
Spark ML: sharing one label encoder across multiple columns

Background

Consider two city columns, for example a "from" city and a "to" city.

The requirement: after label encoding, the same city must receive the same index in both columns.
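Before turning to Spark, the core idea can be shown with plain Python dictionaries (a minimal sketch; the city names are illustrative): fitting one encoder per column produces inconsistent indices, while fitting a single mapping on the union of both columns does not.

```python
# Why per-column encoders disagree, and how one shared mapping fixes it.
origins = ["Beijing", "Shanghai", "Guangzhou"]
dests = ["Shanghai", "Beijing", "Shenzhen"]

# Independent encoders: each column gets its own label order.
enc_origin = {v: i for i, v in enumerate(dict.fromkeys(origins))}
enc_dest = {v: i for i, v in enumerate(dict.fromkeys(dests))}
print(enc_origin["Shanghai"], enc_dest["Shanghai"])  # 1 0 -- inconsistent

# Shared encoder: fit one mapping on the union of both columns.
shared = {v: i for i, v in enumerate(dict.fromkeys(origins + dests))}
print(shared["Shanghai"])  # 1 -- same index no matter which column it came from
```

The class below does exactly this, but fits the unified label list with Spark's StringIndexer on the stacked columns.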

Code

from __future__ import annotations
from typing import Dict, Iterable, List, Optional

from pyspark.sql import SparkSession, functions as F, types as T
from pyspark.ml.feature import StringIndexer


class SharedLabelEncoder:
    """Shared label encoder: one label->index mapping applied to multiple columns.

    - handle_invalid: "keep" (unknown values map to the unknown index),
      "skip" (returns None), "error" (raises).
    - The unknown index defaults to len(labels) and is only used when
      handle_invalid="keep".
    """

    def __init__(self, labels: Optional[List[str]] = None, handle_invalid: str = "keep"):
        self.labels: List[str] = labels or []
        self.label_to_index: Dict[str, int] = {v: i for i, v in enumerate(self.labels)}
        self.handle_invalid = handle_invalid

    def fit(self, df, cols: Iterable[str]) -> "SharedLabelEncoder":
        # Stack all columns into a single "value" column, then fit one
        # StringIndexer to obtain a unified label list.
        stacked = None
        for c in cols:
            col_df = df.select(F.col(c).cast(T.StringType()).alias("value")).na.fill({"value": ""})
            stacked = col_df if stacked is None else stacked.unionByName(col_df)
        indexer = StringIndexer(inputCol="value", outputCol="value_idx", handleInvalid="keep")
        model = indexer.fit(stacked)
        self.labels = list(model.labels)
        self.label_to_index = {v: i for i, v in enumerate(self.labels)}
        return self

    def _build_udf(self, spark: SparkSession):
        m_b = spark.sparkContext.broadcast(self.label_to_index)
        unknown_index = len(self.labels)
        handle_invalid = self.handle_invalid

        def map_value(v: Optional[str]) -> Optional[int]:
            if v is None:
                return unknown_index if handle_invalid == "keep" else None
            idx = m_b.value.get(v)
            if idx is not None:
                return idx
            if handle_invalid == "keep":
                return unknown_index
            if handle_invalid == "skip":
                return None
            raise ValueError(f"Unknown label: {v}")

        return F.udf(map_value, T.IntegerType())

    def transform(self, df, input_cols: Iterable[str], suffix: str = "_idx"):
        udf_map = self._build_udf(df.sparkSession)
        out = df
        for c in input_cols:
            out = out.withColumn(c + suffix, udf_map(F.col(c).cast(T.StringType())))
        return out

    def save(self, path: str):
        import json
        obj = {"labels": self.labels, "handle_invalid": self.handle_invalid}
        with open(path, "w", encoding="utf-8") as f:
            json.dump(obj, f, ensure_ascii=False)

    @staticmethod
    def load(path: str) -> "SharedLabelEncoder":
        import json
        with open(path, "r", encoding="utf-8") as f:
            obj = json.load(f)
        return SharedLabelEncoder(labels=obj.get("labels", []),
                                  handle_invalid=obj.get("handle_invalid", "keep"))


def main():
    spark = SparkSession.builder.appName("shared_label_encoder").getOrCreate()
    spark.sparkContext.setLogLevel("ERROR")

    data = [
        (1, "北京", "上海", 1),
        (2, "上海", "北京", 0),
        (3, "广州", "深圳", 1),
        (4, "深圳", "广州", 0),
        (5, "北京", "广州", 1),
        (6, "上海", "深圳", 0),
    ]
    columns = ["id", "origin_city", "dest_city", "label"]
    df = spark.createDataFrame(data, schema=columns)

    # Fit the shared encoder on both columns.
    encoder = SharedLabelEncoder(handle_invalid="keep").fit(df, ["origin_city", "dest_city"])

    # Transform both columns into the same index space.
    out_df = encoder.transform(df, ["origin_city", "dest_city"])
    print("Encoded result:")
    out_df.show(truncate=False)

    # Save, reload, and reuse the encoder.
    path = "./shared_label_encoder_city.json"
    encoder.save(path)
    encoder2 = SharedLabelEncoder.load(path)
    new_df = spark.createDataFrame([(7, "北京", "杭州", 1)], schema=columns)  # 杭州 is an unseen value
    out_new = encoder2.transform(new_df, ["origin_city", "dest_city"])
    print("Reusing the saved encoder:")
    out_new.show(truncate=False)


if __name__ == "__main__":
    main()
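Note that save/load persists only the ordered label list and the handle_invalid mode as JSON; the label->index mapping is rebuilt from the list order on load, so no refitting is needed. A Spark-free sketch of that round trip (the file name is illustrative):

```python
import json
import os
import tempfile

# Persist only what is needed to rebuild the mapping: the ordered label
# list and the invalid-handling mode (mirrors SharedLabelEncoder.save/load).
labels = ["上海", "北京", "广州", "深圳"]
state = {"labels": labels, "handle_invalid": "keep"}

path = os.path.join(tempfile.gettempdir(), "shared_label_encoder_city.json")
with open(path, "w", encoding="utf-8") as f:
    json.dump(state, f, ensure_ascii=False)

with open(path, "r", encoding="utf-8") as f:
    restored = json.load(f)

# Indices survive the round trip because they derive from list order.
label_to_index = {v: i for i, v in enumerate(restored["labels"])}
print(label_to_index["北京"])  # 1
```

Keeping the state this small also makes the artifact portable: any process that can read JSON can reconstruct the same encoding.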

Output

Encoded result:
+---+-----------+---------+-----+---------------+-------------+
|id |origin_city|dest_city|label|origin_city_idx|dest_city_idx|
+---+-----------+---------+-----+---------------+-------------+
|1  |北京       |上海     |1    |1              |0            |
|2  |上海       |北京     |0    |0              |1            |
|3  |广州       |深圳     |1    |2              |3            |
|4  |深圳       |广州     |0    |3              |2            |
|5  |北京       |广州     |1    |1              |2            |
|6  |上海       |深圳     |0    |0              |3            |
+---+-----------+---------+-----+---------------+-------------+

Reusing the saved encoder:
+---+-----------+---------+-----+---------------+-------------+
|id |origin_city|dest_city|label|origin_city_idx|dest_city_idx|
+---+-----------+---------+-----+---------------+-------------+
|7  |北京       |杭州     |1    |1              |4            |
+---+-----------+---------+-----+---------------+-------------+
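杭州 was never seen at fit time, so under handle_invalid="keep" it receives the unknown index len(labels) = 4, as the last row shows. The three invalid-handling modes can be sketched in plain Python (a simplified stand-in for the UDF body above):

```python
labels = ["上海", "北京", "广州", "深圳"]  # fitted label order
label_to_index = {v: i for i, v in enumerate(labels)}
unknown_index = len(labels)  # 4, used only by "keep"

def encode(value, handle_invalid="keep"):
    # Mirrors the UDF's branching: known labels map directly;
    # unknown ones follow the configured policy.
    idx = label_to_index.get(value)
    if idx is not None:
        return idx
    if handle_invalid == "keep":
        return unknown_index
    if handle_invalid == "skip":
        return None
    raise ValueError(f"Unknown label: {value}")

print(encode("北京"))           # 1
print(encode("杭州"))           # 4 (unseen -> unknown index)
print(encode("杭州", "skip"))   # None
```

If downstream models embed these indices, remember to size the embedding table as len(labels) + 1 so the unknown index has a slot.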
