Day 28 ~ AI从入门到放弃 - 猫狗辨识之三

今天延续之前的主题,我们将使用EfficientNetB0的架构,但不使用预训练权重,参考了Keras文档的文章,我们将input_shape设定为EfficientNetB0预设的值224X224,等等会利用Keras提供的函数,帮我们载入并缩放这些图片,由於验证集的部分不希望使用资料增强,所以分成了两个资料夹,使用两个ImageDataGenerator去处理。

input_shape = (224, 224)
batch_size = 64

from keras.preprocessing.image import ImageDataGenerator

traindatagen = ImageDataGenerator(
    width_shift_range = 0.1,
    height_shift_range = 0.1,
    zoom_range = 0.2,
    shear_range = 0.1,
    rotation_range = 25,
    horizontal_flip = True,
    rescale = 1/255.
)

validdatagen = ImageDataGenerator(
    rescale = 1/255.
)

train = traindatagen.flow_from_directory(
    img_directory + 'train',
    target_size=input_shape,
    color_mode="rgb",
    class_mode="binary",
    batch_size=batch_size,
    shuffle=True,
    interpolation="lanczos",
)

valid = validdatagen.flow_from_directory(
    img_directory + 'validation',
    target_size=input_shape,
    color_mode="rgb",
    class_mode="binary",
    batch_size=batch_size,
    shuffle=True,
    interpolation="lanczos",
)

你应该可以看到它印出了两行结果,表示完成了对路径的扫瞄,数量比例也大致相符。接着引入我们的主角。

from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.models import Model

model = EfficientNetB0(include_top=True, weights=None, input_shape=(*input_shape,3), classes=1, activation='sigmoid', pooling='avg')

model.compile(
    optimizer = 'adam',
    loss = 'binary_crossentropy',
    metrics = ['accuracy']
)

再来设定一些回调函数。

from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger, TerminateOnNaN, EarlyStopping
mcp = ModelCheckpoint(filepath='EfficientNetB0-{epoch:02d}.h5', monitor='val_loss', verbose=0, save_best_only=True, save_weights_only=False, mode='auto', save_freq='epoch')
log = CSVLogger(filename='EfficientNetB0.csv', separator=',', append=False)
ton = TerminateOnNaN()
esl = EarlyStopping(monitor='val_loss', patience=10, mode='auto', restore_best_weights=True)
esa = EarlyStopping(monitor='val_accuracy', patience=10, mode='auto', restore_best_weights=True)

只要经过漫长的等待就可以收获胜利的果实了。

hist = model.fit(
    x = train,
    steps_per_epoch = train.samples // batch_size,
    epochs = 50,
    validation_data = valid,
    validation_steps = valid.samples // batch_size,
    callbacks = [mcp, log, ton, esl, esa]
)

<<:  DAY27 深度学习-卷积神经网路-Yolo v2 (一)

>>:  [ Day 29 | Essay ] 作梦也会梦到内心最深刻的恐惧

Day 9 - 利用路由协议来组 SD-WAN 网路

那路由器及虚拟机都安装好後,我们要来异地组网啦! 在此之前,我们先来介绍一下吧 什麽是 SD-WAN...

Unity与Photon的新手相遇旅途 | Day5-灯光介绍、粒子效果

今天天内容为灯光、粒子效果的基本介绍! Duration 粒子发射的时间 Looping 设定粒子是...

【C++】Encryption and Decryption

在学习Encryption 跟Decryption前~ ASCII电脑编码系统是必须要知道的。 AS...

[Day3]odd sum

今天来讲解比较简单又很长出的题目 odd sum 先点选CPE颗星广场 再点选右边的一星 用ctrl...

ISMS 程序书1~4阶着样写

(一)政策性(第一阶文件) 说明ISMS目标、方向及执行原则。 文件:资安政策、资安组织 ISMS-...