PySpark
Install
INFO
Make sure the PySpark version matches the installed Spark version (a quick check is shown after the install commands below).
bash
brew install apache-spark
pip3 install pyspark
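# quick sanity check that the two versions match
spark-submit --version
pip3 show pyspark | grep -i version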
Resources
Add-Ons
Dummy dataframe
python
# https://stackoverflow.com/a/57960267/19652796
df = spark.createDataFrame(
[
(1, "foo"), # create your data here, be consistent in the types.
(2, "bar"),
],
["id", "label"], # add your column names here
)
Init
python
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql.functions import coalesce, col, lit, when
from pyspark.sql.types import (
DoubleType,
IntegerType,
StringType,
StructField,
StructType,
TimestampType,
)
from pyspark.sql.window import Window
spark = (
SparkSession.builder.appName("Pyspark playground")
.config("spark.hadoop.fs.s3a.access.key", KEY)
.config("spark.hadoop.fs.s3a.secret.key", SECRET)
.config("spark.executor.memory", "4g")
.config("spark.driver.memory", "4g")
.config("spark.jars.packages", "")
.getOrCreate()
)
# set config after the spark session is created (only runtime-modifiable configs, e.g. SQL configs)
spark.conf.set("spark.sql.shuffle.partitions", "200")
spark.sparkContext.setLogLevel("ERROR")
spark.sparkContext.setCheckpointDir("checkpoint") # [DEBUG]
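# usage sketch (assumes some DataFrame `df`): .checkpoint() writes it to the checkpoint dir and truncates lineage
df = df.checkpoint()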
I/O
python
# CSV / TSV
project = spark.read.csv(
    project_file,
    header=True,
    sep="\t",
)
project.write.csv(output_path, header=True)
# JSON
df = spark.read.json(
    "data/DMP_HIVE/all_listing.json"
)  # add .option("multiLine", True) for multi-line
df.write.json(OUTPATH, compression="gzip")
DataFrame
python
# metadata
df.printSchema() # or .columns
# select columns
df.select(["a", "b", "c"])
# sampling
df.sample(False, sampling_percentage, seed=0)
# count records
df.count()
# conversions
df.toPandas()  # Spark to pandas
spark.createDataFrame(pandas_df)  # pandas to Spark
# show in vertical
df.show(n=3, truncate=False, vertical=True)
# get schema in JSON
schema = df.schema.jsonValue()
schema = schema["fields"]
Transformations
python
# rename columns
df.withColumnRenamed("old_name", "new_name")
# add null column
df.withColumn(col_name, F.lit(None).cast(col_type))
# dtype casting
df.withColumn("col_name", df["col_name"].cast(IntegerType()))
# combine values from multiple rows via groupby
df.groupBy(groupby_col).agg(F.collect_list(col_name))
# select field by name from a struct / map column
F.col(col_name)["elem_key"]
# select element from an array column by index
F.col(col_name).getItem(0)
# find median
df.approxQuantile(df.columns, [0.5], 0.25)
# get percentile
df.approxQuantile(["Apple", "Oranges"], [0.1, 0.25, 0.5, 0.75, 0.9, 0.95], 0.1)
# get median during groupby
# https://stackoverflow.com/a/71735997
df.groupBy("Id").agg(F.percentile_approx("value", 0.5).alias("median_approximate"))
# join
df.join(
df2,
[key],  # use df.key == df2.key if the key names differ; passing column names avoids duplicate key columns
how="left",
)
functions
python
# combine cols to array
F.array("x_1", "x_2")
# fillna with another column
F.coalesce("a", "b")
# create new column with max value from set of columns
F.greatest(a["one"], a["two"], a["three"])
# regex matching --> matching the longest alternative only works if it is listed first in the pattern
F.regexp_replace(F.trim(F.lower(F.col(col_name))), regex_str, "")
# explode array
## also works: explode_outer
df.withColumn("tmp", F.explode("tmp")).select(
*df.columns, col("tmp.a"), col("tmp.b"), col("tmp.c")
)
# convert to JSON
F.to_json(c)
# convert to list
df.select("mvv").rdd.flatMap(lambda x: x).collect()
import pyspark.sql.functions as F
# udf
from bson import ObjectId
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
@udf(returnType=StringType())
def object_id_to_date(object_id: str):
    object_id = ObjectId(object_id)
    return str(object_id.generation_time.date())
df.select(object_id_to_date("_id").alias("creation_date"))
datetime
python
# epoch (in milliseconds) to timestamp
F.from_unixtime(df.column / 1000)
# timestamp to date
F.to_date("listing_update")
# UTC to another time zone
F.from_utc_timestamp("datetime_utc", "CST")
# filter between date range
df.filter(
F.col("date_col").between(
F.date_add(F.current_date(), -7), F.date_add(F.current_date(), -1)
)
)
SQL
python
df.createOrReplaceTempView("foo")
spark.sql(
"""
SELECT *
FROM foo
"""
)
Optimization
Caching
INFO
Improves read performance for frequently accessed DataFrames
python
df.cache()
# clear cache
spark.catalog.clearCache()
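# usage sketch: caching pays off when the same DataFrame feeds several actions
df.cache()
df.count()  # first action materializes the cache
df.distinct().count()  # later actions reuse the cached data
df.unpersist()  # release memory when done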
Repartition + partition data
python
df.repartition(4)
df.write.partitionBy(*partition_columns).parquet(base_path, mode=write_mode)
Dynamic partition write mode
python
"""
Note: Why do we have to change partitionOverwriteMode?
Without config partitionOverwriteMode = 'dynamic', Spark will
overwrite all partitions in hierarchy with the new
data we are writing. That's undesirable and dangerous.
https://stackoverflow.com/questions/42317738/how-to-partition-and-write-dataframe-in-spark-without-deleting-partitions-with-n
Therefore, we will temporarily use 'dynamic' within the context of writing files to storage.
"""
spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
Skew join optimization
https://stackoverflow.com/a/57951114
python
from datetime import datetime
from math import exp
def count_elements(splitIndex, iterator):
    n = sum(1 for _ in iterator)
    yield (splitIndex, n)
def get_part_index(splitIndex, iterator):
    for it in iterator:
        yield (splitIndex, it)
num_parts = 18
sc = spark.sparkContext  # the RDD snippets below use the underlying SparkContext
# create the large skewed rdd
skewed_large_rdd = sc.parallelize(range(0, num_parts), num_parts).flatMap(
lambda x: range(0, int(exp(x)))
)
skewed_large_rdd = skewed_large_rdd.mapPartitionsWithIndex(
lambda ind, x: get_part_index(ind, x)
)
skewed_large_df = spark.createDataFrame(skewed_large_rdd, ["x", "y"])
small_rdd = sc.parallelize(range(0, num_parts), num_parts).map(lambda x: (x, x))
small_df = spark.createDataFrame(small_rdd, ["a", "b"])
## prep salts
salt_bins = 100
from pyspark.sql import functions as F
skewed_transformed_df = skewed_large_df.withColumn(
"salt", (F.rand() * salt_bins).cast("int")
).cache()
small_transformed_df = small_df.withColumn(
"replicate", F.array([F.lit(i) for i in range(salt_bins)])
)
small_transformed_df = (
small_transformed_df.select("*", F.explode("replicate").alias("salt"))
.drop("replicate")
.cache()
)
## magic happens here
t0 = datetime.now()
result2 = skewed_transformed_df.join(
small_transformed_df,
(skewed_transformed_df["x"] == small_transformed_df["a"])
& (skewed_transformed_df["salt"] == small_transformed_df["salt"]),
)
result2.count()
print("The direct join takes %s" % (str(datetime.now() - t0)))
JDBC
INFO
To overwrite without dropping the existing schema and permissions, use the truncate option
Postgres
python
spark = (
SparkSession.builder.appName("Pyspark playground")
.config("spark.executor.memory", "16g")
.config("spark.driver.memory", "16g")
.config("spark.jars.packages", "org.postgresql:postgresql:42.5.0")
.getOrCreate()
)
uri = "jdbc:postgresql://host.docker.internal:5432/postgres"
### read
(
spark.read.format("jdbc")
.option("url", uri)
.option("dbtable", TABLENAME)
.option("user", USERNAME)
.option("password", PASSWORD)
.option("driver", "org.postgresql.Driver")
.load()
)
### write
(
df.write.format("jdbc")
.option("url", uri)
.option("dbtable", TABLENAME)
.option("user", USERNAME)
.option("password", PASSWORD)
.option("driver", "org.postgresql.Driver")
.option("truncate", "true")
.mode("overwrite")
.save()
)
MongoDB
python
import os
os.environ["PYSPARK_SUBMIT_ARGS"] = (
'--packages "org.mongodb.spark:mongo-spark-connector_2.11:2.4.2" pyspark-shell'
)
from pyspark.sql import SparkSession
### read
(
spark.read.format("mongodb")
.option("connection.uri", uri)
.option("database", DBNAME)
.option("collection", TABLENAME)
.load()
)
### write
(
df.write.format("mongodb")
.option("connection.uri", uri)
.option("database", DBNAME)
.option("collection", TABLENAME)
.mode("overwrite")
.save()
)
Clickhouse
python
spark = (
SparkSession.builder.appName("Pyspark playground")
.config("spark.executor.memory", "16g")
.config("spark.driver.memory", "16g")
.config(
"spark.jars.packages",
"com.github.housepower:clickhouse-spark-runtime-3.4_2.12:0.7.3,com.clickhouse:clickhouse-jdbc:0.6.2,org.apache.httpcomponents.client5:httpclient5:5.3.1",
)
.getOrCreate()
)
uri = "jdbc:clickhouse://localhost:8123/clickhouse"
### read
(
spark.read.format("jdbc")
.option("url", uri)
.option("dbtable", "nyc_taxi")
.option("user", "clickhouse")
.option("password", "clickhousepassword")
.option("driver", "com.clickhouse.jdbc.ClickHouseDriver")
.load()
)
### write
(
df.write.format("jdbc")
.option("url", uri)
.option("dbtable", "nyc_taxi") # [TODO] change me
.option("user", "clickhouse")
.option("password", "clickhousepassword")
.option("driver", "com.clickhouse.jdbc.ClickHouseDriver")
.option("createTableOptions", "engine=MergeTree() order by tpep_pickup_datetime")
.option("truncate", "true")
.option("numPartitions", 6)
.mode("overwrite")
.save()
)
spark-submit
bash
spark-submit --executor-memory 13g --num-executors 50 --driver-memory 20g FILE.PY
spark-submit --conf maximizeResourceAllocation=true FILE.PY
### AWS EMR
# also change the checkpoint dir to "/mnt/checkpoint"
spark-submit --py-files dist-matrix-module.zip property_distance_matrix.py
# alternative
spark-submit --deploy-mode cluster s3://<PATH TO FILE>/sparky.py
Misc
bash
# get spark location
echo 'sc.getConf.get("spark.home")' | spark-shell
JARs
AWS
python
from pyspark.sql import SparkSession
def init_spark() -> SparkSession:
    spark = (
        SparkSession.builder.config(
            "spark.jars.packages",
            "org.apache.hadoop:hadoop-aws:3.3.4,"
            "com.amazonaws:aws-java-sdk-core:1.12.725,"
            "com.amazonaws:aws-java-sdk-dynamodb:1.12.725,"
            "com.amazonaws:aws-java-sdk-s3:1.12.725,"
            "com.amazonaws:aws-java-sdk:1.12.725",
        )
        .config(
            "spark.hadoop.fs.s3a.aws.credentials.provider",
            "org.apache.hadoop.fs.s3a.AnonymousAWSCredentialsProvider",
        )
        .config("spark.executor.memory", "8g")
        .config("spark.driver.memory", "8g")
        .getOrCreate()
    )
    return spark
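# usage sketch (the s3a path below is a placeholder, not from the original notes)
spark = init_spark()
df = spark.read.parquet("s3a://some-public-bucket/some/prefix/")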
Extensions
Sedona
python
from sedona.spark import SedonaContext
config = (
SedonaContext.builder()
.config(
"spark.jars.packages",
"org.apache.sedona:sedona-spark-shaded-3.4_2.12:1.4.1,"
"org.datasyslab:geotools-wrapper:1.4.0-28.2",
)
.getOrCreate()
)
sedona = SedonaContext.create(config)
## read geoparquet
df = sedona.read.format("geoparquet").load("data/example.parquet")
## spatial query
df.createOrReplaceTempView("df")
sedona.sql(
"""
SELECT *, ST_GeoHash(geometry, 5) as geohash
FROM df
ORDER BY geohash
"""
).show()
## polygon from point
ST_Intersects(ST_Point(l.longitude, l.latitude), ST_GeomFromWKT(geometry))
## find surrounding points within x radius
ST_Distance( ST_Point(df.LON, df.LAT), ST_Point(poi.longitude, poi.latitude) ) <= 100/1000/111.319 -- 100 meter in degrees
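# sketch combining the fragments above into a full query; assumes a second registered
# view `poi` (with columns longitude, latitude, name) in addition to `df` (with LON/LAT)
sedona.sql(
    """
    SELECT df.*, poi.name
    FROM df
    JOIN poi
      ON ST_Distance(ST_Point(df.LON, df.LAT), ST_Point(poi.longitude, poi.latitude)) <= 100/1000/111.319
    """
).show()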
Cookbook
Generate fake data
python
import uuid
from faker import Faker
from pyspark import SparkContext
from tqdm import tqdm
from utils.create_spark_session import get_spark_session
### config
N = 500000
### init spark
spark = get_spark_session()
### main
fake = Faker()
schema = {
"id": uuid.uuid4,
"name": fake.name,
"company": fake.company,
"address": fake.address,
"latitude": fake.latitude,
"longitude": fake.longitude,
"phone_number": fake.phone_number,
"created_at": fake.date_time,
"updated_at": fake.date_time,
}
values = [
tuple(str(i()) if i == uuid.uuid4 else i() for i in schema.values())
for _ in tqdm(range(N))
]
### generate seed data
sc = SparkContext.getOrCreate()
rdd = sc.parallelize(values)
# Create a DataFrame from the RDD
df = spark.createDataFrame(rdd, list(schema.keys()))
df.repartition(4).write.parquet("data/seed_data", mode="overwrite")
Count missing values + groupby
python
df.select([F.count(F.when(F.col(c).isNull(), c)).alias(c) for c in df.columns])
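# per-group variant (a sketch; `group_col` is a placeholder for your grouping column)
df.groupBy("group_col").agg(
    *[F.count(F.when(F.col(c).isNull(), c)).alias(c) for c in df.columns if c != "group_col"]
)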
Filter by order
python
from pyspark.sql.window import Window
w = Window.partitionBy(partition_col).orderBy(F.desc(order_by_key))
(
df.withColumn("rank", F.row_number().over(w))
.filter(col("rank") == 1)
.drop(col("rank"))
)
Sum columns horizontally
python
t = spark.createDataFrame(
[
(1, 2, 4, 0),
(2, 5, 4, None),
(2, 6, 4, 1),
],
["id", "a", "b", "c"], # add your column names here
)
t.withColumn(
    "count",
    sum(when(t[c].isNull(), F.lit(0)).otherwise(t[c]) for c in ["a", "b", "c"]),
).show()
# OR, the same thing using col() instead of indexing the DataFrame
t.withColumn(
    "count",
    sum(when(col(c).isNull(), F.lit(0)).otherwise(col(c)) for c in ["a", "b", "c"]),
).show()
Output
+---+---+---+----+-----+
| id| a| b| c|count|
+---+---+---+----+-----+
| 1| 2| 4| 0| 6|
| 2| 5| 4|null| 9|
| 2| 6| 4| 1| 11|
+---+---+---+----+-----+