Apache Spark 用比特位操作处理状态标志类数据的一种高效实践

前段时间遇到一个略有挑战性的数据需求,要求高速处理巨量的标志位数据。我费了些折腾功夫在 Spark 里用 SQL 的二进制操作给实现了个,但总觉得应该有更优雅的实现,有一种动作变形的感觉,虽然目前还挺堪用。

下边这是大致需求。有个大宽表 T(大约不超过 5k 列),对 T 的每条数据做某种计算可以得到它们的一些状态,状态的种类是固定的,比如这里的数据只能有 A、B、C、D 四种状态中的一种或多种。现在要求保存 T 每一列每一行所有数据的状态,以便能被后续步骤继续使用,在保证数据完整性的同时,除了不能让结果数据过分膨胀,还要减少在后续步骤里的额外 io 开销/计算开销,避免影响正常业务……大概这就是既要马儿跑又要马儿不吃草😂?

一般对简单不是非常大量的数据,线上逻辑是干 json 就完了,因为数据量不大的缘故,几乎不会有什么序列化反序列化压缩解压缩的开销,并且 json 本身是肉眼可读数据,对所有人都方便。

大宽表里再用 json 显然是不可行的,业务内少说就有上百列、至少上千万行的状态要处理,可能除了代码比较简单好写以外,光序列化反序列化的性能损耗就够难受了。对这种时间/空间两难的境地,咱第一反应是自定义底层数据结构。然而对 Spark 这样固定封装了数据类型的框架来讲,改源码去定义新类型显然不是最优解。所以转念想有没有可能用已有数据类型,做一个自定义的抽象类型呢?翻了翻 Spark 数据类型相关的文档,所有基础原始类型都有固定长度、固定范围,嗯嗯很好,这些类型定义是统一的。再去翻了翻 SQL 层文档,看看有哪些可以操作原始类型的 SQL 函数,诶不错,左右位移,与或非异或之类也齐活了……

0. 定义数据结构

由于状态种类是固定的,对每条数据而言所有可能出现状态的情形也是固定了,下面就自然想到可以用类似 bitset 的方法去做状态的存储,利用率高,解析读取几乎没有计算量,相对来讲对时间空间都友好。还是 ABCD 四种状态为例,我们用 1 表示某状态存在,0 表示某状态不存在。比如说 0000 表示 ABCD 四种状态都没有,1000 表示只有 A,1001 表示有 A 和 D,0110 表示有 BC,1111 表示 ABCD 都有……

1. 定义存储方法

现在可以用 4 个比特位表示一个数据的状态,那么问题又来了,怎么才能把这些状态存入原有表中呢?Spark DataFrame 里显然没有四个比特位(半个字节)的基础类型,那我们拿个常用的整型 IntegerType 来用好啦,32 位的数据,存四个比特绰绰有余。下边拿 pyspark 举例说明:

>>> from pyspark.sql import (
...     SparkSession,
...     functions as F,
...     types as T,
... )

# 创建 session
>>> spark = SparkSession.builder.appName('app').getOrCreate()

# 读入原表文件
>>> dataframe = spark.read.parquet('/path/to/file')

# 假定只有三列
>>> dataframe.columns
['column1', 'column2', 'column3']

# 假设计算第一列四种状态的 sql 如下,
# 每条 sql 返回 0 或 1,返回 1 表示某状态存在,0 则不存在
# column2, column3 其他两列以此类推
>>> column1_A_sql = F.when(..., 1).otherwise(0)
>>> column1_B_sql = F.when(..., 1).otherwise(0)
>>> column1_C_sql = F.when(..., 1).otherwise(0)
>>> column1_D_sql = F.when(..., 1).otherwise(0)

# 四个状态位合并到一个 IntegerType 整型中
>>> column1_states = F.lit(0) \
...     .bitwiseOR(F.shiftLeft(columns1_A_sql, 3)) \
...     .bitwiseOR(F.shiftLeft(columns1_B_sql, 2)) \
...     .bitwiseOR(F.shiftLeft(columns1_C_sql, 1)) \
...     .bitwiseOR(F.shiftLeft(columns1_D_sql, 0)) \
...     .astype(T.IntegerType())

# 类推得到另两列的完整状态 column2_states, column3_states,
# 执行以上 sql 得到包含三列完整状态的 df `states`
>>> states = dataframe.select(
...     column1_states.alias('column1_states')
...     column2_states.alias('column2_states')
...     column3_states.alias('column3_states')
... )
>>> states.show()
+--------------+--------------+--------------+
|column1_states|column2_states|column3_states|
+--------------+--------------+--------------+
|             1|             4|             7|
|             2|             5|             8|
|             3|             6|             9|
|           ...|           ...|           ...|
+--------------+--------------+--------------+

2. 定义读取/解析方法

比特存储的标识位数据,读取起来也很方便。

# 每种标识位所在的顺序
>>> state_index = {'A': 3, 'B': 2, 'C': 1, 'D': 0}

# 定义从一列中读取某种标识状态的函数
>>> def get_state_x(states_dataframe, state_column, state_type):
...     # 获取状态对应的顺序位,状态名不对时抛出错误
...     index = state_index.get(state_type)
...     if index is None:
...         raise Exception(f'Unknown state type of {state_type}')
...
...     # 过滤掉其他状态位
...     masked_state = F.col(state_column).bitwiseAND(1 << index)
...
...     # 判断剩下的状态位上是否还有 1,如果只剩 0,表明当前状态不存在,否则为存在
...     result = F.when(masked_state == F.lit(0), F.lit(False)) \
...         .otherwise(F.lit(True))
...
...     # 对状态 df 执行 sql
...     return states_dataframe.select(result.alias(f'{state_type}_exists'))
...

# 继续拿上边得到的 states 举例,解析 column1 中是否有 D 状态
>>> get_state_x(states, 'column1_states', 'D').show()
+--------+
|D_exists|
+--------+
|    true|
|   false|
|    true|
|     ...|
+--------+

到这里就完成了完整流程。

3. 优化改进

上边的实现方式有一个明显可改进之处 —— 使用 32 位的 IntegerType 存储 4 位的数据,只存了一列的状态,剩余 28 位浪费掉了,是不是可以在一个 32 位整型中最多塞入 8 列的状态,直接提升 7 倍的利用率呢?这思路是可行的,我们甚至还可以把整表的状态全塞入一列中

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql import types as T
from more_itertools import ichunked

spark = SparkSession.builder.appName('app').getOrCreate()

# 假设这次原表有 1000 列数据的状态要处理, column1, column2, column3, ...
table = spark.read.parquet('/path/to/table')

# 假设计算一列中某状态的 sql 函数如下,同之前一样,仍为返回 0 或 1 的 sql
def column_state_A(column):
    ...
    
def column_state_B(column):
    ...
    
def column_state_C(column):
    ...

def column_state_D(column):
    ...
    
# 获取全表状态的函数,一列有 ABCD 四种状态,八列一组,共 32 位,存入一个 32 位整型
# 最后把所有整数做成一个数组,存入一个数组列中
def get_full_states(table):
    # 创建一个整数数组列,用来存入全表的所有状态
    all_states_column = F.array().astype(T.ArrayType(T.IntegerType()))
    
    # 每 8 列取一组,每组生成一个整型存入组内所有状态
    for chunk in ichunked(enumerate(table.columns), 8):
        state_int = F.lit(0).astype(T.IntegerType())
        
        # 遍历组内每一列
        for column_index, column in chunk:
            # 计算本列的每种状态
            state_A = column_state_A(column)
            state_B = column_state_B(column)
            state_C = column_state_C(column)
            state_D = column_state_D(column)
            
            # column_index 为本列在原表中序号,
            # 计算本列在本组内的序号,因为是 8 列一组,所以可得组内序号如下
            ingroup_index = column_index % 8
            
            # 然后可得本列每个状态在组内的具体序号
            a_index = ingroup_index * 4 + 3
            b_index = ingroup_index * 4 + 2
            c_index = ingroup_index * 4 + 1
            d_index = ingroup_index * 4
            
            # 写入本列的全部状态到组内
            for state in [
                F.shiftLeft(state_A, a_index),
                F.shiftLeft(state_B, b_index),
                F.shiftLeft(state_C, c_index),
                F.shiftLeft(state_D, d_index)
            ]:
                state_int = state_int.bitwiseOR(state)
                
        # 一组的状态合并到全表
        all_states_column = F.concat(
            all_states_column, 
            F.array(state_int.astype(T.IntegerType()))
        )
        
    # 返回生成全表状态列的完整 sql
    return all_states_column

# 生成所需 sql,并执行得到全表的状态列,放入表中,列名为 all_states
all_states_sql = get_full_states(table)
table = table.withColumn('all_states', all_states_sql)

预览看下最后生成的全表状态列,大概长这个样

>>> table.select(F.col('all_states')).show()
+-----------+
| all_states|
+-----------+
|[10 10 ...]|
|  [0 0 ...]|
|  [0 1 ...]|
+-----------+

核心步骤和改进前是一致的,只是多了缩减使用空间的步骤:将多列合并为一组,一组写入到同一个整数,最后将所有组的整数合并到一个数组列中。

接着读取状态的解析函数如下:

# 接上,生成的全表状态列为 all_states
# 同样,四种状态的相对顺序序号
state_index = {'A': 3, 'B': 2, 'C': 1, 'D': 0}

def get_state_x_from_column(table, all_states, column, state):
    """
    计算某列中包含某状态的情况
    
    :param table:        带有全表状态列的整表
    :param all_states:   全表状态列的列名
    :param column:       需要关注的列名
    :param state:        需要关注的状态名
    """
    columns = table.columns
    columns.remove(all_states)
    if column not in columns:
        raise Exception(f'列名 {column} 不正确,表中没有此列')
        
    if state not in state_index:
        raise Exception(f'状态名 {state} 不正确,没有此状态')
        
    # 使用本列在全表中的序号,计算本列所在组的组序号
    group_index = columns.index(column) // 8
    # 计算本列在组内的序号
    column_ingroup_index = columns.index(column) % 8
    # 计算所需状态在组内的序号
    state_ingroup_index = state_index.get(state) + 4 * column_ingroup_index
    
    # 根据以上序号过滤出所需状态的标识位
    state_masked = F.col(all_states).getItem(group_index).bitwiseAND(1 << state_ingroup_index)
    
    # 根据标识位返回 true/false 表示本列包含此状态的情况,返回完整 sql
    return F.when(state_masked == 0, False).otherwise(True)

我们看下使用效果

>>> sql = get_state_x_from_column(table, 'all_states', 'column1', 'A')
>>> table.select(sql.alias('A_in_column1')).show()
+------------+
|A_in_column1|
+------------+
|        true|
|       false|
|       false|
+------------+

Perfecto!

4. 其他

目前这段代码已经上线,实测生成全表状态占用的计算时间为任务总用时的 0.01% ~ 0.1%,解析某列状态占用的计算时间稳定在个位数毫秒,几乎可忽略不计。空间占用的增加量同所需状态量、数据表行列总数呈简单的线性关系,同表数据本身关系不大。

最后效果还不错,不过说到底还得归功于 Spark 这样的神级开源软件,默默在背后提供了如此完整又强大的功能,让我等小白能混口饭吃😂