Tracing the Use of Type Information in Shuffle
This document is based on the final-dev branch and describes how the type information set on the driver side is used on the worker side during the shuffle phase.
Passing type information on the driver side
On this branch, the types for shuffle read and shuffle write are set when the ShuffleDependency is created, in Dependency.scala at line 86:
shuffleHandle.setTypes((0, 2), (0, 2))
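For orientation, here is a minimal, self-contained sketch of the handle-side API that this call and the later getWriteType/getReadType accesses imply. The class and field names below are assumptions for illustration, not the branch's actual code. Under this reading, setTypes((0, 2), (0, 2)) declares an Int key and a Double value for both the write side and the read side (type codes: 0 = Int, 1 = Long, 2 = Double).

object FlintShuffleHandleSketch {
  // Illustrative only: a handle carrying (keyType, valueType) code pairs
  // for the write and the read side, mirroring the calls seen later in
  // this document. Not the branch's real ShuffleHandle.
  class Handle {
    private var writeTypes: (Int, Int) = _
    private var readTypes: (Int, Int) = _
    def setTypes(write: (Int, Int), read: (Int, Int)): Unit = {
      writeTypes = write
      readTypes = read
    }
    def getWriteType: (Int, Int) = writeTypes
    def getReadType: (Int, Int) = readTypes
  }

  def main(args: Array[String]): Unit = {
    val h = new Handle
    h.setTypes((0, 2), (0, 2))  // write side: Int key, Double value; read side: the same
    assert(h.getWriteType._2 == 2 && h.getReadType._1 == 0)
  }
}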
Using the type information on the worker side
The Shuffle Write process
In ShuffleMapTask's runTask function, the rdd and the dependency are first obtained by deserialization:
val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
  ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
The ShuffleWriter is then obtained through the ShuffleHandle held by the dependency; once the RDD has been computed, this ShuffleWriter performs the shuffle write.
writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
Note: a shuffle write triggered by groupByKey takes Spark's original code path; only a shuffle write triggered by reduceByKey takes our branch.
Once the RDD has been computed, the worker performs the shuffle write by calling the write function of UnsafeShuffleWriter. For a shuffle triggered by reduceByKey this calls insertOrUpdateRecordIntoSorter, which turns the key and the value into bytes using a FlintSerializer; that serializer holds the ShuffleHandle.
FlintSerializer fs = new FlintSerializer(handle);
Since the actual key and value instances are available at conversion time, the type information wrapped in the ShuffleHandle is not used here for now; the instances are converted by matching on their runtime types instead. See the overridden writeKey and writeValue functions in FlintSerializationStream:
override def writeKey[T: ClassTag](key: T): SerializationStream
override def writeValue[T: ClassTag](value: T): SerializationStream
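As a concrete illustration of that instance-based dispatch, here is a minimal, self-contained sketch; java.io.DataOutputStream stands in for the stream's actual output, and all names are illustrative rather than the branch's real code.

import java.io.{ByteArrayOutputStream, DataOutputStream}

object WriteKeySketch {
  // Illustrative only: dispatch on the runtime type of the key instance,
  // as FlintSerializationStream.writeKey reportedly does.
  def writeKey[T](out: DataOutputStream, key: T): Unit = key match {
    case i: Int    => out.writeInt(i)     // corresponds to type code 0
    case l: Long   => out.writeLong(l)    // type code 1
    case d: Double => out.writeDouble(d)  // type code 2
    case other     => throw new UnsupportedOperationException(s"unsupported key type: ${other.getClass}")
  }

  def main(args: Array[String]): Unit = {
    val buf = new ByteArrayOutputStream()
    writeKey(new DataOutputStream(buf), 42)  // an Int key becomes exactly 4 bytes
    assert(buf.size == 4)
  }
}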
After the key and the value have been turned into bytes, insertOrUpdateRecord of UnsafeShuffleExternalSorter is called. This function uses the type information wrapped in the ShuffleHandle to aggregate off-heap; see the relevant code of insertOrUpdateRecord:
switch (Integer.parseInt(handle.getWriteType()._2().toString())) {
  case 0:  // value type code 0: Int, aggregate in place with a 4-byte add
    int oriValue1 = PlatformDependent.UNSAFE.getInt(loc.getPage(), loc.getPosition() + 4 + keyLengthInBytes);
    PlatformDependent.UNSAFE.putInt(loc.getPage(), loc.getPosition() + 4 + keyLengthInBytes,
        oriValue1 + PlatformDependent.UNSAFE.getInt(valueBaseObject, valueBaseOffset));
    break;
  case 1:  // value type code 1: Long
    long oriValue2 = PlatformDependent.UNSAFE.getLong(loc.getPage(), loc.getPosition() + 4 + keyLengthInBytes);
    PlatformDependent.UNSAFE.putLong(loc.getPage(), loc.getPosition() + 4 + keyLengthInBytes,
        oriValue2 + PlatformDependent.UNSAFE.getLong(valueBaseObject, valueBaseOffset));
    break;
  case 2:  // value type code 2: Double
    double oriValue3 = PlatformDependent.UNSAFE.getDouble(loc.getPage(), loc.getPosition() + 4 + keyLengthInBytes);
    PlatformDependent.UNSAFE.putDouble(loc.getPage(), loc.getPosition() + 4 + keyLengthInBytes,
        oriValue3 + PlatformDependent.UNSAFE.getDouble(valueBaseObject, valueBaseOffset));
    break;
  default:
    break;
}
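The offsets above also reveal the sorter's record layout, although the source does not state it explicitly. Reading the code, each record appears to be stored as:

// inferred record layout inside the sorter's page:
// [ key length : 4 bytes ][ key bytes ][ value bytes ]
//   ^ loc.getPosition()                 ^ loc.getPosition() + 4 + keyLengthInBytes

which is why the stored value can be aggregated in place, without deserializing the record.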
This concludes the use of the type information wrapped in the ShuffleHandle during the shuffle write phase.
The Shuffle Read process
Shuffle read takes place in ShuffledRDD's overridden compute method, which first obtains the ShuffledRDD's dependency:
val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
The ShuffleReader is then obtained through the ShuffleHandle held by that dependency and used to perform the shuffle read:
SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
  .read()
  .asInstanceOf[Iterator[(K, C)]]
Shuffle read goes through the overridden read function of FlintHashShuffleReader; note that at line 51 this function rewrites the types:
dep.shuffleHandle.setTypes((0, 0), (0, 0))
This is a manual setting added so that DecaPR.scala can run; it can be removed once the integration is complete.
Shuffle read comes in two flavors, triggered by groupByKey or by reduceByKey. A groupByKey-triggered shuffle read fetches data along Spark's original code path and uses KryoSerializer. A reduceByKey-triggered shuffle read fetches data along our branch and uses FlintSerializer, which is initialized with the ShuffleHandle:
ser = new FlintSerializer(handle)
During the fetch, the readKey and readValue functions of FlintDeserializationStream are called; both use the type information wrapped in the ShuffleHandle:
override def readKey[T: ClassTag](): T = {
  handle.getReadType._1 match {
    case 0 => input.readInt().asInstanceOf[T]     // Int key
    case 1 => input.readLong().asInstanceOf[T]    // Long key
    case 2 => input.readDouble().asInstanceOf[T]  // Double key
  }
}

override def readValue[T: ClassTag](): T = {
  handle.getReadType._2 match {
    case 0 => input.readInt().asInstanceOf[T]     // Int value
    case 1 => input.readLong().asInstanceOf[T]    // Long value
    case 2 => input.readDouble().asInstanceOf[T]  // Double value
    case 8 =>                                     // four consecutive Ints, read back as a tuple
      (input.readInt(), input.readInt(), input.readInt(), input.readInt()).asInstanceOf[T]
  }
}
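Note that neither function has a catch-all case, so an unrecognized type code surfaces as a scala.MatchError at runtime. Type code 8 denotes a value made of four consecutive Ints, read back as a four-element tuple.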
The fetched data is then aggregated, either in the groupByKey way or in the reduceByKey way.
The reduceByKey aggregation function – combineCombinersByKey
combineCombinersByKey uses the UnsafeFixedFlintMap class for off-heap aggregation; the map is initialized with the ShuffleHandle:
val combiners = new UnsafeFixedFlintMap[K, C](1024, context.taskMemoryManager(), false, handle)
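To make the idea of a fixed-width off-heap aggregation map concrete, here is a self-contained toy that mimics what UnsafeFixedFlintMap is described as doing. It hard-codes Int keys and Double values (type code 2) and is an illustration only, not the branch's implementation.

import java.nio.ByteBuffer
import scala.collection.mutable

object FixedFlintMapToy {
  // Toy stand-in: each value occupies a fixed-width slot in an off-heap
  // buffer, and aggregation adds into that slot directly, the way the
  // type codes let the real map do without rebuilding objects.
  class Toy(capacity: Int) {
    private val values = ByteBuffer.allocateDirect(capacity * 8)
    private val slotOf = mutable.HashMap.empty[Int, Int]  // Int key -> slot index
    private var next = 0

    def insertOrUpdate(key: Int, value: Double): Unit = slotOf.get(key) match {
      case Some(slot) =>
        // existing key: aggregate in place, no object is deserialized
        values.putDouble(slot * 8, values.getDouble(slot * 8) + value)
      case None =>
        slotOf(key) = next
        values.putDouble(next * 8, value)
        next += 1
    }

    def get(key: Int): Double = values.getDouble(slotOf(key) * 8)
  }

  def main(args: Array[String]): Unit = {
    val m = new Toy(16)
    m.insertOrUpdate(1, 0.5)
    m.insertOrUpdate(1, 1.5)
    assert(m.get(1) == 2.0)
  }
}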
UnsafeFixedFlintMap contains a FlintSerializer, initialized with the ShuffleHandle, for turning keys and values into bytes. Its getKey and getValue methods likewise use the type information in the ShuffleHandle to rebuild objects from off-heap memory:
def getKey[K](address: MemoryLocation): K = {
  var key: Any = null
  handle.getReadType._1 match {
    case 0 => key = PlatformDependent.UNSAFE.getInt(address.getBaseObject, address.getBaseOffset)
    case 1 => key = PlatformDependent.UNSAFE.getLong(address.getBaseObject, address.getBaseOffset)
    case 2 => key = PlatformDependent.UNSAFE.getDouble(address.getBaseObject, address.getBaseOffset)
    case _ => require(false, s"type ${handle.getReadType._1} mismatch")  // unknown type code: fail fast
  }
  key.asInstanceOf[K]
}
def getValue[V](address: MemoryLocation): V = {
  var value: Any = null
  handle.getReadType._2 match {
    case 0 => value = PlatformDependent.UNSAFE.getInt(address.getBaseObject, address.getBaseOffset)
    case 1 => value = PlatformDependent.UNSAFE.getLong(address.getBaseObject, address.getBaseOffset)
    case 2 => value = PlatformDependent.UNSAFE.getDouble(address.getBaseObject, address.getBaseOffset)
    case 8 =>  // four consecutive Ints, rebuilt as a tuple
      value = (
        PlatformDependent.UNSAFE.getInt(address.getBaseObject, address.getBaseOffset),
        PlatformDependent.UNSAFE.getInt(address.getBaseObject, address.getBaseOffset + 4),
        PlatformDependent.UNSAFE.getInt(address.getBaseObject, address.getBaseOffset + 8),
        PlatformDependent.UNSAFE.getInt(address.getBaseObject, address.getBaseOffset + 12)
      )
    case _ => require(false, s"type ${handle.getReadType._2} mismatch")  // unknown type code: fail fast
  }
  value.asInstanceOf[V]
}
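For type code 8 the value slot is thus a fixed 16 bytes, four Ints at offsets 0, 4, 8 and 12, which matches the four readInt calls in readValue's case 8 above.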
The groupByKey aggregation function – combineValuesByKey
combineValuesByKey uses the UnsafeUnfixedFlintMap class for off-heap aggregation; the map is likewise initialized with the ShuffleHandle:
val combiners = new UnsafeUnfixedFlintMap[K, C](1024, context.taskMemoryManager(), false, handle)
UnsafeUnfixedFlintMap also contains a FlintSerializer, initialized with the ShuffleHandle, for turning keys and values into bytes. Its getKey and getRealValue methods again use the type information in the ShuffleHandle to rebuild objects from off-heap memory:
def getKey(address: MemoryLocation): K = {
  var key: Any = null
  handle.getReadType._1 match {
    case 0 => key = PlatformDependent.UNSAFE.getInt(address.getBaseObject, address.getBaseOffset)
    case 1 => key = PlatformDependent.UNSAFE.getLong(address.getBaseObject, address.getBaseOffset)
    case 2 => key = PlatformDependent.UNSAFE.getDouble(address.getBaseObject, address.getBaseOffset)
    case _ => require(false, s"type ${handle.getReadType._1} mismatch")  // unknown type code: fail fast
  }
  key.asInstanceOf[K]
}
def getRealValue(address: Long): V = {
  // val size: Int = PlatformDependent.UNSAFE.getInt(null, address)
  // offset just past the last element; elements start after the 8-byte header
  val sizeCount: Int = PlatformDependent.UNSAFE.getInt(null, address + 4)
  handle.getReadType._2 match {
    case 0 =>  // rebuild a CompactBuffer[Int] from 4-byte elements
      val realValue = new CompactBuffer[Int]
      for (i <- 8 until sizeCount by 4)
        realValue += PlatformDependent.UNSAFE.getInt(null, address + i)
      realValue.asInstanceOf[V]
    case 1 =>  // rebuild a CompactBuffer[Long] from 8-byte elements
      val realValue = new CompactBuffer[Long]
      for (i <- 8 until sizeCount by 8)
        realValue += PlatformDependent.UNSAFE.getLong(null, address + i)
      realValue.asInstanceOf[V]
    case 2 =>  // rebuild a CompactBuffer[Double] from 8-byte elements
      val realValue = new CompactBuffer[Double]
      for (i <- 8 until sizeCount by 8)
        realValue += PlatformDependent.UNSAFE.getDouble(null, address + i)
      realValue.asInstanceOf[V]
    case _ =>
      require(false, s"type ${handle.getReadType._2} mismatch")  // unknown type code: fail fast
      null.asInstanceOf[V]
  }
}
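The offsets in getRealValue suggest a variable-length value layout of the form [header word at offset 0][end offset at offset 4][elements from offset 8 up to the end offset]. That layout is an inference from the code, not something the source states; the following self-contained toy replays it on a heap ByteBuffer for the Int case.

import java.nio.ByteBuffer

object RealValueLayoutSketch {
  // Layout inferred from getRealValue above (an assumption):
  // [ header word @0 ][ end offset @4 ][ elements from 8 until end ]
  def readIntValues(buf: ByteBuffer): Seq[Int] = {
    val end = buf.getInt(4)  // offset just past the last element
    (8 until end by 4).map(i => buf.getInt(i))
  }

  def main(args: Array[String]): Unit = {
    val buf = ByteBuffer.allocate(20)
    buf.putInt(4, 20)  // end offset: 8-byte header + 3 * 4-byte Int elements
    buf.putInt(8, 1); buf.putInt(12, 2); buf.putInt(16, 3)
    assert(readIntValues(buf) == Seq(1, 2, 3))
  }
}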