自定模型5種花卉

      在〈自定模型5種花卉〉中尚無留言

VGG19可以分辨1000個種類,但如果我們要辨識的,並不在這1000個之內,就要自已訓練。訓練的步驟為 : 
收集圖片資料,訓練,偵測。開始之前,請先安裝如下套件

pip install tensorflow==2.10.1 matplotlib opencv-python

收集圖片

收集圖片是一項大工程,還好 tensorflow 幫我們收集了一大堆圖片。請到 http://download.tensorflow.org/example_images/flower_photos.tgz 下載圖片。下載後置於專案之下並解開,就會多一個 flower_photos 目錄。

flower_photos 裏面 有 5 個資料夾,分別是 daisy(雛菊),dandelion(蒲公英),roses(玫瑰),sunflowers(向日葵),tulips(鬱金香)

Dense 層

Dense 層只會出現在全連接層,也就是輸出層。Dense 的作用就是由原本的 x 種狀況 (特徵),經過某種演算後,變成下一層的 y 種狀況。

剛剛講的某種演算,通常就是用捲積的方式來計算下一層的結果。

訓練模型

把每張圖片讀入縮小後放入 data 中,並同時由目錄名稱記錄每張圖的類別放在 labels 中。然後將VGG19前三個連接層抽出,再加入自已的全連接層。

一開始始用 GlobalAveragePooling2D 將 (長*寬*通道) 轉換成 (1*1*通道),其方法是將每個通道中的權重作平均值。如下圖所示,最前面的通道中,(1+5+4+5+6+5+3+9+4+2+5+2+8+6+8+7)/16 = 5

Dense 種類為 5 種,激活方式為 relu (線性整流, 將負值變為0),最後的輸出層激活含數為 softmax (轉換為機率,總合為1)。

在每個連接層( Dense) 之後,都需作 BatchNormalization。BN 層的作用是把一個 batch 內的所有數據,從不規則的分佈拉到常態分佈,將數字集中在平均為 0,標準差為 1 的範圍。這樣作的好處是使得數據能夠分佈在激活函數的敏感區域,敏感區域即為梯度較大的區域,因此在反向傳播時能夠較快反饋誤差傳播。BN的運作方式又是一篇論文,裏面有著複雜的數學運算,所以只需了解其功能即可。

模型中所加入的各層,不一定要按這個方式,也可以改用其它層,可自行測試看看。

訓練時要注意 batch_size,如果太大會造成顯卡記憶体不足,就需往下調整。訓練的時間依硬体等級有所不同。訓練好的模型會儲存在 flower 目錄中。

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import random
import numpy as np
import cv2
from keras import Sequential, Model
from keras.applications import VGG19
from keras.layers import GlobalAveragePooling2D, Dense, Dropout, Flatten, BatchNormalization
from keras.optimizers import SGD, Adam
import pylab as plt
import shutil
imgs=[]
labels=[]
path="./flower_photos"
for name in os.listdir("flower_photos"):
    for flower in os.listdir(os.path.join(path, name)):
        img=cv2.imdecode(np.fromfile(os.path.join(path, name, flower), 
            dtype=np.uint8), cv2.IMREAD_COLOR)[:,:,::-1].copy()
        img=cv2.resize(img, (224,224), interpolation=cv2.INTER_LINEAR)
        imgs.append(img)
        labels.append(name)
upset=list(zip(imgs, labels))

#混淆資料並切割 10% 為 test
random.seed(1)
random.shuffle(upset)
imgs, labels=zip(*upset)
train=int(len(imgs)*0.9)
train_imgs=np.array(imgs[:train])
train_labels=np.array(labels[:train])
test_imgs=np.array(imgs[train:])
test_labels=np.array(labels[train:])

#將 label 變成 onehot
train_onehot=np.zeros([len(train_labels),5])
test_onehot=np.zeros([len(test_labels),5])

kind={'daisy':0,'dandelion':1,'roses':2,'sunflowers':3,'tulips':4}
for i in range(len(train_onehot)):
    # item=train_labels[i]#'sunflowers'
    # no=kind[item]#kind['sunflowers'], no=3
    # train_onehot[i][no] = 1  # [0 0 0 1 0]
    train_onehot[i][kind[train_labels[i]]]=1
for i in range(len(test_onehot)):
    test_onehot[i][kind[test_labels[i]]]=1

#建立模型
model_base=VGG19(weights='imagenet', include_top=False, input_shape=(224,224,3))
for layer in model_base.layers:
    layer.trainable=False

model=Sequential()
model.add(model_base)
model.add(GlobalAveragePooling2D())
model.add(Dense(256,activation='relu'))
model.add(BatchNormalization())
model.add(Dense(64,activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.2))
model.add(Dense(5, activation='softmax'))

#編譯模型
model.compile(optimizer=Adam(learning_rate=0.0001),
              loss='categorical_crossentropy',
              metrics=['accuracy']
              )

#開始訓練模型
history=model.fit(
    train_imgs,
    train_onehot,
    batch_size=64,
    epochs=50,
    validation_data=(test_imgs, test_onehot)
)
if os.path.exists('./flower'):
    shutil.rmtree('./flower')
model.save("flower")
p1=plt.plot(history.history['accuracy'], label='training acc')#訓練時的準確度
p2=plt.plot(history.history['val_accuracy'], label='val acc')#測試時的準確度
p3=plt.plot(history.history['loss'], label='training loss')#訓練時的損失率
p4=plt.plot(history.history['val_loss'], label='val loss')#測試時的損失率
plt.legend()
plt.show()

辨識圖片

如果無法訓練模型,可下載人本已訓練好的模型 : flower_5_model.zip

載入模型後,再載入圖片即可辨識。記得圖片需縮小為 224*224,然後擴展為4維。

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import keras
from keras.applications.vgg19 import preprocess_input
import pylab as plt
import cv2, numpy as np
model=keras.models.load_model('flower')
path="./images"
kind={0:"daisy", 1:"dandelion",2:"roses", 3:"sunflowers", 4:"tulips"}
for i, file in enumerate(os.listdir(path)):
    full=os.path.join(path, file)
    img=cv2.imdecode(np.fromfile(full, dtype=np.uint8), cv2.IMREAD_COLOR)
    img=img[:,:,::-1].copy()
    x=cv2.resize(img, (224,224), interpolation=cv2.INTER_LINEAR)
    x=np.expand_dims(x, axis=0)
    x=preprocess_input(x)
    out=model.predict(x)
    idx=out[0].argmax()
    name=kind[idx]
    ax=plt.subplot(2,5,i+1)
    ax.set_title(name)
    ax.imshow(img)
    ax.axis("off")
plt.show()

結果 : 1/1 [==============================] - 2s 2s/step [[7.7020763e-05 1.8764113e-04 9.9917346e-01 8.6867498e-05 4.7506584e-04]] roses

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *