
TensorFlow: using tf.where to edit each value of a tensor


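```python
import tensorflow as tf  # TensorFlow 1.x API (uses tf.Session)

# Conditions: a plain Python list of bools, and a list of boolean scalar tensors
judge_list1 = [True, True, False, False]
judge_list2 = [tf.constant(True), tf.constant(True), tf.constant(False), tf.constant(False)]

# Values to keep where the condition is True
input_tensor1 = [1, 2, 3, 4]
input_tensor2 = [tf.constant(1), tf.constant(2), tf.constant(3), tf.constant(4)]

# tf.where(condition, x, y) picks x[i] where condition[i] is True, otherwise y[i]
result1 = tf.where(judge_list1,
                   input_tensor1,
                   [100, 200, 300, 400])

# Lists of scalar tensors are packed into a single tensor automatically
result2 = tf.where(judge_list2,
                   input_tensor2,
                   [100, 200, 300, 400])

# No variables are created, so tf.global_variables_initializer() is not needed
with tf.Session() as sess:
    print(sess.run(result1))
    print(sess.run(result2))
```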

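Printed output:

[  1   2 300 400]
[  1   2 300 400]

Both calls give the same result: where the condition is True the value comes from the input, and everywhere else it comes from [100, 200, 300, 400]. Passing plain Python lists or lists of scalar tensors makes no difference.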
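The snippet above relies on the TensorFlow 1.x graph API (tf.Session), which no longer exists under that name in TensorFlow 2.x. A minimal sketch of the same selection in TF 2.x eager mode follows, assuming TensorFlow 2.x is installed. It also shows the pattern the title hints at: computing the condition from the tensor itself to edit each value. The names `judge`, `x`, `y`, and `edited` are illustrative only.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution on by default)

judge = tf.constant([True, True, False, False])
x = tf.constant([1, 2, 3, 4])
y = tf.constant([100, 200, 300, 400])

# Element-wise selection: take x where judge is True, otherwise y.
result = tf.where(judge, x, y)
print(result.numpy())

# The condition can be derived from the tensor itself:
# values greater than 2 are multiplied by 100, the rest are kept unchanged.
edited = tf.where(x > 2, x * 100, x)
print(edited.numpy())
```

Both `result` and `edited` hold [1, 2, 300, 400]. Because eager tensors evaluate immediately, no session or initializer is needed.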

