On tf.cond Dependencies in TensorFlow

First, an example from the official TensorFlow documentation:

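import tensorflow as tf  # TF 1.x graph mode; a, b, x, y are tensors defined elsewhere
z = tf.multiply(a, b)  # created outside the cond branches, so it always runs
result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))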

Since z is needed for at least one branch of the cond, the tf.mul operation is always executed unconditionally. Although this behavior is consistent with the dataflow model of TensorFlow, it has occasionally surprised some users who expected a lazier semantics. (https://www.tensorflow.org/versions/r1.0/api_docs/python/tf/cond)

Next, an example from Stack Overflow:

https://stackoverflow.com/questions/37063952/confused-by-the-behavior-of-tf-cond

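import tensorflow as tf  # TensorFlow 1.x graph mode

pred = tf.constant(True)
x = tf.Variable([1])
assign_x_2 = tf.assign(x, [2])  # created OUTSIDE the branch functions

def update_x_2():
    # The control dependency points at an op that lives outside the cond,
    # so assign_x_2 runs no matter which branch pred selects.
    with tf.control_dependencies([assign_x_2]):
        return tf.identity(x)

y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    print(y.eval())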

Whether pred is set to True or False, the result is y = [2].

Now a slight modification:

import tensorflow as tf  # TensorFlow 1.x graph mode

pred = tf.placeholder(tf.bool, shape=[])
x = tf.Variable([1])

def update_x_2():
    # tf.assign is now created INSIDE the branch function,
    # so it only runs when pred is True.
    with tf.control_dependencies([tf.assign(x, [2])]):
        return tf.identity(x)

y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    print(y.eval(feed_dict={pred: False}))  # ==> [1]
    print(y.eval(feed_dict={pred: True}))   # ==> [2]

Note that tf.assign(x, [2]) has been moved inside update_x_2. With that change alone, we get the expected results.

Explanation

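Think of the ops that tf.cond's two branch functions (fn_true and fn_false here) create when tf.cond calls them as two inner graphs of tf.cond; at run time tf.cond executes only one of them, depending on pred. If a branch instead depends on an op created before tf.cond (think of it as part of the graph outside tf.cond), that op is executed regardless of which branch pred selects.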

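So the key question is whether an op that tf.cond depends on was created outside or inside the branch functions. If it was created outside, it runs no matter what pred is; if it was created inside, only the ops of the branch selected by pred run.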
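To make the inside/outside rule concrete, here is a toy, pure-Python model of it that needs no TensorFlow installation. It is a hypothetical sketch, not TensorFlow's API or its real executor: `Op`, `merge`, `run`, `docs_example`, and `so_example` are invented names, and pred is passed in as a plain Python bool rather than computed in the graph. The model encodes the rule above in simplified form: every op the fetched result transitively depends on is visited, ops created outside the branch functions always run, and ops created inside a branch run only when pred selects it (the untaken branch yields a "dead" marker). It replays the documentation example and both Stack Overflow variants.

```python
# Toy model (hypothetical, NOT TensorFlow's API) of running a graph that
# contains a cond. Every op the fetched value transitively depends on is
# visited; ops created outside the branch functions (branch=None) always
# execute, ops created inside a branch execute only when pred selects it.

DEAD = object()  # stands in for the "dead" value flowing out of the untaken branch


class Op:
    def __init__(self, name, fn, inputs=(), control_inputs=(), branch=None):
        self.name = name
        self.fn = fn                                # computes the value from input values
        self.inputs = list(inputs)                  # data dependencies
        self.control_inputs = list(control_inputs)  # "run before me" dependencies
        self.branch = branch                        # None: outside cond; True/False: inside that branch


def merge(true_out, false_out):
    """Join node of the cond: forwards whichever branch output is alive."""
    return Op("merge", lambda t, f: f if t is DEAD else t, inputs=[true_out, false_out])


def run(fetch, pred):
    """Evaluate `fetch`; return (value, names of the ops that actually executed)."""
    values, executed = {}, []

    def visit(op):
        if id(op) in values:
            return values[id(op)]
        # Dependencies are visited even for ops in the untaken branch,
        # because they are still part of the fetch's transitive closure.
        for dep in op.control_inputs:
            visit(dep)
        args = [visit(dep) for dep in op.inputs]
        if op.branch is not None and op.branch != pred:
            result = DEAD                           # gated off by pred: does not execute
        else:
            result = op.fn(*args)
            executed.append(op.name)
        values[id(op)] = result
        return result

    return visit(fetch), executed


def docs_example(pred, multiply_inside):
    """z = a * b; result = cond(pred, x + z, y * y), as in the docs example."""
    a, b = Op("a", lambda: 2), Op("b", lambda: 3)
    x, y = Op("x", lambda: 1), Op("y", lambda: 4)
    z = Op("multiply", lambda p, q: p * q, inputs=[a, b],
           branch=True if multiply_inside else None)
    add = Op("add", lambda p, q: p + q, inputs=[x, z], branch=True)
    square = Op("square", lambda v: v * v, inputs=[y], branch=False)
    return run(merge(add, square), pred)


def so_example(pred, assign_inside):
    """The Stack Overflow example: x = [2] assigned outside or inside update_x_2."""
    state = {"x": [1]}

    def do_assign():
        state["x"] = [2]
        return state["x"]

    assign = Op("assign", do_assign, branch=True if assign_inside else None)
    read_true = Op("identity_true", lambda: list(state["x"]),
                   control_inputs=[assign], branch=True)
    read_false = Op("identity_false", lambda: list(state["x"]), branch=False)
    return run(merge(read_true, read_false), pred)


if __name__ == "__main__":
    print(docs_example(pred=False, multiply_inside=False))  # multiply executes anyway
    print(docs_example(pred=False, multiply_inside=True))   # multiply is skipped
    print(so_example(pred=False, assign_inside=False))      # assign executes anyway
    print(so_example(pred=False, assign_inside=True))       # assign is skipped
```

With assign_inside=False and pred=False, the toy reports that assign executed and returns [2], matching the first Stack Overflow observation. Note that the toy gets [2] partly because it happens to visit the true branch first; in a real graph no edge orders the assignment relative to the false branch's tf.identity(x), so that ordering is a property of this model only.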