Closed
Description
Environment info
Operating System: OSX on CPU
TensorFlow 1.0.0
Problem
Hello, I've been trying to use tf.assign with a tf.control_dependencies scheme while changing the shape of a variable on the fly.
import tensorflow as tf

# I define a "shape-able" Variable
x = tf.Variable(
    [],
    dtype=tf.int32,
    validate_shape=False,
    trainable=False
)

# I build a new shape and assign it to x
concat = tf.concat([x, [0]], 0)
assign_op = tf.assign(x, concat, validate_shape=False)

with tf.control_dependencies([assign_op]):
    # I print x after the assignment
    # Note that the Print call is on "x" and NOT "assign_op"
    print_op_dep = tf.Print(x, data=[x], message="print_op_dep:")

# The assign_op is called, but it seems that the print statement happens
# before the assignment
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(3):
        sess.run(print_op_dep)
Outputs:
I tensorflow/core/kernels/logging_ops.cc:79] print_op_dep:[]
I tensorflow/core/kernels/logging_ops.cc:79] print_op_dep:[0]
I tensorflow/core/kernels/logging_ops.cc:79] print_op_dep:[0 0]
I would expect:
I tensorflow/core/kernels/logging_ops.cc:79] print_op_dep:[0]
I tensorflow/core/kernels/logging_ops.cc:79] print_op_dep:[0 0]
I tensorflow/core/kernels/logging_ops.cc:79] print_op_dep:[0 0 0]
Is this a bug?
Activity
poxvoculi commented on Feb 23, 2017
Interesting. I modified your program as follows, introducing an intermediate y (see the sketch below).
If "y = x" is used instead of "y = assign_op", I get the same output you reported. Setting "y = assign_op", I get what you expected.
It appears as though the control_dependencies construct correctly forces assign_op to execute, but the new value isn't really accessible to an evaluation of x until later. This surpasses my understanding. Summon @mrry.
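A plausible reconstruction of that modification, assuming y simply replaces x as the tensor passed to tf.Print (the exact snippet isn't preserved in the thread):

import tensorflow as tf

x = tf.Variable([], dtype=tf.int32, validate_shape=False, trainable=False)
concat = tf.concat([x, [0]], 0)
assign_op = tf.assign(x, concat, validate_shape=False)

with tf.control_dependencies([assign_op]):
    # Switch between these two lines to reproduce both behaviors described above:
    y = x            # reads the stale snapshot of the variable
    # y = assign_op  # reads the freshly assigned value
    print_op_dep = tf.Print(y, data=[y], message="print_op_dep:")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(3):
        sess.run(print_op_dep)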
mrry commented on Feb 23, 2017
This is a subtle corner of the tf.Variable semantics, which has tripped up a few people. The main thing to note is that when you first read a tf.Variable (in this case, when it is used as part of the argument to tf.concat()) the value read is "cached".
What does it mean for a value to be "cached"? In the code, it's fed to a tf.identity(), which implicitly dereferences the ref-typed variable tensor, but then (perhaps surprisingly) returns a value-typed tensor that aliases the buffer used for the variable. This behavior was chosen for distributed (or multi-device) execution, where the aliasing isn't usually noticeable because the reader is typically on a remote device, and the buffer will be copied between devices anyway.
However, when you assign a tensor of a different shape to a tf.Variable, the snapshot and the current variable value can no longer be aliases, because they are buffers of different sizes. (Aside: if you'd done tf.assign_add(x, x + 1), or something else that preserved the shape of x, you would see things happen in the order you expected, because everything happens on the same device and the "snapshot" remains an alias of the underlying buffer.) The tf.Print() op gets the old snapshot value, and prints that.
How can you avoid this? One way is to force an explicit x.read_value(), which forces a new snapshot to be taken, respecting the control dependencies. Changing your program to read the variable through x.read_value() (see the sketch below) gives the output you expected.
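A minimal sketch of that fix, assuming the only change is reading the variable through x.read_value() before printing (the original snippet and its output aren't preserved in the thread); with this change the three runs print [0], [0 0] and [0 0 0], as the reporter expected:

import tensorflow as tf

x = tf.Variable([], dtype=tf.int32, validate_shape=False, trainable=False)
concat = tf.concat([x, [0]], 0)
assign_op = tf.assign(x, concat, validate_shape=False)

with tf.control_dependencies([assign_op]):
    # read_value() takes a fresh snapshot of the variable, and it respects
    # the control dependency, so it sees the result of assign_op.
    new_value = x.read_value()
    print_op_dep = tf.Print(new_value, data=[new_value], message="print_op_dep:")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(3):
        sess.run(print_op_dep)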
mrry commented on Feb 23, 2017
Of course, these semantics are not very intuitive, and there's no way you'd have guessed that from the documentation. @alextp is working on a new version of variables that will have more sensible semantics. I'll let him comment on how things will look in the brave new world.
morgangiraud commented on Feb 23, 2017
👍
Thanks a lot for this fast answer. I understand now, and I've been testing this successfully.
Also, one last question on my side: does using tf.assign in this way (especially the fact that I change the shape) lead to poor optimization in TensorFlow and, therefore, poor performance?
mrry commented on Feb 23, 2017
Glad to hear it!
As for the performance impact, it's hard to say. At the level of individual assign calls, TensorFlow doesn't do much to optimize your code, so you aren't necessarily missing any optimizations. Concatenating and copying like you do in that code snippet will have quadratic time complexity, but I'm not sure if you're going to be doing that in such a tight loop that it matters :). (If you find yourself concatenating dynamic lists of tensors a lot, you might be interested in tf.TensorArray instead... it was introduced in part to avoid doing quadratic concatenation when accumulating loop state in a tf.while_loop().)
It is possible that having varying-shape variables will lead to e.g. more unknowns in shape inference, which could inhibit some nice optimizations that are possible when the shape of a tensor is static, but I assume you have a reason for wanting to change the shape of a variable, so some amount of dynamism is probably necessary.
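Regarding the tf.TensorArray suggestion: a minimal sketch, not taken from the thread, of accumulating per-iteration values with a tf.TensorArray inside tf.while_loop instead of repeatedly concatenating (the loop bound of 3 and the written values are just illustrative):

import tensorflow as tf

# A dynamically sized TensorArray grows by one element per write,
# avoiding the quadratic cost of re-concatenating a growing tensor.
ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)

def cond(i, ta):
    return i < 3

def body(i, ta):
    # write() returns a new TensorArray handle that is threaded through the loop
    return [i + 1, ta.write(i, i)]

_, final_ta = tf.while_loop(cond, body, [tf.constant(0), ta])
result = final_ta.stack()  # a [3] tensor: [0 1 2]

with tf.Session() as sess:
    print(sess.run(result))  # prints [0 1 2]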
morgangiraud commented on Feb 23, 2017
All right, thanks for those insights!