AI ninja project [day 12] Image Classification (2)

In this post I follow the official tutorial once more,
but with some extra handling on the code side, plus what to do when the model overfits.

Reference page: https://www.tensorflow.org/tutorials/images/classification?hl=zh_tw

First, import the modules:

import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

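Download the flower photo dataset. It has five classes of flowers, 3,670 images in total, stored in five folders.
We can print the download path and check the number of images: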

import pathlib
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)
print(data_dir)

image_count = len(list(data_dir.glob('*/*.jpg')))
print(image_count)

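Split the data into a training set and a validation set. validation_split is the fraction of the data held out for validation.
A seed must be given and must be the same in both calls (it controls the shuffle), so that the two subsets do not overlap.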

batch_size = 32
img_height = 180
img_width = 180


train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  seed=456,
  image_size=(img_height, img_width),
  batch_size=batch_size)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="validation",
  seed=456,
  image_size=(img_height, img_width),
  batch_size=batch_size)  
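
To make the role of validation_split and seed more concrete, here is a minimal pure-Python sketch of a seeded shuffle followed by a split. It only illustrates the idea and is not how Keras implements it internally; the helper split_files and the file names are hypothetical.

```python
import random

def split_files(files, validation_split, seed, subset):
    """Shuffle with a fixed seed, then return one side of the split (hypothetical helper)."""
    shuffled = list(files)
    random.Random(seed).shuffle(shuffled)  # same seed -> same order on every call
    n_val = int(len(shuffled) * validation_split)
    if subset == "training":
        return shuffled[:len(shuffled) - n_val]
    return shuffled[len(shuffled) - n_val:]

# Made-up file names standing in for the 3670 flower photos.
files = [f"img_{i:04d}.jpg" for i in range(3670)]

train_files = split_files(files, 0.2, seed=456, subset="training")
val_files = split_files(files, 0.2, seed=456, subset="validation")

print(len(train_files), len(val_files))
# Both calls used the same seed, so the two subsets share no file.
print(set(train_files).isdisjoint(val_files))
```

If the two calls used different seeds, each would shuffle the files differently, and some images could end up in both the training and the validation subset.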

We can check which flowers the labels contain:

class_names = train_ds.class_names
print(class_names)

https://ithelp.ithome.com.tw/upload/images/20210912/20122678fqIbJJokSE.png

Cache the data to speed up training (a path can also be given, as in cache("/path/to/file")),
and add the shuffle and prefetch steps to the pipeline:

AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

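Build the CNN model.
Note that the first layer does the preprocessing: it divides every pixel value by 255 to rescale the inputs to the [0, 1] range, so the model does not have to work with large values.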

num_classes = 5

model = Sequential([
  layers.experimental.preprocessing.Rescaling(1./255, input_shape=(img_height, img_width, 3)),
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(32, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(64, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Flatten(),
  layers.Dense(128, activation='relu'),
  layers.Dense(num_classes)
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
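
As a rough illustration of what the rescaling step does to the pixel values, here is a small pure-Python sketch. The function rescale and the sample patch are made up for this example; the real layer works on tensors, not nested lists.

```python
def rescale(pixels, scale=1.0 / 255):
    """Multiply every pixel value by `scale`, mimicking the idea of Rescaling(1./255)."""
    return [[value * scale for value in row] for row in pixels]

# A made-up 2x3 grayscale patch with values in the usual 0-255 range.
patch = [[0, 128, 255],
         [64, 32, 16]]

scaled = rescale(patch)
for row in scaled:
    print([round(value, 3) for value in row])
```

Every value ends up between 0 and 1, with 0 staying at 0 and 255 mapping to roughly 1.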

Set epochs to 10 and train:

epochs=10
history = model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=epochs
)

Look at the training history:

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

https://ithelp.ithome.com.tw/upload/images/20210912/20122678sjZVFO4uto.png

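The training accuracy keeps rising over the epochs,
but the validation accuracy stalls at about 0.65 and stops improving.
The validation loss even goes up over time,
which indicates overfitting.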
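
The same pattern can also be checked numerically instead of by eye. The sketch below assumes two lists shaped like history.history['loss'] and history.history['val_loss']; the numbers are made up for illustration and are not the results of this run, and best_epoch is a hypothetical helper.

```python
def best_epoch(val_loss):
    """Return the index of the epoch with the lowest validation loss."""
    return min(range(len(val_loss)), key=lambda i: val_loss[i])

# Made-up values for illustration only.
loss = [1.30, 0.95, 0.75, 0.55, 0.40, 0.25]
val_loss = [1.10, 0.98, 0.92, 0.95, 1.05, 1.20]

best = best_epoch(val_loss)
train_still_improving = loss[-1] < loss[best]
val_getting_worse = val_loss[-1] > val_loss[best]
looks_overfit = train_still_improving and val_getting_worse

print("best validation epoch (0-based):", best)
print("looks like overfitting:", looks_overfit)
```

A training loss that keeps falling after the validation loss has bottomed out is the signature described above.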

How to handle it

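We can enlarge the training data by distorting, flipping, and tilting the images (data augmentation). This does not suit every task: if the goal were to recognize a "no entry ahead" traffic sign, for example, this trick would not be appropriate.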

data_augmentation = keras.Sequential(
  [
    layers.experimental.preprocessing.RandomFlip("horizontal", 
                                                 input_shape=(img_height, 
                                                              img_width,
                                                              3)),
    layers.experimental.preprocessing.RandomRotation(0.1),
    layers.experimental.preprocessing.RandomZoom(0.1),
  ]
)
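
To show the idea behind a random horizontal flip without TensorFlow, here is a small pure-Python sketch on a made-up 2x3 "image". The helpers flip_horizontal and random_flip are hypothetical and only mimic the concept; they are not what the Keras layer does internally.

```python
import random

def flip_horizontal(image):
    """Mirror each row left-to-right."""
    return [row[::-1] for row in image]

def random_flip(image, rng):
    """Flip with probability 0.5, otherwise return an unchanged copy."""
    if rng.random() < 0.5:
        return flip_horizontal(image)
    return [row[:] for row in image]

# A made-up 2x3 single-channel image.
image = [[1, 2, 3],
         [4, 5, 6]]

rng = random.Random(0)
augmented = [random_flip(image, rng) for _ in range(4)]
for sample in augmented:
    print(sample)
```

Each augmented sample keeps the label of the original image, so the model sees more variety without any new photos being collected.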

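Another approach is to add a Dropout layer to the model. During training it randomly sets a fraction of its inputs to zero, which acts as regularization: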

layers.Dropout(0.2)
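
For intuition, here is a minimal pure-Python sketch of the dropout idea with a rate of 0.2. The function dropout below is a hypothetical stand-in written for this example: it zeroes each value with probability rate during training, scales the survivors by 1 / (1 - rate), and leaves the values untouched at inference time.

```python
import random

def dropout(values, rate, rng, training=True):
    """Zero each value with probability `rate` while training and scale the
    rest by 1 / (1 - rate); return the values unchanged at inference time."""
    if not training:
        return list(values)
    keep_scale = 1.0 / (1.0 - rate)
    return [0.0 if rng.random() < rate else v * keep_scale for v in values]

rng = random.Random(0)
activations = [1.0] * 10  # made-up activations

train_out = dropout(activations, 0.2, rng)
infer_out = dropout(activations, 0.2, rng, training=False)

print(train_out)
print(infer_out)
```

Because a different random subset is dropped at every step, the network cannot lean too heavily on any single activation.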

Rebuild the model:

model = Sequential([
  data_augmentation,
  layers.experimental.preprocessing.Rescaling(1./255),
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(32, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(64, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Dropout(0.2),
  layers.Flatten(),
  layers.Dense(128, activation='relu'),
  layers.Dense(num_classes)
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

Train it:

epochs = 15
history = model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=epochs
)

Look at the training results:

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

https://ithelp.ithome.com.tw/upload/images/20210912/20122678mv018Jn0yH.png

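We can also run a prediction on a photo provided by the official site.
Because the model was fed batches of data, its input carries an extra batch dimension, so tf.expand_dims is used to add that dimension: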

sunflower_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/592px-Red_sunflower.jpg"
sunflower_path = tf.keras.utils.get_file('Red_sunflower', origin=sunflower_url)

img = keras.preprocessing.image.load_img(
    sunflower_path, target_size=(img_height, img_width)
)
img_array = keras.preprocessing.image.img_to_array(img)
img_array = tf.expand_dims(img_array, 0) # Create a batch

predictions = model.predict(img_array)
score = tf.nn.softmax(predictions[0])

print(
    "This image most likely belongs to {} with a {:.2f} percent confidence."
    .format(class_names[np.argmax(score)], 100 * np.max(score))
)
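
The last Dense layer has no activation, so the model outputs raw logits, and softmax turns them into scores that sum to 1. The sketch below reproduces that step in plain Python. The logits are made up, and the class name list is assumed to match what train_ds.class_names prints for this dataset.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    largest = max(logits)
    exps = [math.exp(x - largest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Assumed class order; check it against class_names in your own run.
class_names = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
# Made-up logits standing in for predictions[0].
logits = [-1.2, 0.3, 0.8, 4.1, 1.5]

score = softmax(logits)
best_index = max(range(len(score)), key=lambda i: score[i])
best_name = class_names[best_index]

print(
    "This image most likely belongs to {} with a {:.2f} percent confidence."
    .format(best_name, 100 * score[best_index])
)
```

The class with the largest logit also gets the largest softmax score, so the argmax can be taken on either; softmax is only needed to report a confidence value.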

https://ithelp.ithome.com.tw/upload/images/20210912/20122678q32rXgCk4O.png

