0
点赞
收藏
分享

微信扫一扫

基于AlexNet网络的猫十二分类

1.项目简介

1.数据集

cat_12数据集包含3个部分,训练集cat_12_train,测试集cat_test,以及存储图片名称及标签的train_list.txt

2.数据预处理

首先,定义一个prepare_image函数,取出文本文件中的图片路径与标签,并且打乱顺序

def prepare_image(file_path):
X_train = []
y_train = []

with open(file_path) as f:
context = f.readlines()
random.shuffle(context)

for str in context:
str = str.strip('\n').split('\t')

X_train.append('./cat_12/' + str[0])
y_train.append(str[1])

return X_train, y_train

再定义一个preprocess_image进行图片归一化操作,将像素值限制在0-1之间。

# 数据归一化
def preprocess_image(image):
image = tf.io.read_file(image)
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.resize_with_
举报

相关推荐

0 条评论