YOLOV8自訂美金模型

YOLO v8 預設模型可辨識 COCO 資料集的 80 種物件。若要辨識其它物件,就需準備自訂資料集自行訓練

安裝套件

請記得安裝 ultralytics 時,自動安裝的 torch 無法啟動 GPU,所以請由如下指令安裝套件

pip install pip install ultralytics labelimg
pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio===2.0.2+cu118 -f https://download.pytorch.org/whl/cu118/torch_stable.html

下載資料集

自訂模型最麻煩的就是資料集準備,要收集大量圖片,然後標示,光這項作業就要耗費相當長的時間。還好 Roboflow 提供許多 公用資料集。請由公用資料集的 Roboflow 100 進入,然後搜尋 “Dollar bill detection”,或由如下網址下載 https://universe.roboflow.com/currency-7mtqe/dollarbilldetection

按下 Download this Dataset,再選擇 v8 版本,將 Dollar Bill Detection.v24-raw-images.yolov8.zip 存放在專案根目錄下。

目錄安排

將 .zip 檔解壓縮檔案,就會產生 “Dollar Bill Detection.v24-raw-images.yolov8” 目錄,請改名為 “Dollar”。

datasets 設定

datasets 又稱為訓練資料集。也就是包含了 train/valid/test 三個目錄,每個目錄又有 images/labels 二個目錄。專案在安裝 utlralytics 套件時,會在 C:\Users\登入者\AppData\Roaming\Ultralytics\settings.yaml 裏將 datasets_dir 設定為如下

settings_version: 0.0.4
datasets_dir: E:\python_ai\yolov8\datasets
weights_dir: weights

如果新專案(b)又安裝一次的話,那麼 datasets_dir 就會被改成 e:\python_ai\b\datasets,那麼舊專案就找不到 datasets了。所以打開 Dollar/data.yaml , 將前三行的 “../” 改成絕對路徑。

#train: ../train/images
#val: ../valid/images
#test: ../test/images
train: e:/python_ai/yolov8/Dollar/train/images val: e:/python_ai/yolov8/Dollar/valid/images test: e:/python_ai/yolov8/Dollar/test/images

下載預訓練權重

https://github.com/ultralytics/ultralytics 網站,往下拉即有預訓練模型可供下載。預訓練模型有yolov8nyolov8syolov8myolov8lyolov8x 共 5 種,愈往下愈精準但速度愈慢。請將下載的 .pt 檔儲存在專案目錄下。

使用 yolo.exe 訓練模型

進入 Terminal  執行如下指令

yolo task=detect mode=train model=./yolov8n.pt data=./Dollar/data.yaml epochs=200 imgsz=640

以 RTX 3080Ti , 在 epochs 約 193 時就會自動停止。

預訓練模型使用 yolov8n.pt 大約需要 10 分鐘才能訓練完成。

訓練完成,會在 ./runs/detect/train/weights 下產生 best.pt 模型,這個就是日後要偵測的主要東西。

使用 Python 訓練模型

在專案下新增 train.py 檔,由 YOLO 戴入預訓練模型 yolov8n.pt 產生 model 物件,再由 model.train() 即可開始訓練。

請注意這段程式碼一定要寫在 if __name__ 的區塊中,否則會出現需使用 fork 執行子行程的錯誤。

import os
import shutil
import time

from ultralytics import YOLO
#訓練模型時,一定要放在 __name__ 區塊內
#否則會出現需使用 fork來執行子行程的錯誤
if __name__=='__main__':
    train_path="./runs/detect/train"
    if os.path.exists(train_path):
        shutil.rmtree(train_path)
    model = YOLO("yolov8n.pt")
    print("開始訓練 .........")
    t1=time.time()
    model.train(data="./Dollar/data.yaml", epochs=200, imgsz=640)
    t2=time.time()
    print(f'訓練花費時間 : {t2-t1}秒')
    path=model.export()
    print(f'模型匯出路徑 : {path}')

開始偵測

如下代碼可以開始偵測,並手動繪制方框

import os
import platform
import pylab as plt
import cv2
import numpy as np
from PIL import Image, ImageFont, ImageDraw
from ultralytics import YOLO

def text(img, text, xy=(0, 0), color=(0, 0, 0), size=12):
    pil = Image.fromarray(img)
    s = platform.system()
    if s == "Linux":
        font =ImageFont.truetype('/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc', size)
    elif s == "Darwin":
        font = ImageFont.truetype('....', size)
    else:
        font = ImageFont.truetype('simsun.ttc', size)
    ImageDraw.Draw(pil).text(xy, text, font=font, fill=color)
    return np.array(pil)

model=YOLO('./runs/detect/train/weights/best.pt')
path="./Dollar/test/images"
files=['IMG_2055_jpg.rf.c8eb7f24a411ac5878e4752ba96fa844.jpg',
       'IMG_1935_jpg.rf.31a9a0fed380075a91755ea0c8e7a7de.jpg',
       'IMG_1975_jpg.rf.65a944ac43d8319e6b2d434f063f42ad.jpg',
       'IMG_2051_jpg.rf.610864c59657a73ac5773f528eb05865.jpg']

for i,file in enumerate(files):
    full=os.path.join(path, file)
    img=cv2.imdecode(np.fromfile(full, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
    img=img[:,:,::-1].copy()#一定要 copy,不然無法繪製矩型(原因不明)

    results=model.predict(img, save=False)
    boxes=results[0].boxes.xyxy
    names = [results[0].names[int(i.cpu().numpy())] for i in results[0].boxes.cls]
    for box, name in zip(boxes, names):
        print(name)
        box=box.cpu().numpy()
        x1 = int(box[0])
        y1 = int(box[1])
        x2 = int(box[2])
        y2 = int(box[3])
        img=cv2.rectangle(img,(x1, y1), (x2, y2), (0,255,0) , 3)
        img=text(img, name, (x1, y1), (0,0,255),200)
    plt.subplot(2,2,i+1)
    plt.axis("off")
    plt.imshow(img)
plt.show()

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *