Using Scala 3 with Spark
by Chris Birchall
February 08, 2022
scala • functional programming • scala3 • spark
17 minutes to read

Apache Spark is a hugely popular data engineering tool that accounts for a large segment of the Scala community. Every Spark release is tied to a specific Scala version, so a large subset of Scala users have little control over the Scala version they use because it is dictated by Spark.
Scala 2.13 was released in June 2019, but it took more than two years and a huge effort by the Spark maintainers for the first Scala 2.13-compatible Spark release (Spark 3.2.0) to arrive. While developers appreciated how much work went into upgrading Spark to Scala 2.13, it was still a little frustrating to be stuck on an older version of Scala for so long.
If it took two years for Spark to upgrade from Scala 2.12 to 2.13, you might be wondering how long it will take for a Scala 3 version of Spark to arrive. The answer is: it doesn’t matter! We can already use Scala 3 to build Spark applications thanks to the compatibility between Scala 2.13 and Scala 3.
In the remainder of this post, I’d like to demonstrate how to build a Scala 3 application that runs on a Spark 3.2.0 cluster. I’ll start with a small hello-world app, then demonstrate a more realistic app that exercises more of Spark’s features. My hope is that this will dispel some myths and reassure people that putting a Scala 3 Spark app into production is a reasonable thing to do.
All the code and scripts for this post are available on GitHub.
Hello, World!
Without further ado, let’s write a Spark app in Scala 3.
sbt configuration
We’ll start with the sbt configuration. The `build.sbt` file is very short, but requires a little explanation:
scalaVersion := "3.1.1"
libraryDependencies ++= Seq(
("org.apache.spark" %% "spark-sql" % "3.2.0" % "provided").cross(CrossVersion.for3Use2_13)
)
// include the 'provided' Spark dependency on the classpath for `sbt run`
Compile / run := Defaults.runTask(Compile / fullClasspath, Compile / run / mainClass, Compile / run / runner).evaluated
- We set the Scala version to 3.1.1, the latest Scala release at the time of writing. Woohoo, Scala 3! That means we can use the new hipster syntax and get our hands on shiny new features like enums, extension methods, and opaque types.
- We add a dependency on `spark-sql` v3.2.0 so we can use the Spark API in our code.
  - We mark it as `provided` because we are going to package our app as an uber-jar for deployment to a Spark cluster. We don’t want to include the Spark API in our deployment package because it’s provided by the Spark runtime.
  - Because there is no Scala 3 version of `spark-sql` available, we use `CrossVersion.for3Use2_13` to tell sbt we want the Scala 2.13 version of this library.
- Finally, we add a one-line incantation to tell sbt to add all “provided” dependencies to the classpath when we execute the `run` task. Exercising our code using `sbt run` will give us a much faster testing cycle than repackaging the app and deploying it to a Spark cluster after every change.
A warning about dependencies
While Scala 3 and 2.13 can magically interoperate, there is one restriction worth being aware of: if you accidentally end up with both the `_2.13` and `_3` versions of a given library on your classpath, you’re headed directly to Dependency Hell. In fact, sbt will detect this case during dependency resolution and fail loudly to save you from yourself.
This doesn’t matter for our hello-world app, as we have no dependencies other than `spark-sql`, but a real Spark job is likely to have dependencies on other libraries. Every time you add a dependency, you’ll need to check its graph of transitive dependencies and use `CrossVersion.for3Use2_13` where necessary.
Luckily, Spark has only a handful of Scala dependencies:
- `com.fasterxml.jackson.module::jackson-module-scala`
- `com.twitter::chill`
- `org.json4s::json4s-jackson`
- `org.scala-lang.modules::scala-parallel-collections`
- `org.scala-lang.modules::scala-parser-combinators`
- `org.scala-lang.modules::scala-xml`
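As an illustration (this particular dependency and version are hypothetical examples, not something the hello-world app needs), adding a library that is published only for Scala 2.13 uses the same cross-version flag, so that it lines up with Spark’s own `_2.13` artifacts:

```scala
// build.sbt sketch: force the Scala 2.13 artifact of a library that has no
// Scala 3 build yet (the version shown is illustrative).
libraryDependencies += ("org.json4s" %% "json4s-jackson" % "3.7.0-M11")
  .cross(CrossVersion.for3Use2_13)
```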
Code
Our hello-world is just a few lines of Scala:
// src/main/scala/helloworld/HelloWorld.scala
package helloworld
import org.apache.spark.sql.SparkSession
object HelloWorld:
@main def run =
val spark = SparkSession.builder
.appName("HelloWorld")
.master(sys.env.getOrElse("SPARK_MASTER_URL", "local[*]"))
.getOrCreate() // 1
import spark.implicits._ // 2
val df = List("hello", "world").toDF // 3
df.show() // 4
spark.stop
If you’re familiar with Spark, there should be no surprises here:
1. Create a Spark session, running it on a local standalone Spark cluster by default.
2. Import some implicits so we can call `.toDF` in step 3.
3. Create a one-column DataFrame with two rows.
4. Print the contents of the DataFrame to stdout.
When we run this code using `sbt run`, we see a load of Spark logs and our DataFrame:
+-----+
|value|
+-----+
|hello|
|world|
+-----+
Success! We just ran our first Scala 3 Spark app.
You may have noticed that the name of the main class is `helloworld.run`, even though our `@main`-annotated `run` method is inside the `HelloWorld` object. This behavior is documented in the Scala 3 docs, but I found it a little surprising.
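If you’d rather keep the familiar class name, one option (a sketch, not what the example repo does) is to skip `@main` and write a conventional main method, which produces a main class named after the enclosing object:

```scala
// Sketch: a conventional entry point instead of @main.
// A cluster submission would then reference the class helloworld.HelloWorld.
package helloworld

import org.apache.spark.sql.SparkSession

object HelloWorld:
  def main(args: Array[String]): Unit =
    val spark = SparkSession.builder
      .appName("HelloWorld")
      .master(sys.env.getOrElse("SPARK_MASTER_URL", "local[*]"))
      .getOrCreate()
    import spark.implicits._
    List("hello", "world").toDF.show()
    spark.stop()
```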
Running on a real Spark cluster
Running inside sbt is all well and good, but we need to see whether this will work when deployed to an actual Spark cluster.
The GitHub repo for this post contains a `Dockerfile`, `docker-compose.yml` and scripts for building and running a 3-node Spark cluster in Docker. I won’t go into all the details of the Docker image, but there is one important point to mention: there are multiple Spark 3.2.0 binaries available for download and you need to pick the right one when installing Spark. Even though Spark 3.2.0 supports Scala 2.13, the default Scala version is still 2.12, so you need to pick the one with the `_scala2.13` suffix.
(Similarly, if you use a managed Spark platform such as Databricks or AWS EMR, you’ll need to choose a Spark 3.2.0 + Scala 2.13 runtime. Neither Databricks nor EMR offers such a runtime just yet.)
Running `./start-spark-cluster.sh` will build the Docker image and start a Spark cluster containing a master node and two worker nodes.
If you run `./open-spark-UIs.sh`, you should be able to see the Spark UI in your browser:
We are now ready to package our Spark application and submit it to the cluster for execution.
Running `sbt assembly` will build an uber-jar containing our application code and the Scala standard library. (Actually it contains two Scala standard libraries: one for Scala 2.13.6 and one for 3.1.1. I don’t fully understand this, but it seems to work!) Remember, we marked the Spark dependency as “provided”, so it is not included in the uber-jar.
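(If you’re recreating this build yourself rather than using the repo, the `assembly` task comes from the sbt-assembly plugin. A minimal `project/plugins.sbt` sketch is shown below; the version is an assumption, so check the plugin’s releases for the latest.)

```scala
// project/plugins.sbt (sketch; the repo for this post already contains the equivalent)
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "1.1.0") // version is illustrative
```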
You can now run the `./run-hello-world.sh` script to execute our app on the Spark cluster. You should see similar terminal output to what you saw when you ran the app inside sbt.
If you refresh the Spark master UI, you should see that the app completed successfully:
Nice!
A more realistic application
Now we’ve got a hello-world under our belts, let’s build a more complex application that exercises more of Spark’s features and looks a bit more like the kind of Spark job you might write in the real world.
Spark features of interest
One Spark feature worth verifying is typed Datasets. It’s often preferable to work with a `Dataset[MyCaseClass]` rather than a raw `DataFrame`, so that Scala’s type system can help us avoid bugs. But Spark’s mechanism for converting between case classes and its internal data representation relies on runtime reflection - is that going to work in Scala 3? We’ll find out soon.
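To make the distinction concrete, here is a rough sketch (using a hypothetical `Trip` case class, not part of the TSP job) of the difference: `DataFrame` columns are referenced by name and checked only at runtime, while fields of a typed `Dataset` are checked by the compiler.

```scala
import org.apache.spark.sql.{DataFrame, Dataset}

case class Trip(city: String, distanceKm: Double)

// Stringly typed: a typo in "distanceKm" only fails when the job runs.
def nonTrivialLegs(df: DataFrame): DataFrame =
  df.filter(df("distanceKm") > 0)

// Typed: a typo in the field name will not compile.
def nonTrivialLegsTyped(ds: Dataset[Trip]): Dataset[Trip] =
  ds.filter(_.distanceKm > 0)
```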
Another important Spark feature we should exercise is user-defined functions (UDFs), as they require Spark to serialize our Scala code and execute it on the worker nodes. Intuitively that seems like something else that might break when we use Scala 3.
Application overview
We’re going to use Spark to solve the classic Traveling Salesman Problem (TSP). We want to find the shortest route that starts at any of a set of N cities, visits all the other cities exactly once, and finally comes back to where it started.
Solving the TSP in an efficient way is far beyond the scope of this post, so our job will use a very naive brute-force approach. However, we are at least able to distribute the computation across the Spark cluster, which should speed things up slightly.
(Aside: for an entertaining tour of the history of TSP research and an overview of the state of the art, check out this talk by Prof. William Cook of the University of Waterloo.)
Our application will work as follows:
- Load a list of cities (names and locations) from a configuration file.
- Build a list of every possible journey through those cities.
- Split the journeys into legs, from one city to the next.
- Calculate the distance of each leg.
- Sum those distances to give the total distance of each journey.
- Print out the route and distance of the shortest journey.
This will give us a chance to verify a few Spark features, including the two mentioned above.
Data model
Here are the case classes we’ll use to model our data as it goes through a few transformations:
// A city with a name and a lat/long coordinate
case class City(name: String, lat: Double, lon: Double)
// A stop at a given city as part of a given journey
case class JourneyStop(journeyID: String, index: Int, city: City)
// A leg of a journey, from one city to the next
case class JourneyLeg(journeyID: String, index: Int, city: City, previousCity: City)
// The distance in km of a given journey leg
case class JourneyLegDistance(journeyID: String, index: Int, distance: Double)
We will use typed Datasets where possible so we can work with these case classes in Spark instead of dealing with raw `Row`s.
Loading the data into Spark
Skipping the parts where we load the cities from configuration and enumerate all possible journeys, let’s just assume we have a list of `JourneyStop`s.
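(If you’re curious, the enumeration could be sketched roughly as below, with one journey per permutation of the city list. This is illustrative only and not the code from the repo; among other simplifications, it ignores the return leg back to the starting city.)

```scala
// Illustrative sketch: turn every permutation of the cities into a journey,
// and every journey into a list of indexed JourneyStops.
def enumerateJourneyStops(cities: List[City]): List[JourneyStop] =
  cities.permutations.zipWithIndex.flatMap { case (route, journeyIdx) =>
    route.zipWithIndex.map { case (city, stopIdx) =>
      JourneyStop(s"journey-$journeyIdx", stopIdx, city)
    }
  }.toList
```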
We’d like to load it into a `Dataset` so we can work with it in Spark:
val journeyStops: List[JourneyStop] = ???
val journeyStopsDs: Dataset[JourneyStop] = spark.createDataset(journeyStops)
Unfortunately this doesn’t compile. We’ve hit our first Scala 3 stumbling block!
[error] -- Error: /Users/chris/code/spark-scala3-example/src/main/scala/tsp/TravellingSalesman.scala:61:80
[error] 61 | val journeyStopsDs: Dataset[JourneyStop] = spark.createDataset(journeyStops)
[error] | ^
[error] |Unable to find encoder for type tsp.TravellingSalesman.JourneyStop. An implicit Encoder[tsp.TravellingSalesman.JourneyStop] is needed to store tsp.TravellingSalesman.JourneyStop instances in a Dataset. Primitive types (Int, String, etc) and Product types (case classes) are supported by importing spark.implicits._ Support for serializing other types will be added in future releases..
[error] |I found:
[error] |
[error] | spark.implicits.newProductEncoder[tsp.TravellingSalesman.JourneyStop](
[error] | /* missing */
[error] | summon[reflect.runtime.universe.TypeTag[tsp.TravellingSalesman.JourneyStop]]
[error] | )
[error] |
[error] |But no implicit values were found that match type reflect.runtime.universe.TypeTag[tsp.TravellingSalesman.JourneyStop].
The compiler message is quite helpful: implicit resolution errors are much more informative in Scala 3 than they were in Scala 2. Spark is trying to derive an encoder to convert our case class into its internal data representation, but it’s failing because it can’t find a `scala.reflect.runtime.universe.TypeTag`. Scala 3 does not support runtime reflection, so it’s not surprising that this failed.
Luckily, Vincenzo Bazzucchi at the Scala Center has written a handy library to take care of this problem. It’s called spark-scala3 and it provides generic derivation of `Encoder` instances for case classes using Scala 3’s new metaprogramming features instead of runtime reflection. After adding that library as a dependency, and adding the necessary import, our code compiles.
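For reference, the wiring looks roughly like this. The coordinates, version, and import are taken from the spark-scala3 README as it stood around the time of writing, so double-check them against the project before copying:

```scala
// build.sbt (sketch): spark-scala3 publishes native Scala 3 artifacts,
// so no CrossVersion flag is needed for this dependency.
libraryDependencies += "io.github.vincenzobaz" %% "spark-scala3" % "0.1.3"
```

In the application code, `import scala3encoders.given` then brings the derived `Encoder` instances into scope.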
That takes care of the first of the two Spark features of interest: typed Datasets.
We can now perform our first transformation, turning journey stops into journey legs:
val journeyLegs: Dataset[JourneyLeg] = journeyStopsDs
.withColumn(
"previousCity",
lag("city", 1).over(Window.partitionBy("journeyID").orderBy("index"))
)
.as[JourneyLeg]
There’s not much to say about this, except that we go from a typed `Dataset` to a `DataFrame` and then back to a typed `Dataset` without any problems.
User Defined Functions
The next step of the job is to calculate the distance of each journey leg. We do this by rolling our own implementation of the Haversine formula. I’ll omit the details of the implementation.
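For completeness, a standard Haversine implementation matching the `Haversine.distance(lat1, lon1, lat2, lon2)` calls used below might look like this sketch (not necessarily identical to the code in the repo):

```scala
import scala.math.{asin, cos, pow, sin, sqrt, toRadians}

// Sketch of the Haversine great-circle distance between two points, in kilometres.
object Haversine:
  private val EarthRadiusKm = 6371.0

  def distance(lat1: Double, lon1: Double, lat2: Double, lon2: Double): Double =
    val dLat = toRadians(lat2 - lat1)
    val dLon = toRadians(lon2 - lon1)
    val a = pow(sin(dLat / 2), 2) +
      cos(toRadians(lat1)) * cos(toRadians(lat2)) * pow(sin(dLon / 2), 2)
    2 * EarthRadiusKm * asin(sqrt(a))
```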
As we have a custom function we need to apply to the data, this is a good opportunity to test UDFs. In general, it’s best practice to minimize use of UDFs and use only Spark’s built-in transformation operators, because UDFs are a black box to the Catalyst optimizer. But this distance calculation would be pretty arduous to write with the built-in operators.
Normally to create a UDF with Scala, you’d write a plain Scala function and then wrap it in Spark’s `udf(...)` helper function to lift it into a UDF:
def addAndDouble(x: Int, y: Int): Int = (x + y) * 2
val addAndDoubleUDF = udf(addAndDouble)
Unfortunately that doesn’t work with Scala 3 because it relies on `TypeTag`, just like the automatic derivation of `Encoder` instances for case classes.
I was able to use the Java API to build a UDF, but it’s pretty unpleasant:
val haversineJavaUDF: UDF4[JDouble, JDouble, JDouble, JDouble, JDouble] =
new UDF4[JDouble, JDouble, JDouble, JDouble, JDouble] {
def call(lat1: JDouble, lon1: JDouble, lat2: JDouble, lon2: JDouble): JDouble =
JDouble.valueOf(Haversine.distance(lat1, lon1, lat2, lon2))
}
val haversineUDF: UserDefinedFunction = udf(haversineJavaUDF, DataTypes.DoubleType)
Once that’s done, we can use the UDF to add a `distance` column to our Dataset:
val journeyLegDistances: Dataset[JourneyLegDistance] = journeyLegs
.withColumn(
"distance",
when(isnull($"previousCity"), 0.0)
.otherwise(haversineUDF($"city.lat", $"city.lon", $"previousCity.lat", $"previousCity.lon"))
)
.drop("city", "previousCity")
.as[JourneyLegDistance]
However, in our case, we don’t actually need to use a UDF here. We can achieve the same thing more simply and safely with `map`:
val journeyLegDistancesWithoutUDF: Dataset[JourneyLegDistance] = journeyLegs.map { leg =>
val distance = Option(leg.previousCity) match {
case Some(City(_, prevLat, prevLon)) =>
Haversine.distance(prevLat, prevLon, leg.city.lat, leg.city.lon)
case None =>
0.0
}
JourneyLegDistance(leg.journeyID, leg.index, distance)
}
Aggregation
Now the only thing left to do is a simple aggregation to find the total distance of each journey. We order by total distance and pick the first row, i.e., the shortest journey:
val journeyDistances: Dataset[(String, Double)] = journeyLegDistancesWithoutUDF
.groupByKey(_.journeyID)
.agg(typed.sum[JourneyLegDistance](_.distance).name("totalDistance"))
.orderBy($"totalDistance")
val (shortestJourney, shortestDistance) = journeyDistances.take(1).head
Result
If you run the TSP job, either via `sbt run` or on the Spark cluster using `./run-travelling-salesman.sh`, you should get the following result:
The shortest journey is New York -> Chicago -> San Jose -> Los Angeles -> San Diego -> Phoenix -> San Antonio -> Houston -> Dallas -> Philadelphia -> New York with a total distance of 9477.70 km
And here it is on a map:
Google Maps even agrees with our distance calculation to within about 50km, which is a nice sanity check.
Conclusion
We’ve used the power of Spark and Scala 3 to plan an epic (and optimal) road trip around America, and validated a few Spark features along the way.
Spark mostly works fine with Scala 3. Just remember, whenever you see a method in the Spark API that requires a `TypeTag`, you’ll need to be prepared to find a workaround.
One final reminder: all the code and scripts for this post are available on GitHub.