TensorFlow 1.15 KL loss code
2022-07-27
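import tensorflow as tf
from tensorflow.python.keras.utils import losses_utils

# KLDivergence expects probability distributions, so softmax the raw logits first
prob1 = tf.nn.softmax(logit1, axis=-1)
prob2 = tf.nn.softmax(logit2, axis=-1)
kl = tf.keras.losses.KLDivergence(
    reduction=losses_utils.ReductionV2.NONE,  # one KL(prob1 || prob2) value per sample
    name='kullback_leibler_divergence')
kl_loss = tf.reduce_mean(kl(prob1, prob2))  # average over the batch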
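To see what the loss above computes, here is a minimal pure-Python sketch that needs no TensorFlow. It assumes the Keras definition of KLDivergence: both inputs are clipped to [epsilon, 1] with epsilon = 1e-7 (the Keras default), and the per-sample loss is sum(y_true * log(y_true / y_pred)) over the last axis. Reduction NONE followed by tf.reduce_mean then gives a plain batch mean. It also assumes logit1 and logit2 are unnormalized logits, turned into distributions with softmax. The helper names softmax, kl_row and kl_loss are hypothetical and only for illustration.

```python
import math

EPSILON = 1e-7  # assumed to match Keras' default backend epsilon


def softmax(logits):
    """Numerically stable softmax over one row of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_row(p, q):
    """KL(p || q) for one row, clipping both inputs to [EPSILON, 1] like Keras."""
    p = [min(max(x, EPSILON), 1.0) for x in p]
    q = [min(max(x, EPSILON), 1.0) for x in q]
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))


def kl_loss(logits1, logits2):
    """Batch mean of KL(softmax(logits1) || softmax(logits2)), row by row."""
    per_row = [kl_row(softmax(a), softmax(b)) for a, b in zip(logits1, logits2)]
    return sum(per_row) / len(per_row)


logit1 = [[2.0, 1.0, 0.1], [0.5, 0.5, 0.5]]
logit2 = [[1.0, 2.0, 0.1], [0.5, 0.5, 0.5]]
print(kl_loss(logit1, logit2))
```

Identical rows (the second row here) contribute exactly zero, and the argument order matters in general: kl(prob1, prob2) measures KL(prob1 || prob2), not the reverse.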