TensorFlow 构造对角线为1的其余全0矩阵

IT影子

关注

阅读 58

2022-07-27


import tensorflow as tf

batch_size = 4

a = tf.one_hot(tf.range(batch_size), batch_size)

sess = tf.Session()

print(sess.run(a))

print结果:

[[1. 0. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 0. 1.]]


相关推荐

精彩评论(0)

0 0 举报