Spark Transformations example (Part 1)

Updated: Oct 25, 2019

Apache Spark transformations such as reduceByKey, groupByKey, mapPartitions and map are very widely used. There are several other transformations as well, and I will explain each of them with examples.

But before I proceed with Spark transformation examples: if you are new to Spark and Scala, I highly encourage you to go through this post first - Spark RDD, Transformation and Actions example.


We will pass lambda functions to most of these Spark transformations, so readers who are new to Scala should have a basic understanding of lambda functions.

Lambda Functions

In brief, lambda functions (anonymous functions) are like normal functions, except that they have no name and can be written inline. This makes them convenient when a function is needed in only one place, because you don't have to define it separately.

For example, to double a number you can simply write x => x + x, much like lambda x: x + x in Python or similar constructs in other languages. In Scala the syntax looks like this:

scala> val lfunc = (x:Int) => x + x

lfunc: Int => Int = <function1>

scala> lfunc(3)

res0: Int = 6
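For comparison, here is the same doubling logic written as a named method and as a lambda, each then passed to map() on an ordinary Scala List. This is a plain-Scala sketch that needs no Spark; the names double, doubleLambda and so on are only illustrative. RDD.map() accepts a function argument in the same way.

```scala
// A named method that doubles its argument
def double(x: Int): Int = x + x

// The equivalent anonymous function (lambda), stored in a value
val doubleLambda: Int => Int = x => x + x

// Lambdas are handiest when passed inline to a higher-order function such as map()
val doubledInline = List(1, 2, 3).map(x => x + x)

// A named method can be passed too; Scala converts it to a function value
val doubledNamed = List(1, 2, 3).map(double)

// Placeholder syntax: _ + 1 is shorthand for x => x + 1
val incremented = List(1, 2, 3).map(_ + 1)

println(doubleLambda(3)) // 6
println(doubledInline)   // List(2, 4, 6)
```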


Sample Data

We will use the "Where is Mount Everest?" text data that we created in an earlier post (SparkContext and text files). It is just some arbitrary text to work through these examples with.

Where is Mount Everest? (MountEverest.txt)

Mount Everest (Nepali: Sagarmatha सगरमाथा; Tibetan: Chomolungma ཇོ་མོ་གླང་མ; Chinese Zhumulangma 珠穆朗玛) is Earth's highest mountain above sea level, located in the Mahalangur Himal sub-range of the Himalayas. The international border between Nepal (Province No. 1) and China (Tibet Autonomous Region) runs across its summit point. - Reference Wikipedia

scala> val mountEverest = sc.textFile("/Users/Rajput/Documents/testdata/MountEverest.txt")

mountEverest: org.apache.spark.rdd.RDD[String] = /Users/Rajput/Documents/testdata/MountEverest.txt MapPartitionsRDD[1] at textFile at <console>:24

Spark Transformations

I encourage you to run these examples in spark-shell side by side as you read.


map(func)

The map() transformation returns a new RDD formed by passing each element of the source RDD through func.

For example, to split the Mount Everest text into individual words, you just pass the lambda x => x.split(" ") and map() creates a new RDD, as shown below.

What is this function doing? It reads each element (textFile() creates one element per line of the file) and splits it on the space character.

scala> val words = mountEverest.map(x => x.split(" "))

words: org.apache.spark.rdd.RDD[Array[String]] = MapPartitionsRDD[3] at map at <console>:25

scala> words.collect()

res1: Array[Array[String]] = Array(Array(Mount, Everest, (Nepali:, Sagarmatha, सगरमाथा;, Tibetan:, Chomolungma, ཇོ་མོ་གླང་མ;, Chinese, Zhumulangma, 珠穆朗玛), is, Earth's, highest, mountain, above, sea, level,, located, in, the, Mahalangur, Himal, sub-range, of, the, Himalayas., The, international, border, between, Nepal, (Province, No., 1), and, China, (Tibet, Autonomous, Region), runs, across, its, summit, point.))

Don't worry about the collect() action; it is a very basic Spark action that returns all the elements of the RDD to the driver as an array. Now, suppose you want the word count of this text file. You can first split each line and then take the length (or size) of the resulting array:

scala> mountEverest.map(x => x.split(" ").length).collect()

res6: Array[Int] = Array(45)

scala> mountEverest.map(x => x.split(" ").size).collect()

res7: Array[Int] = Array(45)

Let's say you want the number of characters in each line (the file has just one line, so this is the total for the file). You can do this:

scala> mountEverest.map(x => x.length).collect()

res5: Array[Int] = Array(329)

To convert all the text to upper case, you can do this:

scala> mountEverest.map(x => x.toUpperCase()).collect()
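Because textFile() turns every line into one RDD element, map() returns exactly one result per line. The sketch below repeats these map() calls on a plain Scala Seq standing in for a hypothetical two-line file (the sample lines are made up for illustration), so each result has two entries instead of one. On an RDD, a file-wide total would come from an action such as reduce(_ + _) applied after map().

```scala
// Each String stands in for one line of a text file, i.e. one RDD element
val lines = Seq("Mount Everest is high", "Himalayas India")

// map() calls the function once per element and keeps one result per element
val wordsPerLine = lines.map(x => x.split(" ").length)
val charsPerLine = lines.map(x => x.length)
val upperLines   = lines.map(x => x.toUpperCase())

// Summing the per-line counts gives a total for the whole "file"
val totalWords = wordsPerLine.sum

println(wordsPerLine) // List(4, 2)
```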



flatMap(func)

This is similar to map(), except that it flattens the output. For example,

scala> val rdd = sc.parallelize(Seq("Where is Mount Everest","Himalayas India"))

rdd: org.apache.spark.rdd.RDD[String] = ParallelCollectionRDD[22] at parallelize at <console>:24

scala> rdd.collect

res26: Array[String] = Array(Where is Mount Everest, Himalayas India)

scala> rdd.map(x => x.split(" ")).collect

res21: Array[Array[String]] = Array(Array(Where, is, Mount, Everest), Array(Himalayas, India))

scala> rdd.flatMap(x => x.split(" ")).collect

res23: Array[String] = Array(Where, is, Mount, Everest, Himalayas, India)

In the above case, rdd has two elements: "Where is Mount Everest" and "Himalayas India".

When the map() transformation is applied, it returns an array of string arrays, Array[Array[String]]: two separate arrays of strings inside an outer array. Each input element produces exactly one output element (1st element => (Where, is, Mount, Everest), 2nd element => (Himalayas, India)), and each of those output elements is itself a collection of words.

With flatMap(), on the other hand, the output is flattened into a single array of strings, Array[String]. Thus flatMap() is similar to map(), but each input item can be mapped to 0 or more output items (1st element => 4 elements, 2nd element => 2 elements).

Counting the elements makes the difference clear:

scala> rdd.map(x => x.split(" ")).count()

res24: Long = 2

scala> rdd.flatMap(x => x.split(" ")).count()

res25: Long = 6

map() => [Where is Mount Everest, Himalayas India] => [[Where, is, Mount, Everest],[Himalayas, India]]

flatMap() => [Where is Mount Everest, Himalayas India] => [Where, is, Mount, Everest, Himalayas, India]
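The same contrast can be reproduced on a local Scala Seq, whose map and flatMap behave like their RDD counterparts here. The sketch adds a third, empty input line (an assumption purely for illustration) to show the "0 or more" part: flatMap() drops an element entirely when the function returns an empty collection. The helper name words is hypothetical.

```scala
// Three input "elements"; the empty line is added here only to show
// that flatMap() can produce zero outputs for an input
val lines = Seq("Where is Mount Everest", "Himalayas India", "")

// Split a line into words. Note that "".split(" ") returns Array(""),
// so empty strings are filtered out explicitly.
def words(line: String): Seq[String] = line.split(" ").toSeq.filter(w => w.nonEmpty)

// map(): exactly one output per input, so the result is nested
val nested = lines.map(words)   // 3 elements; the last one is empty

// flatMap(): zero or more outputs per input, flattened into one Seq
val flat = lines.flatMap(words) // 6 words
```

In other words, flatMap(f) gives the same elements as map(f) followed by flattening the nested result.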


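filter(func)

As the name says, filter() returns a new RDD containing only the elements for which func returns true, much like a WHERE clause in SQL. Note that the contains() check used below is case sensitive. For example,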

// returns one element which contains match

scala> rdd.filter(x=>x.contains("Himalayas")).collect

res31: Array[String] = Array(Himalayas India)

// No match

scala> rdd.filter(x=>x.contains("Dataneb")).collect

res32: Array[String] = Array()

// Case sensitive

scala> rdd.filter(x=>x.contains("himalayas")).collect

res33: Array[String] = Array()

scala> rdd.filter(x=>x.toLowerCase.contains("himalayas")).collect

res37: Array[String] = Array(Himalayas India)
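Below is a plain-Scala sketch of the same predicates on a local Seq (Spark's filter() takes the same kind of String => Boolean function). It shows that the case sensitivity comes from String.contains(), not from filter() itself. containsIgnoreCase is a hypothetical helper, not a Spark API.

```scala
import java.util.Locale

val lines = Seq("Where is Mount Everest", "Himalayas India")

// Case-sensitive: String.contains compares characters exactly,
// so "himalayas" does not match "Himalayas"
val exact = lines.filter(x => x.contains("himalayas"))

// Hypothetical helper: lower-case both sides (Locale.ROOT avoids
// locale-specific casing rules) before comparing
def containsIgnoreCase(text: String, term: String): Boolean =
  text.toLowerCase(Locale.ROOT).contains(term.toLowerCase(Locale.ROOT))

val relaxed = lines.filter(x => containsIgnoreCase(x, "HIMALAYAS"))

// Negating the predicate keeps the elements that do not match
val others = lines.filter(x => !containsIgnoreCase(x, "himalayas"))
```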


mapPartitions(func)

This is similar to the map() transformation, but func runs once per partition (block) of the RDD, whereas map() runs it once per element. That is why mapPartitions() is useful when you are looking for a performance gain: any per-call overhead in your function is paid once per partition instead of once per element.

  • Suppose you have elements 1 to 100 distributed among 10 partitions, i.e. 10 elements per partition. map() will call func 100 times to process these 100 elements, but mapPartitions() will call func once per partition, i.e. 10 times.

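  • Secondly, if func builds its result as a strict collection (for example a List or an Array), that result is held in memory until all the elements of the partition have been processed.

  • In that case, mapPartitions() hands back results only after it finishes processing the whole partition. If func instead returns a lazy iterator (for example iter.map(...)), elements are processed and passed on one at a time.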

  • The function passed to mapPartitions() takes an Iterator over the partition's elements and must return an Iterator, unlike map(), whose function takes a single element.
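The call-count difference, and the memory point above, can be simulated without Spark by splitting 1 to 100 into 10 chunks of 10 with grouped(), each chunk standing in for a partition. This is only a local analogy (Spark actually runs partitions on executors); the counters and the functions perElement and perPartition are hypothetical.

```scala
// Local simulation only: ten "partitions" of ten elements each
val chunks: List[List[Int]] = (1 to 100).toList.grouped(10).toList

var elementCalls = 0
var partitionCalls = 0

// Hypothetical per-element function, as map() would call it
def perElement(x: Int): Int = { elementCalls += 1; x * 2 }

// Hypothetical per-partition function, as mapPartitions() would call it:
// it receives an Iterator over one partition and returns an Iterator
def perPartition(iter: Iterator[Int]): Iterator[Int] = {
  partitionCalls += 1
  iter.map(x => x * 2)
}

val viaMap = chunks.flatMap(chunk => chunk.map(perElement))
val viaMapPartitions = chunks.flatMap(chunk => perPartition(chunk.iterator))
// elementCalls is now 100, partitionCalls is 10, and both results are equal

// Laziness: iter.map(...) computes nothing until elements are pulled
var pulled = 0
val lazyResult = (1 to 10).iterator.map { x => pulled += 1; x * 2 }
val pulledBefore = pulled                 // 0
val firstTwo = lazyResult.take(2).toList  // List(2, 4)
val pulledAfter = pulled                  // 2

// Strict: toList forces the whole "partition" into memory first
var pulledStrict = 0
val strictResult = (1 to 10).iterator.map { x => pulledStrict += 1; x * 2 }.toList.iterator
// pulledStrict is already 10 before anything reads strictResult
```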

What is an iterator? (for new programmers) An iterator is a way to access a collection of elements one by one. It is similar to collections such as List, Array or Map in some ways, but the difference is that an iterator does not need the whole collection in memory at once; it produces elements one after another. In Scala you access these elements with the hasNext and next() methods.
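A small plain-Scala illustration of hasNext and next() on a standard Iterator, including the one-way behaviour that the mapPartitions() examples rely on (the values are arbitrary):

```scala
val it = Iterator(1, 2, 3)

val first = it.next()        // 1: next() returns the current element and advances
val second = it.next()       // 2
val more = it.hasNext        // true: 3 is still left
val third = it.next()        // 3
val exhausted = !it.hasNext  // true: nothing left

// There is no way to step back; calling next() on an exhausted
// iterator throws java.util.NoSuchElementException
val threw =
  try { it.next(); false }
  catch { case _: NoSuchElementException => true }

// The usual way to walk an iterator: loop while hasNext is true
val numbers = Iterator(4, 5, 6)
var loopTotal = 0
while (numbers.hasNext) loopTotal += numbers.next()
```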

For example,

scala> sc.parallelize(1 to 9, 3).map(x=>(x, "Hello")).collect

res3: Array[(Int, String)] = Array((1,Hello), (2,Hello), (3,Hello), (4,Hello), (5,Hello), (6,Hello), (7,Hello), (8,Hello), (9,Hello))

scala> sc.parallelize(1 to 9, 3).mapPartitions(x=>(Array("Hello").iterator)).collect

res7: Array[String] = Array(Hello, Hello, Hello)

scala> sc.parallelize(1 to 9, 3).mapPartitions(x=>(List(x.next()).iterator)).collect

res11: Array[Int] = Array(1, 4, 7)

In the first example, I applied map() to a dataset distributed across 3 partitions, and you can see the function is called 9 times, once per element. In the second example, with mapPartitions(), you will notice it ran 3 times, i.e. once per partition. We had to convert the string "Hello" into an iterator because the function passed to mapPartitions() must return an iterator. In the third example, I called next() on each partition's iterator to show the first element of each partition. Note that next() always moves forward, so you can't step back. See this:

scala> sc.parallelize(1 to 9, 3).mapPartitions(x=>(List(x.next(), x.next(), "|").iterator)).collect

res18: Array[Any] = Array(1, 2, |, 4, 5, |, 7, 8, |)

On the second call to next(), the value for partition 1 changed from 1 to 2, for partition 2 from 4 to 5, and for partition 3 from 7 to 8. You can keep advancing until hasNext is false (hasNext is a method of Iterator that tells you whether any elements are left; it returns true or false). For example,

scala> sc.parallelize(1 to 9, 3).mapPartitions(x=>(List(x.next(), x.hasNext).iterator)).collect

res19: Array[AnyVal] = Array(1, true, 4, true, 7, true)

You can see that hasNext is true because there are elements left in each partition. Now suppose we access all three elements of each partition; then hasNext will return false. For example,

scala> sc.parallelize(1 to 9, 3).mapPartitions(x=>(List(x.next(), x.next(), x.next(), x.hasNext).iterator)).collect