PySpark
Install
INFO
Make sure the pyspark version matches the installed Spark version.
bash
brew install apache-spark
pip3 install pyspark
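A quick way to confirm the versions match (the second line assumes an active SparkSession named `spark`):
python
import pyspark
print(pyspark.__version__)  # pip-installed PySpark version
print(spark.version)  # version of the running Spark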
Resources
Add-Ons
Dummy dataframe
python
# https://stackoverflow.com/a/57960267/19652796
df = spark.createDataFrame(
[
(1, "foo"), # create your data here, be consistent in the types.
(2, "bar"),
],
["id", "label"], # add your column names here
)
Init
python
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql.functions import coalesce
from pyspark.sql.functions import col
from pyspark.sql.functions import lit
from pyspark.sql.functions import when
from pyspark.sql.types import DoubleType
from pyspark.sql.types import IntegerType
from pyspark.sql.types import StringType
from pyspark.sql.types import StructField
from pyspark.sql.types import StructType
from pyspark.sql.types import TimestampType
from pyspark.sql.window import Window
spark = (
SparkSession.builder.appName("Pyspark playground")
.config("spark.hadoop.fs.s3a.access.key", KEY)
.config("spark.hadoop.fs.s3a.secret.key", SECRET)
.config("spark.executor.memory", "4g")
.config("spark.driver.memory", "4g")
.config("spark.jars.packages", "")
.getOrCreate()
)
# set runtime configs after the Spark session is created
# (static configs such as spark.executor.memory must go on the builder before getOrCreate)
spark.conf.set("spark.sql.shuffle.partitions", "200")
spark.sparkContext.setLogLevel("ERROR")
spark.sparkContext.setCheckpointDir("checkpoint") # [DEBUG]
I/O
python
# CSV / TSV
project = spark.read.csv(
    project_file,
    header=True,
    sep="\t",
)
project.write.csv(output_path, header=True)
# JSON
df = spark.read.json(
    "data/DMP_HIVE/all_listing.json"
)  # add .option("multiLine", True) for multi-line files
df.write.json(OUTPATH, compression="gzip")
DataFrame
python
# metadata
df.printSchema() # or .columns
# select columns
df.select(["a", "b", "c"])
# sampling
df.sample(False, sampling_percentage, seed=0)
# count records
df.count()
# conversions
df.toPandas() # spark to pandas
spark.createDataFrame(pandas_df)  # pandas to spark
# show in vertical
df.show(n=3, truncate=False, vertical=True)
# get schema in JSON
schema = df.schema.jsonValue()
schema = schema["fields"]
Transformations
python
# rename columns
df.withColumnRenamed("old_name", "new_name")
# add null column
df.withColumn(col_name, F.lit(None).cast(col_type))
# dtype casting
df.withColumn("col_name", df["col_name"].cast(IntegerType()))
# combine values from multiple rows via groupby
df.groupBy(groupby_col).agg(F.collect_list(col_name))
# select element by key from a struct / map column
F.col(col_name)["elem_key"]
# select element from an array column by index
F.col(col_name).getItem(0)
# find median
df.approxQuantile(df.columns, [0.5], 0.25)
# get percentile
df.approxQuantile(["Apple", "Oranges"], [0.1, 0.25, 0.5, 0.75, 0.9, 0.95], 0.1)
# get median during groupby
# https://stackoverflow.com/a/71735997
df.groupBy("Id").agg(F.percentile_approx("value", 0.5).alias("median_approximate"))
# join
df.join(
    df2,
    [key],  # a list of shared column names avoids duplicate key columns; use df.key == df2.key when the names differ
    how="left",
)
functions
python
# combine cols to array
F.array("x_1", "x_2")
# fillna with another column
F.coalesce("a", "b")
# create new column with max value from set of columns
F.greatest(a["one"], a["two"], a["three"])
# regex matching --> the longest alternative only wins if it comes first in the pattern
F.regexp_replace(F.trim(F.lower(F.col(col_name))), regex_str, "")
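# illustrative sketch (hypothetical column "name"): put the longer alternative first,
# e.g. r"foo bar|foo" removes "foo bar" as a whole instead of leaving a dangling " bar"
df.withColumn("clean", F.regexp_replace(F.lower(F.col("name")), r"foo bar|foo", ""))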
# explode an array-of-structs column (one row per element), then flatten the struct fields
df.withColumn("tmp", F.explode("tmp")).select(
*df.columns, col("tmp.a"), col("tmp.b"), col("tmp.c")
)
# convert to JSON
F.to_json(c)
# convert to list
df.select("mvv").rdd.flatMap(lambda x: x).collect()
# udf
from bson import ObjectId  # ObjectId lives in pymongo's bson package
from pyspark.sql.functions import udf
import pyspark.sql.functions as F
from pyspark.sql.types import StringType
@udf(returnType=StringType())
def object_id_to_date(object_id: str):
    object_id = ObjectId(object_id)
    return str(object_id.generation_time.date())
df.select(object_id_to_date("_id").alias("creation_date"))
datetime
python
# epoch (in milliseconds) to timestamp
F.from_unixtime(df.column / 1000)
# timestamp to date
F.to_date("listing_update")
# UTC to another timezone
F.from_utc_timestamp("datetime_utc", "CST")
# filter between date range
df.filter(
F.col("date_col").between(
F.date_add(F.current_date(), -7), F.date_add(F.current_date(), -1)
)
)
SQL
python
df.createOrReplaceTempView("foo")
spark.sql(
"""
SELECT *
FROM foo
"""
)
Optimization
Caching
INFO
Caching improves read performance for frequently accessed DataFrames.
python
df.cache()
# clear cache
spark.catalog.clearCache()
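When the DataFrame does not fit in memory, `persist` with an explicit storage level is the usual alternative to `cache()`; a minimal sketch using the same `df`:
python
from pyspark.storagelevel import StorageLevel
df.persist(StorageLevel.MEMORY_AND_DISK)  # spill to disk instead of recomputing evicted partitions
df.unpersist()  # release when done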
Repartition + partition data
python
df.repartition(4)
df.write.partitionBy(*partition_columns).parquet(base_path, mode=write_mode)
Dynamic partition write mode
python
"""
Note: Why do we have to change partitionOverwriteMode?
Without config partitionOverwriteMode = 'dynamic', Spark will
overwrite all partitions in hierarchy with the new
data we are writing. That's undesirable and dangerous.
https://stackoverflow.com/questions/42317738/how-to-partition-and-write-dataframe-in-spark-without-deleting-partitions-with-n
Therefore, we will temporarily use 'dynamic' within the context of writing files to storage.
"""
spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
Skew join optimization
https://stackoverflow.com/a/57951114
python
from datetime import datetime
from math import exp
def count_elements(splitIndex, iterator):
n = sum(1 for _ in iterator)
yield (splitIndex, n)
def get_part_index(splitIndex, iterator):
for it in iterator:
yield (splitIndex, it)
num_parts = 18
sc = spark.sparkContext
# create the large skewed rdd
skewed_large_rdd = sc.parallelize(range(0, num_parts), num_parts).flatMap(
lambda x: range(0, int(exp(x)))
)
skewed_large_rdd = skewed_large_rdd.mapPartitionsWithIndex(
lambda ind, x: get_part_index(ind, x)
)
skewed_large_df = spark.createDataFrame(skewed_large_rdd, ["x", "y"])
small_rdd = sc.parallelize(range(0, num_parts), num_parts).map(lambda x: (x, x))
small_df = spark.createDataFrame(small_rdd, ["a", "b"])
## prep salts
salt_bins = 100
from pyspark.sql import functions as F
skewed_transformed_df = skewed_large_df.withColumn(
"salt", (F.rand() * salt_bins).cast("int")
).cache()
small_transformed_df = small_df.withColumn(
"replicate", F.array([F.lit(i) for i in range(salt_bins)])
)
small_transformed_df = (
small_transformed_df.select("*", F.explode("replicate").alias("salt"))
.drop("replicate")
.cache()
)
## magic happens here
t0 = datetime.now()
result2 = skewed_transformed_df.join(
small_transformed_df,
(skewed_transformed_df["x"] == small_transformed_df["a"])
& (skewed_transformed_df["salt"] == small_transformed_df["salt"]),
)
result2.count()
print("The direct join takes %s" % (str(datetime.now() - t0)))
JDBC
INFO
To overwrite a table without dropping its schema & permissions, use the truncate option.
Postgres
python
spark = (
SparkSession.builder.appName("Pyspark playground")
.config("spark.executor.memory", "16g")
.config("spark.driver.memory", "16g")
.config("spark.jars.packages", "org.postgresql:postgresql:42.5.0")
.getOrCreate()
)
uri = "jdbc:postgresql://host.docker.internal:5432/postgres"
### read
(
spark.read.format("jdbc")
.option("url", uri)
.option("dbtable", TABLENAME)
.option("user", USERNAME)
.option("password", PASSWORD)
.option("driver", "org.postgresql.Driver")
.load()
)
### write
(
df.write.format("jdbc")
.option("url", uri)
.option("dbtable", TABLENAME)
.option("user", USERNAME)
.option("password", PASSWORD)
.option("driver", "org.postgresql.Driver")
.option("truncate", "true")
.mode("overwrite")
.save()
)
MongoDB
python
import os
# the "mongodb" format with connection.uri / database / collection options requires the
# 10.x connector; older 2.x connectors use format("mongo") instead
os.environ["PYSPARK_SUBMIT_ARGS"] = (
    '--packages "org.mongodb.spark:mongo-spark-connector_2.12:10.1.1" pyspark-shell'
)
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("Pyspark playground").getOrCreate()
### read
(
spark.read.format("mongodb")
.option("connection.uri", uri)
.option("database", DBNAME)
.option("collection", TABLENAME)
.load()
)
### write
(
df.write.format("mongodb")
.option("connection.uri", uri)
.option("database", DBNAME)
.option("collection", TABLENAME)
.mode("overwrite")
.save()
)
Clickhouse
python
spark = (
SparkSession.builder.appName("Pyspark playground")
.config("spark.executor.memory", "16g")
.config("spark.driver.memory", "16g")
.config("spark.jars.packages", "com.github.housepower:clickhouse-spark-runtime-3.4_2.12:0.7.3,com.clickhouse:clickhouse-jdbc:0.6.2,org.apache.httpcomponents.client5:httpclient5:5.3.1")
.getOrCreate()
)
uri = "jdbc:clickhouse://localhost:8123/clickhouse"
### read
(
    spark.read.format("jdbc")
    .option("url", uri)
    .option("dbtable", "nyc_taxi")
    .option("user", "clickhouse")
    .option("password", "clickhousepassword")
    .option("driver", "com.clickhouse.jdbc.ClickHouseDriver")
    .load()
)
### write
(
df.write.format("jdbc")
.option("url", uri)
.option("dbtable", "nyc_taxi") # [TODO] change me
.option("user", "clickhouse")
.option("password", "clickhousepassword")
.option("driver", "com.clickhouse.jdbc.ClickHouseDriver")
.option("createTableOptions", "engine=MergeTree() order by tpep_pickup_datetime")
.option("truncate", "true")
.option("numPartitions", 6)
.mode("overwrite")
.save()
)
spark-submit
bash
spark-submit --driver-memory 20g --executor-memory 13g --num-executors 50 FILE.PY
# EMR only: maximizeResourceAllocation is normally set via the cluster's "spark" configuration classification
spark-submit --conf maximizeResourceAllocation=true FILE.PY
### AWS EMR
# also change the checkpoint dir to "/mnt/checkpoint"
spark-submit --py-files dist-matrix-module.zip property_distance_matrix.py
# alternative
spark-submit --deploy-mode cluster s3://<PATH TO FILE>/sparky.py
Misc
bash
# get spark location
echo 'sc.getConf.get("spark.home")' | spark-shell
JARs
AWS
python
from pyspark.sql import SparkSession
def init_spark() -> SparkSession:
    spark = (
        SparkSession.builder.config(
            "spark.jars.packages",
            "org.apache.hadoop:hadoop-aws:3.3.4,"
            "com.amazonaws:aws-java-sdk-core:1.12.725,"
            "com.amazonaws:aws-java-sdk-dynamodb:1.12.725,"
            "com.amazonaws:aws-java-sdk-s3:1.12.725,"
            "com.amazonaws:aws-java-sdk:1.12.725",
        )
        .config(
            "spark.hadoop.fs.s3a.aws.credentials.provider",
            "org.apache.hadoop.fs.s3a.AnonymousAWSCredentialsProvider",
        )
        .config("spark.executor.memory", "8g")
        .config("spark.driver.memory", "8g")
        .getOrCreate()
    )
    return spark
Extensions
Sedona
python
from sedona.spark import SedonaContext
config = (
SedonaContext.builder()
.config(
"spark.jars.packages",
"org.apache.sedona:sedona-spark-shaded-3.4_2.12:1.4.1,"
"org.datasyslab:geotools-wrapper:1.4.0-28.2",
)
.getOrCreate()
)
sedona = SedonaContext.create(config)
## read geoparquet
df = sedona.read.format("geoparquet").load("data/example.parquet")
## spatial query
df.createOrReplaceTempView("df")
sedona.sql(
"""
SELECT *, ST_GeoHash(geometry, 5) as geohash
FROM df
ORDER BY geohash
"""
).show()
## point-in-polygon check
ST_Intersects(ST_Point(l.longitude, l.latitude), ST_GeomFromWKT(geometry))
## find surrounding points within an x-meter radius
ST_Distance(ST_Point(df.LON, df.LAT), ST_Point(poi.longitude, poi.latitude)) <= 100/1000/111.319 -- 100 m in degrees (1 degree ~ 111.319 km)
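A sketch that uses the radius predicate above in a full query (hypothetical temp views `df` and `poi`, with `poi.name` assumed for illustration):
python
sedona.sql(
    """
    SELECT df.*, poi.name AS poi_name
    FROM df
    JOIN poi
      ON ST_Distance(ST_Point(df.LON, df.LAT),
                     ST_Point(poi.longitude, poi.latitude)) <= 100/1000/111.319
    """
).show()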
Cookbook
Generate fake data
python
import uuid
from faker import Faker
from pyspark import SparkContext
from tqdm import tqdm
from utils.create_spark_session import get_spark_session
### config
N = 500000
### init spark
spark = get_spark_session()
### main
fake = Faker()
schema = {
"id": uuid.uuid4,
"name": fake.name,
"company": fake.company,
"address": fake.address,
"latitude": fake.latitude,
"longitude": fake.longitude,
"phone_number": fake.phone_number,
"created_at": fake.date_time,
"updated_at": fake.date_time,
}
values = [
tuple(str(i()) if i == uuid.uuid4 else i() for i in schema.values())
for _ in tqdm(range(N))
]
### generate seed data
sc = SparkContext.getOrCreate()
rdd = sc.parallelize(values)
# Create a DataFrame from the RDD
df = spark.createDataFrame(rdd, list(schema.keys()))
df.repartition(4).write.parquet("data/seed_data", mode="overwrite")
Count missing values + groupby
python
# null count per column; see the per-group variant below
df.select([F.count(F.when(F.col(c).isNull(), c)).alias(c) for c in df.columns])
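For null counts per group rather than across the whole DataFrame, a sketch assuming a hypothetical grouping column `group`:
python
df.groupBy("group").agg(
    *[F.count(F.when(F.col(c).isNull(), c)).alias(c) for c in df.columns if c != "group"]
)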
Filter by order
python
from pyspark.sql.window import Window
w = Window.partitionBy(partition_col).orderBy(F.desc(order_by_key))
(
    df.withColumn("rank", F.row_number().over(w))
    .filter(col("rank") == 1)
    .drop("rank")
)
Sum columns horizontally
python
t = spark.createDataFrame(
[
(1, 2, 4, 0),
(2, 5, 4, None),
(2, 6, 4, 1),
],
["id", "a", "b", "c"], # add your column names here
)
t.withColumn(
    "count",
    sum(when(t[c].isNull(), F.lit(0)).otherwise(t[c]) for c in ["a", "b", "c"]),
).show()
# OR, this one uses `col` instead of `df`
t.withColumn(
"count",
sum(when(col(i).isNull(), F.lit(0)).otherwise(col(i)) for i in ["a", "b", "c"]),
).show()
Output
+---+---+---+----+-----+
| id| a| b| c|count|
+---+---+---+----+-----+
| 1| 2| 4| 0| 6|
| 2| 5| 4|null| 9|
| 2| 6| 4| 1| 11|
+---+---+---+----+-----+