水果模型訓練,比 Dollar bill detection 多了圖片的標示,train/valid 的分割。
安裝套件
請記得安裝 ultralytics 時,會自動安裝無法使用 GPU 的 torch 版本,所以需先安裝 torch GPU 版,再安裝 ultralytics。
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 --no-cache-dir
pip install pip install ultralytics labelimg
圖片下載
請下載本站的 fruits.zip,儲存在專案下的 fruits 目錄,然後解壓縮至此,就會在 fruits 下產生 data 目錄
標示圖片
請在命令提示字元視窗輸入 pip install labelimg,然後直接執行 labelimg 即可開始標識圖片。標示完記得將 .txt 檔存放在 fruits/data/labels 之下
pip install labelimg
切割 train/valid
使用下面代碼,將 fruits/data 裏的資料,分割成 80% 存放在 datasets/train,20% 存放 datasets/valid。執行完後,將 fruits 下的 train 及 valid 二個目錄 copy 到專案下的 datasets 目錄。
import os import random import shutil data_path='./data' train_path='./train' valid_path='./valid' if os.path.exists(train_path): shutil.rmtree(train_path) if os.path.exists(valid_path): shutil.rmtree(valid_path) os.makedirs(os.path.join(train_path, 'images')) os.makedirs(os.path.join(train_path, 'labels')) os.makedirs(os.path.join(valid_path, 'images')) os.makedirs(os.path.join(valid_path, 'labels')) files=[os.path.splitext(file)[0] for file in os.listdir(os.path.join(data_path, "images"))] random.shuffle(files) mid=int(len(files)*0.8) for file in files[:mid]: source=os.path.join(data_path, "images", f'{file}.jpg') target=os.path.join(train_path,"images", f'{file}.jpg') print(source, target) shutil.copy(source, target) source=os.path.join(data_path, "labels", f'{file}.txt') target=os.path.join(train_path,"labels", f'{file}.txt') print(source, target) shutil.copy(source, target) for file in files[mid:]: source=os.path.join(data_path, "images", f'{file}.jpg') target=os.path.join(valid_path,"images", f'{file}.jpg') print(source, target) shutil.copy(source, target) source=os.path.join(data_path, "labels", f'{file}.txt') target=os.path.join(valid_path,"labels", f'{file}.txt') print(source, target) shutil.copy(source, target)
設定 data.yaml
在專案根目錄下新增 data.yaml,內容如下。請注意一定要寫絕對路徑
train: e:/python_ai/yolov8_fruit/train/images val: e:/python_ai/yolov8_fruit/valid/images nc: 4 names: ['guava', 'lemon', 'pitaya', 'wax']
下載預訓練模型
到 https://github.com/ultralytics/ultralytics 網站,往下拉下載 YOLOv8s.pt ,儲存在專案根目錄下。
yolo.exe 訓練權重
在 Terminal 執行如下指令,epochs 設定為 200 次。不過 V8很聰明,在 epochs 約 98 次後,發現無法再逼近,就會自動停止
yolo task=detect mode=train model=./yolov8x.pt data=./data.yaml epochs=200 imgsz=640
本例圖片並不多,RTX 3080Ti 訓練時間約 6 分鐘。無法訓練的人,請下載 best.pt,然後儲存在 ./runs/detect/train/weights 之下。
使用 Python 訓練模型
在專案下新增 train.py 檔,由 YOLO 戴入預訓練模型 yolov8x.pt 產生 model 物件,再由 model.train() 即可開始訓練。
請注意這段程式碼一定要寫在 if __name__ 的區塊中,否則會出現需使用 fork 執行子行程的錯誤。
import os import shutil import time from ultralytics import YOLO #訓練模型時,一定要放在 __name__ 區塊內 #否則會出現需使用 fork來執行子行程的錯誤 if __name__=='__main__': train_path="./runs/detect/train" if os.path.exists(train_path): shutil.rmtree(train_path) model = YOLO("yolov8x.pt")#因圖片數少,所以使用v8x比較準 print("開始訓練 .........") t1=time.time() model.train(data="./data.yaml", epochs=200, imgsz=640) t2=time.time() print(f'訓練花費時間 : {t2-t1}秒') path=model.export() print(f'模型匯出路徑 : {path}')
偵測圖片
請使用如下代碼偵測圖片
import os import platform import pylab as plt import cv2 import numpy as np from PIL import Image, ImageFont, ImageDraw from ultralytics import YOLO def text(img, text, xy=(0, 0), color=(0, 0, 0), size=12): pil = Image.fromarray(img) s = platform.system() if s == "Linux": font =ImageFont.truetype('/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc', size) elif s == "Darwin": font = ImageFont.truetype('....', size) else: font = ImageFont.truetype('simsun.ttc', size) ImageDraw.Draw(pil).text(xy, text, font=font, fill=color) return np.asarray(pil) model=YOLO('./runs/detect/train/weights/best.pt') img_path="./valid/images" plt.figure(figsize=(12,9)) for i,file in enumerate(os.listdir(img_path)[-6:]): full=os.path.join(img_path, file) img=cv2.imdecode(
np.fromfile(full, dtype=np.uint8),
cv2.IMREAD_UNCHANGED
)[:,:,::-1].copy() results=model.predict(img, save=False) boxes=results[0].boxes.xyxy names = [results[0].names[int(idx.cpu().numpy())] for idx in results[0].boxes.cls] for box, name in zip(boxes, names): print(name) box=box.cpu().numpy() x1 = int(box[0]);y1 = int(box[1]);x2 = int(box[2]);y2 = int(box[3]) print(img.shape, x1, y1, x2, y2) try: img=cv2.rectangle(img,(x1, y1), (x2, y2), (0,255,0) , 2) img=text(img, name, (x1, y1), (0,255,0),25) except: pass plt.subplot(2,3,i+1) plt.axis("off") plt.imshow(img) plt.show()