Spark mapPartitions - Similar to the map() transformation, but here the function runs once on each partition (block) of the RDD, unlike map(), where it runs on each element of the partition. Hence mapPartitions() is useful when you are looking for a performance gain: it calls your function once per partition, not once per element.
Suppose you have the elements 1 to 100 distributed among 10 partitions, i.e. 10 elements per partition. The map() transformation will call func 100 times to process these 100 elements, whereas mapPartitions() will call func once per partition, i.e. 10 times.
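The call counts above can be checked without a cluster. The following is a minimal sketch in plain Scala, assuming a list of lists as a stand-in for an RDD's partitions; the object and method names are hypothetical and it does not use Spark itself.

```scala
object CallCountSketch {
  // Hypothetical stand-in for an RDD: 1 to 100 split into 10 "partitions" of 10 elements.
  val partitions: List[List[Int]] = (1 to 100).toList.grouped(10).toList

  // map-style: the function is invoked once per element.
  def countPerElementCalls(): Int = {
    var calls = 0
    partitions.foreach { p => p.map { x => calls += 1; x * 2 } }
    calls
  }

  // mapPartitions-style: the function is invoked once per partition
  // and receives an Iterator over that partition's elements.
  def countPerPartitionCalls(): Int = {
    var calls = 0
    val func: Iterator[Int] => Iterator[Int] = it => { calls += 1; it.map(_ * 2) }
    partitions.foreach { p => func(p.iterator).toList } // toList drains the result
    calls
  }

  def main(args: Array[String]): Unit = {
    println(countPerElementCalls())   // 100
    println(countPerPartitionCalls()) // 10
  }
}
```

The per-partition version does the same work per element; only the number of function invocations changes, which matters most when each invocation carries a fixed setup cost.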
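Secondly, mapPartitions() can hold data in memory: if your function builds its result as a collection (for example a List) before returning it, the output for the whole partition is stored in memory until all the elements of the partition have been processed.
In that case mapPartitions() returns the result only after it finishes processing the whole partition. If the function instead returns a lazy iterator (for example by calling map on the input iterator), elements are processed one at a time as the result is consumed.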
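Whether a whole partition's output sits in memory depends on how the function builds its result. The sketch below, in plain Scala with hypothetical names and no Spark dependency, counts how many elements have already been transformed at the moment the function hands back its iterator.

```scala
object LazyVsEagerSketch {
  // Returns how many elements were already transformed when the
  // per-partition function returned its result iterator.
  def processedBeforeReturn(eager: Boolean): Int = {
    var processed = 0
    val partition: Iterator[Int] = Iterator(1, 2, 3)
    val transform: Int => Int = x => { processed += 1; x * 10 }

    val result: Iterator[Int] =
      if (eager) partition.map(transform).toList.iterator // materializes the whole partition
      else partition.map(transform)                       // lazy: nothing has run yet

    val seen = processed    // snapshot taken before anyone consumes the result
    result.foreach(_ => ()) // drain the iterator
    seen
  }

  def main(args: Array[String]): Unit = {
    println(processedBeforeReturn(eager = true))  // 3
    println(processedBeforeReturn(eager = false)) // 0
  }
}
```

For large partitions, returning the lazy form is usually the safer default, since it avoids buffering every output element at once.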
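The function passed to mapPartitions() takes an iterator over the partition's elements as input and must return an iterator, unlike the function passed to map(), which takes a single element.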
What is an Iterator? An iterator is a way to access a collection of elements one by one. It is similar to collections such as List() and Array() in a few ways, but the difference is that an iterator doesn't load the whole collection of elements into memory at once. Instead, it yields elements one after another. In Scala you access these elements with the hasNext and next operations.
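As a small illustration of hasNext and next outside Spark, here is a sketch in plain Scala; the object name and the drain helper are hypothetical and exist only for this example.

```scala
import scala.collection.mutable.ListBuffer

object IteratorBasicsSketch {
  // Walks an iterator with hasNext/next and collects whatever is still left in it.
  def drain[A](it: Iterator[A]): List[A] = {
    val seen = ListBuffer.empty[A]
    while (it.hasNext) seen += it.next()
    seen.toList
  }

  def main(args: Array[String]): Unit = {
    val it = Iterator(1, 2, 3)
    println(it.next())  // 1
    println(drain(it))  // List(2, 3): the element already consumed is gone
    println(it.hasNext) // false
  }
}
```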
scala> sc.parallelize(1 to 9, 3).map(x=>(x, "Hello")).collect
scala> sc.parallelize(1 to 9, 3).mapPartitions(x=>(Array("Hello").iterator)).collect
res7: Array[String] = Array(Hello, Hello, Hello)
scala> sc.parallelize(1 to 9, 3).mapPartitions(x=>(List(x.next).iterator)).collect
res11: Array[Int] = Array(1, 4, 7)
In the first example, I applied the map() transformation to a dataset distributed across 3 partitions, so the function is called 9 times. In the second example, where we applied mapPartitions(), you will notice it ran 3 times, i.e. once for each partition. We had to convert the string "Hello" into an iterator because the function passed to mapPartitions() must return an iterator. In the third example, I called next on the iterator to show you the first element of each partition. Note that next always moves forward, so you can't step back.
scala> sc.parallelize(1 to 9, 3).mapPartitions(x=>(List(x.next,x.next, "|").iterator)).collect
Here the first call to next returned 1 for partition 1 and the second call returned 2; for partition 2 the values were 4 and then 5, and similarly for partition 3 they were 7 and then 8. You can keep advancing until hasNext is false (hasNext is an iterator method that tells you whether any elements are left; it returns true or false).
scala> sc.parallelize(1 to 9, 3).mapPartitions(x=>(List(x.next, x.hasNext).iterator)).collect
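One caveat with calling next directly: on an empty partition (for example when there are fewer elements than partitions) it throws a NoSuchElementException. A minimal sketch of a guard, using a hypothetical helper in plain Scala:

```scala
object SafeFirstSketch {
  // Returns an iterator holding only the first element,
  // or an empty iterator when the partition has no elements.
  def firstOrNothing[A](it: Iterator[A]): Iterator[A] =
    if (it.hasNext) Iterator(it.next()) else Iterator.empty

  def main(args: Array[String]): Unit = {
    println(firstOrNothing(Iterator(4, 5, 6)).toList)    // List(4)
    println(firstOrNothing(Iterator.empty[Int]).toList)  // List()
  }
}
```

In the Spark shell this could be used as mapPartitions(x => SafeFirstSketch.firstOrNothing(x)), assuming the object has been defined in the same session.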