0
点赞
收藏
分享

微信扫一扫

demosaicnet-master的包代码阅读笔记


__init__.py
(参考:我在学习 `__init__.py` 用法时的笔记)该文件里全部是导入语句。其中 `from .dataset import *` 是通配符导入:dataset.py 中定义了 `__all__`,所以只会导入 `__all__` 列表里列出的名字,而不是 dataset.py 中的所有名字。

from .modules import BayerDemosaick
 from .modules import XTransDemosaick
 from .mosaic import xtrans
 from .mosaic import bayer
 from .mosaic import xtrans_cell
 from .dataset import *

dataset.py 文件:这个文件里的很多代码与 download_dataset.py 中的一样(参见:download_dataset.py 的阅读笔记)。
 “”“Dataset loader for demosaicnet.”""
 import os
 import platform
 import subprocess#用来复制和删除文件的模块
 import shutil#hash值,md5值
 import hashlibimport numpy as np
#图像读写模块
 from imageio import imread
 from torch.utils.data import Dataset as TorchDataset
 import wget  # 看起来像作者自己写的模块,其实是 PyPI 上的第三方包 wget(pip install wget);下文用 wget.download(url, 文件名) 下载数据

ttools 的资料:ttools 的阅读笔记
 import ttools#文件夹下面的包里面的模块
 from .mosaic import bayer, xtrans

注:在 Python 中,如果在一个 list 或 tensor 表达式的最后多写一个逗号(例如 `x = [1, 2],` 或 `t = torch.zeros(3),`),结果会变成一个元组;而列表/元组字面量内部末尾的逗号(例如 `[1, 2,]`)是合法的,不会改变类型。为什么 Python 允许在列表和元组的末尾使用逗号?

`__all__` 的用法:因为 `__init__.py` 里写的是 `from .dataset import *`,所以导入的是 `__all__` 列表里的这些名字:
# Public names exported by ``from .dataset import *`` (as done in __init__.py).
# It must be spelled ``__all__``: a plain ``all = [...]`` shadows the builtin
# ``all`` and leaves the wildcard import exporting every public module name.
__all__ = ["BAYER_MODE", "XTRANS_MODE", "Dataset",
           "TRAIN_SUBSET", "VAL_SUBSET", "TEST_SUBSET"]

# 先找到 ttools 的安装路径,用 PyCharm 打开该目录,再用全局查找(Find in Path)定位 get_logger / set_logger 的定义处;下面两段代码就是在 ttools 中找到的定义。
(PyCharm 全局搜索方法)

def get_logger(name):
    """Return the :mod:`logging` logger registered under ``name``.

    Args:
        name(string): name of the logger to look up (created on first use
            by :func:`logging.getLogger`).
    """
    named_logger = logging.getLogger(name)
    return named_logger

# ``__name__`` is the module's special (dunder) attribute holding its import
# name, so this logger is named after the current module. (A bare ``name``
# is undefined at module level and would raise NameError on import.)
LOG = ttools.get_logger(__name__)

def set_logger(debug=False):
    """Configure the default logging level and record format.

    Args:
        debug(bool): when True, log at DEBUG instead of INFO and append the
            source file name and line number to every record's prefix.
    """
    fmt = "[%(process)d] %(levelname)s %(name)s"
    if debug:
        fmt += " %(filename)s:%(lineno)s"
    fmt += " | %(message)s"
    level = logging.DEBUG if debug else logging.INFO

    # Use coloredlogs when HAS_COLORED_LOGS is set (presumably by a guarded
    # import elsewhere in ttools); otherwise fall back to plain logging.
    configure = coloredlogs.install if HAS_COLORED_LOGS else logging.basicConfig
    configure(level=level, format=fmt)

# Configure logging as soon as this module is imported; debug=True selects
# the DEBUG level and adds the file:line prefix to log records.
ttools.set_logger(debug=True)
 BAYER_MODE = “bayer”
 “”“Applies a Bayer mosaic pattern.”""XTRANS_MODE = “xtrans”
 “”“Applies an X-Trans mosaic pattern.”""TRAIN_SUBSET = “train”
 “”“Loads the ‘train’ subset of the data.”""VAL_SUBSET = “val”
 “”“Loads the ‘val’ subset of the data.”""TEST_SUBSET = “test”
 “”“Loads the ‘test’ subset of the data.”""class Dataset(TorchDataset):
 “”"Dataset of challenging image patches for demosaicking.

Args:
    download(bool): if True, automatically download the dataset.
    mode(:class:`BAYER_MODE` or :class:`XTRANS_MODE`): mosaic pattern to apply to the data.
    subset(:class:`TRAIN_SUBET`, :class:`VAL_SUBSET` or :class:`TEST_SUBSET`): subset of the data to load.
"""

def __init__(self, root, download=False,
             mode=BAYER_MODE, subset="train"):

    super(Dataset, self).__init__()

    self.root = os.path.abspath(root)

    if subset not in [TRAIN_SUBSET, VAL_SUBSET, TEST_SUBSET]:
        raise ValueError("Dataset subet should be '%s', '%s' or '%s', got"
                         " %s" % (TRAIN_SUBSET, TEST_SUBSET, VAL_SUBSET,
                                  subset))

    if mode not in [BAYER_MODE, XTRANS_MODE]:
        raise ValueError("Dataset mode should be '%s' or '%s', got"
                         " %s" % (BAYER_MODE, XTRANS_MODE, mode))
    self.mode = mode

    listfile = os.path.join(self.root, subset, "filelist.txt")
    LOG.debug("Reading image list from %s", listfile)

    if not os.path.exists(listfile):
        if download:
            _download(self.root)
        else:
            LOG.error("Filelist %s not found", listfile)
            raise ValueError("Filelist %s not found" % listfile)
    else:
        LOG.debug("No need no download the data, filelist exists.")

    self.files = []
    with open(listfile, "r") as fid:
        for fname in fid.readlines():
            self.files.append(os.path.join(self.root, subset, fname.strip()))

def __len__(self):
    return len(self.files)

def __getitem__(self, idx):
    """Fetches a mosaic / demosaicked pair of images.

    Returns
        mosaic(np.array): with size [3, h, w] the mosaic data with separated color channels.
        img(np.array): with size [3, h, w] the groundtruth image.
    """
    fname = self.files[idx]
    img = np.array(imread(fname)).astype(np.float32) / (2**8-1)
    img = np.transpose(img, [2, 0, 1])

    if self.mode == BAYER_MODE:
        mosaic = bayer(img)
    else:
        mosaic = xtrans(img)

    return mosaic, img

# Expected MD5 hex digest of each part of the split dataset archive, keyed by
# the file name that is downloaded from URL_ROOT.
CHECKSUMS = {
    'datasets.z01': 'da46277afe85d3a91c065e4751fb8175',
    'datasets.zip': '3434f60f5e9b263ef78e207b54e9debe',
}


def _download(dst):
    """Download, verify, join and extract the demosaicnet data into ``dst``.

    The archive parts listed in :data:`CHECKSUMS` are downloaded (parts already
    on disk with a matching MD5 are reused) and joined into ``joined.zip``
    with ``zip -FF``. When ``joined.zip`` already exists, the download and
    join steps are skipped. The joined archive is then extracted, and the
    ``images/{train,test,val}`` folders are moved directly under ``dst``.

    Args:
        dst(str): destination directory; created if missing.

    Raises:
        ValueError: if a stale part with a bad checksum cannot be deleted, or
            if a freshly downloaded part does not match its checksum.
        subprocess.CalledProcessError: if ``zip`` or ``unzip`` fails.
    """
    dst = os.path.abspath(dst)
    files = CHECKSUMS.keys()
    fullzip = os.path.join(dst, "datasets.zip")
    joinedzip = os.path.join(dst, "joined.zip")

    URL_ROOT = "https://data.csail.mit.edu/graphics/demosaicnet"

    if not os.path.exists(joinedzip):
        LOG.info("Downloading %d files to %s (This will take a while, and ~80GB)",
                 len(files), dst)

        os.makedirs(dst, exist_ok=True)
        for f in files:
            fname = os.path.join(dst, f)
            # URLs always use '/'; os.path.join would insert '\\' on Windows.
            url = URL_ROOT + "/" + f

            do_download = True
            if os.path.exists(fname):
                checksum = md5sum(fname)
                if checksum == CHECKSUMS[f]:  # File exists and is correct
                    LOG.info('%s already downloaded, with correct checksum', f)
                    do_download = False
                else:
                    LOG.warning('%s checksums do not match, got %s, should be %s',
                                f, checksum, CHECKSUMS[f])
                    try:
                        os.remove(fname)
                    except OSError as e:
                        LOG.error("Could not delete broken part %s: %s", f, e)
                        raise ValueError(
                            "Could not delete broken part %s: %s" % (f, e)) from e

            if do_download:
                LOG.info('Downloading %s', f)
                wget.download(url, fname)

            checksum = md5sum(fname)
            if checksum == CHECKSUMS[f]:
                LOG.info("%s MD5 correct", f)
            else:
                LOG.error('%s checksums do not match, got %s, should be %s. Downloading failed',
                          f, checksum, CHECKSUMS[f])
                # Stop here: joining a corrupt part would only fail later in
                # zip/unzip with a far less helpful error.
                raise ValueError("Downloading %s failed: MD5 %s, expected %s"
                                 % (f, checksum, CHECKSUMS[f]))

        LOG.info("Joining zip files")
        # Argument list without a shell, so paths containing spaces or shell
        # metacharacters are passed through intact.
        subprocess.check_call(["zip", "-FF", fullzip, "--out", joinedzip])

        # Cleanup the parts
        for f in files:
            try:
                os.remove(os.path.join(dst, f))
            except OSError:
                LOG.warning("Could not delete file %s", f)

    # Extract inside dst by giving unzip its own working directory, so the
    # caller's process-wide current directory is left untouched.
    LOG.info("Extracting files from %s", joinedzip)
    subprocess.check_call(["unzip", joinedzip], cwd=dst)

    try:
        os.remove(joinedzip)
    except OSError:
        LOG.warning("Could not delete file %s", joinedzip)

    LOG.info("Moving subfolders")
    for k in ["train", "test", "val"]:
        shutil.move(os.path.join(dst, "images", k), os.path.join(dst, k))
    images = os.path.join(dst, "images")
    LOG.info("removing '%s' folder", images)
    shutil.rmtree(images)

def md5sum(filename, blocksize=65536):
    """Return the hex MD5 digest of a file, read in ``blocksize``-byte chunks.

    Args:
        filename(str): path of the file to hash.
        blocksize(int): number of bytes read per chunk; bounds memory use when
            hashing large archive parts. Must be positive.

    Returns:
        str: the lowercase hexadecimal MD5 digest.

    Raises:
        ValueError: if ``blocksize`` is not positive. ``read(0)`` returns
            ``b""``, which would end the loop at once and silently yield the
            digest of empty input.
    """
    if blocksize <= 0:
        raise ValueError("blocksize must be positive, got %r" % (blocksize,))
    digest = hashlib.md5()  # named 'digest' to avoid shadowing builtin hash()
    with open(filename, "rb") as fid:
        for block in iter(lambda: fid.read(blocksize), b""):
            digest.update(block)
    return digest.hexdigest()


举报

相关推荐

0 条评论