TensorFlow tf.where code examples

2022-12-24


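import tensorflow as tf

# tf.Session exists only in TF 1.x; the tf.compat.v1 calls below run on TF 1.x and 2.x
# (TF 2.x runs eagerly by default, so graph mode has to be switched on explicitly).
tf.compat.v1.disable_eager_execution()

judge_list = [True, False]               # condition: True -> take from x, False -> take from y
input_list = tf.constant([1.0, -1.0])    # x
result_list = tf.where(judge_list,
                       input_list,               # x
                       tf.ones(shape=[2]) * 2)   # y = [2.0, 2.0]

with tf.compat.v1.Session() as sess:
    print(sess.run(result_list))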

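Output:
[1. 2.]

The first condition is True, so element 0 comes from input_list (1.0); the second is False, so element 1 comes from tf.ones(shape=[2]) * 2 (2.0).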
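The selection rule for the three-argument form, tf.where(condition, x, y), can be modeled in plain Python. This is a minimal sketch assuming 1-D inputs of equal length; `where_1d` is a hypothetical helper for illustration, not part of TensorFlow, and it does not model TensorFlow's shape or broadcasting rules.

```python
def where_1d(condition, x, y):
    """Element-wise select for equal-length 1-D lists: x[i] if condition[i] else y[i]."""
    if not (len(condition) == len(x) == len(y)):
        raise ValueError("condition, x and y must have the same length")
    return [xi if c else yi for c, xi, yi in zip(condition, x, y)]

# Same inputs as the first example; y stands in for tf.ones(shape=[2]) * 2
judge_list = [True, False]
input_list = [1.0, -1.0]
result_list = where_1d(judge_list, input_list, [2.0, 2.0])
print(result_list)  # [1.0, 2.0]
```

Flipping the condition to [False, True] would give [2.0, -1.0]: each position is chosen independently by its own condition value.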

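import tensorflow as tf

# tf.Session exists only in TF 1.x; the tf.compat.v1 calls below run on TF 1.x and 2.x
# (TF 2.x runs eagerly by default, so graph mode has to be switched on explicitly).
tf.compat.v1.disable_eager_execution()

judge_list = [True, False]               # condition: True -> take from x, False -> take from y
input_list = tf.constant([1.0, -1.0])    # x
result_list = tf.where(judge_list,
                       input_list,        # x
                       input_list * 2)    # y = [2.0, -2.0]

with tf.compat.v1.Session() as sess:
    print(sess.run(result_list))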

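Output:
[ 1. -2.]

Here y is input_list * 2, i.e. [2.0, -2.0], so element 0 still comes from input_list (1.0) and element 1 comes from y (-2.0).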
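In tf.where, x and y are ordinary input tensors, so input_list * 2 is evaluated over the whole tensor before any element is selected. The plain-Python sketch below mimics that order for the second example; `where_1d` is a hypothetical helper for illustration, not a TensorFlow function.

```python
def where_1d(condition, x, y):
    """Element-wise select for equal-length 1-D lists: x[i] if condition[i] else y[i]."""
    if not (len(condition) == len(x) == len(y)):
        raise ValueError("condition, x and y must have the same length")
    return [xi if c else yi for c, xi, yi in zip(condition, x, y)]

judge_list = [True, False]
input_list = [1.0, -1.0]
# y is derived from x and computed for every element before selection,
# including index 0, whose doubled value (2.0) is then discarded.
doubled = [v * 2 for v in input_list]   # [2.0, -2.0]
result_list = where_1d(judge_list, input_list, doubled)
print(result_list)  # [1.0, -2.0]
```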

