自编码器重建 Fashion_mnist 数据集

鱼满舱

关注

阅读 75

2023-01-12


自编码器

import os

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from PIL import Image
from tensorflow import keras
from tensorflow.keras import Sequential, layers

加载数据集

# Load Fashion-MNIST (the dataset this walkthrough is about); it is exposed
# through the same Keras loader interface as the digits dataset.
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
# Scale the 8-bit pixel values into [0, 1] as float32.
x_train, x_test = x_train.astype(np.float32) / 255., x_test.astype(np.float32) / 255.
# Reconstruction is unsupervised, so only the images go into the pipelines.
train_db = tf.data.Dataset.from_tensor_slices(x_train)
train_db = train_db.shuffle(buffer_size=512).batch(512)
test_db = tf.data.Dataset.from_tensor_slices(x_test)
test_db = test_db.batch(512)

print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)

(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)

构建网络

class AutoEncoder(keras.Model):
    """Fully connected autoencoder for flattened 28x28 images.

    The encoder compresses a 784-value input to an ``h_dim``-value code
    (784 -> 256 -> 128 -> h_dim) and the decoder mirrors it back
    (h_dim -> 128 -> 256 -> 784). The decoder's last layer has no
    activation, so the model returns logits; apply ``tf.sigmoid`` to
    obtain pixel intensities.

    Args:
        h_dim: Size of the latent code. Defaults to 20.
    """

    def __init__(self, h_dim=20):
        super().__init__()

        # Encoder: [b, 784] => [b, h_dim]
        self.encoder = Sequential([
            layers.Dense(256, activation=tf.nn.relu),
            layers.Dense(128, activation=tf.nn.relu),
            layers.Dense(h_dim),
        ])

        # Decoder: [b, h_dim] => [b, 784]
        self.decoder = Sequential([
            layers.Dense(128, activation=tf.nn.relu),
            layers.Dense(256, activation=tf.nn.relu),
            layers.Dense(784),
        ])

    def call(self, inputs, training=None):
        """Encode then decode ``inputs`` and return reconstruction logits.

        Args:
            inputs: Batch of flattened images, [b, 784].
            training: Forwarded to both sub-networks.

        Returns:
            Reconstruction logits, [b, 784].
        """
        # [b, 784] => [b, h_dim]
        h = self.encoder(inputs, training=training)
        # [b, h_dim] => [b, 784]
        x_hat = self.decoder(h, training=training)

        return x_hat

网络训练

def save_images(imgs, name, tile=28, grid=10):
    """Paste images into one square grayscale contact sheet and save it.

    Tiles are laid out column by column: the first ``grid`` images fill the
    left-most column from top to bottom, the next ``grid`` the second
    column, and so on. At most ``grid * grid`` images are used; if fewer
    are supplied the remaining cells stay black.

    Args:
        imgs: Sequence of 2-D arrays that ``Image.fromarray(..., mode='L')``
            accepts (the callers in this file pass uint8 arrays of 28x28).
        name: Output path (or any target accepted by ``Image.save``).
        tile: Edge length in pixels of one cell. Defaults to 28.
        grid: Number of cells per row and per column. Defaults to 10.
    """
    side = tile * grid
    new_im = Image.new('L', (side, side))

    # The x offset is the outer loop, hence the column-major layout.
    positions = (
        (x, y)
        for x in range(0, side, tile)
        for y in range(0, side, tile)
    )
    # zip stops at the shorter input, so a short batch no longer raises
    # IndexError and extra images beyond the grid are ignored.
    for (x, y), im in zip(positions, imgs):
        new_im.paste(Image.fromarray(im, mode='L'), (x, y))

    # Create the destination directory so the save does not fail on a
    # fresh checkout; skip this for file-like targets.
    if isinstance(name, (str, os.PathLike)):
        out_dir = os.path.dirname(name)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    new_im.save(name)

# Instantiate the autoencoder and create its weights for flattened
# 28x28 inputs so that summary() can report parameter counts.
model = AutoEncoder()
model.build(input_shape=(None, 28 * 28))
model.summary()

# `learning_rate` is the supported keyword; the old `lr` alias is deprecated
# and rejected by newer Keras releases.
optimizer = tf.optimizers.Adam(learning_rate=1e-3)

Model: "auto_encoder"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
sequential (Sequential) multiple 236436
_________________________________________________________________
sequential_1 (Sequential) multiple 237200
=================================================================
Total params: 473,636
Trainable params: 473,636
Non-trainable params: 0
_________________________________________________________________

开始训练

for epoch in range(20):

    for step, batch in enumerate(train_db):
        # Flatten every image: [b, 28, 28] => [b, 784]
        flat = tf.reshape(batch, [-1, 28 * 28])

        # Record the forward pass so gradients can be taken afterwards.
        with tf.GradientTape() as tape:
            x_rec_logits = model(flat)
            # Sigmoid cross-entropy computed on raw logits, reduced to a scalar.
            rec_loss = tf.reduce_mean(
                tf.losses.binary_crossentropy(flat, x_rec_logits, from_logits=True)
            )

        # Back-propagate and apply one optimizer step.
        weights = model.trainable_variables
        optimizer.apply_gradients(zip(tape.gradient(rec_loss, weights), weights))

    # Report the loss of the last batch seen in this epoch.
    print("epoch: ", epoch, "loss: ", float(rec_loss))

    # Reconstruct the first test batch to monitor progress visually.
    x = next(iter(test_db))
    logits = model(tf.reshape(x, [-1, 784]))
    # Sigmoid maps logits to [0, 1] pixel values; then [b, 784] => [b, 28, 28].
    x_hat = tf.reshape(tf.sigmoid(logits), [-1, 28, 28])

    # First 50 inputs followed by their 50 reconstructions: [100, 28, 28].
    x_concat = tf.concat([x[:50], x_hat[:50]], axis=0)
    # Back to the 0-255 range as unsigned 8-bit integers for image export.
    x_concat = (x_concat.numpy() * 255.).astype(np.uint8)
    save_images(x_concat, 'ae_images/mnist_%d.png'%epoch)

epoch:  0 loss:  0.1876431256532669
epoch: 1 loss: 0.14163847267627716
epoch: 2 loss: 0.12352141737937927
epoch: 3 loss: 0.11942803859710693
epoch: 4 loss: 0.11525192111730576
epoch: 5 loss: 0.10021436214447021
epoch: 6 loss: 0.10526927560567856
epoch: 7 loss: 0.10288294404745102
epoch: 8 loss: 0.10139968246221542
epoch: 9 loss: 0.10215207189321518
epoch: 10 loss: 0.0961870551109314
epoch: 11 loss: 0.091026671230793
epoch: 12 loss: 0.09655070304870605
epoch: 13 loss: 0.09417414665222168
epoch: 14 loss: 0.08978977054357529
epoch: 15 loss: 0.08931374549865723
epoch: 16 loss: 0.08951258659362793
epoch: 17 loss: 0.08937102556228638
epoch: 18 loss: 0.09456444531679153
epoch: 19 loss: 0.08556753396987915

def printImage(images):
    """Draw up to the first 20 images on a 5-column grid.

    Images are placed row by row on a 5x5 subplot layout with ticks and
    grid lines hidden, using the inverted-grayscale ``binary`` colormap.

    Args:
        images: Sequence of 2-D image arrays. Only the first 20 are drawn;
            shorter sequences are drawn in full.
    """
    plt.figure(figsize=(10, 10))
    # Keep the 20-image cap, but never index past the end of a shorter input.
    for i in range(min(20, len(images))):
        plt.subplot(5, 5, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(images[i], cmap=plt.cm.binary)

# Reconstruct the first batch of the test set with the trained model.
x = next(iter(test_db))
logits = model(tf.reshape(x, [-1, 784]))
# Sigmoid turns logits into [0, 1] pixel values; then [b, 784] => [b, 28, 28].
x_hat = tf.reshape(tf.sigmoid(logits), [-1, 28, 28])

# First 10 inputs followed by their 10 reconstructions: [20, 28, 28].
x_concat = tf.concat([x[:10], x_hat[:10]], axis=0)
# Rescale to 0-255 and cast to unsigned 8-bit integers for display.
x_concat = (x_concat.numpy() * 255.).astype(np.uint8)
printImage(x_concat)

  • 共 20 张图片,每行 5 张:上面 2 行是原始图片,下面 2 行是重建后的图片

保存本地的图片:

第一次 epoch

左边 5 列是原图片,右边 5 列是经过重建后的。可以看到,此时还不是很清楚

自编码器重建 Fashion_mnist 数据集_cv

第十次 epoch

自编码器重建 Fashion_mnist 数据集_深度学习_02


第二十次 epoch

自编码器重建 Fashion_mnist 数据集_神经网络_03


精彩评论(0)

0 0 举报