import tensorflow as tf
# Element-wise selection with tf.where: at each position the value is taken
# from `input_list` where the mask is True, and from the fallback tensor
# (a length-2 vector filled with 2.0) where the mask is False.
judge_list = [True, False]
input_list = tf.constant([1.0, -1.0])
result_list = tf.where(
    judge_list,
    input_list,
    tf.ones(shape=[2]) * 2,
)

# Run the graph inside a context manager so the session's resources are
# released as soon as the result has been evaluated, instead of leaking
# an open session.
with tf.Session() as sess:
    print(sess.run(result_list))
# Printed result:
# [1. 2.]
import tensorflow as tf
# Element-wise selection with tf.where where both branches derive from the
# same tensor: positions whose mask is True keep the original value, and
# positions whose mask is False get the doubled value (`input_list * 2`).
judge_list = [True, False]
input_list = tf.constant([1.0, -1.0])
result_list = tf.where(
    judge_list,
    input_list,
    input_list * 2,
)

# Run the graph inside a context manager so the session's resources are
# released as soon as the result has been evaluated, instead of leaking
# an open session.
with tf.Session() as sess:
    print(sess.run(result_list))
# Printed result:
# [ 1. -2.]