在 WWDC 2018 蘋果推出了 Create ML ,讓開發者可以輕鬆的建立並訓練適合自己需求的機器學習模型,它支援圖片、自然語言、表格資料的學習。最近這幾天我就想到要訓練一個自家的模型,用來檢查使用者是否打算上傳不恰當的圖片。

訓練自己的模型

首先要準備好訓練資料,將資料分類放在不同的資料夾,不同的資料夾名稱就是所謂的 Label。舉例來說我會建立一個 Unsafe 資料夾存放腥羶色的圖片,再建立一個 Safe 資料夾存放正常的圖片。這裡有一些要注意的地方:

  • 每個資料夾裡頭的檔案個數至少要有 10 個。
  • 每個資料夾裡頭的檔案個數不要相差太多。
  • 每個資料夾裡頭的檔案個數越多越好。

再來我們要建立一個新的 playground,選擇 macOS -> Blank template(注意:要選 macOS 不是 iOS),然後把內容改成以下程式碼:

import CreateMLUI

let builder = MLImageClassifierBuilder()
builder.showInLiveView()

在 Xcode 裡頭切換到 Assistant Editor 並執行 playground,就能在 Assistant Editor 看到 Live View 了。把我們事先準備好的訓練資料拉進 Live View,它就會開始建立並訓練模型。

訓練好之後,我們可以把一些測試資料拉進 Live View,看看這個模型的判斷是否準確,如果滿意的話就可以把模型存起來了。

利用模型判斷資料

假設我們把模型存成 ImageClassifier.mlmodel,接下來就是把它拉進 Xcode Project 裡頭。再來我們要建立一個 ImageDetector class,負責把圖片餵給模型,並回傳模型判斷的結果。這個 class 只有一個 method,讓使用者傳一張圖片進來判斷,判斷成功的話會回應信心值(0~1),失敗的話回應一個 Error。

@interface ImageDetector : NSObject

- (void)checkImage:(UIImage *)image withSuccess:(void (^)(float confidence))success failure:(void (^)(NSError *error))failure;

@end

實作也很簡單,首先要 import 必要的檔案,ImageClassifier.h 是我們把 ImageClassifer.mlmodel 拉進 project 的時候自動產生的:

#import "ImageDetector.h"
#import "ImageClassifier.h"

@import CoreML;
@import Vision;

@interface ImageDetector ()
@property (nonatomic, strong) VNCoreMLModel *model;
@end

我們在 init 的時候載入模型:

- (instancetype)init {
    if (self = [super init]) {
        ImageClassifier *classifier = [[ImageClassifier alloc] init];
        _model = [VNCoreMLModel modelForMLModel:classifier.model error:NULL];
    }
    return self;
}

然後實作唯一的 public method:

- (void)checkImage:(UIImage *)image withSuccess:(void (^)(float confidence))success failure:(void (^)(NSError *error))failure {
    VNImageRequestHandler *handler = nil;
    if (image.CGImage) {
        handler = [[VNImageRequestHandler alloc] initWithCGImage:image.CGImage options:@{}];
    } else if (image.CIImage) {
        handler = [[VNImageRequestHandler alloc] initWithCIImage:image.CIImage options:@{}];
    }
    [self checkWithRequestHandler:handler success:success failure:failure];
}

最後實作唯一的 private method:

- (void)checkWithRequestHandler:(nullable VNImageRequestHandler *)handler success:(void (^)(float confidence))success failure:(void (^)(NSError *error))failure {
    if (!handler) {
        NSError *error = [NSError errorWithDomain:@"com.imageDetector" code:0 userInfo:nil];
        failure(error);
        return;
    }

    // 建立一個 request,用來判斷圖片並處理判斷的結果
    VNCoreMLRequest *req = [[VNCoreMLRequest alloc] initWithModel:self.model completionHandler:^(VNRequest * _Nonnull request, NSError * _Nullable error) {
        if (error) {
            failure(error);
            return;
        }

        NSString *label = @"UnSafe"; // 這個就是你要的 Label,也就是訓練資料的資料夾名稱
        NSArray<VNClassificationObservation *> *observations = request.results;
        for (VNClassificationObservation *observation in observations) {
            if (![observation.identifier isEqualToString:label]) {
                continue;
            }
            success(observation.confidence);
            return;
        }

        NSError *e = [NSError errorWithDomain:@"com.imageDetector" code:0 userInfo:nil];
        failure(e);
    }];
    req.preferBackgroundProcessing = YES;

    // 丟到 background queue 執行,才不會卡住 UI
    dispatch_async(dispatch_get_global_queue(QOS_CLASS_UTILITY, 0), ^{
        NSError *error = nil;
        if (![handler performRequests:@[req] error:&error]) {
            failure(error);
        }
    });
}

要注意的是,呼叫者如果要在 success / failure callback 處理畫面更新,記得要切回 main queue

參考資料

至此,一個為你量身定制的機器學習模型就可以正常運作了,而且 iOS 11 也有支援喔。如果有興趣的話,也可以接著看看更多的參考資料。