Python – Reshaping/Pivoting data in Spark RDD and/or Spark DataFrames

apache-spark, apache-spark-sql, pivot, pyspark, python

I have some data in the following format (either RDD or Spark DataFrame):

from pyspark.sql import SQLContext
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

# sc is an existing SparkContext (e.g. the one the pyspark shell provides)
sqlContext = SQLContext(sc)

rdd = sc.parallelize([('X01', 41, 'US', 3),
                      ('X01', 41, 'UK', 1),
                      ('X01', 41, 'CA', 2),
                      ('X02', 72, 'US', 4),
                      ('X02', 72, 'UK', 6),
                      ('X02', 72, 'CA', 7),
                      ('X02', 72, 'XX', 8)])

# convert to a Spark DataFrame
schema = StructType([StructField('ID', StringType(), True),
                     StructField('Age', IntegerType(), True),
                     StructField('Country', StringType(), True),
                     StructField('Score', IntegerType(), True)])

df = sqlContext.createDataFrame(rdd, schema)

What I would like to do is to 'reshape' the data, turning certain values of Country (specifically US, UK and CA) into columns:

ID    Age  US  UK  CA  
'X01'  41  3   1   2  
'X02'  72  4   6   7   

Essentially, I need something along the lines of the pandas pivot workflow:

categories = ['US', 'UK', 'CA']
new_df = df[df['Country'].isin(categories)].pivot(index = 'ID', 
                                                  columns = 'Country',
                                                  values = 'Score')

My dataset is rather large, so I can't collect() it and load the data into memory to do the reshaping in Python itself. Is there a way to turn pandas' .pivot() into a function I can invoke while mapping over either an RDD or a Spark DataFrame? Any help would be appreciated!

Best Solution

Since Spark 1.6 you can call the pivot method on GroupedData and then provide an aggregate expression.

pivoted = (df
    .groupBy("ID", "Age")
    .pivot(
        "Country",
        ['US', 'UK', 'CA'])  # Optional list of levels
    .sum("Score"))  # alternatively you can use .agg(expr)
pivoted.show()

## +---+---+---+---+---+
## | ID|Age| US| UK| CA|
## +---+---+---+---+---+
## |X01| 41|  3|  1|  2|
## |X02| 72|  4|  6|  7|
## +---+---+---+---+---+
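To make explicit what the grouped pivot computes, here is a plain-Python model of the same reshaping on the question's sample data (no Spark involved; `local_pivot` and `ROWS` are hypothetical names for this sketch). Each output cell is the aggregate of the Score values in its (ID, Age) group whose Country equals that level, and None where no row matches, which Spark displays as null. It only illustrates the result; Spark does not execute the pivot this way.

```python
# Sample rows from the question: (ID, Age, Country, Score)
ROWS = [('X01', 41, 'US', 3), ('X01', 41, 'UK', 1), ('X01', 41, 'CA', 2),
        ('X02', 72, 'US', 4), ('X02', 72, 'UK', 6), ('X02', 72, 'CA', 7),
        ('X02', 72, 'XX', 8)]


def local_pivot(rows, levels, agg=sum):
    """Plain-Python model of groupBy("ID", "Age").pivot("Country", levels).agg(...).

    Illustrates the result only; Spark computes it in a distributed way.
    """
    groups = {}  # (ID, Age) -> {level: [Score, ...]}, in first-seen order
    for id_, age, country, score in rows:
        per_level = groups.setdefault((id_, age), {lvl: [] for lvl in levels})
        if country in per_level:  # countries outside `levels` are ignored
            per_level[country].append(score)

    result = []
    for (id_, age), per_level in groups.items():
        # A level with no matching rows gives None (shown as null by Spark).
        cells = tuple(agg(per_level[lvl]) if per_level[lvl] else None
                      for lvl in levels)
        result.append((id_, age) + cells)
    return result


print(local_pivot(ROWS, ['US', 'UK', 'CA']))
# [('X01', 41, 3, 1, 2), ('X02', 72, 4, 6, 7)]
```

Passing a different `agg` (for example `max`) plays the role of swapping `.sum("Score")` for another aggregate in `.agg(...)`.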

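The list of levels can be omitted, but providing it has two benefits: Spark does not need an extra pass over the data to find the distinct values of the pivot column, and only the listed values become columns (here the XX row is ignored).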
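A rough plain-Python illustration of that trade-off (no Spark; `pivot_levels` and `COUNTRIES` are made-up names). Without levels, every value of the pivot column has to be scanned to find the distinct ones, which in Spark means an additional job before the aggregation; to my understanding Spark also sorts the discovered values, so the column order changes and XX gets a column. With levels supplied, no scan is needed and the columns are fixed.

```python
# Pivot-column (Country) values from the question's sample data
COUNTRIES = ['US', 'UK', 'CA', 'US', 'UK', 'CA', 'XX']


def pivot_levels(values, levels=None):
    """Hypothetical helper mirroring how the output columns are chosen.

    levels=None: scan every value and return the sorted distinct set
    (the extra pass the answer refers to).
    levels given: use them as-is; other values get no column.
    """
    if levels is None:
        return sorted(set(values))  # full pass over the data
    return list(levels)


print(pivot_levels(COUNTRIES))                      # ['CA', 'UK', 'US', 'XX']
print(pivot_levels(COUNTRIES, ['US', 'UK', 'CA']))  # ['US', 'UK', 'CA']
```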

This method is still relatively slow, but it certainly beats passing data manually between the JVM and Python.
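If you have to stay on the RDD API (the question allows either), a minimal sketch using `RDD.aggregateByKey` is shown below. This is a different technique from the DataFrame pivot above: records are keyed by (ID, Age) and folded into a tuple with one slot per level, so nothing is collected to the driver. The names `LEVELS`, `seq_op`, `comb_op` and `pivot_rdd` are hypothetical, and the levels are assumed to be known up front. Every record passes through Python workers, so expect it to be slower than the DataFrame version.

```python
# Assumed, hard-coded pivot levels; their order defines the output columns.
LEVELS = ('US', 'UK', 'CA')


def seq_op(acc, country_score):
    """Fold one (Country, Score) pair into a per-key tuple of level sums."""
    country, score = country_score
    if country not in LEVELS:
        return acc  # countries outside LEVELS are dropped
    i = LEVELS.index(country)
    current = acc[i]
    new = score if current is None else current + score
    return acc[:i] + (new,) + acc[i + 1:]  # build a new tuple, never mutate


def comb_op(a, b):
    """Merge two partial tuples built on different partitions."""
    return tuple(y if x is None else x if y is None else x + y
                 for x, y in zip(a, b))


def pivot_rdd(rdd):
    """Sketch of an RDD-only pivot; rdd holds (ID, Age, Country, Score) tuples."""
    zero = (None,) * len(LEVELS)  # one empty slot per level
    return (rdd
            .map(lambda r: ((r[0], r[1]), (r[2], r[3])))  # key by (ID, Age)
            .aggregateByKey(zero, seq_op, comb_op)
            .map(lambda kv: kv[0] + kv[1]))  # -> (ID, Age, US, UK, CA)
```

With the question's `rdd`, `pivot_rdd(rdd)` should yield rows such as `('X01', 41, 3, 1, 2)`; row order is not guaranteed.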