0
点赞
收藏
分享

微信扫一扫

计数排序算法

深夜瞎琢磨 2024-08-22 阅读 44

inference_detector(model, imgs)

源代码

def inference_detector(model, imgs):
    """Inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector.
        imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
           Either image files or loaded images.

    Returns:
        If imgs is a list or tuple, the same length list type results
        will be returned, otherwise return the detection results directly.
    """

    if isinstance(imgs, (list, tuple)):
        is_batch = True
    else:
        imgs = [imgs]
        is_batch = False

    cfg = model.cfg
    device = next(model.parameters()).device  # model device

    if isinstance(imgs[0], np.ndarray):
        cfg = cfg.copy()
        # set loading pipeline type
        cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'

    cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
    test_pipeline = Compose(cfg.data.test.pipeline)

    datas = []
    for img in imgs:
        # prepare data
        if isinstance(img, np.ndarray):
            # directly add img
            data = dict(img=img)
        else:
            # add information into dict
            data = dict(img_info=dict(filename=img), img_prefix=None)
        # build the data pipeline
        data = test_pipeline(data)
        datas.append(data)

    data = collate(datas, samples_per_gpu=len(imgs))
    # just get the actual data from DataContainer
    data['img_metas'] = [img_metas.data[0] for img_metas in data['img_metas']]
    data['img'] = [img.data[0] for img in data['img']]
    if next(model.parameters()).is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device])[0]
    else:
        for m in model.modules():
            assert not isinstance(
                m, RoIPool
            ), 'CPU inference with RoIPool is not supported currently.'

    # forward the model
    with torch.no_grad():
        results = model(return_loss=False, rescale=True, **data)

    if not is_batch:
        return results[0]
    else:
        return results

分析

这个函数 inference_detector 主要用于使用已加载的检测模型对一张或多张图像进行推理检测。以下是对该函数的详细分析:

参数

  • model (nn.Module): 这是一个已经加载的检测模型,基于 PyTorch 的 nn.Module。
  • imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]): 输入可以是图片文件路径或已加载的图片数据,支持单张图片或多张图片列表/元组。

返回值
如果输入是列表或元组,函数将返回与输入相同长度的结果列表;如果输入是单个图片,直接返回该图片的检测结果。
函数流程
1. 批量检测判断:

  • 首先判断输入 imgs 是否为列表或元组,以决定是否进行批量处理。
  • 如果不是,将单个图片转换为列表,便于后续处理。

2. 配置和设备设置:

  • 从模型中获取配置(model.cfg)和模型所在的设备(例如 GPU)

3. 数据预处理管道配置:

  • 如果输入图像是 numpy 数组格式,复制配置并修改加载管道类型为 LoadImageFromWebcam。
  • 更新图像到张量的转换管道 replace_ImageToTensor。
  • 使用 Compose 构建最终的数据处理管道。

4. 数据处理:

  • 遍历所有图片,根据图片格式(路径或 numpy 数组)添加到数据字典中。
  • 通过数据管道处理图像数据,生成最终用于模型输入的数据格式。

5. 数据整合:

  • 使用 collate 函数整合所有处理后的数据,根据 GPU 上的图片数量进行配置。
  • 提取必要的元数据和图像数据。

6. GPU 分配:

  • 如果模型参数在 CUDA 上,将数据分散到指定的 GPU。
  • 如果在 CPU 上运行,确保不使用 RoIPool,因为目前不支持在 CPU 上使用 RoIPool。

7. 模型推理:

  • 在不计算梯度的条件下,使用模型进行前向传播,获取检测结果。
    8. 结果返回:
  • 根据输入是单张图片还是多张图片批处理,返回相应的结果。
举报

相关推荐

0 条评论