
Get placeholder for initial state of nested RNN #2838

Description

danijar (Contributor)

For RNN cells, we get the initial state using cell.zero_state() and the last state after processing a sequence using rnn.dynamic_rnn(). However, to use the last state as the initial state for the next run, one must create a tf.placeholder(). As far as I know, currently there is no way to create and fill such a placeholder (or nested tuple of placeholders) automatically. Such a feature would be very useful so that we don't have to adjust the placeholder manually when changing the RNN cell.
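A helper like the one requested could walk the nested initial-state structure and create one placeholder per leaf. The sketch below simulates the traversal without TensorFlow: `FakeTensor`, `map_nested`, and the `('placeholder', shape)` leaves are all illustrative stand-ins, not real API. With actual TensorFlow the leaf function would call something like `tf.placeholder(t.dtype, t.get_shape())`, and the recursion mirrors what TensorFlow's internal nest utilities do.

```python
from collections import namedtuple

# Stand-in for tf.nn.rnn_cell.LSTMStateTuple.
LSTMStateTuple = namedtuple('LSTMStateTuple', ['c', 'h'])

class FakeTensor:
    """Stand-in for a tf.Tensor leaf; only carries a shape."""
    def __init__(self, shape):
        self.shape = shape

def map_nested(fn, structure):
    """Apply fn to every leaf of a possibly nested tuple structure,
    preserving plain tuples and namedtuples alike."""
    if isinstance(structure, tuple):
        mapped = [map_nested(fn, s) for s in structure]
        if hasattr(structure, '_fields'):   # namedtuple, e.g. LSTMStateTuple
            return type(structure)(*mapped)
        return tuple(mapped)
    return fn(structure)

# Pretend zero_state of a two-layer LSTM: batch size 32, 128 units per layer.
zero_state = (LSTMStateTuple(FakeTensor((32, 128)), FakeTensor((32, 128))),
              LSTMStateTuple(FakeTensor((32, 128)), FakeTensor((32, 128))))

# With real TensorFlow the leaf function would be something like
#   lambda t: tf.placeholder(t.dtype, t.get_shape())
state_placeholders = map_nested(lambda t: ('placeholder', t.shape), zero_state)
```

The result has exactly the same nested shape as `zero_state`, so it can be passed wherever the initial state is expected and filled from a fetched final state via `feed_dict`.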

Activity

ebrevdo (Contributor) commented on Jun 14, 2016

Is this request specifically for truncated BPTT? or something more general?

danijar (Contributor, Author) commented on Jun 15, 2016

It's for both truncated BPTT and architectures using LSTM decoders. In the second case, the cells are initialized with some encoded activation. For an example see: Skip-Thought Vectors (Kiros et al. 2015).

tomrunia commented on Jul 10, 2016

I am also interested in having a good way to remember the LSTM state for the next batch. I also asked this question on StackOverflow: http://stackoverflow.com/questions/38241410/tensorflow-remember-lstm-state-for-next-batch-stateful-lstm

ebrevdo (Contributor) commented on Jul 10, 2016

We are working on a system that makes this easy. Something should already be in the GitHub repo within a couple of weeks.

tomrunia commented on Jul 12, 2016

This might be of interest to @danijar : #2695
Great that you are working on things to make this easier, I will wait for an update :-)

yangzw commented on Jul 31, 2016

In the example ptb_word_lm.py, line 257

  state = m.initial_state.eval()
  for step, (x, y) in enumerate(reader.ptb_iterator(data, m.batch_size,
                                                    m.num_steps)):
    cost, state, _ = session.run([m.cost, m.final_state, eval_op],
                                 {m.input_data: x,
                                  m.targets: y,
                                  m.initial_state: state})

However, there is no placeholder for m.initial_state. Why does this work?

danijar (Contributor, Author) commented on Jul 31, 2016

@yangzw You can feed in values for any tensor. Placeholders are just special in that they throw an error if you don't feed them, whereas a variable would silently use its last value.

ebrevdo (Contributor) commented on Aug 17, 2016

We now have a comprehensive solution for truncated BPTT; introduced in 955efc9. See tf.contrib.training.batch_sequences_with_states. Unfortunately for now the only examples are in the unit tests.

Automatically creating nested placeholders would be useful. I'll look into adding this.

wpm commented on Aug 28, 2016

Is there example code for the my_parser function in the example in the documentation for batch_sequences_with_states?

I'm trying to figure it out from the documentation but still have questions.

nitishgupta commented on Aug 29, 2016

@ebrevdo: In my opinion there should be something more general. For cases in which one trains an RNN with the whole sequence fed at once, but inference requires fetching and feeding the state one time step at a time, the current solution is not very neat. Is there some hope of being able to fetch and feed a list of state tuples for MultiRNNCell?

danijar (Contributor, Author) commented on Aug 29, 2016

I'm currently solving this by holding the state in a non-trainable variable that I initialize from the default state. The variable name is prefixed with state/, and I have a helper function that returns a dictionary from name to tensor containing all variables matching this prefix. Similarly, I have a helper function to assign variable values from this dictionary.

This is a general way to handle context, but it's not straightforward using the existing TensorFlow features. Moreover, it doesn't work with the new decision to represent LSTM states as tuples.

I can contribute code to TensorFlow for a feature like this, but we should think this through first, and see that it matches TensorFlow's preferred way to handle states.
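The variable-based approach described above could be sketched roughly as follows. This is a minimal stand-in, not the actual helpers: a plain dict plays the role of the graph's variable collection, and all names (graph_variables, collect_state, assign_state, state/lstm/...) are illustrative; in real TensorFlow these would be non-trainable tf.Variable objects found by name matching.

```python
# Plain dict standing in for the graph's variable collection.
graph_variables = {
    'state/lstm/c': [0.0, 0.0],
    'state/lstm/h': [0.0, 0.0],
    'weights/kernel': [1.0, 2.0],   # not part of the RNN state
}

def collect_state(prefix='state/'):
    """Return a name -> value dict of all variables under the prefix."""
    return {name: value for name, value in graph_variables.items()
            if name.startswith(prefix)}

def assign_state(state):
    """Write saved values back into the matching variables."""
    for name, value in state.items():
        graph_variables[name] = value

saved = collect_state()                        # snapshot the RNN state
graph_variables['state/lstm/c'] = [9.9, 9.9]   # state mutates during a run
assign_state(saved)                            # restore it for the next sequence
```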

ebrevdo (Contributor) commented on Aug 29, 2016

@danijar: I recommend against using a non-trainable variable, because it is not thread-safe (you can't run multiple inference threads against the same graph). However, it's not too hard to create some placeholder tensors and wrap them in the necessary tuple type. Similarly, when calling session.run, one can pull out the "next state" tuple, store it, and feed it as an input to the next session.run. This is decidedly more thread-safe than the variable solution (and in fact is zero-copy if you're doing this from a C++ client; sadly it is not zero-copy in Python, since the TF runtime must copy feed_dict inputs from Python, which does its own memory management).
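The fetch-and-feed pattern described here can be simulated without TensorFlow in a few lines. rnn_step is a toy stand-in for one session.run over a segment: each call returns the next state, which the client stores and passes back in, so no graph-side variable (and hence no shared mutable state between threads) is needed.

```python
def rnn_step(x, state):
    """Toy stand-in for session.run on one truncated-BPTT segment:
    accumulates inputs into the state and emits an output."""
    new_state = state + x
    output = new_state * 2
    return output, new_state

state = 0             # analogue of cell.zero_state()
outputs = []
for x in [1, 2, 3]:   # successive batches/segments
    out, state = rnn_step(x, state)   # fetch the final state...
    outputs.append(out)               # ...and feed it on the next iteration
```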

@nitishgupta: you can now fetch an arbitrary tuple type in Python. I don't think you can feed one, though (but that may have changed recently?). Since per-step RNN inference is usually meant to be done from a C++ client, I don't have any plans to add Python sugar for this.

@wpm an example of "my_parser" is something that reads a serialized SequenceExample via a reader and deserializes it using parse_single_sequence_example. The parse_single_sequence_example call returns context and sequences dictionaries that exactly match some of the inputs of batch_sequences_with_states.
