在 WWDC 2018 蘋果推出了 Create ML ,讓開發者可以輕鬆的建立並訓練適合自己需求的機器學習模型,它支援圖片、自然語言、表格資料的學習。最近這幾天我就想到要訓練一個自家的模型,用來檢查使用者是否打算上傳不恰當的圖片。
訓練自己的模型
首先要準備好訓練資料,將資料分類放在不同的資料夾,不同的資料夾名稱就是所謂的 Label
。舉例來說我會建立一個 Unsafe
資料夾存放腥羶色的圖片,再建立一個 Safe
資料夾存放正常的圖片。這裡有一些要注意的地方:
- 每個資料夾裡頭的檔案個數至少要有 10 個。
- 每個資料夾裡頭的檔案個數不要相差太多。
- 每個資料夾裡頭的檔案個數越多越好。
再來我們要建立一個新的 playground,選擇 macOS -> Blank
template(注意:要選 macOS
不是 iOS
),然後把內容改成以下程式碼:
import CreateMLUI
let builder = MLImageClassifierBuilder()
builder.showInLiveView()
在 Xcode 裡頭切換到 Assistant Editor
並執行 playground,就能在 Assistant Editor 看到 Live View
了。把我們事先準備好的訓練資料拉進 Live View,它就會開始建立並訓練模型。
訓練好之後,我們可以把一些測試資料拉進 Live View,看看這個模型的判斷是否準確,如果滿意的話就可以把模型存起來了。
利用模型判斷資料
假設我們把模型存成 ImageClassifier.mlmodel
,接下來就是把它拉進 Xcode Project 裡頭。再來我們要建立一個 ImageDetector
class,負責把圖片餵給模型,並回傳模型判斷的結果。這個 class 只有一個 method,讓使用者傳一張圖片進來判斷,判斷成功的話會回應信心值(0~1),失敗的話回應一個 Error。
@interface ImageDetector : NSObject
- (void)checkImage:(UIImage *)image withSuccess:(void (^)(float confidence))success failure:(void (^)(NSError *error))failure;
@end
實作也很簡單,首先要 import 必要的檔案,ImageClassifier.h
是我們把 ImageClassifer.mlmodel
拉進 project 的時候自動產生的:
#import "ImageDetector.h"
#import "ImageClassifier.h"
@import CoreML;
@import Vision;
@interface ImageDetector ()
@property (nonatomic, strong) VNCoreMLModel *model;
@end
我們在 init
的時候載入模型:
- (instancetype)init {
if (self = [super init]) {
ImageClassifier *classifier = [[ImageClassifier alloc] init];
_model = [VNCoreMLModel modelForMLModel:classifier.model error:NULL];
}
return self;
}
然後實作唯一的 public method:
- (void)checkImage:(UIImage *)image withSuccess:(void (^)(float confidence))success failure:(void (^)(NSError *error))failure {
VNImageRequestHandler *handler = nil;
if (image.CGImage) {
handler = [[VNImageRequestHandler alloc] initWithCGImage:image.CGImage options:@{}];
} else if (image.CIImage) {
handler = [[VNImageRequestHandler alloc] initWithCIImage:image.CIImage options:@{}];
}
[self checkWithRequestHandler:handler success:success failure:failure];
}
最後實作唯一的 private method:
- (void)checkWithRequestHandler:(nullable VNImageRequestHandler *)handler success:(void (^)(float confidence))success failure:(void (^)(NSError *error))failure {
if (!handler) {
NSError *error = [NSError errorWithDomain:@"com.imageDetector" code:0 userInfo:nil];
failure(error);
return;
}
// 建立一個 request,用來判斷圖片並處理判斷的結果
VNCoreMLRequest *req = [[VNCoreMLRequest alloc] initWithModel:self.model completionHandler:^(VNRequest * _Nonnull request, NSError * _Nullable error) {
if (error) {
failure(error);
return;
}
NSString *label = @"UnSafe"; // 這個就是你要的 Label,也就是訓練資料的資料夾名稱
NSArray<VNClassificationObservation *> *observations = request.results;
for (VNClassificationObservation *observation in observations) {
if (![observation.identifier isEqualToString:label]) {
continue;
}
success(observation.confidence);
return;
}
NSError *e = [NSError errorWithDomain:@"com.imageDetector" code:0 userInfo:nil];
failure(e);
}];
req.preferBackgroundProcessing = YES;
// 丟到 background queue 執行,才不會卡住 UI
dispatch_async(dispatch_get_global_queue(QOS_CLASS_UTILITY, 0), ^{
NSError *error = nil;
if (![handler performRequests:@[req] error:&error]) {
failure(error);
}
});
}
要注意的是,呼叫者如果要在 success / failure
callback 處理畫面更新,記得要切回 main queue
。
參考資料
至此,一個為你量身定制的機器學習模型就可以正常運作了,而且 iOS 11 也有支援喔。如果有興趣的話,也可以接著看看更多的參考資料。