1. Download and parse the data
I have uploaded the CIFAR-10 dataset to Baidu Netdisk; grab it there if you need it.
Link: https://pan.baidu.com/s/1w6R5_p-XzpNyjiatSEtsPQ  Extraction code: jqfi
Unpack the batches with the pickle module, then randomly re-split them into training and test sets.
import pickle
import os
import numpy as np

filepath = os.path.join(os.path.dirname(__file__), 'cifar-10-python/cifar-10-batches-py/')
# Only load the pickled batch files; skip batches.meta and readme.html
filelist = [os.path.join(filepath, name) for name in os.listdir(filepath)
            if name.startswith('data_batch') or name == 'test_batch']
datas = []
for f in filelist:
    with open(f, 'rb') as fp:  # read-only is enough; 'rb+' was unnecessary
        datas.append(pickle.load(fp, encoding='bytes'))

all_data = []
all_label = []
for data in datas:
    all_data.append(data[b'data'])
    all_label.append(data[b'labels'])
data = np.vstack(all_data)
labels = np.hstack(all_label)
# Each row stores 3072 bytes in channel-first order; reshape to NCHW,
# then transpose to NHWC for plotting and training
data = data.reshape((-1, 3, 32, 32))
data = data.transpose((0, 2, 3, 1))
# Shuffle the data
count = data.shape[0]
p = np.random.permutation(count)
data = data[p]
labels = labels[p]
# Split the data: the training set takes 0.8 and the validation set 0.2;
# change split_rate if you need a different ratio
split_rate = 0.8
split_count = int(split_rate * count)
train_data = data[:split_count]
train_label = labels[:split_count]
train = list(zip(train_data, train_label))  # list, so it can be iterated more than once
test_data = data[split_count:]
test_label = labels[split_count:]
test = list(zip(test_data, test_label))
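Before going further, a quick sanity check on the split (a minimal sketch using the variables above; the expected shapes assume the standard 60,000-image CIFAR-10 with split_rate = 0.8):

print(train_data.shape, train_label.shape)  # expected: (48000, 32, 32, 3) (48000,)
print(test_data.shape, test_label.shape)    # expected: (12000, 32, 32, 3) (12000,)
print(data.dtype, data.min(), data.max())   # uint8 pixels in [0, 255]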
Split the data into batches for training
class DatasetGenerator():
    def __init__(self, datas, shuffle, batch_size):
        """
        :param datas: dataset, an iterable of (data, label) pairs
        :param shuffle: whether to shuffle the data, True or False
        :param batch_size: number of samples per batch
        """
        self._shuffle = shuffle
        self._batch_size = batch_size
        self._indicator = 0
        all_data = []
        all_label = []
        for data, label in datas:
            all_data.append(data)
            all_label.append(label)
        self._data = np.array(all_data)
        self._labels = np.array(all_label)
        self.count = self._data.shape[0]

    def __iter__(self):
        return self

    def __next__(self):
        return self._next_batch()

    def _shuffle_data(self):
        p = np.random.permutation(self.count)
        self._data = self._data[p]
        self._labels = self._labels[p]

    def _next_batch(self):
        end_indicator = self._indicator + self._batch_size
        if end_indicator > self.count:
            # Wrap around for a new pass so model.fit can draw batches
            # indefinitely; reshuffle first if requested
            if self._shuffle:
                self._shuffle_data()
            self._indicator = 0
            end_indicator = self._batch_size
            if end_indicator > self.count:
                raise StopIteration
        batch_data = self._data[self._indicator: end_indicator] / 255.0  # normalize to [0, 1]
        batch_labels = self._labels[self._indicator: end_indicator]
        self._indicator = end_indicator
        return batch_data, batch_labels
BATCH_SIZE = 32
trainGenerator = DatasetGenerator(train, True, BATCH_SIZE)
testGenerator = DatasetGenerator(test, False, BATCH_SIZE)
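For reference, the same shuffle/normalize/batch pipeline can also be written with tf.data; this is an optional sketch built on the arrays from the split above, not the pipeline used in the rest of this post:

import tensorflow as tf
# Optional tf.data equivalent of DatasetGenerator
train_ds = (tf.data.Dataset.from_tensor_slices((train_data, train_label))
            .shuffle(buffer_size=len(train_data))                   # reshuffles each pass
            .map(lambda x, y: (tf.cast(x, tf.float32) / 255.0, y))  # normalize to [0, 1]
            .batch(BATCH_SIZE)
            .repeat())
test_ds = (tf.data.Dataset.from_tensor_slices((test_data, test_label))
           .map(lambda x, y: (tf.cast(x, tf.float32) / 255.0, y))
           .batch(BATCH_SIZE))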
Inspect the dataset
import matplotlib.pyplot as plt

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

def display(train_images, train_labels):
    plt.figure(figsize=(10, 10))
    for i in range(25):
        plt.subplot(5, 5, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(train_images[i], cmap=plt.cm.binary)
        plt.xlabel(class_names[train_labels[i]])
    plt.show()

# The generator wraps around forever, so show a single batch rather than
# looping over it
display(*next(trainGenerator))
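Because the split is random, each of the ten classes should land in the training set in roughly equal numbers; a one-line check with np.bincount (using train_label from above):

# With 48000 training samples, each class should appear about 4800 times
print(np.bincount(train_label))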
Build the neural network
Build a VGG16 network with TensorFlow.
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model

def VGG16(input_shape):
    inputs = Input(shape=input_shape)
    conv1_1 = Conv2D(64, 3, strides=(1, 1), padding='same', activation='relu')(inputs)
    conv1_2 = Conv2D(64, 3, strides=(1, 1), padding='same', activation='relu')(conv1_1)
    pool1 = MaxPooling2D(pool_size=(2, 2), strides=2, padding='valid')(conv1_2)
    conv2_1 = Conv2D(128, 3, strides=(1, 1), padding='same', activation='relu')(pool1)
    conv2_2 = Conv2D(128, 3, strides=(1, 1), padding='same', activation='relu')(conv2_1)
    pool2 = MaxPooling2D(pool_size=(2, 2), strides=2, padding='valid')(conv2_2)
    conv3_1 = Conv2D(256, 3, strides=(1, 1), padding='same', activation='relu')(pool2)
    conv3_2 = Conv2D(256, 3, strides=(1, 1), padding='same', activation='relu')(conv3_1)
    conv3_3 = Conv2D(256, 3, strides=(1, 1), padding='same', activation='relu')(conv3_2)
    pool3 = MaxPooling2D(pool_size=(2, 2), strides=2, padding='valid')(conv3_3)
    conv4_1 = Conv2D(512, 3, strides=(1, 1), padding='same', activation='relu')(pool3)
    conv4_2 = Conv2D(512, 3, strides=(1, 1), padding='same', activation='relu')(conv4_1)
    conv4_3 = Conv2D(512, 3, strides=(1, 1), padding='same', activation='relu')(conv4_2)
    pool4 = MaxPooling2D(pool_size=(2, 2), strides=2, padding='valid')(conv4_3)
    conv5_1 = Conv2D(512, 3, strides=(1, 1), padding='same', activation='relu')(pool4)
    conv5_2 = Conv2D(512, 3, strides=(1, 1), padding='same', activation='relu')(conv5_1)
    conv5_3 = Conv2D(512, 3, strides=(1, 1), padding='same', activation='relu')(conv5_2)
    pool5 = MaxPooling2D(pool_size=(2, 2), strides=2, padding='valid')(conv5_3)
    fc = Flatten()(pool5)
    fc6 = Dense(4096, activation='relu')(fc)
    fc6 = Dropout(0.5)(fc6)
    fc7 = Dense(4096, activation='relu')(fc6)
    fc7 = Dropout(0.5)(fc7)
    fc8 = Dense(1000, activation='relu')(fc7)
    fc8 = Dropout(0.5)(fc8)
    out = Dense(10, activation='softmax')(fc8)
    model = Model(inputs, out)
    return model
input_size = 32  # CIFAR-10 images are 32x32
model = VGG16(input_shape=[input_size, input_size, 3])
model.summary()
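As an aside, tf.keras.applications also ships a ready-made VGG16; a minimal sketch of using it for CIFAR-10 (untrained, ten output classes) is below, though the hand-built version above keeps every layer explicit:

import tensorflow as tf
# Off-the-shelf VGG16 with random weights; 32x32 is the minimum input size it accepts
alt_model = tf.keras.applications.VGG16(weights=None, input_shape=(32, 32, 3), classes=10)
alt_model.summary()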
Define the loss function and optimizer
import tensorflow as tf

learning_rate = 0.0001
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy']
)
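A fixed learning rate works, but a decaying schedule often trains deep nets like VGG more smoothly; a hedged sketch (the decay_steps and decay_rate values here are illustrative, not tuned for this task):

# Optional: exponential learning-rate decay instead of a fixed rate
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-4,
    decay_steps=10000,  # illustrative value, not tuned
    decay_rate=0.9)
# then: optimizer=tf.keras.optimizers.Adam(learning_rate=lr_schedule)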
Save the model
model_filepath = 'model/'
checkpoint_filepath = model_filepath + 'tmp/'
# Keep only the best weights seen during training; monitoring 'val_accuracy'
# instead would select for generalization rather than training accuracy
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_best_only=True,
    save_weights_only=True,
    monitor='accuracy',
    mode='max'
)
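If 100 epochs turns out to be more than necessary, an EarlyStopping callback can cut training short once validation loss plateaus; a minimal sketch (the patience value is an assumption, tune as needed):

# Optional: stop when validation loss stops improving
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=10,  # assumed value
    restore_best_weights=True)
# then pass callbacks=[cp_callback, early_stop] to model.fit below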
Let the alchemy begin
Start training the model.
history = model.fit(
    trainGenerator,
    epochs=100,
    steps_per_epoch=trainGenerator.count // BATCH_SIZE + 1,
    validation_data=testGenerator,
    validation_steps=testGenerator.count // BATCH_SIZE + 1,
    callbacks=[cp_callback]
)
model.load_weights(checkpoint_filepath)
model.save(model_filepath + 'model', save_format='tf')
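With the best checkpointed weights reloaded, it is worth running one evaluation pass on the held-out generator to get a final number (a quick sketch using the objects defined above):

# Final evaluation with the best weights
loss, acc = model.evaluate(testGenerator, steps=testGenerator.count // BATCH_SIZE)
print('test loss: {:.4f}, test accuracy: {:.4f}'.format(loss, acc))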
Visualize the training process
import matplotlib.pyplot as plt

# Save the training history so it can be re-plotted later
with open(model_filepath + 'history', 'wb') as file_pi:
    pickle.dump(history.history, file_pi)

plt.subplot(211)
plt.title('Cross Entropy Loss')
plt.plot(history.history['loss'], color='blue', label='train')
plt.plot(history.history['val_loss'], color='orange', label='test')
plt.legend()
plt.subplot(212)
plt.title('Accuracy')
plt.plot(history.history['accuracy'], color='blue', label='train')
plt.plot(history.history['val_accuracy'], color='orange', label='test')
plt.legend()
plt.show()
Load the model and run a prediction
import tensorflow as tf
import cv2 as cv
import os
import numpy as np
import matplotlib.pyplot as plt

img = cv.imread("test.jpg")
img = cv.cvtColor(img, cv.COLOR_BGR2RGB)  # OpenCV loads BGR; the network was trained on RGB
org = cv.resize(img, (32, 32))
img = np.reshape(org, (-1, 32, 32, 3)) / 255.0
model_filepath = os.path.join(os.path.dirname(__file__), "model/model")
model = tf.keras.models.load_model(model_filepath)
model.summary()
predict = model.predict(img)
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
plt.imshow(org)
plt.xlabel(class_names[np.argmax(predict)])
plt.show()
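Beyond the single argmax label, the softmax output carries a confidence for every class; a small sketch that prints the three most likely classes:

# Top-3 predictions with their softmax probabilities
top3 = np.argsort(predict[0])[::-1][:3]
for i in top3:
    print('{}: {:.2%}'.format(class_names[i], predict[0][i]))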
The result looks like this:
That wraps up this post; feedback and corrections are welcome.