control_dependencies and assign new shape not working (using validate_shape=False) #7782

Closed
morgangiraud opened this issue Feb 22, 2017 · 6 comments

@morgangiraud

Environment info

Operating System: OS X (CPU)
TensorFlow 1.0.0

Problem

Hello, I've been trying to use tf.assign together with tf.control_dependencies while changing a variable's shape on the fly.

import tensorflow as tf

# I define a "shape-able" Variable
x = tf.Variable(
    [], 
    dtype=tf.int32,
    validate_shape=False,
    trainable=False
)
# I build a new shape and assign it to x
concat = tf.concat([x, [0]], 0)
assign_op = tf.assign(x, concat, validate_shape=False)

with tf.control_dependencies([assign_op]):
    # I print x after the assignment
    # Note that the Print call is on "x" and NOT "assign_op"
    print_op_dep = tf.Print(x, data=[x], message="print_op_dep:")
    # The assign_op is executed, but the print statement seems to run
    # before the assignment takes effect

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(3):
        sess.run(print_op_dep)

Outputs:

I tensorflow/core/kernels/logging_ops.cc:79] print_op_dep:[]
I tensorflow/core/kernels/logging_ops.cc:79] print_op_dep:[0]
I tensorflow/core/kernels/logging_ops.cc:79] print_op_dep:[0 0]

I would expect:

I tensorflow/core/kernels/logging_ops.cc:79] print_op_dep:[0]
I tensorflow/core/kernels/logging_ops.cc:79] print_op_dep:[0 0]
I tensorflow/core/kernels/logging_ops.cc:79] print_op_dep:[0 0 0]

Is this a bug?

@poxvoculi
Contributor

Interesting. I modified your program as follows:

import tensorflow as tf

x = tf.Variable(
    [], 
    dtype=tf.int32,
    validate_shape=False,
    trainable=False
)
x_alias = tf.Print(x, data=[x], message="x_alias")
concat = tf.concat([x_alias, [0]], 0)
concat_alias = tf.Print(concat, data=[concat], message="concat_alias")
assign_op = tf.assign(x, concat_alias, validate_shape=False)

with tf.control_dependencies([assign_op]):
    y = assign_op
    # y = x
    print_op_dep = tf.Print(y, data=[y], message="print_op_dep:")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(3):
        print(sess.run(print_op_dep))

If "y = x" is used instead of "y = assign_op", I get this output:

I tensorflow/core/kernels/logging_ops.cc:79] x_alias[]
I tensorflow/core/kernels/logging_ops.cc:79] concat_alias[0]
I tensorflow/core/kernels/logging_ops.cc:79] print_op_dep:[]
[]
I tensorflow/core/kernels/logging_ops.cc:79] x_alias[0]
I tensorflow/core/kernels/logging_ops.cc:79] concat_alias[0 0]
I tensorflow/core/kernels/logging_ops.cc:79] print_op_dep:[0]
[0]
I tensorflow/core/kernels/logging_ops.cc:79] x_alias[0 0]
I tensorflow/core/kernels/logging_ops.cc:79] concat_alias[0 0 0]
I tensorflow/core/kernels/logging_ops.cc:79] print_op_dep:[0 0]
[0 0]

Setting "y = assign_op", I get what you expected.

I tensorflow/core/kernels/logging_ops.cc:79] x_alias[]
I tensorflow/core/kernels/logging_ops.cc:79] concat_alias[0]
I tensorflow/core/kernels/logging_ops.cc:79] print_op_dep:[0]
[0]
I tensorflow/core/kernels/logging_ops.cc:79] x_alias[0]
I tensorflow/core/kernels/logging_ops.cc:79] concat_alias[0 0]
I tensorflow/core/kernels/logging_ops.cc:79] print_op_dep:[0 0]
[0 0]
I tensorflow/core/kernels/logging_ops.cc:79] x_alias[0 0]
I tensorflow/core/kernels/logging_ops.cc:79] concat_alias[0 0 0]
I tensorflow/core/kernels/logging_ops.cc:79] print_op_dep:[0 0 0]
[0 0 0]

It appears as though the control_dependencies construct correctly forces assign_op to execute, but the new value isn't really accessible to an evaluation of x until later. This surpasses my understanding. Summon @mrry.

@poxvoculi poxvoculi added stat:awaiting tensorflower Status - Awaiting response from tensorflower type:support Support issues labels Feb 23, 2017
@mrry
Contributor

mrry commented Feb 23, 2017

This is a subtle corner of the tf.Variable semantics, which has tripped up a few people. The main thing to note is that when you first read a tf.Variable—in this case, when it is used as part of the argument to tf.concat()—the value read is "cached".

What does it mean for a value to be "cached"? In the code, it's fed to a tf.identity(), which implicitly dereferences the ref-typed variable tensor, but then (perhaps surprisingly) returns a value-typed tensor that aliases the buffer used for the variable. This behavior was chosen for distributed (or multi-device) execution, where the aliasing isn't usually noticeable because the reader is typically on a remote device, and the buffer will be copied between devices anyway.

However, when you assign a tensor of a different shape to a tf.Variable, the snapshot and the current variable value can no longer be aliases, because they are buffers of different sizes. The tf.Print() op therefore gets the old snapshot value, and prints that.
(Aside: if you had done tf.assign_add(x, x + 1), or something else that preserved the shape of x, you would see things happen in the order you expected, because everything happens on the same device and the "snapshot" remains an alias of the underlying buffer.)
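
For illustration, here is a minimal sketch (TF 1.x, not from the thread) of that shape-preserving aside; because the assignment never resizes the buffer, the print inside the control_dependencies block sees the updated value:

import tensorflow as tf

# Shape-preserving update: the "snapshot" stays an alias of the buffer.
v = tf.Variable([0], dtype=tf.int32, trainable=False)
assign_op = tf.assign_add(v, [1])  # same shape before and after

with tf.control_dependencies([assign_op]):
    print_op = tf.Print(v, data=[v], message="print_op:")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(3):
        sess.run(print_op)

# Expected output: print_op:[1], then [2], then [3] -- the increment
# is visible immediately, unlike in the shape-changing case above.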

How can you avoid this? One way is to force an explicit x.read_value(), which forces a new snapshot to be taken, respecting the control dependencies. Changing your program as follows will give the expected output:

import tensorflow as tf

# I define a "shape-able" Variable
x = tf.Variable([], dtype=tf.int32, validate_shape=False, trainable=False)
# I build a new shape and assign it to x
concat = tf.concat([x, [0]], 0)
assign_op = tf.assign(x, concat, validate_shape=False)

with tf.control_dependencies([assign_op]):
  # I print x after the assignment
  # Note that the Print call is on "new_x" (a fresh read) and NOT on "x"
  new_x = x.read_value()
  print_op_dep = tf.Print(new_x, data=[new_x], message="print_op_dep:")

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for i in range(3):
    sess.run(print_op_dep)

The output is:

I tensorflow/core/kernels/logging_ops.cc:79] print_op_dep:[0]
I tensorflow/core/kernels/logging_ops.cc:79] print_op_dep:[0 0]
I tensorflow/core/kernels/logging_ops.cc:79] print_op_dep:[0 0 0]

@mrry
Contributor

mrry commented Feb 23, 2017

Of course, these semantics are not very intuitive, and there's no way you'd have guessed that from the documentation. @alextp is working on a new version of variables that will have more sensible semantics. I'll let him comment on how things will look in the brave new world.

@mrry mrry added stat:awaiting response Status - Awaiting response from author and removed stat:awaiting tensorflower Status - Awaiting response from tensorflower labels Feb 23, 2017
@mrry mrry assigned alextp and unassigned mrry Feb 23, 2017
@morgangiraud
Author

👍
Thanks a lot for the fast answer. I understand now, and I've tested this successfully.

Also, one last question on my side:
Does using tf.assign in this way (especially the fact that I change the shape) lead to poor optimization in TensorFlow, and therefore poor performance?

@aselle aselle removed the stat:awaiting response Status - Awaiting response from author label Feb 23, 2017
@mrry
Contributor

mrry commented Feb 23, 2017

Glad to hear it!

As for the performance impact, it's hard to say. At the level of individual assign calls, TensorFlow doesn't do much to optimize your code, so you aren't necessarily missing any optimizations. Concatenating and copying like you do in that code snippet will have quadratic time complexity, but I'm not sure if you're going to be doing that in such a tight loop that it matters :). (If you find yourself concatenating dynamic lists of tensors a lot, you might be interested in tf.TensorArray instead... it was introduced in part to avoid doing quadratic concatenation when accumulating loop state in a tf.while_loop().)
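
As a hedged sketch (TF 1.x, not from the thread) of that tf.TensorArray pattern: writes are constant-time per step, and a single stack() at the end replaces the repeated concatenations:

import tensorflow as tf

# Accumulate loop state with tf.TensorArray instead of growing a
# tensor via tf.concat on every iteration.
n = 5
ta = tf.TensorArray(dtype=tf.int32, size=n)

def cond(i, ta):
    return i < n

def body(i, ta):
    return i + 1, ta.write(i, i)  # O(1) write per iteration

_, ta_final = tf.while_loop(cond, body, [tf.constant(0), ta])
result = ta_final.stack()  # one final pack into a dense tensor

with tf.Session() as sess:
    print(sess.run(result))  # [0 1 2 3 4]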

It is possible that having varying-shape variables will lead to e.g. more unknowns in shape inference, which could inhibit some nice optimizations that are possible when the shape of a tensor is static, but I assume you have a reason for wanting to change the shape of a variable, so some amount of dynamism is probably necessary.
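
For example, a quick check (my assumption about TF 1.x behavior, not from the thread) of how validate_shape=False removes the static shape that shape inference would otherwise use:

import tensorflow as tf

# With validate_shape=False the variable carries no static shape.
x = tf.Variable([], dtype=tf.int32, validate_shape=False, trainable=False)
print(x.get_shape())  # <unknown>

# With the default validate_shape=True, the initializer fixes the shape.
y = tf.Variable([], dtype=tf.int32)
print(y.get_shape())  # (0,)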

@morgangiraud
Author

All right, thanks for those insights!
