TensorFlow 自定义函数

阅读 86

2022-07-27


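在自定义函数中，也只能使用 TensorFlow 自带的算子对 tensor 进行操作：函数是在构建计算图时被调用的，传入的参数是符号化的 Tensor 而不是具体数值，因此不能用 Python 的 `if` 直接判断其值，而要使用 `tf.cond` 这类算子。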

import tensorflow as tf


# Demo input: a (batch_size, hidden_size) tensor of random values drawn
# from a normal distribution (tf.random_normal's default mean/stddev).
batch_size, hidden_size = 4, 2
input_tensor = tf.random_normal(shape=[batch_size, hidden_size])

def true_function():
    """Return 1.0, the value produced when the ``tf.cond`` predicate holds."""
    return 1.0

def false_function():
    """Return 0.0, the value produced when the ``tf.cond`` predicate fails."""
    return 0.0

def map_function(each_value):
    """Map one element to 1.0 if it is greater than 0.5, otherwise 0.0.

    This function is handed to ``tf.map_fn``, which calls it while the graph
    is being built. ``each_value`` is therefore a symbolic tensor, not a
    concrete number. The branch has to be expressed with a TensorFlow op
    (``tf.cond``) rather than a Python ``if`` on the value.

    Args:
        each_value: One element of the tensor being mapped over (presumably
            a scalar, since the caller flattens the input first; confirm).

    Returns:
        A tensor holding 1.0 or 0.0, chosen by ``tf.cond``.
    """
    # Prints the symbolic Tensor object at graph-construction time, not its value.
    print(each_value)
    return tf.cond(each_value > 0.5, true_function, false_function)

# tf.map_fn iterates over the first axis, so flatten to 1-D first;
# map_function then receives one element at a time.
input_tensor = tf.reshape(input_tensor, [-1])

input_tensor = tf.map_fn(map_function, input_tensor)

# Restore the original (batch_size, hidden_size) layout.
input_tensor = tf.reshape(input_tensor, [batch_size, hidden_size])

# Use the session as a context manager so it is closed (and its
# resources released) once the result has been fetched.
with tf.Session() as sess:
    print(sess.run(input_tensor))


精彩评论(0)

0 0 举报