0
点赞
收藏
分享

微信扫一扫

语义分割,去除边缘线代码


import tensorflow as tf
import scipy.misc as msc



'''
对于语义分割的边缘线,白色的为255,进行去除
'''




def remove_ignore_label(gt ,output=None ,pred=None):
''' 去除label为255的值,进行交叉熵的计算
gt: not one-hot
output: a distriution of all labels, and is scaled to macth the size of gt
NOTE the result is a flatted tensor
and all label which is bigger that or equal to self.category_num is void label
'''
gt = tf.reshape(gt ,shape=[-1]) # (180000,) 把矩阵 转化为向量

indices = tf.squeeze(tf.where(tf.less(gt, 21)) ,axis=1) #除去边缘线 判断是否小于 255
#tf.less(gt, 21) 找到所以小于21的label,相当于除去边缘线, 某位置< 21 返回True, 否则返回False
#tf.where(tf.less(gt, 21)) 返回这个位置的index,在为True的位置
#tf.squeeze 压缩为1的维度

gt = tf.gather(gt ,indices)
# 根据indices 取出这个位置的值


if output is not None:
output = tf.reshape(output, shape=[-1, 21]) #转化为21维度的特征,每个特征,相当于一张图片
output = tf.gather(output ,indices) # output 输出也是 [小于21的索引值(相当与一张图片除为255的所以值) , 21]
return gt ,output
elif pred is not None:
pred = tf.reshape(pred, shape=[-1])
pred = tf.gather(pred, indices)
return gt ,pred



label = tf.truncated_normal(shape=(3,281,500),stddev=0.1) # 输入图片的label (b, w, h)

output = tf.truncated_normal(shape=(3,281,500,21),stddev=0.1) # 网络输出图片的概率 (b, w, h, c) #21代表类别



label,output = remove_ignore_label(label,output)
label = tf.cast(label, tf.int32)
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label, logits=output))

#这里进行计算交叉熵

with tf.Session() as sess:
print(sess.run(loss))

 

举报

相关推荐

0 条评论