
Commit f6d06ad

Committed Oct 22, 2015
[SPARK-10708] Consolidate sort shuffle implementations
There's a lot of duplication between SortShuffleManager and UnsafeShuffleManager. Now that UnsafeShuffleManager supports large records, the two provide the same set of functionality, so I think we should replace SortShuffleManager's serialized shuffle implementation with UnsafeShuffleManager's and merge the two managers together.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #8829 from JoshRosen/consolidate-sort-shuffle-implementations.
1 parent 94e2064 commit f6d06ad

30 files changed: +456 -1317 lines

core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java

Lines changed: 81 additions & 25 deletions
@@ -21,21 +21,30 @@
 import java.io.FileInputStream;
 import java.io.FileOutputStream;
 import java.io.IOException;
+import javax.annotation.Nullable;
 
+import scala.None$;
+import scala.Option;
 import scala.Product2;
 import scala.Tuple2;
 import scala.collection.Iterator;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.io.Closeables;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.spark.Partitioner;
+import org.apache.spark.ShuffleDependency;
 import org.apache.spark.SparkConf;
 import org.apache.spark.TaskContext;
 import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.scheduler.MapStatus$;
 import org.apache.spark.serializer.Serializer;
 import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.ShuffleWriter;
 import org.apache.spark.storage.*;
 import org.apache.spark.util.Utils;
 
@@ -62,7 +71,7 @@
  * <p>
  * There have been proposals to completely remove this code path; see SPARK-6026 for details.
  */
-final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<K, V> {
+final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
 
   private final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class);
 
@@ -72,31 +81,52 @@ final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<
   private final BlockManager blockManager;
   private final Partitioner partitioner;
   private final ShuffleWriteMetrics writeMetrics;
+  private final int shuffleId;
+  private final int mapId;
   private final Serializer serializer;
+  private final IndexShuffleBlockResolver shuffleBlockResolver;
 
   /** Array of file writers, one for each partition */
   private DiskBlockObjectWriter[] partitionWriters;
+  @Nullable private MapStatus mapStatus;
+  private long[] partitionLengths;
+
+  /**
+   * Are we in the process of stopping? Because map tasks can call stop() with success = true
+   * and then call stop() with success = false if they get an exception, we want to make sure
+   * we don't try deleting files, etc twice.
+   */
+  private boolean stopping = false;
 
   public BypassMergeSortShuffleWriter(
-      SparkConf conf,
       BlockManager blockManager,
-      Partitioner partitioner,
-      ShuffleWriteMetrics writeMetrics,
-      Serializer serializer) {
+      IndexShuffleBlockResolver shuffleBlockResolver,
+      BypassMergeSortShuffleHandle<K, V> handle,
+      int mapId,
+      TaskContext taskContext,
+      SparkConf conf) {
     // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
     this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
     this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
-    this.numPartitions = partitioner.numPartitions();
     this.blockManager = blockManager;
-    this.partitioner = partitioner;
-    this.writeMetrics = writeMetrics;
-    this.serializer = serializer;
+    final ShuffleDependency<K, V, V> dep = handle.dependency();
+    this.mapId = mapId;
+    this.shuffleId = dep.shuffleId();
+    this.partitioner = dep.partitioner();
+    this.numPartitions = partitioner.numPartitions();
+    this.writeMetrics = new ShuffleWriteMetrics();
+    taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics));
+    this.serializer = Serializer.getSerializer(dep.serializer());
+    this.shuffleBlockResolver = shuffleBlockResolver;
   }
 
   @Override
-  public void insertAll(Iterator<Product2<K, V>> records) throws IOException {
+  public void write(Iterator<Product2<K, V>> records) throws IOException {
     assert (partitionWriters == null);
     if (!records.hasNext()) {
+      partitionLengths = new long[numPartitions];
+      shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
+      mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
       return;
     }
     final SerializerInstance serInstance = serializer.newInstance();
@@ -124,13 +154,24 @@ public void insertAll(Iterator<Product2<K, V>> records) throws IOException {
     for (DiskBlockObjectWriter writer : partitionWriters) {
       writer.commitAndClose();
     }
+
+    partitionLengths =
+      writePartitionedFile(shuffleBlockResolver.getDataFile(shuffleId, mapId));
+    shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
+    mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
   }
 
-  @Override
-  public long[] writePartitionedFile(
-      BlockId blockId,
-      TaskContext context,
-      File outputFile) throws IOException {
+  @VisibleForTesting
+  long[] getPartitionLengths() {
+    return partitionLengths;
+  }
+
+  /**
+   * Concatenate all of the per-partition files into a single combined file.
+   *
+   * @return array of lengths, in bytes, of each partition of the file (used by map output tracker).
+   */
+  private long[] writePartitionedFile(File outputFile) throws IOException {
     // Track location of the partition starts in the output file
     final long[] lengths = new long[numPartitions];
     if (partitionWriters == null) {
@@ -165,18 +206,33 @@ public long[] writePartitionedFile(
   }
 
   @Override
-  public void stop() throws IOException {
-    if (partitionWriters != null) {
-      try {
-        for (DiskBlockObjectWriter writer : partitionWriters) {
-          // This method explicitly does _not_ throw exceptions:
-          File file = writer.revertPartialWritesAndClose();
-          if (!file.delete()) {
-            logger.error("Error while deleting file {}", file.getAbsolutePath());
+  public Option<MapStatus> stop(boolean success) {
+    if (stopping) {
+      return None$.empty();
+    } else {
+      stopping = true;
+      if (success) {
+        if (mapStatus == null) {
+          throw new IllegalStateException("Cannot call stop(true) without having called write()");
+        }
+        return Option.apply(mapStatus);
+      } else {
+        // The map task failed, so delete our output data.
+        if (partitionWriters != null) {
+          try {
+            for (DiskBlockObjectWriter writer : partitionWriters) {
+              // This method explicitly does _not_ throw exceptions:
+              File file = writer.revertPartialWritesAndClose();
+              if (!file.delete()) {
                logger.error("Error while deleting file {}", file.getAbsolutePath());
+              }
+            }
+          } finally {
+            partitionWriters = null;
          }
        }
-    } finally {
-      partitionWriters = null;
+        shuffleBlockResolver.removeDataByMap(shuffleId, mapId);
+        return None$.empty();
      }
    }
  }
}
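The stopping flag introduced above guards against double cleanup: a map task may call stop(true) after a successful write and then stop(false) from its error handler if a later step throws. A minimal caller sketch of that protocol, assuming a writer of this class and a records iterator (the scaffolding is illustrative, not the actual ShuffleMapTask code):

try {
  writer.write(records);
  return writer.stop(true).get();
} catch (Exception e) {
  try {
    // Safe even if stop(true) already ran: the stopping flag turns this into a no-op.
    writer.stop(false);
  } catch (Exception suppressed) {
    // Swallow, so the original exception stays the primary failure.
  }
  throw e;
}

Note that stop(false) deletes the map output via shuffleBlockResolver.removeDataByMap, which is exactly why the flag must keep it from running after a successful stop(true).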

core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java renamed to core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package org.apache.spark.shuffle.unsafe;
+package org.apache.spark.shuffle.sort;
 
 /**
  * Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer.
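The class comment above describes the whole trick: a partition number in the upper 24 bits and a record pointer in the lower 40 bits of a single 8-byte word. A minimal sketch of that arithmetic, with hypothetical helper names (this is not PackedRecordPointer's actual API):

static long pack(int partitionId, long recordPointer) {
  assert partitionId >= 0 && partitionId < (1 << 24);    // must fit in 24 bits
  return (((long) partitionId) << 40) | (recordPointer & ((1L << 40) - 1));
}

static int getPartitionId(long packed) {
  return (int) (packed >>> 40);                          // upper 24 bits
}

static long getRecordPointer(long packed) {
  return packed & ((1L << 40) - 1);                      // lower 40 bits
}

Because the partition number sits in the high bits, a comparator can order records by partition without ever dereferencing the record data.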

core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java renamed to core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java

Lines changed: 20 additions & 8 deletions
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package org.apache.spark.shuffle.unsafe;
+package org.apache.spark.shuffle.sort;
 
 import javax.annotation.Nullable;
 import java.io.File;
@@ -48,7 +48,7 @@
  * <p>
  * Incoming records are appended to data pages. When all records have been inserted (or when the
  * current thread's shuffle memory limit is reached), the in-memory records are sorted according to
- * their partition ids (using a {@link UnsafeShuffleInMemorySorter}). The sorted records are then
+ * their partition ids (using a {@link ShuffleInMemorySorter}). The sorted records are then
  * written to a single output file (or multiple files, if we've spilled). The format of the output
  * files is the same as the format of the final output file written by
  * {@link org.apache.spark.shuffle.sort.SortShuffleWriter}: each output partition's records are
@@ -59,9 +59,9 @@
  * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a
  * specialized merge procedure that avoids extra serialization/deserialization.
  */
-final class UnsafeShuffleExternalSorter {
+final class ShuffleExternalSorter {
 
-  private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class);
+  private final Logger logger = LoggerFactory.getLogger(ShuffleExternalSorter.class);
 
   @VisibleForTesting
   static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
@@ -76,6 +76,10 @@ final class UnsafeShuffleExternalSorter {
   private final BlockManager blockManager;
   private final TaskContext taskContext;
   private final ShuffleWriteMetrics writeMetrics;
+  private long numRecordsInsertedSinceLastSpill = 0;
+
+  /** Force this sorter to spill when there are this many elements in memory. For testing only */
+  private final long numElementsForSpillThreshold;
 
   /** The buffer size to use when writing spills using DiskBlockObjectWriter */
   private final int fileBufferSizeBytes;
@@ -94,12 +98,12 @@ final class UnsafeShuffleExternalSorter {
   private long peakMemoryUsedBytes;
 
   // These variables are reset after spilling:
-  @Nullable private UnsafeShuffleInMemorySorter inMemSorter;
+  @Nullable private ShuffleInMemorySorter inMemSorter;
   @Nullable private MemoryBlock currentPage = null;
   private long currentPagePosition = -1;
   private long freeSpaceInCurrentPage = 0;
 
-  public UnsafeShuffleExternalSorter(
+  public ShuffleExternalSorter(
       TaskMemoryManager memoryManager,
       ShuffleMemoryManager shuffleMemoryManager,
       BlockManager blockManager,
@@ -117,6 +121,8 @@ public UnsafeShuffleExternalSorter(
     this.numPartitions = numPartitions;
     // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
     this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
+    this.numElementsForSpillThreshold =
+      conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE);
     this.pageSizeBytes = (int) Math.min(
       PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, shuffleMemoryManager.pageSizeBytes());
     this.maxRecordSizeBytes = pageSizeBytes - 4;
@@ -140,7 +146,8 @@ private void initializeForWriting() throws IOException {
       throw new IOException("Could not acquire " + memoryRequested + " bytes of memory");
     }
 
-    this.inMemSorter = new UnsafeShuffleInMemorySorter(initialSize);
+    this.inMemSorter = new ShuffleInMemorySorter(initialSize);
+    numRecordsInsertedSinceLastSpill = 0;
   }
 
   /**
@@ -166,7 +173,7 @@ private void writeSortedFile(boolean isLastFile) throws IOException {
     }
 
     // This call performs the actual sort.
-    final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator sortedRecords =
+    final ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords =
       inMemSorter.getSortedIterator();
 
     // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this
@@ -406,6 +413,10 @@ public void insertRecord(
       int lengthInBytes,
       int partitionId) throws IOException {
 
+    if (numRecordsInsertedSinceLastSpill > numElementsForSpillThreshold) {
+      spill();
+    }
+
     growPointerArrayIfNecessary();
     // Need 4 bytes to store the record length.
     final int totalSpaceRequired = lengthInBytes + 4;
@@ -453,6 +464,7 @@ public void insertRecord(
       recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes);
     assert(inMemSorter != null);
     inMemSorter.insertRecord(recordAddress, partitionId);
+    numRecordsInsertedSinceLastSpill += 1;
   }
 
   /**
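The spark.shuffle.spill.numElementsForceSpillThreshold key added above is a test-only knob: insertRecord() now forces a spill once the in-memory count passes it, so the multi-spill merge path can be exercised on small inputs. A sketch of how a test might use it (the key comes from the diff; the surrounding setup is assumed):

import org.apache.spark.SparkConf;

SparkConf conf = new SparkConf()
  .set("spark.shuffle.spill.numElementsForceSpillThreshold", "10");
// With the threshold at 10, the sorter spills roughly every 10 inserted records,
// producing several spill files without large inputs or real memory pressure.

The default of Long.MAX_VALUE means the check never fires in production.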

core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java renamed to core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java

Lines changed: 8 additions & 8 deletions
@@ -15,13 +15,13 @@
  * limitations under the License.
  */
 
-package org.apache.spark.shuffle.unsafe;
+package org.apache.spark.shuffle.sort;
 
 import java.util.Comparator;
 
 import org.apache.spark.util.collection.Sorter;
 
-final class UnsafeShuffleInMemorySorter {
+final class ShuffleInMemorySorter {
 
   private final Sorter<PackedRecordPointer, long[]> sorter;
   private static final class SortComparator implements Comparator<PackedRecordPointer> {
@@ -44,10 +44,10 @@ public int compare(PackedRecordPointer left, PackedRecordPointer right) {
   */
   private int pointerArrayInsertPosition = 0;
 
-  public UnsafeShuffleInMemorySorter(int initialSize) {
+  public ShuffleInMemorySorter(int initialSize) {
     assert (initialSize > 0);
     this.pointerArray = new long[initialSize];
-    this.sorter = new Sorter<PackedRecordPointer, long[]>(UnsafeShuffleSortDataFormat.INSTANCE);
+    this.sorter = new Sorter<PackedRecordPointer, long[]>(ShuffleSortDataFormat.INSTANCE);
   }
 
   public void expandPointerArray() {
@@ -92,14 +92,14 @@ public void insertRecord(long recordPointer, int partitionId) {
   /**
    * An iterator-like class that's used instead of Java's Iterator in order to facilitate inlining.
    */
-  public static final class UnsafeShuffleSorterIterator {
+  public static final class ShuffleSorterIterator {
 
     private final long[] pointerArray;
     private final int numRecords;
     final PackedRecordPointer packedRecordPointer = new PackedRecordPointer();
     private int position = 0;
 
-    public UnsafeShuffleSorterIterator(int numRecords, long[] pointerArray) {
+    public ShuffleSorterIterator(int numRecords, long[] pointerArray) {
       this.numRecords = numRecords;
       this.pointerArray = pointerArray;
     }
@@ -117,8 +117,8 @@ public void loadNext() {
   /**
    * Return an iterator over record pointers in sorted order.
    */
-  public UnsafeShuffleSorterIterator getSortedIterator() {
+  public ShuffleSorterIterator getSortedIterator() {
     sorter.sort(pointerArray, 0, pointerArrayInsertPosition, SORT_COMPARATOR);
-    return new UnsafeShuffleSorterIterator(pointerArrayInsertPosition, pointerArray);
+    return new ShuffleSorterIterator(pointerArrayInsertPosition, pointerArray);
   }
 }
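The sort itself never touches record bytes: SORT_COMPARATOR orders the packed 8-byte words by their partition field, and getSortedIterator() then walks the reordered pointer array. A toy illustration of that idea, reusing the hypothetical pack/getPartitionId/getRecordPointer helpers sketched earlier:

Long[] pointers = { pack(2, 7L), pack(0, 3L), pack(1, 9L) };
java.util.Arrays.sort(pointers, java.util.Comparator.comparingInt(p -> getPartitionId(p)));
for (long p : pointers) {
  // Prints partitions in ascending order: 0 -> 3, then 1 -> 9, then 2 -> 7.
  System.out.println(getPartitionId(p) + " -> " + getRecordPointer(p));
}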

core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java renamed to core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java

Lines changed: 4 additions & 4 deletions
@@ -15,15 +15,15 @@
  * limitations under the License.
  */
 
-package org.apache.spark.shuffle.unsafe;
+package org.apache.spark.shuffle.sort;
 
 import org.apache.spark.util.collection.SortDataFormat;
 
-final class UnsafeShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, long[]> {
+final class ShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, long[]> {
 
-  public static final UnsafeShuffleSortDataFormat INSTANCE = new UnsafeShuffleSortDataFormat();
+  public static final ShuffleSortDataFormat INSTANCE = new ShuffleSortDataFormat();
 
-  private UnsafeShuffleSortDataFormat() { }
+  private ShuffleSortDataFormat() { }
 
   @Override
   public PackedRecordPointer getKey(long[] data, int pos) {
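SortDataFormat is what lets Spark's generic Sorter reorder the long[] buffer in place, one 8-byte word at a time, with no per-element allocation. A rough sketch of the two central callbacks, with bodies inferred from that pattern rather than copied from the class:

@Override
public void swap(long[] data, int pos0, int pos1) {
  // Sorting moves raw 8-byte words; no record objects are created or copied.
  long tmp = data[pos0];
  data[pos0] = data[pos1];
  data[pos1] = tmp;
}

@Override
public PackedRecordPointer getKey(long[] data, int pos, PackedRecordPointer reuse) {
  // The comparator reads the partition id out of a single reused wrapper object.
  reuse.set(data[pos]);
  return reuse;
}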
