inference_detector(model, imgs)
源代码
def inference_detector(model, imgs):
"""Inference image(s) with the detector.
Args:
model (nn.Module): The loaded detector.
imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
Either image files or loaded images.
Returns:
If imgs is a list or tuple, the same length list type results
will be returned, otherwise return the detection results directly.
"""
if isinstance(imgs, (list, tuple)):
is_batch = True
else:
imgs = [imgs]
is_batch = False
cfg = model.cfg
device = next(model.parameters()).device # model device
if isinstance(imgs[0], np.ndarray):
cfg = cfg.copy()
# set loading pipeline type
cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
test_pipeline = Compose(cfg.data.test.pipeline)
datas = []
for img in imgs:
# prepare data
if isinstance(img, np.ndarray):
# directly add img
data = dict(img=img)
else:
# add information into dict
data = dict(img_info=dict(filename=img), img_prefix=None)
# build the data pipeline
data = test_pipeline(data)
datas.append(data)
data = collate(datas, samples_per_gpu=len(imgs))
# just get the actual data from DataContainer
data['img_metas'] = [img_metas.data[0] for img_metas in data['img_metas']]
data['img'] = [img.data[0] for img in data['img']]
if next(model.parameters()).is_cuda:
# scatter to specified GPU
data = scatter(data, [device])[0]
else:
for m in model.modules():
assert not isinstance(
m, RoIPool
), 'CPU inference with RoIPool is not supported currently.'
# forward the model
with torch.no_grad():
results = model(return_loss=False, rescale=True, **data)
if not is_batch:
return results[0]
else:
return results
分析
这个函数 inference_detector 主要用于使用已加载的检测模型对一张或多张图像进行推理检测。以下是对该函数的详细分析:
参数
- model (nn.Module): 这是一个已经加载的检测模型,基于 PyTorch 的 nn.Module。
- imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]): 输入可以是图片文件路径或已加载的图片数据,支持单张图片或多张图片列表/元组。
返回值
如果输入是列表或元组,函数将返回与输入相同长度的结果列表;如果输入是单个图片,直接返回该图片的检测结果。
函数流程
1. 批量检测判断:
- 首先判断输入 imgs 是否为列表或元组,以决定是否进行批量处理。
- 如果不是,将单个图片转换为列表,便于后续处理。
2. 配置和设备设置:
- 从模型中获取配置(model.cfg)和模型所在的设备(例如 GPU)
3. 数据预处理管道配置:
- 如果输入图像是 numpy 数组格式,复制配置并修改加载管道类型为 LoadImageFromWebcam。
- 更新图像到张量的转换管道 replace_ImageToTensor。
- 使用 Compose 构建最终的数据处理管道。
4. 数据处理:
- 遍历所有图片,根据图片格式(路径或 numpy 数组)添加到数据字典中。
- 通过数据管道处理图像数据,生成最终用于模型输入的数据格式。
5. 数据整合:
- 使用 collate 函数整合所有处理后的数据,根据 GPU 上的图片数量进行配置。
- 提取必要的元数据和图像数据。
6. GPU 分配:
- 如果模型参数在 CUDA 上,将数据分散到指定的 GPU。
- 如果在 CPU 上运行,确保不使用 RoIPool,因为目前不支持在 CPU 上使用 RoIPool。
7. 模型推理:
- 在不计算梯度的条件下,使用模型进行前向传播,获取检测结果。
8. 结果返回: - 根据输入是单张图片还是多张图片批处理,返回相应的结果。