Using Scala 3 with Spark

07 Mar, 2023

This article was originally published at 47deg.com on February 8, 2022.

Apache Spark is a hugely popular data engineering tool that accounts for a large segment of the Scala community. Every Spark release is tied to a specific Scala version, so a large subset of Scala users have little control over the Scala version they use because Spark dictates it.

Scala 2.13 was released in June 2019, but it took more than two years and a huge effort by the Spark maintainers for the first Scala 2.13-compatible Spark release (Spark 3.2.0) to arrive. While developers appreciated how much work went into upgrading Spark to Scala 2.13, it was still a little frustrating to be stuck on an older version of Scala for so long.

If it took two years for Spark to upgrade from Scala 2.12 to 2.13, you might be wondering how long it will take for a Scala 3 version of Spark to arrive. The answer is: it doesn’t matter! We can already use Scala 3 to build Spark applications thanks to the compatibility between Scala 2.13 and Scala 3.

In the remainder of this post, I’d like to demonstrate how to build a Scala 3 application that runs on a Spark 3.2.0 cluster. I’ll start with a small hello-world app, then demonstrate a more realistic app that exercises more of Spark’s features. I hope this will dispel some myths and reassure people that putting a Scala 3 Spark app into production is reasonable.

All the code and scripts for this post are available on GitHub.

Hello, World!

Without further ado, let’s write a Spark app in Scala 3.

sbt configuration

We’ll start with the sbt configuration. The build.sbt file is very short, but requires a little explanation:

scalaVersion := "3.1.1"

libraryDependencies ++= Seq(
  ("org.apache.spark" %% "spark-sql" % "3.2.0" % "provided").cross(CrossVersion.for3Use2_13)
)

// include the 'provided' Spark dependency on the classpath for `sbt run`
Compile / run := Defaults.runTask(Compile / fullClasspath, Compile / run / mainClass, Compile / run / runner).evaluated
  • We set the Scala version to 3.1.1, the latest Scala release at the time of writing. Woohoo, Scala 3! That means we can use the new hipster syntax and get our hands on shiny new features like enums, extension methods, and opaque types.
  • We add a dependency on spark-sql v3.2.0 so we can use the Spark API in our code.
    • We mark it as provided because we will package our app as an uber-jar for deployment to a Spark cluster. We don’t want to include the Spark API in our deployment package because the Spark runtime provides it.
    • Because there is no Scala 3 version of spark-sql available, we use CrossVersion.for3Use2_13 to tell sbt we want the Scala 2.13 version of this library.
  • Finally, we add a one-line incantation to tell sbt to add all “provided” dependencies to the classpath when we execute the run task. Exercising our code using sbt run will give us a much faster testing cycle than repackaging the app and deploying it to a Spark cluster after every change.

A warning about dependencies

While Scala 3 and 2.13 can magically interoperate, there is one restriction worth being aware of: if you accidentally end up with both the _2.13 and _3 versions of a given library on your classpath, you’re headed directly to Dependency Hell. In fact, sbt will detect this case during dependency resolution and fail loudly to save you from yourself.

This doesn’t matter for our hello-world app, as we have no dependencies other than spark-sql, but a real Spark job is likely to have dependencies on other libraries. Every time you add a dependency, you’ll need to check its graph of transitive dependencies and use CrossVersion.for3Use2_13 where necessary.

Luckily, Spark has only a handful of Scala dependencies:

  • com.fasterxml.jackson.module::jackson-module-scala
  • com.twitter::chill
  • org.json4s::json4s-jackson
  • org.scala-lang.modules::scala-parallel-collections
  • org.scala-lang.modules::scala-parser-combinators
  • org.scala-lang.modules::scala-xml

Code

Our hello-world is just a few lines of Scala:

// src/main/scala/helloworld/HelloWorld.scala
package helloworld

import org.apache.spark.sql.SparkSession

object HelloWorld:

  @main def run =
    val spark = SparkSession.builder
      .appName("HelloWorld")
      .master(sys.env.getOrElse("SPARK_MASTER_URL", "local[*]"))
      .getOrCreate()           // 1
    import spark.implicits._   // 2

    val df = List("hello", "world").toDF  // 3
    df.show()                             // 4

    spark.stop()

If you’re familiar with Spark, there should be no surprises here:

  1. Create a Spark session. Unless SPARK_MASTER_URL is set, it runs in local mode (local[*]), i.e. inside the current JVM using all available cores.
  2. Import some implicits so we can call .toDF in step 3.
  3. Create a one-column DataFrame with two rows.
  4. Print the contents of the DataFrame to stdout.

When we run this code using sbt run, we see a load of Spark logs and our DataFrame:

+-----+
|value|
+-----+
|hello|
|world|
+-----+

Success! We just ran our first Scala 3 Spark app.

You may have noticed that the name of the main class is helloworld.run, even though our @main-annotated run method is inside the HelloWorld object. This behavior is documented in the Scala 3 docs, but I found it a little surprising.

Running on a real Spark cluster

Running inside sbt is all well and good, but we need to see whether this will work when deployed to an actual Spark cluster.

The GitHub repo for this post contains a Dockerfile, a docker-compose.yml, and scripts for building and running a 3-node Spark cluster in Docker. I won’t go into all the details of the Docker image, but there is one important point to mention: multiple Spark 3.2.0 binaries are available for download and you need to pick the right one when installing Spark. Even though Spark 3.2.0 supports Scala 2.13, the default Scala version is still 2.12, so you need to pick the one with the _scala2.13 suffix.

(Similarly, if you use a managed Spark platform such as Databricks or AWS EMR, you’ll need to choose a Spark 3.2.0 + Scala 2.13 runtime. At the time of writing, neither Databricks nor EMR offers such a runtime.)

Running ./start-spark-cluster.sh will build the Docker image and start a Spark cluster containing a master node and two worker nodes.

If you run ./open-spark-UIs.sh, you should be able to see the Spark UI in your browser.

We are now ready to package our Spark application and submit it to the cluster for execution.

Running sbt assembly will build an uber-jar containing our application code and the Scala standard library. (Actually, it contains two standard library artifacts: scala-library 2.13.6 and scala3-library 3.1.1. That is expected, because Scala 3 reuses the Scala 2.13 standard library and adds a small Scala 3-specific library on top.) Remember, we marked the Spark dependency as “provided,” so it is not included in the uber-jar.
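
The assembly task is not built into sbt; it comes from the sbt-assembly plugin. A minimal sketch of the plugin setup, assuming version 1.1.0 (the version used in the repo for this post may differ):

```scala
// project/plugins.sbt (sketch; the plugin version is an assumption)
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "1.1.0")
```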

You can now run the ./run-hello-world.sh script to execute our app on the Spark cluster. You should see similar terminal output to what you saw when you ran the app inside sbt.

If you refresh the Spark master UI, you should see that the app has been completed successfully.

Nice!

A more realistic application

Now we’ve got a hello-world under our belts, let’s build a more complex application that exercises more of Spark’s features and looks a bit more like the kind of Spark job you might write in the real world.

Spark features of interest

One Spark feature worth verifying is typed Datasets. It’s often preferable to work with a Dataset[MyCaseClass] rather than a raw DataFrame, so that Scala’s type system can help us avoid bugs. But Spark’s mechanism for converting between case classes and its internal data representation relies on runtime reflection – is that going to work in Scala 3? We’ll find out soon.

Another important Spark feature we should exercise is user-defined functions (UDFs), as it requires Spark to serialize our Scala code and execute it on the worker nodes. Intuitively that seems like something else that might break when we use Scala 3.

Application overview

We’re going to use Spark to solve the classic Traveling Salesman Problem (TSP). Given a set of N cities, we want to find the shortest route that starts at any one of them, visits all the other cities exactly once, and finally returns to where it started.

Solving the TSP efficiently is far beyond this post’s scope, so our job will use a very naive brute-force approach. However, we can at least distribute the computation across the Spark cluster, which should speed things up slightly.

(Aside: for an entertaining tour of the history of TSP research and an overview of the state of the art, check out this talk by Prof. William Cook of the University of Waterloo.)

Our application will work as follows:

  1. Load a list of cities (names and locations) from a configuration file.
  2. Build a list of every possible journey through those cities.
  3. Split the journeys into legs, from one city to the next.
  4. Calculate the distance of each leg.
  5. Sum those distances to give the total distance of each journey.
  6. Print out the route and distance of the shortest journey.

This will give us a chance to verify a few Spark features, including the two mentioned above.

Data model

Here are the case classes we’ll use to model our data as it goes through a few transformations:

// A city with a name and a lat/long coordinate
case class City(name: String, lat: Double, lon: Double)

// A stop at a given city as part of a given journey
case class JourneyStop(journeyID: String, index: Int, city: City)

// A leg of a journey, from one city to the next
case class JourneyLeg(journeyID: String, index: Int, city: City, previousCity: City)

// The distance in km of a given journey leg
case class JourneyLegDistance(journeyID: String, index: Int, distance: Double)

We will use typed Datasets where possible so we can work with these case classes in Spark instead of dealing with raw Rows.
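
For illustration, here is one way the brute-force enumeration of journeys could be written in plain Scala before anything is handed to Spark. This is a sketch, not the code from the repo: enumerateJourneys and toStops are hypothetical helper names, and fixing the first city as the start is my assumption (any round trip can be rotated to start there, so no route is lost).

```scala
case class City(name: String, lat: Double, lon: Double)
case class JourneyStop(journeyID: String, index: Int, city: City)

// Every round trip that starts and ends at the first city in the list.
// For N cities this yields (N - 1)! journeys, so it only scales to small N.
def enumerateJourneys(cities: List[City]): List[List[City]] =
  cities match
    case Nil => Nil
    case start :: rest =>
      rest.permutations.map(middle => start :: middle ::: List(start)).toList

// Flatten the journeys into one JourneyStop per (journey, position) pair.
def toStops(journeys: List[List[City]]): List[JourneyStop] =
  for
    (journey, id) <- journeys.zipWithIndex
    (city, index) <- journey.zipWithIndex
  yield JourneyStop(id.toString, index, city)
```

With three cities this produces 2 journeys of 4 stops each, i.e. 8 JourneyStop rows.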

Loading the data into Spark

Skipping the parts where we load the cities from configuration and enumerate all possible journeys, let’s just assume we have a list of JourneyStops.

We’d like to load it into a Dataset so we can work with it in Spark:

val journeyStops: List[JourneyStop] = ???
val journeyStopsDs: Dataset[JourneyStop] = spark.createDataset(journeyStops)

Unfortunately, this doesn’t compile. We’ve hit our first Scala 3 stumbling block!

[error] -- Error: /Users/chris/code/spark-scala3-example/src/main/scala/tsp/TravellingSalesman.scala:61:80
[error] 61 |    val journeyStopsDs: Dataset[JourneyStop] = spark.createDataset(journeyStops)
[error]    |                                                                                ^
[error]    |Unable to find encoder for type tsp.TravellingSalesman.JourneyStop. An implicit Encoder[tsp.TravellingSalesman.JourneyStop] is needed to store tsp.TravellingSalesman.JourneyStop instances in a Dataset. Primitive types (Int, String, etc) and Product types (case classes) are supported by importing spark.implicits._  Support for serializing other types will be added in future releases..
[error]    |I found:
[error]    |
[error]    |    spark.implicits.newProductEncoder[tsp.TravellingSalesman.JourneyStop](
[error]    |      /* missing */
[error]    |        summon[reflect.runtime.universe.TypeTag[tsp.TravellingSalesman.JourneyStop]]
[error]    |    )
[error]    |
[error]    |But no implicit values were found that match type reflect.runtime.universe.TypeTag[tsp.TravellingSalesman.JourneyStop].

The compiler message is quite helpful: implicit resolution errors are much more informative in Scala 3 than they were in Scala 2. Spark is trying to derive an encoder to convert our case class into its internal data representation, but it’s failing because it can’t find a scala.reflect.runtime.universe.TypeTag. Scala 3 does not support Scala 2’s runtime reflection API (scala.reflect), so the compiler cannot materialize a TypeTag and it’s not surprising that this failed.

Luckily, Vincenzo Bazzucchi at the Scala Center has written a handy library to take care of this problem. It’s called spark-scala3 and it provides generic derivation of Encoder instances for case classes using Scala 3’s new metaprogramming features instead of runtime reflection. After adding that library as a dependency, and adding the necessary import, our code compiles.
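
For reference, a minimal sketch of what that looks like. Both the sbt coordinates ("io.github.vincenzobaz" %% "spark-scala3") and the scala3encoders.given import are my assumptions about the library’s published API, so check its README for the current names and version. The encoderSketch entry point and the sample row are made up for the example.

```scala
// Assumes build.sbt also contains something like:
//   libraryDependencies += "io.github.vincenzobaz" %% "spark-scala3" % "<version>"
import org.apache.spark.sql.{Dataset, SparkSession}
import scala3encoders.given // assumed import: brings derived Encoder instances into scope

case class City(name: String, lat: Double, lon: Double)
case class JourneyStop(journeyID: String, index: Int, city: City)

@main def encoderSketch(): Unit =
  val spark = SparkSession.builder
    .appName("EncoderSketch")
    .master("local[*]")
    .getOrCreate()

  val journeyStops = List(JourneyStop("0", 0, City("New York", 40.7128, -74.0060)))
  // The Encoder[JourneyStop] is now derived at compile time, so no TypeTag is needed
  val journeyStopsDs: Dataset[JourneyStop] = spark.createDataset(journeyStops)
  journeyStopsDs.show()

  spark.stop()
```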

That takes care of the first of the two Spark features of interest: typed Datasets.

We can now perform our first transformation, turning journey stops into journey legs:

val journeyLegs: Dataset[JourneyLeg] = journeyStopsDs
  .withColumn(
    "previousCity",
    lag("city", 1).over(Window.partitionBy("journeyID").orderBy("index"))
  )
  .as[JourneyLeg]

There’s not much to say about this, except that we go from a typed Dataset to a DataFrame and then back to a typed Dataset without any problems.
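
To make the semantics of the window function concrete, here is the same transformation sketched with plain Scala collections, without Spark. The Leg class and the toLegs helper are hypothetical stand-ins: Leg mirrors JourneyLeg but uses Option where Spark’s lag would produce a null previousCity for the first stop of each journey.

```scala
case class City(name: String, lat: Double, lon: Double)
case class JourneyStop(journeyID: String, index: Int, city: City)
// Hypothetical stand-in for JourneyLeg, with Option instead of a nullable field
case class Leg(journeyID: String, index: Int, city: City, previousCity: Option[City])

def toLegs(stops: List[JourneyStop]): List[Leg] =
  stops
    .groupBy(_.journeyID) // like Window.partitionBy("journeyID")
    .values
    .toList
    .flatMap { journey =>
      val ordered = journey.sortBy(_.index) // like orderBy("index")
      // Shift the cities by one position: the first stop has no predecessor
      val previous: List[Option[City]] = None :: ordered.map(stop => Some(stop.city))
      ordered.zip(previous).map { (stop, prev) =>
        Leg(stop.journeyID, stop.index, stop.city, prev)
      }
    }
```

The order of journeys in the result is unspecified because groupBy returns a Map; within each journey the legs are ordered by index.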

User Defined Functions

The next step of the job is to calculate the distance of each journey leg. We do this by rolling our own implementation of the Haversine formula. I’ll omit the details of the implementation.
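
For readers who want something to run, here is a minimal sketch of a Haversine implementation with the same distance(lat1, lon1, lat2, lon2) shape used later in this post. It is not necessarily the code from the repo, and the mean Earth radius of 6371 km is an assumption of this sketch.

```scala
object Haversine:
  // Mean Earth radius in km (assumed value; the repo may use a different constant)
  private val EarthRadiusKm = 6371.0

  /** Great-circle distance in km between two points given in decimal degrees. */
  def distance(lat1: Double, lon1: Double, lat2: Double, lon2: Double): Double =
    val dLat = math.toRadians(lat2 - lat1)
    val dLon = math.toRadians(lon2 - lon1)
    val a = math.pow(math.sin(dLat / 2), 2) +
      math.cos(math.toRadians(lat1)) * math.cos(math.toRadians(lat2)) *
      math.pow(math.sin(dLon / 2), 2)
    // Clamp to 1.0 so floating-point rounding can never push asin out of its domain
    2 * EarthRadiusKm * math.asin(math.min(1.0, math.sqrt(a)))
```

The formula is symmetric in its two points, so the order in which the current and previous city are passed does not affect the result.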

As we have a custom function we need to apply to the data, this is a good opportunity to test UDFs. In general, it’s best practice to minimize use of UDFs and use only Spark’s built-in transformation operators, because UDFs are a black box to the Catalyst optimizer. But this distance calculation would be pretty arduous to write with the built-in operators.

Normally, to create a UDF with Scala, you’d write a plain Scala function and then wrap it in Spark’s udf(...) helper function to lift it into a UDF:

def addAndDouble(x: Int, y: Int): Int = (x + y) * 2

val addAndDoubleUDF = udf(addAndDouble)

Unfortunately, that doesn’t work with Scala 3 because it relies on TypeTag, just like the automatic derivation of Encoder instances for case classes.

I was able to use the Java API to build a UDF, but it’s pretty unpleasant:

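import java.lang.{Double => JDouble}
import org.apache.spark.sql.api.java.UDF4
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.DataTypes

val haversineJavaUDF: UDF4[JDouble, JDouble, JDouble, JDouble, JDouble] =
  new UDF4[JDouble, JDouble, JDouble, JDouble, JDouble] {
    def call(lat1: JDouble, lon1: JDouble, lat2: JDouble, lon2: JDouble): JDouble =
      JDouble.valueOf(Haversine.distance(lat1, lon1, lat2, lon2))
  }
val haversineUDF: UserDefinedFunction = udf(haversineJavaUDF, DataTypes.DoubleType)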

Once that’s done, we can use the UDF to add a distance column to our Dataset:

val journeyLegDistances: Dataset[JourneyLegDistance] = journeyLegs
  .withColumn(
    "distance",
    when(isnull($"previousCity"), 0.0)
      .otherwise(haversineUDF($"city.lat", $"city.lon", $"previousCity.lat", $"previousCity.lon"))
  )
  .drop("city", "previousCity")
  .as[JourneyLegDistance]

However, in our case, we don’t actually need to use a UDF here. We can achieve the same thing more simply and safely with map:

val journeyLegDistancesWithoutUDF: Dataset[JourneyLegDistance] = journeyLegs.map { leg =>
  val distance = Option(leg.previousCity) match {
    case Some(City(_, prevLat, prevLon)) =>
      Haversine.distance(prevLat, prevLon, leg.city.lat, leg.city.lon)
    case None =>
      0.0
  }
  JourneyLegDistance(leg.journeyID, leg.index, distance)
}

Aggregation

Now the only thing left to do is a simple aggregation to find the total distance of each journey. We order by total distance and pick the first row, i.e., the shortest journey:

val journeyDistances: Dataset[(String, Double)] = journeyLegDistancesWithoutUDF
  .groupByKey(_.journeyID)
  .agg(typed.sum[JourneyLegDistance](_.distance).name("totalDistance"))
  .orderBy($"totalDistance")

val (shortestJourney, shortestDistance) = journeyDistances.take(1).head

Result

If you run the TSP job, either via sbt run or on the Spark cluster using ./run-travelling-salesman.sh, you should get the following result:

The shortest journey is
New York
->Chicago
->San Jose
->Los Angeles
->San Diego
->Phoenix
->San Antonio
->Houston
->Dallas
->Philadelphia
->New York
with a total distance of 9477.70 km

And here it is on a map:

Google Maps even agrees with our distance calculation to within about 50 km, which is a nice sanity check.

Conclusion

We’ve used the power of Spark and Scala 3 to plan an epic (and optimal) road trip around America and validated a few Spark features along the way.

Spark mostly works fine with Scala 3. Just remember, whenever you see a method in the Spark API that requires a TypeTag, you’ll need to be prepared to find a workaround.

One final reminder: all the code and scripts for this post are available on GitHub.

Chris Birchall
Chris is a Principal Software Developer at Xebia Functional. His technical interests include distributed systems, functional domain modelling, metaprogramming and property-based testing.