TensorFlow 1.15 KL loss code


2022-07-27


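import tensorflow as tf
from tensorflow.python.keras.utils import losses_utils

# KLDivergence expects probability distributions, not raw logits,
# so convert logit1 and logit2 ([batch, num_classes]) with softmax first.
p = tf.nn.softmax(logit1, axis=-1)  # target distribution (y_true)
q = tf.nn.softmax(logit2, axis=-1)  # predicted distribution (y_pred)

kl = tf.keras.losses.KLDivergence(
    reduction=losses_utils.ReductionV2.NONE,  # one KL value per sample
    name='kullback_leibler_divergence')

# kl(p, q) is KL(p || q) per sample; average it over the batch
kl_loss = tf.reduce_mean(kl(p, q))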
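To make clear what `tf.keras.losses.KLDivergence` computes, here is a minimal NumPy sketch that mirrors its per-sample formula: both inputs are clipped to [epsilon, 1] and then `sum(y_true * log(y_true / y_pred))` is taken over the last axis. The helper names `softmax` and `kl_divergence` and the example logits are hypothetical and only for illustration. Because `KLDivergence` expects probabilities, the logits are first passed through softmax. With `Reduction.NONE` the loss yields one value per sample, and `tf.reduce_mean` then averages those values over the batch.

```python
import numpy as np

# Assumption: matches the default Keras backend epsilon (tf.keras.backend.epsilon())
EPSILON = 1e-7


def softmax(logits):
    # Hypothetical helper: subtract the row max for numerical stability
    z = logits - np.max(logits, axis=-1, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=-1, keepdims=True)


def kl_divergence(y_true, y_pred):
    # Hypothetical helper: per-sample KL(y_true || y_pred) summed over the
    # last axis, with both inputs clipped to [EPSILON, 1] as Keras does
    y_true = np.clip(y_true, EPSILON, 1.0)
    y_pred = np.clip(y_pred, EPSILON, 1.0)
    return np.sum(y_true * np.log(y_true / y_pred), axis=-1)


# Example logits for a batch of 2 samples and 3 classes (made-up values)
logit1 = np.array([[2.0, 1.0, 0.1], [0.5, 0.5, 0.5]])
logit2 = np.array([[1.0, 2.0, 0.1], [0.5, 0.5, 0.5]])

per_sample = kl_divergence(softmax(logit1), softmax(logit2))  # like Reduction.NONE
kl_loss = per_sample.mean()  # like tf.reduce_mean over the batch
print(per_sample, kl_loss)
```

The second sample has identical distributions, so its KL term is zero. KL is also asymmetric, so swapping the argument order of `kl(...)` generally changes the result.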

