YOLOV8自訂水果模型

水果模型訓練,比 Dollar bill detection 多了圖片的標示,train/valid 的分割。

安裝套件

請記得安裝 ultralytics 時,會自動安裝無法使用 GPU 的 torch 版本,所以需先安裝 torch GPU 版,再安裝 ultralytics。

pip install torch torchvision torchaudio --index-url  https://download.pytorch.org/whl/cu118 --no-cache-dir
pip install pip install ultralytics labelimg

圖片下載

請下載本站的 fruits.zip,儲存在專案下的 fruits 目錄,然後解壓縮至此,就會在 fruits 下產生 data 目錄

標示圖片

請在命令提示字元視窗輸入 pip install labelimg,然後直接執行 labelimg 即可開始標識圖片。標示完記得將 .txt 檔存放在 fruits/data/labels 之下

pip install labelimg

切割 train/valid

使用下面代碼,將 fruits/data 裏的資料,分割成 80% 存放在 datasets/train,20% 存放 datasets/valid。執行完後,將 fruits 下的 train 及 valid 二個目錄 copy 到專案下的 datasets 目錄。

import os
import random
import shutil
data_path='./data'
train_path='./train'
valid_path='./valid'
if os.path.exists(train_path):
    shutil.rmtree(train_path)
if os.path.exists(valid_path):
    shutil.rmtree(valid_path)
os.makedirs(os.path.join(train_path, 'images'))
os.makedirs(os.path.join(train_path, 'labels'))
os.makedirs(os.path.join(valid_path, 'images'))
os.makedirs(os.path.join(valid_path, 'labels'))

files=[os.path.splitext(file)[0]
       for file in os.listdir(os.path.join(data_path, "images"))]
random.shuffle(files)
mid=int(len(files)*0.8)
for file in files[:mid]:
    source=os.path.join(data_path, "images", f'{file}.jpg')
    target=os.path.join(train_path,"images", f'{file}.jpg')
    print(source, target)
    shutil.copy(source, target)

    source=os.path.join(data_path, "labels", f'{file}.txt')
    target=os.path.join(train_path,"labels", f'{file}.txt')
    print(source, target)
    shutil.copy(source, target)

for file in files[mid:]:
    source=os.path.join(data_path, "images", f'{file}.jpg')
    target=os.path.join(valid_path,"images", f'{file}.jpg')
    print(source, target)
    shutil.copy(source, target)

    source=os.path.join(data_path, "labels", f'{file}.txt')
    target=os.path.join(valid_path,"labels", f'{file}.txt')
    print(source, target)
    shutil.copy(source, target)

設定 data.yaml

在專案根目錄下新增 data.yaml,內容如下。請注意一定要寫絕對路徑

train: e:/python_ai/yolov8_fruit/train/images
val: e:/python_ai/yolov8_fruit/valid/images

nc: 4
names: ['guava', 'lemon', 'pitaya', 'wax']

下載預訓練模型

到 https://github.com/ultralytics/ultralytics 網站,往下拉下載 YOLOv8s.pt ,儲存在專案根目錄下。

yolo.exe 訓練權重

在 Terminal 執行如下指令,epochs 設定為 200 次。不過 V8很聰明,在 epochs 約 98 次後,發現無法再逼近,就會自動停止

yolo task=detect mode=train model=./yolov8x.pt data=./data.yaml epochs=200 imgsz=640

本例圖片並不多,RTX 3080Ti 訓練時間約 6 分鐘。無法訓練的人,請下載 best.pt,然後儲存在 ./runs/detect/train/weights 之下。

使用 Python 訓練模型

在專案下新增 train.py 檔,由 YOLO 戴入預訓練模型 yolov8x.pt 產生 model 物件,再由 model.train() 即可開始訓練。

請注意這段程式碼一定要寫在 if __name__ 的區塊中,否則會出現需使用 fork 執行子行程的錯誤。

import os
import shutil
import time

from ultralytics import YOLO
#訓練模型時,一定要放在 __name__ 區塊內
#否則會出現需使用 fork來執行子行程的錯誤
if __name__=='__main__':
    train_path="./runs/detect/train"
    if os.path.exists(train_path):
        shutil.rmtree(train_path)
    model = YOLO("yolov8x.pt")#因圖片數少,所以使用v8x比較準
    print("開始訓練 .........")
    t1=time.time()
    model.train(data="./data.yaml", epochs=200, imgsz=640)
    t2=time.time()
    print(f'訓練花費時間 : {t2-t1}秒')
    path=model.export()
    print(f'模型匯出路徑 : {path}')

偵測圖片

請使用如下代碼偵測圖片

import os
import platform
import pylab as plt
import cv2
import numpy as np
from PIL import Image, ImageFont, ImageDraw
from ultralytics import YOLO

def text(img, text, xy=(0, 0), color=(0, 0, 0), size=12):
    pil = Image.fromarray(img)
    s = platform.system()
    if s == "Linux":
        font =ImageFont.truetype('/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc', size)
    elif s == "Darwin":
        font = ImageFont.truetype('....', size)
    else:
        font = ImageFont.truetype('simsun.ttc', size)
    ImageDraw.Draw(pil).text(xy, text, font=font, fill=color)
    return np.asarray(pil)

model=YOLO('./runs/detect/train/weights/best.pt')
img_path="./valid/images"
plt.figure(figsize=(12,9))
for i,file in enumerate(os.listdir(img_path)[-6:]):
    full=os.path.join(img_path, file)
    img=cv2.imdecode(
np.fromfile(full, dtype=np.uint8),
cv2.IMREAD_UNCHANGED
)[:,:,::-1].copy() results=model.predict(img, save=False) boxes=results[0].boxes.xyxy names = [results[0].names[int(idx.cpu().numpy())] for idx in results[0].boxes.cls] for box, name in zip(boxes, names): print(name) box=box.cpu().numpy() x1 = int(box[0]);y1 = int(box[1]);x2 = int(box[2]);y2 = int(box[3]) print(img.shape, x1, y1, x2, y2) img=cv2.rectangle(img,(x1, y1), (x2, y2), (0,255,0) , 2) img=text(img, name, (x1, y1), (0,255,0),25) plt.subplot(2,3,i+1) plt.axis("off") plt.imshow(img) plt.show()

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *