Working Around Missing Array Functions in PySpark

Recently I noticed that PySpark is missing some useful aggregation functions for ArrayType columns: there is no built-in way to sum or average arrays element by element.

Let's suppose you have a DataFrame created as follows:

# we will be using these imports later
from pyspark.sql import functions as fn, types as T
from pyspark.sql import Window
import pandas as pd
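# import pandas_udf directly so we can use it as a decorator
from pyspark.sql.functions import pandas_udf
# `spark` (the SparkSession) and `display` are predefined in a Databricks notebook;
# outside Databricks, use df.show() instead of display(df)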
df_array = spark.createDataFrame(
  [
    (1, [1,2,3]),
    (2, [4,5,6]),
    (3, [7,8,9]),
    (1, [2,2,2]),
    (2, [5,5,5]),
    (3, [8,8,8])
  ],
  ("group", "array")
)

# this only sums the group column, not the array column
display(df_array.groupBy().sum())

If you group by or window over a numeric column, you can easily apply aggregation functions like sum or average. However, if you try to apply these functions to a column containing arrays, you will find that PySpark's sum and average functions do not operate on the arrays element-wise, as you might expect. You can see this in the code above: only the group column is summed, not the array column.
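To make "element-wise" concrete, the following plain-Python sketch (no Spark involved) computes the per-group results we are after for the sample data. The helpers `elementwise_sum` and `elementwise_mean` are hypothetical names used only for illustration, and both assume that all arrays in a group have the same length.

```python
# Plain-Python reference for element-wise aggregation (no Spark needed).
# The rows mirror df_array from the example above.
from collections import defaultdict

rows = [
    (1, [1, 2, 3]),
    (2, [4, 5, 6]),
    (3, [7, 8, 9]),
    (1, [2, 2, 2]),
    (2, [5, 5, 5]),
    (3, [8, 8, 8]),
]


def elementwise_sum(arrays):
    """Sum lists position by position (assumes equal lengths; zip truncates otherwise)."""
    return [sum(values) for values in zip(*arrays)]


def elementwise_mean(arrays):
    """Average lists position by position (assumes equal lengths)."""
    arrays = list(arrays)
    return [total / len(arrays) for total in elementwise_sum(arrays)]


# Collect the arrays belonging to each group, similar to groupBy('group').
by_group = defaultdict(list)
for group, array in rows:
    by_group[group].append(array)

sums = {g: elementwise_sum(a) for g, a in by_group.items()}
means = {g: elementwise_mean(a) for g, a in by_group.items()}

print(sums)   # {1: [3, 4, 5], 2: [9, 10, 11], 3: [15, 16, 17]}
print(means)  # {1: [1.5, 2.0, 2.5], 2: [4.5, 5.0, 5.5], 3: [7.5, 8.0, 8.5]}
```

These are the per-group values the pandas UDFs below should reproduce.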

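On the other hand, pandas aggregation functions handle arrays just the way you might expect. Inside a pandas UDF, each value in an array column arrives as a NumPy array within a pandas Series, and adding NumPy arrays works element by element, so Series.sum() and Series.mean() return per-position results. This gives us a simple workaround for aggregating array columns in PySpark DataFrames: we can use PySpark's pandas_udf function.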

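Note that the way pandas_udf defines UDF types is changing. In the past the kind of UDF was chosen with PandasUDFType; in PySpark 3 we use Python type hints instead, and a function annotated as taking a pd.Series and returning a scalar is treated as a grouped aggregate UDF. More information can be found here.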

Let's use the new method to create a couple of aggregation functions for our DataFrame:

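import numpy as np

# A pd.Series -> scalar type hint makes these grouped aggregate pandas UDFs.
# createDataFrame infers LongType for Python ints, so the sums are longs.
@pandas_udf(T.ArrayType(T.LongType()))
def sum_array(arrays: pd.Series) -> np.ndarray:
  # each element is a NumPy array, so pandas sum adds them element-wise
  return arrays.sum()

@pandas_udf(T.ArrayType(T.DoubleType()))
def avg_array(arrays: pd.Series) -> np.ndarray:
  # pandas mean also works element-wise and returns float64 values
  return arrays.mean()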

window = Window.partitionBy(
  'group'
).rowsBetween(
  Window.unboundedPreceding,
  Window.unboundedFollowing
)

df_out = df_array.withColumn(
  'sum_array', sum_array('array').over(window)
).withColumn(
  'avg_array', avg_array('array').over(window)
)

display(df_out)

And there you go: we have applied aggregation functions element-wise to arrays in PySpark. From here we could use the row-wise array functions that PySpark does provide to combine each row's values with the overall sums or averages.
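The row-wise step could look something like the following plain-Python sketch (no Spark involved). It mimics the unbounded window by attaching each group's element-wise average to every row, so, unlike groupBy, all six rows survive, and then subtracts position by position. `group_average` and `deviations` are hypothetical names for illustration. In PySpark 3.1+ a native equivalent on df_out would probably be something like `fn.zip_with('array', 'avg_array', lambda x, y: x - y)`, but treat that as an untested suggestion.

```python
# Plain-Python sketch of the window step plus a row-wise follow-up (no Spark needed).
# Each row keeps its own array and gains its group's element-wise average,
# mimicking an aggregate over an unbounded window; then we subtract per position.

rows = [
    (1, [1, 2, 3]),
    (2, [4, 5, 6]),
    (3, [7, 8, 9]),
    (1, [2, 2, 2]),
    (2, [5, 5, 5]),
    (3, [8, 8, 8]),
]


def group_average(group):
    """Element-wise mean of every array in the given group (the whole partition)."""
    arrays = [array for g, array in rows if g == group]
    return [sum(values) / len(arrays) for values in zip(*arrays)]


# Like withColumn(...).over(window): every input row is kept.
with_avg = [(g, array, group_average(g)) for g, array in rows]

# Row-wise step: deviation of each element from its group average.
deviations = [
    [x - mean for x, mean in zip(array, avg)]
    for _, array, avg in with_avg
]

for (g, array, avg), dev in zip(with_avg, deviations):
    print(g, array, avg, dev)
```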

An example notebook with this code is available on Databricks Community Edition if you want to look at it, or on GitHub.
