import tensorflow as tf
batch_size = 4
a = tf.one_hot(tf.range(batch_size), batch_size)
sess = tf.Session()
print(sess.run(a))
print结果:
[[1. 0. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 0. 1.]]
TensorFlow 构造对角线为1的其余全0矩阵
阅读 57
2022-07-27
import tensorflow as tf
batch_size = 4
a = tf.one_hot(tf.range(batch_size), batch_size)
sess = tf.Session()
print(sess.run(a))
print结果:
[[1. 0. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 0. 1.]]
相关推荐
精彩评论(0)