How to update multiple tensors using a single value with tf.scan

I assume you have a set of tensors that you want to update iteratively with a sequence. For example, you have a neural network that you’d like to update with the point at time t in a sequence and the values from the network at time t-1. If you want to see this in full-fledged use, look at my Jupyter notebook where I recreate the Variational Recurrent Neural Network!

fn should follow the form fn(parameter_that_changes, parameter_you_change_with). This means each element of elems is always passed in as parameter_you_change_with, and whatever you return becomes the new parameter_that_changes.
Written as a function, it looks something like the following:

def fn(x, elem):
    new_x = ...  # compute the updated value from x and elem
    return new_x

where new_x will be x the next time fn is called. That took me some time to figure out.
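
To make the calling convention concrete, here is a minimal pure-Python sketch of what a scan does with fn. The helper name scan_sketch is hypothetical and this is not TensorFlow's implementation; it only mimics the order in which fn is called and how each return value is fed back in as x.

```python
def scan_sketch(fn, elems, initializer):
    """Hypothetical stand-in for tf.scan, for illustration only."""
    x = initializer
    outputs = []
    for elem in elems:
        x = fn(x, elem)    # the return value becomes x on the next call
        outputs.append(x)  # every intermediate value is kept
    return outputs


def fn(x, elem):
    new_x = x + elem
    return new_x


running_sums = scan_sketch(fn, [1, 2, 3], initializer=0)
print(running_sums)  # [1, 3, 6]
```

Note that the initializer itself does not appear in the output, only the values returned by fn.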

# TensorFlow 1.x API (tf.Session, tf.global_variables_initializer)
import tensorflow as tf

# updating 3 tensors with a single sequence
a1 = tf.Variable([0, 0])
a2 = tf.Variable([1, 1])
a3 = tf.Variable([2, 2])
sequence = tf.Variable([1, 2, 3])

# using tf.multiply instead of '*', e.g. tf.multiply(x, 2) instead of 2*x,
# was key to this compiling...
def replace_one(old, x):
    a1, a2, a3 = old
    a1 = tf.add(a1, tf.multiply(x, 1))
    a2 = tf.add(a2, tf.multiply(x, 2))
    a3 = tf.add(a3, tf.multiply(x, 3))
    return [a1, a2, a3]

# key thing that worked: the initializer needed to match the output of fn.
# dumb mistake I can see tripping up many people
update = tf.scan(replace_one, sequence, initializer=[a1, a2, a3])

# assigning outside of scan builds fine (this op is defined but never run below)
a1 = a1.assign(a2)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(update))

A few notes

So this was more difficult to implement than I expected. I had to get all the ingredients exactly right.

While I can assign outside of scan, for some reason the tensors a1, a2, and a3 couldn’t be assigned inside of scan, e.g. with a1.assign(tf.add(a1,tf.multiply(x,1))) inside replace_one.

You might expect to be able to keep all your values inside a single tensor for the initializer and update them via indexing. This also doesn’t work: with T=tf.concat([a1,a2,a3], axis=0), you can’t do T[0]=x, because tensors don’t support item assignment.
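
Since item assignment is out, the general idea is to have fn return a freshly built value instead of modifying the old one. Here is a plain-Python sketch of that idea, using tuples as a stand-in for immutable tensors. The names step and state are hypothetical, nothing here calls TensorFlow, and the per-row multipliers simply mirror the 1, 2, 3 used in replace_one.

```python
from functools import reduce


def step(state, x):
    # Build a new tuple of rows; row i grows by x * (i + 1).
    return tuple(
        tuple(v + x * (i + 1) for v in row)
        for i, row in enumerate(state)
    )


state = ((0, 0), (1, 1), (2, 2))

try:
    state[0] = (9, 9)  # in-place update, like T[0] = x
except TypeError:
    pass  # rejected: tuples, like tensors, cannot be assigned into

final_state = reduce(step, [1, 2, 3], state)
print(final_state)  # ((6, 6), (13, 13), (20, 20))
```

In TensorFlow the analogous move would presumably be to rebuild the combined tensor inside fn (for example with tf.stack) rather than index into it; treat that as an assumption rather than something verified here.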

I spent a long time trying to manually concatenate the values so that I could track them over time, only to learn that scan does this by default! E.g., for a1, the corresponding output is [a1+1, a1+1+2, a1+1+2+3], since the elements were [1,2,3].
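
To check that arithmetic without running a session, the same update rule can be traced in plain Python. This is only a sketch: it uses itertools.accumulate (Python 3.8+ for the initial argument) in place of tf.scan, and lists in place of tensors, so it shows the expected per-step values rather than actual TensorFlow output.

```python
from itertools import accumulate


def replace_one(old, x):
    a1, a2, a3 = old
    return (
        [v + x * 1 for v in a1],
        [v + x * 2 for v in a2],
        [v + x * 3 for v in a3],
    )


initial = ([0, 0], [1, 1], [2, 2])
sequence = [1, 2, 3]

# accumulate(..., initial=...) yields the starting state first; drop it,
# because the scan output described above holds only the updated values.
steps = list(accumulate(sequence, replace_one, initial=initial))[1:]

# Regroup so that each accumulator gets its own history over the sequence.
a1_history, a2_history, a3_history = (list(h) for h in zip(*steps))

print(a1_history)  # [[1, 1], [3, 3], [6, 6]]
print(a2_history)  # [[3, 3], [7, 7], [13, 13]]
print(a3_history)  # [[5, 5], [11, 11], [20, 20]]
```

Each history has one entry per element of the sequence, which is the stacking that scan does for you.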