前段时间遇到一个略有挑战性的数据需求,要求高速处理巨量的标志位数据。我费了些折腾功夫在 Spark 里用 SQL 的二进制操作给实现了个,但总觉得应该有更优雅的实现,有一种动作变形的感觉,虽然目前还挺堪用。
下边这是大致需求。有个大宽表 T(大约不超过 5k 列),对 T 的每条数据做某种计算可以得到它们的一些状态,状态的种类是固定的,比如这里的数据只能有 A、B、C、D 四种状态中的一种或多种。现在要求保存 T 每一列每一行所有数据的状态,以便能被后续步骤继续使用,在保证数据完整性的同时,除了不能让结果数据过分膨胀,还要减少在后续步骤里的额外 io 开销/计算开销,避免影响正常业务……大概这就是既要马儿跑又要马儿不吃草😂?
一般对简单不是非常大量的数据,线上逻辑是干 json 就完了,因为数据量不大的缘故,几乎不会有什么序列化反序列化压缩解压缩的开销,并且 json 本身是肉眼可读数据,对所有人都方便。
大宽表里再用 json 显然是不可行的,业务内少说就有上百列、至少上千万行的状态要处理,可能除了代码比较简单好写以外,光序列化反序列化的性能损耗就够难受了。对这种时间/空间两难的境地,咱第一反应是自定义底层数据结构。然而对 Spark 这样固定封装了数据类型的框架来讲,改源码去定义新类型显然不是最优解。所以转念想有没有可能用已有数据类型,做一个自定义的抽象类型呢?翻了翻 Spark 数据类型相关的文档,所有基础原始类型都有固定长度、固定范围,嗯嗯很好,这些类型定义是统一的。再去翻了翻 SQL 层文档,看看有哪些可以操作原始类型的 SQL 函数,诶不错,左右位移,与或非异或之类也齐活了……
0. 定义数据结构
由于状态种类是固定的,对每条数据而言所有可能出现状态的情形也是固定了,下面就自然想到可以用类似 bitset 的方法去做状态的存储,利用率高,解析读取几乎没有计算量,相对来讲对时间空间都友好。还是 ABCD 四种状态为例,我们用 1 表示某状态存在,0 表示某状态不存在。比如说 0000
表示 ABCD 四种状态都没有,1000
表示只有 A,1001
表示有 A 和 D,0110
表示有 BC,1111
表示 ABCD 都有……
1. 定义存储方法
现在可以用 4 个比特位表示一个数据的状态,那么问题又来了,怎么才能把这些状态存入原有表中呢?Spark DataFrame 里显然没有四个比特位(半个字节)的基础类型,那我们拿个常用的整型 IntegerType
来用好啦,32 位的数据,存四个比特绰绰有余。下边拿 pyspark 举例说明:
>>> from pyspark.sql import (
... SparkSession,
... functions as F,
... types as T,
... )
# 创建 session
>>> spark = SparkSession.builder.appName('app').getOrCreate()
# 读入原表文件
>>> dataframe = spark.read.parquet('/path/to/file')
# 假定只有三列
>>> dataframe.columns
['column1', 'column2', 'column3']
# 假设计算第一列四种状态的 sql 如下,
# 每条 sql 返回 0 或 1,返回 1 表示某状态存在,0 则不存在
# column2, column3 其他两列以此类推
>>> column1_A_sql = F.when(..., 1).otherwise(0)
>>> column1_B_sql = F.when(..., 1).otherwise(0)
>>> column1_C_sql = F.when(..., 1).otherwise(0)
>>> column1_D_sql = F.when(..., 1).otherwise(0)
# 四个状态位合并到一个 IntegerType 整型中
>>> column1_states = F.lit(0) \
... .bitwiseOR(F.shiftLeft(columns1_A_sql, 3)) \
... .bitwiseOR(F.shiftLeft(columns1_B_sql, 2)) \
... .bitwiseOR(F.shiftLeft(columns1_C_sql, 1)) \
... .bitwiseOR(F.shiftLeft(columns1_D_sql, 0)) \
... .astype(T.IntegerType())
# 类推得到另两列的完整状态 column2_states, column3_states,
# 执行以上 sql 得到包含三列完整状态的 df `states`
>>> states = dataframe.select(
... column1_states.alias('column1_states')
... column2_states.alias('column2_states')
... column3_states.alias('column3_states')
... )
>>> states.show()
+--------------+--------------+--------------+
|column1_states|column2_states|column3_states|
+--------------+--------------+--------------+
| 1| 4| 7|
| 2| 5| 8|
| 3| 6| 9|
| ...| ...| ...|
+--------------+--------------+--------------+
2. 定义读取/解析方法
比特存储的标识位数据,读取起来也很方便。
# 每种标识位所在的顺序
>>> state_index = {'A': 3, 'B': 2, 'C': 1, 'D': 0}
# 定义从一列中读取某种标识状态的函数
>>> def get_state_x(states_dataframe, state_column, state_type):
... # 获取状态对应的顺序位,状态名不对时抛出错误
... index = state_index.get(state_type)
... if index is None:
... raise Exception(f'Unknown state type of {state_type}')
...
... # 过滤掉其他状态位
... masked_state = F.col(state_column).bitwiseAND(1 << index)
...
... # 判断剩下的状态位上是否还有 1,如果只剩 0,表明当前状态不存在,否则为存在
... result = F.when(masked_state == F.lit(0), F.lit(False)) \
... .otherwise(F.lit(True))
...
... # 对状态 df 执行 sql
... return states_dataframe.select(result.alias(f'{state_type}_exists'))
...
# 继续拿上边得到的 states 举例,解析 column1 中是否有 D 状态
>>> get_state_x(states, 'column1_states', 'D').show()
+--------+
|D_exists|
+--------+
| true|
| false|
| true|
| ...|
+--------+
到这里就完成了完整流程。
3. 优化改进
上边的实现方式有一个明显可改进之处 —— 使用 32 位的 IntegerType
存储 4 位的数据,只存了一列的状态,剩余 28 位浪费掉了,是不是可以在一个 32 位整型中最多塞入 8 列的状态,直接提升 7 倍的利用率呢?这思路是可行的,我们甚至还可以把整表的状态全塞入一列中
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql import types as T
from more_itertools import ichunked
spark = SparkSession.builder.appName('app').getOrCreate()
# 假设这次原表有 1000 列数据的状态要处理, column1, column2, column3, ...
table = spark.read.parquet('/path/to/table')
# 假设计算一列中某状态的 sql 函数如下,同之前一样,仍为返回 0 或 1 的 sql
def column_state_A(column):
...
def column_state_B(column):
...
def column_state_C(column):
...
def column_state_D(column):
...
# 获取全表状态的函数,一列有 ABCD 四种状态,八列一组,共 32 位,存入一个 32 位整型
# 最后把所有整数做成一个数组,存入一个数组列中
def get_full_states(table):
# 创建一个整数数组列,用来存入全表的所有状态
all_states_column = F.array().astype(T.ArrayType(T.IntegerType()))
# 每 8 列取一组,每组生成一个整型存入组内所有状态
for chunk in ichunked(enumerate(table.columns), 8):
state_int = F.lit(0).astype(T.IntegerType())
# 遍历组内每一列
for column_index, column in chunk:
# 计算本列的每种状态
state_A = column_state_A(column)
state_B = column_state_B(column)
state_C = column_state_C(column)
state_D = column_state_D(column)
# column_index 为本列在原表中序号,
# 计算本列在本组内的序号,因为是 8 列一组,所以可得组内序号如下
ingroup_index = column_index % 8
# 然后可得本列每个状态在组内的具体序号
a_index = ingroup_index * 4 + 3
b_index = ingroup_index * 4 + 2
c_index = ingroup_index * 4 + 1
d_index = ingroup_index * 4
# 写入本列的全部状态到组内
for state in [
F.shiftLeft(state_A, a_index),
F.shiftLeft(state_B, b_index),
F.shiftLeft(state_C, c_index),
F.shiftLeft(state_D, d_index)
]:
state_int = state_int.bitwiseOR(state)
# 一组的状态合并到全表
all_states_column = F.concat(
all_states_column,
F.array(state_int.astype(T.IntegerType()))
)
# 返回生成全表状态列的完整 sql
return all_states_column
# 生成所需 sql,并执行得到全表的状态列,放入表中,列名为 all_states
all_states_sql = get_full_states(table)
table = table.withColumn('all_states', all_states_sql)
预览看下最后生成的全表状态列,大概长这个样
>>> table.select(F.col('all_states')).show()
+-----------+
| all_states|
+-----------+
|[10 10 ...]|
| [0 0 ...]|
| [0 1 ...]|
+-----------+
核心步骤和改进前是一致的,只是多了缩减使用空间的步骤:将多列合并为一组,一组写入到同一个整数,最后将所有组的整数合并到一个数组列中。
接着读取状态的解析函数如下:
# 接上,生成的全表状态列为 all_states
# 同样,四种状态的相对顺序序号
state_index = {'A': 3, 'B': 2, 'C': 1, 'D': 0}
def get_state_x_from_column(table, all_states, column, state):
"""
计算某列中包含某状态的情况
:param table: 带有全表状态列的整表
:param all_states: 全表状态列的列名
:param column: 需要关注的列名
:param state: 需要关注的状态名
"""
columns = table.columns
columns.remove(all_states)
if column not in columns:
raise Exception(f'列名 {column} 不正确,表中没有此列')
if state not in state_index:
raise Exception(f'状态名 {state} 不正确,没有此状态')
# 使用本列在全表中的序号,计算本列所在组的组序号
group_index = columns.index(column) // 8
# 计算本列在组内的序号
column_ingroup_index = columns.index(column) % 8
# 计算所需状态在组内的序号
state_ingroup_index = state_index.get(state) + 4 * column_ingroup_index
# 根据以上序号过滤出所需状态的标识位
state_masked = F.col(all_states).getItem(group_index).bitwiseAND(1 << state_ingroup_index)
# 根据标识位返回 true/false 表示本列包含此状态的情况,返回完整 sql
return F.when(state_masked == 0, False).otherwise(True)
我们看下使用效果
>>> sql = get_state_x_from_column(table, 'all_states', 'column1', 'A')
>>> table.select(sql.alias('A_in_column1')).show()
+------------+
|A_in_column1|
+------------+
| true|
| false|
| false|
+------------+
Perfecto!
4. 其他
目前这段代码已经上线,实测生成全表状态占用的计算时间为任务总用时的 0.01% ~ 0.1%,解析某列状态占用的计算时间稳定在个位数毫秒,几乎可忽略不计。空间占用的增加量同所需状态量、数据表行列总数呈简单的线性关系,同表数据本身关系不大。
最后效果还不错,不过说到底还得归功于 Spark 这样的神级开源软件,默默在背后提供了如此完整又强大的功能,让我等小白能混口饭吃😂