今天延续之前的主题,我们将使用EfficientNetB0的架构,但不使用预训练权重,参考了Keras文档的文章,我们将input_shape设定为EfficientNetB0预设的值224X224,等等会利用Keras提供的函数,帮我们载入并缩放这些图片,由於验证集的部分不希望使用资料增强,所以分成了两个资料夹,使用两个ImageDataGenerator去处理。
input_shape = (224, 224)
batch_size = 64
from keras.preprocessing.image import ImageDataGenerator
traindatagen = ImageDataGenerator(
width_shift_range = 0.1,
height_shift_range = 0.1,
zoom_range = 0.2,
shear_range = 0.1,
rotation_range = 25,
horizontal_flip = True,
rescale = 1/255.
)
validdatagen = ImageDataGenerator(
rescale = 1/255.
)
train = traindatagen.flow_from_directory(
img_directory + 'train',
target_size=input_shape,
color_mode="rgb",
class_mode="binary",
batch_size=batch_size,
shuffle=True,
interpolation="lanczos",
)
valid = validdatagen.flow_from_directory(
img_directory + 'validation',
target_size=input_shape,
color_mode="rgb",
class_mode="binary",
batch_size=batch_size,
shuffle=True,
interpolation="lanczos",
)
你应该可以看到它印出了两行结果,表示完成了对路径的扫瞄,数量比例也大致相符。接着引入我们的主角。
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.models import Model
model = EfficientNetB0(include_top=True, weights=None, input_shape=(*input_shape,3), classes=1, activation='sigmoid', pooling='avg')
model.compile(
optimizer = 'adam',
loss = 'binary_crossentropy',
metrics = ['accuracy']
)
再来设定一些回调函数。
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger, TerminateOnNaN, EarlyStopping
mcp = ModelCheckpoint(filepath='EfficientNetB0-{epoch:02d}.h5', monitor='val_loss', verbose=0, save_best_only=True, save_weights_only=False, mode='auto', save_freq='epoch')
log = CSVLogger(filename='EfficientNetB0.csv', separator=',', append=False)
ton = TerminateOnNaN()
esl = EarlyStopping(monitor='val_loss', patience=10, mode='auto', restore_best_weights=True)
esa = EarlyStopping(monitor='val_accuracy', patience=10, mode='auto', restore_best_weights=True)
只要经过漫长的等待就可以收获胜利的果实了。
hist = model.fit(
x = train,
steps_per_epoch = train.samples // batch_size,
epochs = 50,
validation_data = valid,
validation_steps = valid.samples // batch_size,
callbacks = [mcp, log, ton, esl, esa]
)
<<: DAY27 深度学习-卷积神经网路-Yolo v2 (一)
>>: [ Day 29 | Essay ] 作梦也会梦到内心最深刻的恐惧
那路由器及虚拟机都安装好後,我们要来异地组网啦! 在此之前,我们先来介绍一下吧 什麽是 SD-WAN...
今天天内容为灯光、粒子效果的基本介绍! Duration 粒子发射的时间 Looping 设定粒子是...
在学习Encryption 跟Decryption前~ ASCII电脑编码系统是必须要知道的。 AS...
今天来讲解比较简单又很长出的题目 odd sum 先点选CPE颗星广场 再点选右边的一星 用ctrl...
(一)政策性(第一阶文件) 说明ISMS目标、方向及执行原则。 文件:资安政策、资安组织 ISMS-...