基于GoogLeNet的猫十二分类

科牛

关注

阅读 44

2022-05-01

1.导包

from tensorflow import keras
import tensorflow as tf
from keras.preprocessing import image
import random
from matplotlib import pyplot as plt
import cv2
from tqdm import tqdm

2.数据预处理

cat_12数据集包含3个部分,训练集cat_12_train,测试集cat_test,以及存储图片名称及标签的train_list.txt

(1)定义prepare_image函数从文件中分离路径和标签

def prepare_image(file_path):
    X_train = []
    y_train = []

    with open(file_path) as f:
        context = f.readlines()
    random.shuffle(context)

    for str in context:
        str = str.strip('\n').split('\t')

        X_train.append('./image/cat_12/' + str[0])
        y_train.append(str[1])

    return X_train, y_train

(2)定义preprocess_image函数进行图像的归一化

def preprocess_image(img):
    img = image.load_img(img, target_s

精彩评论(0)

0 0 举报