0
点赞
收藏
分享

微信扫一扫

VOC格式或者COCO格式检测数据集提取特定类

序言

有时候我们需要从已经标记好的数据集中提取某些类进行训练,以常见的COCO数据集和VOC数据集格式的标注为例,本文提供了两种数据集格式的特定类提取方法,网上也有很多类似的内容,权当总结记录,以后用到时方便找出。

一、COCO格式数据集提取特定类

# COCO数据集提取某个类或者某些类

from pycocotools.coco import COCO
import os
import shutil
from tqdm import tqdm
import matplotlib.pyplot as plt
import cv2
from PIL import Image, ImageDraw

# 需要设置的路径
savepath = "D:\BaiduNetdiskDownload\coco\COCO/car/"
img_dir = savepath + 'images/'
anno_dir = savepath + 'annotations/'
datasets_list = ['train2017']

# coco有80类,这里写要提取类的名字,以car为例
classes_names = ['car','bus','truck']
# 包含所有类别的原coco数据集路径
'''
目录格式如下:
$COCO_PATH
----|annotations
----|train2017
----|val2017
----|test2017
'''
dataDir = 'D:\BaiduNetdiskDownload\coco\COCO/'

headstr = """\
<annotation>
<folder>VOC</folder>
<filename>%s</filename>
<source>
<database>My Database</database>
<annotation>COCO</annotation>
<image>flickr</image>
<flickrid>NULL</flickrid>
</source>
<owner>
<flickrid>NULL</flickrid>
<name>company</name>
</owner>
<size>
<width>%d</width>
<height>%d</height>
<depth>%d</depth>
</size>
<segmented>0</segmented>
"""
objstr = """\
<object>
<name>%s</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>%d</xmin>
<ymin>%d</ymin>
<xmax>%d</xmax>
<ymax>%d</ymax>
</bndbox>
</object>
"""

tailstr = '''\
</annotation>
'''


# 检查目录是否存在,如果存在,先删除再创建,否则,直接创建
def mkr(path):
if not os.path.exists(path):
os.makedirs(path) # 可以创建多级目录


def id2name(coco):
classes = dict()
for cls in coco.dataset['categories']:
classes[cls['id']] = cls['name']
return classes


def write_xml(anno_path, head, objs, tail):
f = open(anno_path, "w")
f.write(head)
for obj in objs:
f.write(objstr % (obj[0], obj[1], obj[2], obj[3], obj[4]))
f.write(tail)


def save_annotations_and_imgs(coco, dataset, filename, objs):
# 将图片转为xml,例:COCO_train2017_000000196610.jpg-->COCO_train2017_000000196610.xml
dst_anno_dir = os.path.join(anno_dir, dataset)
mkr(dst_anno_dir)
anno_path = dst_anno_dir + '/' + filename[:-3] + 'xml'
img_path = dataDir + dataset + '/' + filename
print("img_path: ", img_path)
dst_img_dir = os.path.join(img_dir, dataset)
mkr(dst_img_dir)
dst_imgpath = dst_img_dir + '/' + filename
print("dst_imgpath: ", dst_imgpath)
img = cv2.imread(img_path)
# if (img.shape[2] == 1):
# print(filename + " not a RGB image")
# return
shutil.copy(img_path, dst_imgpath)

head = headstr % (filename, img.shape[1], img.shape[0], img.shape[2])
tail = tailstr
write_xml(anno_path, head, objs, tail)


def showimg(coco, dataset, img, classes, cls_id, show=True):
global dataDir
I = Image.open('%s/%s/%s' % (dataDir, dataset, img['file_name']))
# 通过id,得到注释的信息
annIds = coco.getAnnIds(imgIds=img['id'], catIds=cls_id, iscrowd=None)
# print(annIds)
anns = coco.loadAnns(annIds)
# print(anns)
# coco.showAnns(anns)
objs = []
for ann in anns:
class_name = classes[ann['category_id']]
if class_name in classes_names:
# print(class_name)
if 'bbox' in ann:
bbox = ann['bbox']
xmin = int(bbox[0])
ymin = int(bbox[1])
xmax = int(bbox[2] + bbox[0])
ymax = int(bbox[3] + bbox[1])
obj = [class_name, xmin, ymin, xmax, ymax]
objs.append(obj)
draw = ImageDraw.Draw(I)
draw.rectangle([xmin, ymin, xmax, ymax])
if show:
plt.figure()
plt.axis('off')
plt.imshow(I)
plt.show()

return objs


for dataset in datasets_list:
# ./COCO/annotations/instances_train2017.json
annFile = '{}/annotations/instances_{}.json'.format(dataDir, dataset)

# 使用COCO API用来初始化注释数据
coco = COCO(annFile)

# 获取COCO数据集中的所有类别
classes = id2name(coco)
# print(classes)
# [1, 2, 3, 4, 6, 8]
classes_ids = coco.getCatIds(catNms=classes_names)
# print(classes_ids)
for cls in classes_names:
# 获取该类的id
cls_id = coco.getCatIds(catNms=[cls])
img_ids = coco.getImgIds(catIds=cls_id)
# print(cls, len(img_ids))
# imgIds=img_ids[0:10]
for imgId in tqdm(img_ids):
img = coco.loadImgs(imgId)[0]
filename = img['file_name']
# print(filename)
objs = showimg(coco, dataset, img, classes, classes_ids, show=False)
# print(objs)
save_annotations_and_imgs(coco, dataset, filename, objs)

二、VOC格式数据集提取特定类

# VOC数据集提取某个类或者某些类
# !/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
import xml.etree.ElementTree as ET
import shutil

# 根据自己的情况修改相应的路径
ann_filepath = r'Annotations/'
img_filepath = r'JPEGImages/'
img_savepath = r'imgs/'
ann_savepath = r'xmls/'
if not os.path.exists(img_savepath):
os.mkdir(img_savepath)

if not os.path.exists(ann_savepath):
os.mkdir(ann_savepath)

# 这是VOC数据集中所有类别
# classes = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
# 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
# 'dog', 'horse', 'motorbike', 'pottedplant',
# 'sheep', 'sofa', 'train', 'person','tvmonitor']

classes = ['car'] # 这里是需要提取的类别

def save_annotation(file):
tree = ET.parse(ann_filepath + '/' + file)
root = tree.getroot()
result = root.findall("object")
bool_num = 0
for obj in result:
if obj.find("name").text not in classes:
root.remove(obj)
else:
bool_num = 1
if bool_num:
tree.write(ann_savepath + file)
return True
else:
return False

def save_images(file):
name_img = img_filepath + os.path.splitext(file)[0] + ".png"
shutil.copy(name_img, img_savepath)
# 文本文件名自己定义,主要用于生成相应的训练或测试的txt文件
with open('train.txt', 'a') as file_txt:
file_txt.write(os.path.splitext(file)[0])
file_txt.write("\n")
return True


if __name__ == '__main__':
for f in os.listdir(ann_filepath):
print(f)
if save_annotation(f):
save_images(f)

三、VOC格式数据集修改某个类的名字

#!/usr/bin/env python2
# -*- coding: utf-8 -*-

import os
import xml.etree.ElementTree as ET

origin_ann_dir = r'xmls/' # 设置原始标签路径为 Annos
new_ann_dir = r'xmls/' # 设置新标签路径 Annotations
for dirpaths, dirnames, filenames in os.walk(origin_ann_dir): # os.walk游走遍历目录名
for filename in filenames:
print("process...")
if os.path.isfile(r'%s%s' % (origin_ann_dir, filename)): # 获取原始xml文件绝对路径,isfile()检测是否为文件 isdir检测是否为目录
origin_ann_path = os.path.join(r'%s%s' % (origin_ann_dir, filename)) # 如果是,获取绝对路径(重复代码)
new_ann_path = os.path.join(r'%s%s' % (new_ann_dir, filename))
tree = ET.parse(origin_ann_path) # ET是一个xml文件解析库,ET.parse()打开xml文件。parse--"解析"
root = tree.getroot() # 获取根节点
for object in root.findall('object'): # 找到根节点下所有“object”节点
name = str(object.find('name').text) # 找到object节点下name子节点的值(字符串)
# 功能1.删除指定类别的标签。如果name等于str,则删除该节点
# if (name in ["car_head"]):
# root.remove(object)

# 功能2.修改指定类别的标签。如果name等于str,则修改name
if (name in ["car","bus","truck"]): # 将car bus truck三个类改成car类
object.find('name').text = "car"

# # 功能3.删除labelmap中没有的标签。检查是否存在labelmap中没有的类别
# for object in root.findall('object'):
# name = str(object.find('name').text)
# if not (name in ["chepai","chedeng","chebiao","person"]):
# print(filename + "------------->label is error--->" + name)

# # 功能4.比对xml中filename名称与图片名称是否一致。如果xml中filename名称与文件名称不一致,则对其进行修改
# get_name = str(root.find('filename').text)
# if filename.replace(".xml", ".jpg") != get_name:
# print("{}-->name is inconformity!".format(filename))
# root.find('filename').text = filename.replace(".xml", ".jpg")
# else:
# continue

tree.write(new_ann_path) # tree为文件,write写入新的文件中。


举报

相关推荐

0 条评论