
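TensorFlow tf.where code examples

tf.where(condition, x, y) selects element by element: where condition is True it takes the element of x, and where it is False it takes the element of y. The two examples below differ only in the y argument.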


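import tensorflow as tf  # TensorFlow 1.x API (tf.Session)

judge_list = [True, False]             # condition
input_list = tf.constant([1.0, -1.0])  # x: used where the condition is True
result_list = tf.where(judge_list,
                       input_list,
                       tf.ones(shape=[2]) * 2)  # y = [2., 2.]: used where the condition is False

# Graph mode: nothing is computed until the Session runs the tensor.
# (TF 2.x has no tf.Session; with eager execution, print(result_list.numpy()) instead.)
with tf.Session() as sess:
    print(sess.run(result_list))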

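Output (position 0 is True, so it takes 1.0 from input_list; position 1 is False, so it takes 2.0 from tf.ones(shape=[2]) * 2):
[1. 2.]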

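To make the selection rule concrete without TensorFlow, here is a minimal plain-Python sketch of what the three-argument tf.where does for 1-D inputs of equal length. The helper `where_select` is a hypothetical name used only for illustration; it is not a TensorFlow function and not how TensorFlow implements the op.

```python
def where_select(condition, x, y):
    """Hypothetical helper: element-wise select for equal-length 1-D lists.

    Takes x[i] where condition[i] is True, otherwise y[i]. This only mirrors
    the semantics of the three-argument tf.where; it is not TensorFlow code.
    """
    if not (len(condition) == len(x) == len(y)):
        raise ValueError("condition, x and y must have the same length")
    return [xi if c else yi for c, xi, yi in zip(condition, x, y)]


judge_list = [True, False]
input_list = [1.0, -1.0]

# Same inputs as the first example: y is a list of 2s, like tf.ones(shape=[2]) * 2
print(where_select(judge_list, input_list, [2.0, 2.0]))  # [1.0, 2.0]
```

The result matches the TensorFlow output above: the True position keeps 1.0 from input_list and the False position takes 2.0 from y.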
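import tensorflow as tf  # TensorFlow 1.x API (tf.Session)

judge_list = [True, False]             # condition
input_list = tf.constant([1.0, -1.0])  # x: used where the condition is True
result_list = tf.where(judge_list,
                       input_list,
                       input_list * 2)  # y = [2., -2.]: used where the condition is False

# Graph mode: nothing is computed until the Session runs the tensor.
# (TF 2.x has no tf.Session; with eager execution, print(result_list.numpy()) instead.)
with tf.Session() as sess:
    print(sess.run(result_list))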

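Output (position 0 is True, so it takes 1.0 from input_list; position 1 is False, so it takes input_list[1] * 2 = -2.0):
[ 1. -2.]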

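NumPy's `np.where(condition, x, y)` follows the same element-wise selection rule, so it is a convenient way to check these results without a Session. This is a minimal sketch, assuming NumPy is installed. One difference worth noting: TensorFlow 1.x's `tf.where` requires x and y to have the same shape (which is why the first example builds `tf.ones(shape=[2]) * 2`), whereas `np.where` broadcasts, so a scalar can stand in for the tensor of 2s. TensorFlow 2.x's `tf.where` broadcasts as well.

```python
import numpy as np

judge_list = np.array([True, False])
input_list = np.array([1.0, -1.0])

# Same inputs as the second example: input_list where True, input_list * 2 where False
result = np.where(judge_list, input_list, input_list * 2)
print(result)  # [ 1. -2.]

# np.where broadcasts, so the scalar 2.0 replaces tf.ones(shape=[2]) * 2
result_scalar = np.where(judge_list, input_list, 2.0)
print(result_scalar)  # [1. 2.]
```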
