
TensorOnSpark for Distributed Deep Learning


Overview

TensorOnSpark is a scalable distributed framework for deep learning. It runs TensorFlow programs seamlessly on top of Spark through SparkSession, a novel distributed machine-learning context. Users write machine learning and deep neural network programs with the ordinary TensorFlow interface and run them in a distributed manner without being aware of the underlying mechanisms. Compared to TensorFlow's own distributed mode, TensorOnSpark offers better compute resource management, faster large-volume data processing through reliable and scalable distributed systems such as Hadoop and Spark, and lower network traffic.

TensorOnSpark is distinguished for the following features:

  • Easy large-volume data preparation
  • Efficient compute resource allocation
  • Reliable and flexible parallel parameter updating
  • High compatibility with TensorFlow
  • Low network traffic and high learning accuracy

SparkSession

SparkSession is the core module of TensorOnSpark. Each TensorFlow session (including the graph and the associated parameter values) corresponds to an instance of SparkSession. SparkSession exposes a single instance of the TensorFlow model graph and model parameters to the users. With SparkSession, users build the training model exactly as they would build a TensorFlow learning model on a single computer, and let SparkSession handle the distributed training on the distributed RDD datasets and the synchronization of the distributed model parameter values.

Architecture

The architecture of SparkSession is shown in Fig. 1.

Fig. 1 Architecture of SparkSession

SparkSession uses a master-slave architecture, where the master is the application master of a Spark job and the slaves are the Spark executors. The master maintains the single instance of the TensorFlow graph and parameter values and hosts the Tensor Parameter Server (TPS) for parallel parameter updating. Each worker feeds a partition of the RDD data into the model as training input and synchronizes with the TPS periodically to update the trained parameter values.

How it Works

Before running a learning program, the user needs to build the learning model and prepare the training input data for it.

The construction of the model graph in the SparkSession master is exactly the same as building a TensorFlow graph: users define variables and operations as graph nodes with connections. In the data-preparation phase with SparkSession, the training input data are usually stored in distributed storage systems such as HDFS and HBase. Users can use Spark to import and process the data in RDD format, where each item of the RDD is one data entry of the input Tensor, as sketched below.
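
For instance, a minimal sketch of this preparation step with ordinary PySpark might look as follows; the HDFS path and the CSV record layout are illustrative assumptions, not part of TensorOnSpark.

import numpy as np
from pyspark import SparkContext

sc = SparkContext(appName='prepare_input')

def parse_line(line):
    # Assumed record layout: "label,pixel0,pixel1,...,pixel783"
    fields = line.split(',')
    label = int(fields[0])
    features = np.array(fields[1:], dtype=np.float32) / 255.0
    return (features, label)

# Each RDD item becomes one data entry of the input Tensor: (features, label).
input_rdd = sc.textFile('hdfs:///data/mnist_train.csv').map(parse_line).cache()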

After the TensorFlow model is built, the training proceeds as follows:

  1. The SparkSession master persists the model graph, including the initial parameter values, to HDFS for later retrieval by the Spark executors.
  2. The master broadcasts the information of TPS and the metadata of the input Tensor (or feed) and the output Tensor (or fetch) to the executors.
  3. The executors retrieve the model graph from HDFS and build up the local graph, which is consistent with that in the master.
  4. Each executor feeds the prepared data in its RDD partition to the graph and updates the local parameter values.
  5. After every designated number of training steps, an executor pushes its new parameter values to the TPS and pulls back the newly updated values (a simplified sketch of this loop follows the list).
  6. An epoch of training ends when every executor has processed its entire partition of the input data.
  7. The input RDD can be repartitioned and re-sorted for the next epoch of training, which starts again from Step 4.
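
The following self-contained sketch mimics Steps 3-6 on a toy linear-regression model with plain NumPy. The ToyTPS class and the helper names are illustrative stand-ins for the real TensorOnSpark components, not its actual API, and the partitions are processed sequentially here rather than in parallel.

import numpy as np

class ToyTPS(object):
    # Stands in for the Tensor Parameter Server hosted by the master.
    def __init__(self, params):
        self.params = {k: v.copy() for k, v in params.items()}

    def push_and_pull(self, new_params):
        # The real TPS merges pushes through a combiner (see the discussion below);
        # this stub simply stores the pushed values and returns them.
        self.params = {k: v.copy() for k, v in new_params.items()}
        return {k: v.copy() for k, v in self.params.items()}

def train_partition(partition, params, tps, sync_interval=10, lr=0.05):
    # Steps 4-5: feed one RDD partition and sync with the TPS at every milestone.
    w = params['w'].copy()
    for step, (x, y) in enumerate(partition, start=1):
        grad = 2.0 * x * (np.dot(w, x) - y)   # gradient of the squared error
        w -= lr * grad
        if step % sync_interval == 0:
            w = tps.push_and_pull({'w': w})['w']
    return {'w': w}

# Steps 1-2 analogue: the master holds the initial parameter values.
rng = np.random.RandomState(0)
true_w = np.array([2.0, -1.0])
data = [(x, float(np.dot(true_w, x))) for x in rng.randn(200, 2)]
partitions = [data[i::4] for i in range(4)]    # four simulated RDD partitions

tps = ToyTPS({'w': np.zeros(2)})
for part in partitions:                        # Step 6: one epoch over all partitions
    tps.push_and_pull(train_partition(part, tps.params, tps))
print(tps.params['w'])                         # approaches true_w = [2.0, -1.0]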

In Step 5, the executors update the parameter values with the TPS at every milestone, but different executors need not wait for each other to synchronize at the same milestone. Instead, executors update the TPS asynchronously; in other words, executors at different milestones can update the parameters with the TPS at the same time. The TPS controls the asynchronous parameter updates via flexible parameter combiners; it provides several efficient built-in combiners and also allows user-defined ones, as illustrated below.
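
As an illustration of the combiner concept (not the actual TPS combiner interface), a user-defined combiner could blend each pushed value into the global value with a weight proportional to the number of local training steps the push represents:

import numpy as np

def weighted_average_combiner(global_value, pushed_value, pushed_steps, total_steps):
    # Blend a pushed value into the global value, weighting it by the number of
    # local training steps it represents relative to what the TPS has accumulated.
    weight = float(pushed_steps) / float(pushed_steps + total_steps)
    return (1.0 - weight) * global_value + weight * pushed_value

# Two executors push at different milestones without waiting for each other.
w = np.zeros(3)                                              # global parameter held by the TPS
w = weighted_average_combiner(w, np.array([0.4, 0.1, -0.2]), 100, 0)
w = weighted_average_combiner(w, np.array([0.5, 0.0, -0.1]), 100, 100)
print(w)   # -> [0.45, 0.05, -0.15]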

Example

Installation instructions can be found on the TensorOnSpark homepage. We demonstrate the use of TensorOnSpark with the MNIST example. MNIST is a learning program that recognizes handwritten digits from images. We show part of the MNIST program in TensorOnSpark and explain the code in Python. The full version, which contains the required imports and the definitions of sc, num_partition and num_epoch used below, can be found in spark_mnist.py.

# Extract the images and labels from the file in HDFS
image_rdd = mnist.extract_images(sc, mnist.train_image_path)
label_rdd = mnist.extract_labels(sc, mnist.train_label_path, num_class=10, one_hot=True)

# image_label_rdd is the RDD where each entry is a tuple of (image, label)
image_label_rdd = image_rdd.join(label_rdd, numPartitions=num_partition).mapPartitions(mnist.flatten_image_label).cache()

# Build the normal TensorFlow graph and initialize the variables. This procedure is exactly the same as in an ordinary TensorFlow program
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x, W) + b)
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

# Indicate the feed and fetch variables for a SparkSession run. 
feed_name_list = [x.name, y_.name]
param_list = [W, b]

# Initialize the SparkSession with the Spark context, the TensorFlow session and other running configuration information.
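# (Note: the meanings below are inferred from the description above, not from
# the API documentation.) server_host and server_port locate the Tensor
# Parameter Server; sync_interval is the number of local training steps between
# two TPS synchronizations; batch_size is the size of each local training batch.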
spark_sess = sps.SparkSession(sc, sess, user='liangfengsid', name='spark_mnist', server_host='localhost', server_port=10080, sync_interval=100, batch_size=100)

# Run the SparkSession and repartition between epochs
partitioner = par.RandomPartitioner(num_partition)
for i in range(num_epoch):
	spark_sess.run(train_step, feed_rdd=image_label_rdd, feed_name_list=feed_name_list, param_list=param_list, shuffle_within_partition=True)
	if i != num_epoch-1:
		temp_image_label_rdd = image_label_rdd.partitionBy(num_partition, partitioner).cache()
		image_label_rdd.unpersist()
		image_label_rdd = temp_image_label_rdd
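
As a usage note, assuming that after spark_sess.run the master's TensorFlow session sess holds the synchronized parameter values (the master maintains the single instance of the graph and parameters), the trained model could then be checked locally with ordinary TensorFlow operations. The test_images and test_labels arrays below are assumed NumPy arrays prepared in the same way as the training data; this is an illustrative sketch, not part of spark_mnist.py.

# Evaluation sketch (assumes sess holds the trained W and b, test_images has
# shape [N, 784], and test_labels is one-hot with shape [N, 10]).
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run(accuracy, feed_dict={x: test_images, y_: test_labels}))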

A more detailed introduction to the running mechanism and the design considerations is coming soon.
