【零基础成为 AI 解梦大师秘笈】Day28 - 周易解梦之人工智慧(9)

人工智慧9

前言

系列文章简介

大家好,我们是 AI . FREE Team - 人工智慧自由团队,这一次的铁人赛,自由团队将从0到1 手把手教各位读者学会 (1)Python基础语法 (2)Python Web 网页开发框架 – Django (3)Python网页爬虫 – 周易解梦网 (4)Tensorflow AI语言模型基础与训练 – LSTM (5)实际部属AI解梦模型到Web框架上。

为什麽技术要从零开始写起

自由团队的成立宗旨为开发AI/新科技的学习资源,提供各领域的学习者能够跨域学习资料科学,并透过自主学习发展协杠职涯,结合智能应用到各式领域,无论是文、法、商、管、医领域的朋友,都可以自由的学习AI技术。

资源

AI . FREE Team 读者专属福利 → Python Basics 免费学习资源

实作 part2

这次深度学习框架我们采用keras(比较快XD)

请自行延续上一篇前处理的程序码

导入套件

from keras.models import Model, Input
from keras.layers import LSTM, Embedding, Dense, TimeDistributed, Dropout, Bidirectional,GRU

model

model = tf.keras.Sequential([
    Embedding(len(words), 64),
    Bidirectional(LSTM(64)),
    Dense(64, activation='relu'),
    Dense(5,activation = 'softmax')
])

看一下模型的架构

model.summary()
Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding_7 (Embedding)      (None, None, 64)          21440     
_________________________________________________________________
bidirectional_7 (Bidirection (None, 128)               66048     
_________________________________________________________________
dense_11 (Dense)             (None, 64)                8256      
_________________________________________________________________
dense_12 (Dense)             (None, 5)                 325       
=================================================================
Total params: 96,069
Trainable params: 96,069
Non-trainable params: 0
_________________________________________________________________

优化器与损失函数

model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

categorical_crossentropy是一种针对multi-class or multi-label的loss_function

training

这里记得要将资料都转为numpy的形式,其他的参数大家都可以乱调看看

batch_size是一次丢进去的数量

epochs是训练完整资料集的次数

verbose为是否启用进度条

如果你有验证及,也可以用validation_data = (np.array(X_valid_int), np.array(y_oh_valid))

history = model.fit(np.array(X_train_int), np.array(y_oh_train), batch_size=4, epochs=20, verbose=1
                    , validation_data = (np.array(X_valid_int), np.array(y_oh_valid)),
                    )
Epoch 1/20
46/46 [==============================] - 1s 28ms/step - loss: 1.5894 - accuracy: 0.2732 - val_loss: 1.5550 - val_accuracy: 0.3929
Epoch 2/20
46/46 [==============================] - 0s 10ms/step - loss: 1.5250 - accuracy: 0.3880 - val_loss: 1.4884 - val_accuracy: 0.3750
Epoch 3/20
46/46 [==============================] - 0s 11ms/step - loss: 1.3422 - accuracy: 0.3825 - val_loss: 1.2045 - val_accuracy: 0.5179
Epoch 4/20
46/46 [==============================] - 0s 11ms/step - loss: 0.8867 - accuracy: 0.6557 - val_loss: 1.2385 - val_accuracy: 0.6429
Epoch 5/20
46/46 [==============================] - 0s 11ms/step - loss: 0.4955 - accuracy: 0.7978 - val_loss: 1.6667 - val_accuracy: 0.6786
Epoch 6/20
46/46 [==============================] - 0s 11ms/step - loss: 0.3452 - accuracy: 0.8361 - val_loss: 1.4240 - val_accuracy: 0.7500
Epoch 7/20
46/46 [==============================] - 0s 11ms/step - loss: 0.2393 - accuracy: 0.9016 - val_loss: 1.6240 - val_accuracy: 0.7679
Epoch 8/20
46/46 [==============================] - 0s 11ms/step - loss: 0.1503 - accuracy: 0.9454 - val_loss: 2.0856 - val_accuracy: 0.7679
Epoch 9/20
46/46 [==============================] - 1s 11ms/step - loss: 0.1088 - accuracy: 0.9508 - val_loss: 2.0742 - val_accuracy: 0.7321
Epoch 10/20
46/46 [==============================] - 0s 10ms/step - loss: 0.0508 - accuracy: 0.9891 - val_loss: 2.0070 - val_accuracy: 0.7857
Epoch 11/20
46/46 [==============================] - 0s 11ms/step - loss: 0.0334 - accuracy: 0.9945 - val_loss: 2.1783 - val_accuracy: 0.7857
Epoch 12/20
46/46 [==============================] - 0s 11ms/step - loss: 0.0079 - accuracy: 1.0000 - val_loss: 2.3317 - val_accuracy: 0.7857
Epoch 13/20
46/46 [==============================] - 1s 11ms/step - loss: 0.0046 - accuracy: 1.0000 - val_loss: 2.3419 - val_accuracy: 0.7857
Epoch 14/20
46/46 [==============================] - 0s 11ms/step - loss: 0.0023 - accuracy: 1.0000 - val_loss: 2.4288 - val_accuracy: 0.7857
Epoch 15/20
46/46 [==============================] - 0s 11ms/step - loss: 0.0015 - accuracy: 1.0000 - val_loss: 2.4866 - val_accuracy: 0.7857
Epoch 16/20
46/46 [==============================] - 0s 11ms/step - loss: 0.0012 - accuracy: 1.0000 - val_loss: 2.5414 - val_accuracy: 0.7857
Epoch 17/20
46/46 [==============================] - 1s 11ms/step - loss: 9.3662e-04 - accuracy: 1.0000 - val_loss: 2.6041 - val_accuracy: 0.7857
Epoch 18/20
46/46 [==============================] - 0s 11ms/step - loss: 8.3006e-04 - accuracy: 1.0000 - val_loss: 2.6372 - val_accuracy: 0.7857
Epoch 19/20
46/46 [==============================] - 1s 11ms/step - loss: 6.8749e-04 - accuracy: 1.0000 - val_loss: 2.6845 - val_accuracy: 0.7857
Epoch 20/20
46/46 [==============================] - 0s 11ms/step - loss: 5.7524e-04 - accuracy: 1.0000 - val_loss: 2.7143 - val_accuracy: 0.7857

视觉化成效

我们使用matplotlib的折线图plt.plot

hist = pd.DataFrame(history.history)
plt.figure(figsize=(12,12))
plt.plot(hist["accuracy"])
plt.plot(hist["val_accuracy"])
plt.show()

测试

idx = 5
test_pred = model.predict(np.array(X_valid_int)[idx])
anwser = max(sum(test_pred))
for index,i in enumerate(sum(test_pred)):
  if i == anwser:
    print(f'sentence : {[idx2word[k] for k in X_valid_int[idx]]}')
    print(f'predict : {label_to_emoji(index)}')
sentence : ['he', 'is', 'a', 'good', 'friend', 'pad', 'pad', 'pad', 'pad', 'pad']
predict : ❤️

大家可以随机调用idx来测试不同的句子

想更深入认识 AI . FREE Team ?

自由团队 官方网站:https://aifreeblog.herokuapp.com/
自由团队 Github:https://github.com/AI-FREE-Team/
自由团队 粉丝专页:https://www.facebook.com/AI.Free.Team/
自由团队 IG:https://www.instagram.com/aifreeteam/
自由团队 Youtube:https://www.youtube.com/channel/UCjw6Kuw3kwM_il39NTBJVTg/

文章同步发布於:自由团队部落格
(想看更多文章?学习更多AI知识?敬请锁定自由团队的频道!)


<<:  Day 29 - ROS 树莓派光达履带小车实作 (3)

>>:  Mongoose Schema TimeZone

D17 - 吃一颗 Class 语法糖 (上)

前言 在 ES6 後,新增了 class 类别,一个更简洁的语法来建立物件,也是建立继承的语法糖。 ...

django新手村6-----HTTP Status Code

常见的 200 ok 404 找不到请求的网页 403 服务器拒绝请求 301 永久移动网页,重新导...

利用 Google App Script 将资料存到 Google Sheet(2)

延续昨天的内容,今天我们要完成写入&读取的功能 将信件内容写入到 Google Sheet ...

想要爬个资料也困难重重

这边先说一下,关於上一篇的程序码好像有些问题,我这次找了其他资料练习,先用了另一组程序抓取,确认抓取...

[Part 7 ] Vue.js 的精随-元件生命周期 (续)

摧毁阶段 这个阶段负责元件的移除,适合用来移除所有的事件监听以及任何会造成记忆体泄漏(memory ...