Spark is delightful for Big Data analysis. It allows using very high-level code to perform a large variety of operations. It also supports SQL, so you don’t need to learn a lot of new stuff to start being productive in Spark (assuming, of course, that you have some knowledge of SQL).

However, if you want to use Spark more efficiently, you need to learn a lot of concepts, especially about data partitioning, the relations between partitions (narrow vs. wide dependencies), and shuffling. I can recommend the Coursera course https://www.coursera.org/learn/scala-spark-big-data (it covers a lot of the fundamentals, like RDDs), or you can find materials covering those topics all over the internet.

Nevertheless, in this post I’ll try to describe some ideas on how to look under the hood of Spark operations related to DataFrames.

Data preparation.

Let’s start with preparing some data. I’ll use a small sample from the immortal iris dataset. This is done in R:

path <- "/tmp/iris.csv"
ir <- iris[c(1:3, 51:53, 101:103), c(1,5)]
colnames(ir) <- c("sep_len", "species")
write.csv(ir, file = path, row.names = FALSE)

Print the data:

cat /tmp/iris.csv
## "sep_len","species"
## 5.1,"setosa"
## 4.9,"setosa"
## 4.7,"setosa"
## 7,"versicolor"
## 6.4,"versicolor"
## 6.9,"versicolor"
## 6.3,"virginica"
## 5.8,"virginica"
## 7.1,"virginica"

Or use the excellent column utility available on Linux:

column -s, -t < /tmp/iris.csv
## "sep_len"  "species"
## 5.1        "setosa"
## 4.9        "setosa"
## 4.7        "setosa"
## 7          "versicolor"
## 6.4        "versicolor"
## 6.9        "versicolor"
## 6.3        "virginica"
## 5.8        "virginica"
## 7.1        "virginica"

Set up the Spark engine for knitr.

I’m using my Scala engine from the previous post (https://www.zstat.pl/2018/07/27/scala-in-knitr/) to render the output of Scala code. The code below prepares everything needed to run a Scala/Spark session inside R’s knitr:

library(knitr)
library(rscala)
# ... args passed to rscala::scala functions. See ?rscala::scala for more information.
make_scala_engine <- function(...) {
  
  rscala::scala(assign.name = "engine", serialize.output = TRUE, stdout = "", ...)
  engine <- force(engine)
  function(options) {
    code <- paste(options$code, collapse = "\n")
    output <- capture.output(invisible(engine + code))
    engine_output(options, options$code, output)
  }
}

jars <- dir("~/spark/spark-2.3.0-bin-hadoop2.7/jars/", full.names = TRUE)

jarsToRemove <- c("scala-compiler-.*\\.jar$",
                  "scala-library-.*\\.jar$",
                  "scala-reflect-.*\\.jar$",
                  "scalap-.*\\.jar$",
                  "scala-parser-combinators_.*\\.jar$",
                  "scala-xml_.*\\.jar$")
jars <- jars[!grepl(jars, pattern = paste(jarsToRemove, collapse = "|"))]
knit_engines$set(scala = make_scala_engine(JARs = jars))

First look at the data.

The code below creates a SparkSession, which is required to perform all the Spark operations. Then the csv file is loaded from local disk. I use spark_partition_id to add a new column with the id of the partition in which each row is located. It’s a useful function if you want to check how Spark has partitioned your data.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.Row
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._

val spark = SparkSession.builder.master("local[*]")
              .appName("Simple Application")
              .getOrCreate()

val df = spark.read
         .format("csv")
         .option("header", "true")
         .load("file:///tmp/iris.csv")
         .withColumn("pid", spark_partition_id())

Now I can examine the data by calling the show function.

df.show
## +-------+----------+---+
## |sep_len|   species|pid|
## +-------+----------+---+
## |    5.1|    setosa|  0|
## |    4.9|    setosa|  0|
## |    4.7|    setosa|  0|
## |      7|versicolor|  0|
## |    6.4|versicolor|  0|
## |    6.9|versicolor|  0|
## |    6.3| virginica|  0|
## |    5.8| virginica|  0|
## |    7.1| virginica|  0|
## +-------+----------+---+

It seems that Spark put all the data into just one partition. If I want to take advantage of the parallel nature of Spark, one partition is a bit problematic. Spark uses partitions to spread the data between workers, and a single partition means that only one task will be scheduled to process it.
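Instead of eyeballing the pid column, the number of partitions can also be read directly from the underlying RDD (a small sketch; df is the DataFrame loaded above):

// number of partitions of the underlying RDD; expected to be 1 for this tiny csv
df.rdd.getNumPartitions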

Repartition the DataFrame.

To spread the data into multiple partitions, the repartition method can be used. The first argument specifies the number of resulting partitions. Let’s run the code and inspect the result:

val dfRep = spark.read
         .format("csv")
         .option("header", "true")
         .load("file:///tmp/iris.csv")
         .repartition(3)
         .withColumn("pid", spark_partition_id())
         
dfRep.show
## +-------+----------+---+
## |sep_len|   species|pid|
## +-------+----------+---+
## |    6.3| virginica|  0|
## |    6.9|versicolor|  0|
## |    5.1|    setosa|  0|
## |    7.1| virginica|  1|
## |    5.8| virginica|  1|
## |    4.9|    setosa|  1|
## |    4.7|    setosa|  2|
## |    6.4|versicolor|  2|
## |      7|versicolor|  2|
## +-------+----------+---+

Spark created three partitions. However, the data for each species are spread across partitions, so some operations will require reshuffling the data. For example, in the case of aggregation functions, Spark will compute partial results for each partition and then combine all of them. Let’s use the collect_list function on pid to gather, for each species, the ids of the partitions in which its rows started out.

val dfRepAgg = dfRep
  .groupBy("species")
  .agg(collect_list(col("pid")).alias("pids"))
  .withColumn("pid", spark_partition_id())
  
dfRepAgg.show
## +----------+---------+---+
## |   species|     pids|pid|
## +----------+---------+---+
## | virginica|[0, 1, 2]| 41|
## |versicolor|[0, 1, 2]|104|
## |    setosa|[0, 1, 2]|158|
## +----------+---------+---+

In this case, Spark had to perform a lot of operations, because after collecting the values within each partition, it had to combine the partial results to form the final vector. But how do I know what Spark did? There’s an explain method which shows all the steps that will be performed to get the final result.

dfRepAgg.explain
## == Physical Plan ==
## *(3) Project [species#42, pids#67, SPARK_PARTITION_ID() AS pid#70]
## +- ObjectHashAggregate(keys=[species#42], functions=[collect_list(pid#46, 0, 0)])
##    +- Exchange hashpartitioning(species#42, 200)
##       +- ObjectHashAggregate(keys=[species#42], functions=[partial_collect_list(pid#46, 0, 0)])
##          +- *(2) Project [species#42, SPARK_PARTITION_ID() AS pid#46]
##             +- Exchange RoundRobinPartitioning(3)
##                +- *(1) FileScan csv [species#42] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:/tmp/iris.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<species:string>

Let’s break the output into pieces (it should be read from the bottom):

  • FileScan - read csv from disk.
  • Exchange RoundRobinPartitioning(3) - split the data into three partitions using round-robin partitioning (you can google spark partitioners for more information). This is the repartition(3) call defined for dfRep.
  • ObjectHashAggregate(keys=[species#42], functions=[partial_collect_list(pid#46, 0, 0)]) - compute partial results within each partition.
  • Exchange hashpartitioning(species#42, 200) - repartition the data using a hash partitioner based on the species column (the 200 is the default number of shuffle partitions; see the note after this list). In this step Spark performs the shuffle, and all the partial results from the previous step for a given species land in the same partition. Note that the data is transferred between nodes. Such a transfer can be quite an expensive operation, especially if there are a lot of partial results.
  • ObjectHashAggregate(keys=[species#42], functions=[collect_list(pid#46, 0, 0)]) - in this step Spark combines all the partial results into the final value.

I skipped the description of the Project steps because they seem to be unimportant in this case.
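As for the 200 mentioned above: it is simply the default value of the spark.sql.shuffle.partitions setting, which controls how many partitions Spark creates after a shuffle. For such a tiny, local example it could be lowered, which is just a configuration change (a sketch, not required for the rest of this post):

// lower the number of post-shuffle partitions for this small, local example
spark.conf.set("spark.sql.shuffle.partitions", "3")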

As we can see, there’s a step in which Spark needs to transfer the data between nodes to merge all the partial results. But we can skip this step by telling Spark how to better partition the source data. In this case, we know in advance that all the data related to a given species should be in the same partition. We can utilize this knowledge by calling repartition with an additional argument:

val dfRepSpec = spark.read
         .format("csv")
         .option("header", "true")
         .load("file:///tmp/iris.csv")
         .repartition(3, col("species"))
         .withColumn("pid", spark_partition_id())
         
dfRepSpec.show
## +-------+----------+---+
## |sep_len|   species|pid|
## +-------+----------+---+
## |    6.3| virginica|  0|
## |    5.8| virginica|  0|
## |    7.1| virginica|  0|
## |    5.1|    setosa|  1|
## |    4.9|    setosa|  1|
## |    4.7|    setosa|  1|
## |      7|versicolor|  2|
## |    6.4|versicolor|  2|
## |    6.9|versicolor|  2|
## +-------+----------+---+

As you can see, all values that belong to the same species lie in the same partition. Let’s perform the same aggregation operation as before, and then examine the result and the explain output:

val dfRepSpecAgg = dfRepSpec
  .groupBy("species")
  .agg(collect_list(col("pid")).alias("pids"))
  .withColumn("pid", spark_partition_id())

dfRepSpecAgg.show
## +----------+---------+---+
## |   species|     pids|pid|
## +----------+---------+---+
## | virginica|[0, 0, 0]|  0|
## |    setosa|[1, 1, 1]|  1|
## |versicolor|[2, 2, 2]|  2|
## +----------+---------+---+

dfRepSpecAgg.explain
## == Physical Plan ==
## *(3) Project [species#124, pids#149, SPARK_PARTITION_ID() AS pid#152]
## +- ObjectHashAggregate(keys=[species#124], functions=[collect_list(pid#128, 0, 0)])
##    +- ObjectHashAggregate(keys=[species#124], functions=[partial_collect_list(pid#128, 0, 0)])
##       +- *(2) Project [species#124, SPARK_PARTITION_ID() AS pid#128]
##          +- Exchange hashpartitioning(species#124, 3)
##             +- *(1) FileScan csv [species#124] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:/tmp/iris.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<species:string>

In this case, there’s no Exchange RoundRobinPartitioning or Exchange hashpartitioning between the two ObjectHashAggregate steps, because all the values required for the aggregation already lie in the same partitions. This will be much faster than the previous solution because there’s no data transfer between nodes.
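We can also compare the two physical plans programmatically. The sketch below is just a sanity check on the plans’ string representations, counting their Exchange nodes:

// count Exchange operators in the string representation of each physical plan
val planRep     = dfRepAgg.queryExecution.executedPlan.toString
val planRepSpec = dfRepSpecAgg.queryExecution.executedPlan.toString
println(planRep.split("\n").count(_.contains("Exchange")))     // 2: round robin + hash partitioning
println(planRepSpec.split("\n").count(_.contains("Exchange"))) // 1: only the repartitioning by species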

We can also check that the results stayed in the same partitions:

dfRepSpecAgg.withColumn("pid", spark_partition_id()).show
## +----------+---------+---+
## |   species|     pids|pid|
## +----------+---------+---+
## | virginica|[0, 0, 0]|  0|
## |    setosa|[1, 1, 1]|  1|
## |versicolor|[2, 2, 2]|  2|
## +----------+---------+---+

We can use more than one column to repartition the data. However, keep in mind that sometimes it’s unreasonable to keep all values related to a given key in the same partition. For example, a binary key would leave us with only two partitions, killing all the parallel capabilities. Another problematic situation is a skewed distribution, where we have a lot of data for a few keys and just a handful of observations for the others. It might lead to a situation where the tasks for the smaller keys finish quickly and then have to wait for the keys with many more values to be processed. So there’s no silver bullet, and the choice of partitioning largely depends on the specific situation. A quick way to spot such problems is to count the rows per partition, as in the sketch below.
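This is a minimal sketch (using the dfRepSpec DataFrame from above) that groups by spark_partition_id to show how many rows ended up in each partition; a heavily unbalanced count would suggest a skewed partitioning key:

// rows per partition; a very uneven distribution indicates a skewed partitioning key
dfRepSpec
  .groupBy(spark_partition_id().alias("pid"))
  .count()
  .show()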

Parquet

In the last part of this post, I’ll briefly describe the parquet format, which is widely used for storing data for Spark (it’s much better than csv). It allows saving the data in a partitioned way:

dfRepSpec
  .select("sep_len","species")
  .write
  .partitionBy("species")
  .parquet("/tmp/iris")

It creates a tree structure in which directories are used to separate the files for the different partition values. Let’s take a look at the created parquet files:

tree /tmp/iris
## /tmp/iris
## ├── species=setosa
## │   └── part-00001-ce7c26e1-7865-4b9f-b532-fdb9a30942dc.c000.snappy.parquet
## ├── species=versicolor
## │   └── part-00002-ce7c26e1-7865-4b9f-b532-fdb9a30942dc.c000.snappy.parquet
## ├── species=virginica
## │   └── part-00000-ce7c26e1-7865-4b9f-b532-fdb9a30942dc.c000.snappy.parquet
## └── _SUCCESS
## 
## 3 directories, 4 files

We can load the data and check that all rows with the same species lie in the same partition:

val dfPar = spark.read.parquet("/tmp/iris").withColumn("pid", spark_partition_id())
dfPar.show
## +-------+----------+---+
## |sep_len|   species|pid|
## +-------+----------+---+
## |    6.3| virginica|  0|
## |    5.8| virginica|  0|
## |    7.1| virginica|  0|
## |    5.1|    setosa|  1|
## |    4.9|    setosa|  1|
## |    4.7|    setosa|  1|
## |      7|versicolor|  2|
## |    6.4|versicolor|  2|
## |    6.9|versicolor|  2|
## +-------+----------+---+

Then we can perform aggregation:

val dfParAgg = dfPar
  .groupBy("species")
  .agg(collect_list(col("pid")).alias("pids"))
  .withColumn("pid", spark_partition_id())
  
dfParAgg.show
## +----------+---------+---+
## |   species|     pids|pid|
## +----------+---------+---+
## | virginica|[0, 0, 0]| 41|
## |versicolor|[2, 2, 2]|104|
## |    setosa|[1, 1, 1]|158|
## +----------+---------+---+

There’s something strange in the output, because the values in pid don’t match the values in pids. It might be a signal that Spark performed a reshuffle and moved the data between nodes. We should check the explain method:

dfParAgg.explain
## == Physical Plan ==
## *(2) Project [species#236, pids#258, SPARK_PARTITION_ID() AS pid#261]
## +- ObjectHashAggregate(keys=[species#236], functions=[collect_list(pid#239, 0, 0)])
##    +- Exchange hashpartitioning(species#236, 200)
##       +- ObjectHashAggregate(keys=[species#236], functions=[partial_collect_list(pid#239, 0, 0)])
##          +- *(1) Project [species#236, SPARK_PARTITION_ID() AS pid#239]
##             +- *(1) FileScan parquet [sep_len#235,species#236] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/tmp/iris], PartitionCount: 3, PartitionFilters: [], PushedFilters: [], ReadSchema: struct<sep_len:string>

There’s an Exchange hashpartitioning(species#236, 200) operation, which is a signal that Spark reshuffled the data. It read the data correctly but did not recover a proper partitioner, so Spark treated the data as randomly distributed. This can easily be solved by adding the repartition method before the aggregation.

val dfParAgg2 = dfPar.repartition(3, col("species"))
  .groupBy("species")
  .agg(collect_list(col("pid")).alias("pids"))
  .withColumn("pid", spark_partition_id())
  
dfParAgg2.show
## +----------+---------+---+
## |   species|     pids|pid|
## +----------+---------+---+
## | virginica|[0, 0, 0]|  0|
## |    setosa|[1, 1, 1]|  1|
## |versicolor|[2, 2, 2]|  2|
## +----------+---------+---+

And the explain output:

dfParAgg2.explain
## == Physical Plan ==
## *(2) Project [species#236, pids#308, SPARK_PARTITION_ID() AS pid#311]
## +- ObjectHashAggregate(keys=[species#236], functions=[collect_list(pid#239, 0, 0)])
##    +- ObjectHashAggregate(keys=[species#236], functions=[partial_collect_list(pid#239, 0, 0)])
##       +- Exchange hashpartitioning(species#236, 3)
##          +- *(1) Project [species#236, SPARK_PARTITION_ID() AS pid#239]
##             +- *(1) FileScan parquet [sep_len#235,species#236] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/tmp/iris], PartitionCount: 3, PartitionFilters: [], PushedFilters: [], ReadSchema: struct<sep_len:string>

Note that Exchange hashpartitioning between partial_collect_list and collect_list is now gone, so everything should be alright.

I hard-coded the number of partitions (3), but usually it’s not so easy to know this value in advance; you can always use df.rdd.getNumPartitions to get it. So the final code would be:

val dfParAgg3 = dfPar.repartition(dfPar.rdd.getNumPartitions, col("species"))
  .groupBy("species")
  .agg(collect_list(col("pid")).alias("pids"))
  .withColumn("pid", spark_partition_id())

Summary

In this post, I briefly showed how to repartition the data to avoid exchanging information between nodes, which is usually very costly and has a significant impact on the performance of a Spark application. Of course, I did not cover all possible topics and caveats. For more information you can check: