Hello World
需要安装 tensorflow tensorflow-datasets matplotlib
import keras
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
def mnist_load_1():
    """Load MNIST through tensorflow_datasets and split it into images/labels.

    Prints the element spec of the training dataset as a side effect.

    Returns:
        ``((train_image, train_label), (test_image, test_label))`` where each
        element is the dataset obtained by mapping the corresponding split to
        its ``'image'`` or ``'label'`` feature.
    """
    # Load via tensorflow_datasets.
    train_ds = tfds.load('mnist', split='train', shuffle_files=True)
    # The held-out data must come from the 'test' split, not 'train'.
    test_ds = tfds.load('mnist', split='test', shuffle_files=True)
    # Print the dataset description.
    print(train_ds.element_spec)
    # Separate image and label; each pair is derived from its own split.
    # NOTE(review): images and labels are mapped into independent datasets;
    # with shuffle_files=True their iteration order may not line up if a split
    # has more than one shard -- confirm before pairing them by position.
    train_image = train_ds.map(lambda example: example['image'])
    train_label = train_ds.map(lambda example: example['label'])
    test_image = test_ds.map(lambda example: example['image'])
    test_label = test_ds.map(lambda example: example['label'])
    return (train_image, train_label), (test_image, test_label)
# Inspect the dataset contents.
def show_example_1(images_ds, labels_ds):
    """Show the first image of ``images_ds`` and print the first label.

    Args:
        images_ds: Dataset-like object whose ``take(1)`` yields images that
            expose ``.shape`` and can be passed to ``plt.imshow``.
        labels_ds: Dataset-like object whose ``take(1)`` yields labels.
    """
    first_images = images_ds.take(1)
    for picture in first_images:
        print('Image shape:', picture.shape)
        plt.imshow(picture, cmap='gray')
        plt.show()

    first_labels = labels_ds.take(1)
    for tag in first_labels:
        print('Label:', tag)
def mnist_load_2():
    """Load MNIST through Keras.

    Returns:
        Whatever ``keras.datasets.mnist.load_data()`` returns; the caller in
        this file unpacks it as ``(x_train, y_train), (x_test, y_test)``.
    """
    # Load via keras. Use the documented public path `keras.datasets` rather
    # than reaching through the `keras.api` namespace.
    return keras.datasets.mnist.load_data()
# Inspect the dataset contents.
def show_example_2(images, labels, index=4):
    """Print summary information about the data and display one image.

    Args:
        images: Indexable collection of images exposing ``.shape``.
        labels: Sized collection of labels (only its length is printed).
        index: Position of the image to display. Defaults to 4, the sample
            that was previously hard-coded, so existing callers are unaffected.
    """
    print('Image shape:', images.shape)
    print('Label length::', len(labels))
    plt.imshow(images[index], cmap='gray')
    plt.show()
if __name__ == '__main__':
    # 1. Load the MNIST dataset.
    # Loading option 1: through tensorflow_datasets.
    # (x_train, y_train), (x_test, y_test) = mnist_load_1()
    # show_example_1(x_train, y_train)
    # Loading option 2: through keras.
    (x_train, y_train), (x_test, y_test) = mnist_load_2()
    # show_example_2(x_train, y_train)

    # 2. Neural network.
    # Define the model layers.
    model = keras.Sequential([
        keras.layers.Dense(512, activation='relu'),
        keras.layers.Dense(10, activation='softmax'),
    ])
    # Configure the compile step: optimizer, loss function, metrics.
    model.compile(optimizer='rmsprop',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    # 3. Prepare the data.
    # Flatten every image to a 28 * 28 vector, cast to float32 and divide by
    # 255 (presumably mapping 8-bit pixel values into [0, 1]). The sample
    # count is inferred with -1 instead of hard-coding 60000 / 10000, so the
    # reshape keeps working if the loader returns a different number of rows.
    train_images = x_train.reshape((-1, 28 * 28)).astype('float32') / 255
    test_images = x_test.reshape((-1, 28 * 28)).astype('float32') / 255

    # 4. Train the model.
    model.fit(train_images, y_train, epochs=5, batch_size=128)

    # 5. Use the model to make predictions.
    # Predict on the first 10 test samples.
    predict = model.predict(test_images[0:10])
    # Print the first prediction result.
    print_prefix = "== > prediction "
    print(print_prefix, ": ")
    print(print_prefix, "info: ", predict[0])
    print(print_prefix, "number: ", predict[0].argmax())
    print(print_prefix, "actual number:", y_test[0])
    plt.imshow(x_test[0], cmap='gray')
    plt.show()

    # 6. Evaluate the trained model on the test data.
    test_loss, test_acc = model.evaluate(test_images, y_test)
    print(f'Test accuracy:{test_acc * 100:.2f}%')