[SPARK-10708] Consolidate sort shuffle implementations
There's a lot of duplication between SortShuffleManager and UnsafeShuffleManager. Now that UnsafeShuffleManager supports large records, the two provide the same set of functionality, so this patch replaces SortShuffleManager's serialized shuffle implementation with UnsafeShuffleManager's and merges the two managers together.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #8829 from JoshRosen/consolidate-sort-shuffle-implementations.
JoshRosen committed Oct 22, 2015
1 parent 94e2064 commit f6d06ad
Showing 30 changed files with 456 additions and 1,317 deletions.
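Before the diff, a rough sketch of where the consolidation lands: after this change, a single sort-based shuffle manager picks one of three writers per map task, depending on which handle was registered for the shuffle dependency. The types and selection logic below are simplified, invented stand-ins for illustration (the real dispatch lives in Scala code not shown in this excerpt), not the actual Spark API.

final class WriterSelectionSketch {
  // Simplified stand-ins for the three handle kinds the consolidated manager registers;
  // these are not the real Spark classes.
  interface Handle {}
  static final class BypassMergeSortHandle implements Handle {}  // few reduce partitions, no map-side combine
  static final class SerializedHandle implements Handle {}       // serializer supports relocation of serialized data
  static final class BaseHandle implements Handle {}             // everything else

  static String chooseWriter(Handle handle) {
    if (handle instanceof BypassMergeSortHandle) {
      // Write one file per reduce partition, then concatenate them into the final map output.
      return "BypassMergeSortShuffleWriter";
    } else if (handle instanceof SerializedHandle) {
      // Sort serialized records via packed pointers (the former "unsafe" path).
      return "UnsafeShuffleWriter";
    } else {
      // General path: deserialized sort with optional aggregation and ordering.
      return "SortShuffleWriter";
    }
  }
}
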
@@ -21,21 +21,30 @@
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import javax.annotation.Nullable;

import scala.None$;
import scala.Option;
import scala.Product2;
import scala.Tuple2;
import scala.collection.Iterator;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.io.Closeables;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.storage.*;
import org.apache.spark.util.Utils;

@@ -62,7 +71,7 @@
* <p>
* There have been proposals to completely remove this code path; see SPARK-6026 for details.
*/
final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<K, V> {
final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {

private final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class);

@@ -72,31 +81,52 @@ final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<
private final BlockManager blockManager;
private final Partitioner partitioner;
private final ShuffleWriteMetrics writeMetrics;
private final int shuffleId;
private final int mapId;
private final Serializer serializer;
private final IndexShuffleBlockResolver shuffleBlockResolver;

/** Array of file writers, one for each partition */
private DiskBlockObjectWriter[] partitionWriters;
@Nullable private MapStatus mapStatus;
private long[] partitionLengths;

/**
* Are we in the process of stopping? Because map tasks can call stop() with success = true
* and then call stop() with success = false if they get an exception, we want to make sure
* we don't try deleting files, etc twice.
*/
private boolean stopping = false;

public BypassMergeSortShuffleWriter(
SparkConf conf,
BlockManager blockManager,
Partitioner partitioner,
ShuffleWriteMetrics writeMetrics,
Serializer serializer) {
IndexShuffleBlockResolver shuffleBlockResolver,
BypassMergeSortShuffleHandle<K, V> handle,
int mapId,
TaskContext taskContext,
SparkConf conf) {
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
this.numPartitions = partitioner.numPartitions();
this.blockManager = blockManager;
this.partitioner = partitioner;
this.writeMetrics = writeMetrics;
this.serializer = serializer;
final ShuffleDependency<K, V, V> dep = handle.dependency();
this.mapId = mapId;
this.shuffleId = dep.shuffleId();
this.partitioner = dep.partitioner();
this.numPartitions = partitioner.numPartitions();
this.writeMetrics = new ShuffleWriteMetrics();
taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics));
this.serializer = Serializer.getSerializer(dep.serializer());
this.shuffleBlockResolver = shuffleBlockResolver;
}

@Override
public void insertAll(Iterator<Product2<K, V>> records) throws IOException {
public void write(Iterator<Product2<K, V>> records) throws IOException {
assert (partitionWriters == null);
if (!records.hasNext()) {
partitionLengths = new long[numPartitions];
shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
return;
}
final SerializerInstance serInstance = serializer.newInstance();
@@ -124,13 +154,24 @@ public void insertAll(Iterator<Product2<K, V>> records) throws IOException {
for (DiskBlockObjectWriter writer : partitionWriters) {
writer.commitAndClose();
}

partitionLengths =
writePartitionedFile(shuffleBlockResolver.getDataFile(shuffleId, mapId));
shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
}

@Override
public long[] writePartitionedFile(
BlockId blockId,
TaskContext context,
File outputFile) throws IOException {
@VisibleForTesting
long[] getPartitionLengths() {
return partitionLengths;
}

/**
* Concatenate all of the per-partition files into a single combined file.
*
* @return array of lengths, in bytes, of each partition of the file (used by map output tracker).
*/
private long[] writePartitionedFile(File outputFile) throws IOException {
// Track location of the partition starts in the output file
final long[] lengths = new long[numPartitions];
if (partitionWriters == null) {
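
The partition lengths returned by writePartitionedFile (and passed to writeIndexFile above) are what lets a reader locate each partition inside the single combined data file. Below is a minimal sketch of that relationship, assuming the index is simply the running total of the lengths; IndexShuffleBlockResolver's actual on-disk format is not shown in this excerpt, so treat the helper as a hypothetical illustration.

final class IndexSketch {
  // offsets[i] is where partition i starts in the combined data file;
  // offsets[numPartitions] is the total file length.
  static long[] lengthsToOffsets(long[] partitionLengths) {
    long[] offsets = new long[partitionLengths.length + 1];
    for (int i = 0; i < partitionLengths.length; i++) {
      offsets[i + 1] = offsets[i] + partitionLengths[i];
    }
    return offsets;
  }
}
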
@@ -165,18 +206,33 @@ public long[] writePartitionedFile(
}

@Override
public void stop() throws IOException {
if (partitionWriters != null) {
try {
for (DiskBlockObjectWriter writer : partitionWriters) {
// This method explicitly does _not_ throw exceptions:
File file = writer.revertPartialWritesAndClose();
if (!file.delete()) {
logger.error("Error while deleting file {}", file.getAbsolutePath());
public Option<MapStatus> stop(boolean success) {
if (stopping) {
return None$.empty();
} else {
stopping = true;
if (success) {
if (mapStatus == null) {
throw new IllegalStateException("Cannot call stop(true) without having called write()");
}
return Option.apply(mapStatus);
} else {
// The map task failed, so delete our output data.
if (partitionWriters != null) {
try {
for (DiskBlockObjectWriter writer : partitionWriters) {
// This method explicitly does _not_ throw exceptions:
File file = writer.revertPartialWritesAndClose();
if (!file.delete()) {
logger.error("Error while deleting file {}", file.getAbsolutePath());
}
}
} finally {
partitionWriters = null;
}
}
} finally {
partitionWriters = null;
shuffleBlockResolver.removeDataByMap(shuffleId, mapId);
return None$.empty();
}
}
}
@@ -15,7 +15,7 @@
* limitations under the License.
*/

package org.apache.spark.shuffle.unsafe;
package org.apache.spark.shuffle.sort;

/**
* Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer.
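
A self-contained sketch of the 24/40-bit split described above. This illustrates only the layout; it is not PackedRecordPointer's actual code, which also compresses the page-number/offset encoding of the record address into the low 40 bits.

final class PackedPointerSketch {
  private static final long LOW_40_BITS = (1L << 40) - 1;

  static long pack(int partitionId, long recordPointer) {
    // partition id in the upper 24 bits, record pointer in the lower 40 bits
    return (((long) partitionId) << 40) | (recordPointer & LOW_40_BITS);
  }

  static int partitionId(long packed) {
    return (int) (packed >>> 40);
  }

  static long recordPointer(long packed) {
    return packed & LOW_40_BITS;
  }
}
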
@@ -15,7 +15,7 @@
* limitations under the License.
*/

package org.apache.spark.shuffle.unsafe;
package org.apache.spark.shuffle.sort;

import javax.annotation.Nullable;
import java.io.File;
@@ -48,7 +48,7 @@
* <p>
* Incoming records are appended to data pages. When all records have been inserted (or when the
* current thread's shuffle memory limit is reached), the in-memory records are sorted according to
* their partition ids (using a {@link UnsafeShuffleInMemorySorter}). The sorted records are then
* their partition ids (using a {@link ShuffleInMemorySorter}). The sorted records are then
* written to a single output file (or multiple files, if we've spilled). The format of the output
* files is the same as the format of the final output file written by
* {@link org.apache.spark.shuffle.sort.SortShuffleWriter}: each output partition's records are
@@ -59,9 +59,9 @@
* spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a
* specialized merge procedure that avoids extra serialization/deserialization.
*/
final class UnsafeShuffleExternalSorter {
final class ShuffleExternalSorter {

private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class);
private final Logger logger = LoggerFactory.getLogger(ShuffleExternalSorter.class);

@VisibleForTesting
static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
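
As the class Javadoc above describes, the sorter buffers records, sorts them by partition id when memory runs out, and writes one partition-ordered spill file per spill, leaving the final merge to UnsafeShuffleWriter. Below is a deliberately simplified sketch of that flow using plain Java objects and invented names in place of memory pages, packed pointers, and DiskBlockObjectWriters.

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

final class SpillFlowSketch {
  static final class BufferedRecord {
    final int partitionId;
    final byte[] payload;
    BufferedRecord(int partitionId, byte[] payload) {
      this.partitionId = partitionId;
      this.payload = payload;
    }
  }

  // Sorts buffered records by partition id and tallies how many bytes each
  // partition's section of the spill file would contain.
  static long[] spill(List<BufferedRecord> inMemory, int numPartitions) {
    List<BufferedRecord> sorted = new ArrayList<>(inMemory);
    sorted.sort(Comparator.comparingInt((BufferedRecord r) -> r.partitionId));
    long[] sectionLengths = new long[numPartitions];
    for (BufferedRecord record : sorted) {
      sectionLengths[record.partitionId] += record.payload.length;
    }
    return sectionLengths;
  }
}
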
@@ -76,6 +76,10 @@ final class UnsafeShuffleExternalSorter {
private final BlockManager blockManager;
private final TaskContext taskContext;
private final ShuffleWriteMetrics writeMetrics;
private long numRecordsInsertedSinceLastSpill = 0;

/** Force this sorter to spill when there are this many elements in memory. For testing only */
private final long numElementsForSpillThreshold;

/** The buffer size to use when writing spills using DiskBlockObjectWriter */
private final int fileBufferSizeBytes;
@@ -94,12 +98,12 @@ final class UnsafeShuffleExternalSorter {
private long peakMemoryUsedBytes;

// These variables are reset after spilling:
@Nullable private UnsafeShuffleInMemorySorter inMemSorter;
@Nullable private ShuffleInMemorySorter inMemSorter;
@Nullable private MemoryBlock currentPage = null;
private long currentPagePosition = -1;
private long freeSpaceInCurrentPage = 0;

public UnsafeShuffleExternalSorter(
public ShuffleExternalSorter(
TaskMemoryManager memoryManager,
ShuffleMemoryManager shuffleMemoryManager,
BlockManager blockManager,
@@ -117,6 +121,8 @@ public UnsafeShuffleExternalSorter(
this.numPartitions = numPartitions;
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
this.numElementsForSpillThreshold =
conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE);
this.pageSizeBytes = (int) Math.min(
PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, shuffleMemoryManager.pageSizeBytes());
this.maxRecordSizeBytes = pageSizeBytes - 4;
@@ -140,7 +146,8 @@ private void initializeForWriting() throws IOException {
throw new IOException("Could not acquire " + memoryRequested + " bytes of memory");
}

this.inMemSorter = new UnsafeShuffleInMemorySorter(initialSize);
this.inMemSorter = new ShuffleInMemorySorter(initialSize);
numRecordsInsertedSinceLastSpill = 0;
}

/**
@@ -166,7 +173,7 @@ private void writeSortedFile(boolean isLastFile) throws IOException {
}

// This call performs the actual sort.
final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator sortedRecords =
final ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords =
inMemSorter.getSortedIterator();

// Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this
@@ -406,6 +413,10 @@ public void insertRecord(
int lengthInBytes,
int partitionId) throws IOException {

if (numRecordsInsertedSinceLastSpill > numElementsForSpillThreshold) {
spill();
}

growPointerArrayIfNecessary();
// Need 4 bytes to store the record length.
final int totalSpaceRequired = lengthInBytes + 4;
@@ -453,6 +464,7 @@ public void insertRecord(
recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes);
assert(inMemSorter != null);
inMemSorter.insertRecord(recordAddress, partitionId);
numRecordsInsertedSinceLastSpill += 1;
}

/**
@@ -15,13 +15,13 @@
* limitations under the License.
*/

package org.apache.spark.shuffle.unsafe;
package org.apache.spark.shuffle.sort;

import java.util.Comparator;

import org.apache.spark.util.collection.Sorter;

final class UnsafeShuffleInMemorySorter {
final class ShuffleInMemorySorter {

private final Sorter<PackedRecordPointer, long[]> sorter;
private static final class SortComparator implements Comparator<PackedRecordPointer> {
@@ -44,10 +44,10 @@ public int compare(PackedRecordPointer left, PackedRecordPointer right) {
*/
private int pointerArrayInsertPosition = 0;

public UnsafeShuffleInMemorySorter(int initialSize) {
public ShuffleInMemorySorter(int initialSize) {
assert (initialSize > 0);
this.pointerArray = new long[initialSize];
this.sorter = new Sorter<PackedRecordPointer, long[]>(UnsafeShuffleSortDataFormat.INSTANCE);
this.sorter = new Sorter<PackedRecordPointer, long[]>(ShuffleSortDataFormat.INSTANCE);
}

public void expandPointerArray() {
@@ -92,14 +92,14 @@ public void insertRecord(long recordPointer, int partitionId) {
/**
* An iterator-like class that's used instead of Java's Iterator in order to facilitate inlining.
*/
public static final class UnsafeShuffleSorterIterator {
public static final class ShuffleSorterIterator {

private final long[] pointerArray;
private final int numRecords;
final PackedRecordPointer packedRecordPointer = new PackedRecordPointer();
private int position = 0;

public UnsafeShuffleSorterIterator(int numRecords, long[] pointerArray) {
public ShuffleSorterIterator(int numRecords, long[] pointerArray) {
this.numRecords = numRecords;
this.pointerArray = pointerArray;
}
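
The comment above ("used instead of Java's Iterator in order to facilitate inlining") refers to this cursor style: a concrete, non-generic class whose current value is read from a field, so the hot loop avoids boxing long pointers and keeps its call sites monomorphic. A minimal illustration with invented names:

final class LongCursorSketch {
  private final long[] values;
  private int position = 0;
  long current;  // valid after each call to loadNext()

  LongCursorSketch(long[] values) {
    this.values = values;
  }

  boolean hasNext() {
    return position < values.length;
  }

  void loadNext() {
    current = values[position++];
  }
}

Callers loop with hasNext()/loadNext() and read current directly, mirroring how ShuffleSorterIterator is consumed by the external sorter's write loop.
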
@@ -117,8 +117,8 @@ public void loadNext() {
/**
* Return an iterator over record pointers in sorted order.
*/
public UnsafeShuffleSorterIterator getSortedIterator() {
public ShuffleSorterIterator getSortedIterator() {
sorter.sort(pointerArray, 0, pointerArrayInsertPosition, SORT_COMPARATOR);
return new UnsafeShuffleSorterIterator(pointerArrayInsertPosition, pointerArray);
return new ShuffleSorterIterator(pointerArrayInsertPosition, pointerArray);
}
}
@@ -15,15 +15,15 @@
* limitations under the License.
*/

package org.apache.spark.shuffle.unsafe;
package org.apache.spark.shuffle.sort;

import org.apache.spark.util.collection.SortDataFormat;

final class UnsafeShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, long[]> {
final class ShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, long[]> {

public static final UnsafeShuffleSortDataFormat INSTANCE = new UnsafeShuffleSortDataFormat();
public static final ShuffleSortDataFormat INSTANCE = new ShuffleSortDataFormat();

private UnsafeShuffleSortDataFormat() { }
private ShuffleSortDataFormat() { }

@Override
public PackedRecordPointer getKey(long[] data, int pos) {