TensorFlow asks for inputs to unnecessary placeholders when using tf.cond()

Consider the following code snippet, which uses TensorFlow's tf.cond().

    import tensorflow as tf
    import numpy as np

    bb = tf.placeholder(tf.bool)
    xx = tf.placeholder(tf.float32, name='xx')
    yy = tf.placeholder(tf.float32, name='yy')

    zz = tf.cond(bb, lambda: xx + yy, lambda: 100 + yy)

    with tf.Session() as sess:
        dict1 = {bb: False, yy: np.array([1., 3, 4]), xx: np.array([5., 6, 7])}
        print(sess.run(zz, feed_dict=dict1))  # works fine without errors

        dict2 = {bb: False, yy: np.array([1., 3, 4])}
        # raises an InvalidArgumentError asking to provide an input for xx
        print(sess.run(zz, feed_dict=dict2))

In both cases, bb is False, so evaluating zz should in theory not depend on xx, yet TensorFlow still requires an input for xx. A dummy array can be fed instead, but it has to match the shape of yy, and that is not as clean as dict2.

Can anybody suggest how to evaluate zz (using tf.cond() or any other approach) without providing a value for xx?

1 answer

  • answered 2018-02-13 02:56 Lior

    You can define xx as a tf.Variable instead, giving it a default value (which will be used whenever xx is not fed with another value). A few things to notice:

    1. Although xx is not a placeholder, you can still treat it as one by feeding values into it through the feed_dict.
    2. Use validate_shape=False so that you can feed values of any shape into xx.
    3. Use trainable=False so that xx is not optimized over (otherwise, an optimizer might change its default value, possibly to NaN, which may cause problems).
    4. Don't forget to initialize xx, e.g. with tf.global_variables_initializer().

    Here is the code:

    import tensorflow as tf
    import numpy as np
    
    bb = tf.placeholder(tf.bool)
    xx = tf.Variable(initial_value=0.0, validate_shape=False, trainable=False, name='xx')
    yy = tf.placeholder(tf.float32, name='yy')
    
    zz = tf.cond(bb, lambda: xx + yy, lambda: 100 + yy)
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())  # gives xx its default value
        dict1 = {bb: False, yy: np.array([1., 3, 4]), xx: np.array([5., 6, 7])}
        print(sess.run(zz, feed_dict=dict1))
        dict2 = {bb: False, yy: np.array([1., 3, 4])}  # xx is not fed here
        print(sess.run(zz, feed_dict=dict2))
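
To make it clearer why the default value helps, here is a toy model in plain Python. It is not TensorFlow code, and the names Placeholder, WithDefault, run_cond and MissingFeedError are all made up for illustration. It assumes what the tf.cond() documentation describes for graph mode: only operations created inside the two branch functions run conditionally, while tensors created outside them (such as xx) are inputs to the conditional and are evaluated whichever branch is taken.

```python
class MissingFeedError(Exception):
    """Hypothetical stand-in for TensorFlow's InvalidArgumentError."""


class Placeholder:
    """Toy input node: it has no value unless one is fed."""

    def __init__(self, name):
        self.name = name

    def resolve(self, feed):
        if self not in feed:
            raise MissingFeedError("no value fed for placeholder '%s'" % self.name)
        return feed[self]


class WithDefault(Placeholder):
    """Toy input node that falls back to a default when it is not fed."""

    def __init__(self, name, default):
        super().__init__(name)
        self.default = default

    def resolve(self, feed):
        return feed.get(self, self.default)


def run_cond(pred, inputs, true_fn, false_fn, feed):
    # Inputs created outside the branches are resolved first,
    # regardless of which branch ends up being selected.
    values = [node.resolve(feed) for node in inputs]
    branch = true_fn if pred.resolve(feed) else false_fn
    return branch(*values)


def add_xy(x, y):
    return x + y


def add_100(x, y):
    return 100 + y  # x is unused, but it was still resolved above


bb = Placeholder('bb')
xx = Placeholder('xx')
yy = Placeholder('yy')

# Plain placeholder: leaving xx out of the feed fails even though bb is False.
try:
    run_cond(bb, [xx, yy], add_xy, add_100, {bb: False, yy: 1.0})
    error_message = None
except MissingFeedError as err:
    error_message = str(err)

# Input with a default: the same feed now works, and xx can still be overridden.
xx_default = WithDefault('xx', 0.0)
result_default = run_cond(bb, [xx_default, yy], add_xy, add_100, {bb: False, yy: 1.0})

print(error_message)
print(result_default)
```

The tf.Variable approach above plays the role of WithDefault in this sketch. TensorFlow 1.x also has tf.placeholder_with_default, which looks like a close fit for the same purpose, although I have not tested it against this exact snippet.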