import os
import cv2
import time
import math
import torch
import random
import matplotlib
import numpy as np
from PIL import Image
from tqdm import tqdm
from threading import Thread
from pathlib import Path
# Default matplotlib font size, then select the 'Agg' backend
# (non-interactive, renders to files only).
matplotlib.rc('font', size=11)
matplotlib.use('Agg')
def color_list():
    """Return one RGB colour tuple per entry of matplotlib's Tableau palette.

    Used to give boxes of different classes different colours.
    Each palette value is assumed to be a '#rrggbb' hex string; the two hex
    digits of each channel start at offsets 1, 3 and 5.
    """
    palette = []
    for hex_code in matplotlib.colors.TABLEAU_COLORS.values():
        palette.append(tuple(int(hex_code[k:k + 2], 16) for k in (1, 3, 5)))
    return palette
def xywh2xyxy(x):
    """Convert boxes from centre format (cx, cy, w, h) to corners (x1, y1, x2, y2).

    @param x: 2-D torch.Tensor or numpy array whose first four columns are
        cx, cy, w, h.
    @return: a new tensor/array of the same kind. Only the first four columns
        are rewritten; any extra columns are copied unchanged. ``x`` itself is
        not modified.
    """
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    half_w = x[:, 2] / 2
    half_h = x[:, 3] / 2
    out[:, 0] = x[:, 0] - half_w
    out[:, 1] = x[:, 1] - half_h
    out[:, 2] = x[:, 0] + half_w
    out[:, 3] = x[:, 1] + half_h
    return out
def plot_one_box(x, img, color=None, label=None, line_thickness=3):
    """Draw one box, and optionally a filled label tag, onto ``img`` in place.

    @param x: box corners (x1, y1, x2, y2); each value is truncated with int().
    @param img: image array that OpenCV draws on (modified in place).
    @param color: 3-element colour; a random colour is used when falsy.
    @param label: text drawn on a filled tag above the box's top-left corner;
        nothing extra is drawn when falsy.
    @param line_thickness: box line width; when falsy, a width derived from
        the image height + width is used.
    """
    thickness = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1
    color = color or [random.randint(0, 255) for _ in range(3)]
    top_left = (int(x[0]), int(x[1]))
    bottom_right = (int(x[2]), int(x[3]))
    cv2.rectangle(img, top_left, bottom_right, color, thickness=thickness, lineType=cv2.LINE_AA)
    if not label:
        return
    font_thickness = max(thickness - 1, 1)
    font_scale = thickness / 3
    text_w, text_h = cv2.getTextSize(label, 0, fontScale=font_scale, thickness=font_thickness)[0]
    # The filled tag extends upward from the box's top-left corner.
    tag_corner = (top_left[0] + text_w, top_left[1] - text_h - 3)
    cv2.rectangle(img, top_left, tag_corner, color, -1, cv2.LINE_AA)
    cv2.putText(img, label, (top_left[0], top_left[1] - 2), 0, font_scale, [225, 255, 255],
                thickness=font_thickness, lineType=cv2.LINE_AA)
def get_xywh(path):
    """Load a YOLO-format label file as a float32 array of shape (N, 5).

    Each non-blank line must contain five whitespace-separated numbers:
    ``class x_center y_center width height``, with the four box values
    normalised to [0, 1].

    @param path: path of the ``.txt`` label file.
    @return: np.ndarray of shape (N, 5) and dtype float32. An empty file, or a
        file that cannot be read or parsed (the error is printed), yields an
        array of shape (0, 5), so callers can still index columns.
    @raise AssertionError: if rows do not have 5 columns, a value is negative,
        or a box value exceeds 1.
    """
    empty = np.zeros((0, 5), dtype=np.float32)
    try:
        with open(path, "r") as f:
            # Skip blank lines (e.g. a trailing empty line) so rows stay uniform.
            rows = [line.split() for line in f.read().splitlines() if line.strip()]
        label = np.array(rows, dtype=np.float32)
    except (OSError, ValueError) as e:
        # Report and return no labels instead of continuing with no `label` bound.
        print("An error occurred while loading the file {}: {}".format(path, e))
        return empty
    if label.shape[0] == 0:
        # np.array([]) is 1-D; return the 2-D (0, 5) shape callers index into.
        return empty
    assert label.shape[1] == 5, "label rows must have 5 columns: %s" % path
    assert (label >= 0).all(), "negative labels: %s" % path
    assert (label[:, 1:] <= 1).all(), "non-normalized or out of bounds coordinate labels: %s" % path
    return label
def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16):
    """Tile a batch of images into one mosaic, draw their boxes, and optionally save it.

    @param images: torch.Tensor or numpy array of shape [batch_size, channels, h, w].
        If the first image's maximum is <= 1, the batch is multiplied by 255.
        The caller's array is not modified.
    @param targets: box table of shape [num_boxes, 6] or [num_boxes, 7], with columns
        [image_index, class, x_center, y_center, width, height(, conf)].
        Coordinates may be normalised (max <= 1.01) or in pixels. Rows that have
        a conf column are drawn only when conf > 0.25.
    @param paths: optional list of image paths; each tile is captioned with the
        file name (first 40 characters).
    @param fname: output path; when falsy, the mosaic is neither resized nor saved.
    @param names: optional list of class names indexed by class id; ids outside
        the list are shown as numbers.
    @param max_size: images whose longer side exceeds this are downscaled to it.
    @param max_subplots: maximum number of images placed in the mosaic.
    @return: the mosaic as a uint8 (H, W, 3) array, resized so that no side
        exceeds 1280 px when ``fname`` is given.
    """
    if isinstance(images, torch.Tensor):
        images = images.cpu().float().numpy()
    if isinstance(targets, torch.Tensor):
        targets = targets.cpu().numpy()
    # Not in place: a numpy array passed by the caller must stay untouched.
    if np.max(images[0]) <= 1:
        images = images * 255
    tl = 3  # line thickness for boxes and captions
    tf = max(tl - 1, 1)  # font thickness
    bs, _, h, w = images.shape
    bs = min(bs, max_subplots)
    ns = np.ceil(bs ** 0.5)  # the mosaic is an ns x ns grid
    scale_factor = max_size / max(h, w)
    if scale_factor < 1:
        h = math.ceil(scale_factor * h)
        w = math.ceil(scale_factor * w)
    colors = color_list()
    mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)
    for i, img in enumerate(images):
        if i == max_subplots:
            break
        # Tiles fill the grid column by column.
        block_x = int(w * (i // ns))
        block_y = int(h * (i % ns))
        img = img.transpose(1, 2, 0)
        if scale_factor < 1:
            img = cv2.resize(img, (w, h))
        mosaic[block_y:block_y + h, block_x:block_x + w, :] = img
        if len(targets) > 0:
            image_targets = targets[targets[:, 0] == i]
            boxes = xywh2xyxy(image_targets[:, 2:6]).T
            classes = image_targets[:, 1].astype('int')
            labels = image_targets.shape[1] == 6  # no conf column
            conf = None if labels else image_targets[:, 6]
            if boxes.shape[1]:
                if boxes.max() <= 1.01:  # normalised -> tile pixels
                    boxes[[0, 2]] *= w
                    boxes[[1, 3]] *= h
                elif scale_factor < 1:  # pixel coords follow the tile downscale
                    boxes *= scale_factor
            boxes[[0, 2]] += block_x
            boxes[[1, 3]] += block_y
            for j, box in enumerate(boxes.T):
                cls = int(classes[j])
                color = colors[cls % len(colors)]
                # Fall back to the numeric id rather than raising IndexError.
                if names and cls < len(names):
                    cls = names[cls]
                if labels or conf[j] > 0.25:
                    label = '%s' % cls if labels else '%s %.1f' % (cls, conf[j])
                    plot_one_box(box, mosaic, label=label, color=color, line_thickness=tl)
        if paths:
            label = Path(paths[i]).name[:40]
            t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
            cv2.putText(mosaic, label, (block_x + 5, block_y + t_size[1] + 5), 0, tl / 3,
                        [220, 220, 220], thickness=tf, lineType=cv2.LINE_AA)
        # White border separating tiles.
        cv2.rectangle(mosaic, (block_x, block_y), (block_x + w, block_y + h), (255, 255, 255), thickness=3)
    if fname:
        r = min(1280. / max(h, w) / ns, 1.0)
        mosaic = cv2.resize(mosaic, (int(ns * w * r), int(ns * h * r)), interpolation=cv2.INTER_AREA)
        Image.fromarray(mosaic).save(fname)
    return mosaic
def draw_prepare(label, label_dir, img_dir, class_names, save_dir, img_type):
    """Draw the boxes from one YOLO label file onto its image and save the result.

    @param label: label file name (e.g. ``xxx.txt``) inside ``label_dir``.
    @param label_dir: directory containing the label files.
    @param img_dir: directory containing the images; the image must have the
        label's stem plus the extension ``img_type``.
    @param class_names: list of class names indexed by class id.
    @param save_dir: directory the annotated image is written to, under the
        same file name as the source image.
    @param img_type: image extension including the dot, e.g. ``".png"``.
    """
    # Swap only the extension; str.replace would also hit '.txt' elsewhere in the name.
    image_name = os.path.splitext(label)[0] + img_type
    image_path = os.path.join(img_dir, image_name)
    assert os.path.exists(image_path), "检查标签对应的图片路径是否正确?"
    # Force 3 channels: grayscale/palette images decode to 2-D arrays (breaking
    # the transpose below), and RGBA gives 4 channels that don't fit the mosaic.
    # The context manager closes the underlying file.
    with Image.open(image_path) as im:
        img = np.array(im.convert("RGB"))
    images = img[np.newaxis].transpose((0, 3, 1, 2))  # (1, 3, H, W)
    labels = get_xywh(os.path.join(label_dir, label))
    # An empty label file gives no rows; keep a 2-D (0, 5) shape for the insert below.
    labels = labels.reshape(-1, 5)
    # Prepend the image-index column: all zeros for this single-image batch.
    labels = np.insert(labels, 0, 0, axis=1)
    save_name = os.path.join(save_dir, image_name)
    plot_images(images, labels, [image_name], save_name, class_names)
def main(image_dir=r'D:\YoungMaster\YOLOv7\tmp\draw_test\images',
         labels_dir=r'D:\YoungMaster\YOLOv7\tmp\draw_test\labels',
         names=None, img_type=".png", max_workers=8):
    """Draw the boxes of every label file onto its image.

    Adjust the defaults (image_dir, labels_dir, names, img_type) or pass them in.
    Annotated images are written to ``image_dir + "_box"``.

    @param image_dir: directory holding the images.
    @param labels_dir: directory holding one ``.txt`` label file per image.
    @param names: class names indexed by class id; defaults to ["bird", "cat", "dog"].
    @param img_type: image extension including the dot.
    @param max_workers: number of worker threads used for drawing.
    """
    from concurrent.futures import ThreadPoolExecutor, as_completed

    if names is None:
        names = ["bird", "cat", "dog"]
    image_box_save_dir = image_dir + "_box"
    os.makedirs(image_box_save_dir, exist_ok=True)
    # Only label files; anything else in the directory has no matching image.
    label_list = [f for f in os.listdir(labels_dir) if f.endswith(".txt")]

    start = time.time()
    # A bounded pool that is fully drained before timing stops, so the elapsed
    # time and the final message reflect finished work.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = {
            pool.submit(draw_prepare, label, labels_dir, image_dir, names,
                        image_box_save_dir, img_type): label
            for label in label_list
        }
        for future in tqdm(as_completed(futures), total=len(futures)):
            try:
                future.result()
            except Exception as e:
                # One bad label/image pair must not abort the rest of the batch.
                print("Failed to draw {}: {}".format(futures[future], e))
    time_use = time.time() - start
    hours, rest = divmod(time_use, 3600)
    minutes, seconds = divmod(rest, 60)
    print(f"用时:{int(hours)} hour {int(minutes)} min {seconds:.2f} seconds")
    print(f"\n勾画结果保存位置: {image_box_save_dir}")
if __name__ == "__main__":
    # Run the batch drawing only when executed as a script, not on import.
    main()