Day 13 : 弱监督式标注资料 Snorkel (视觉关系侦测篇)

  • 接续 Day 12的弱监督式 Snorkel 范例,今天再花点时间示范用 Snorkel 标注影像资料。
  • Snorkel 透过简易广泛的程序撰写判断逻辑後,交由生成对抗网路产生分类结果,分类的结果效果不差,而且不用手动标注。之後也会介绍 AutoML 等工具,在此之前我们来透过 Snorkel 官方范例了解如何进行。
  • Colab 实作范例

Visual Relationship Detection, VRD 视觉关系侦测

  • VRD 说明:
    • 通常图片内容物都有物体之间的关联性,定义描述为为a subject <predictate> object 。例如, person riding bicycle , “person” 和 “bicycle” 分别是主词和受词, “riding” 是关系动词。
    • 此范例源自 snorkel-tutorials,目的为对视觉关系检测 (VRD) 数据集进行操作,专注於图片内物件之间的关系分类任务。
    • 以下图示红色框代表主题,而绿色框代表对象。该谓词(如踢)表示什麽关系连接主体和客体。

0. 设定环境

  • 笔者有调整为 Colab 可以执行的程序,不过范例主程序里执行的model.py遇到pandas.DataFrame.as_matrix()是旧的语法,也提供修正方式pandas.DataFrame.values
  • 当然您也可以如官网范例指定旧版的 pandas ,相关设定不赘述。

1. 加载数据

  • 范例将训练集、有效集和测试集加载为 DataFrame。
  • 数据集的采样版本在训练集、开发集和测试集上使用相同的 26 个数据。此设置旨在快速演示 Snorkel 如何处理此任务,而非演示性能。

2. 编写 Labeling Functions (LFs)

  • 我们现在编写标记函数来检测边界框对之间存在什麽关系。为此,我们可以将各种直觉编码到标记函数中:

  • 分类直觉:关於这些关系中通常涉及的主词和受词类别的知识(例如,person通常是谓词 RIDE 和的主词 CARRY)

  • 空间直觉:关於主词和受词的相对位置的知识(例如,主词通常高於动词的受词RIDE)

    RIDE = 0
    CARRY = 1
    OTHER = 2
    ABSTAIN = -1
    
  • 我们从编码分类直觉的标记函数开始:我们使用关於共同的主题-客体类别对的知识 RIDE,CARRY 以及关於哪些主题或客体不太可能涉及这两种关系的知识。

    from snorkel.labeling import labeling_function
    
    # Category-based LFs
    @labeling_function()
    def lf_ride_object(x):
        if x.subject_category == "person":
            if x.object_category in [
                "bike",
                "snowboard",
                "motorcycle",
                "horse",
                "bus",
                "truck",
                "elephant",
            ]:
                return RIDE
        return ABSTAIN
    
    
    @labeling_function()
    def lf_carry_object(x):
        if x.subject_category == "person":
            if x.object_category in ["bag", "surfboard", "skis"]:
                return CARRY
        return ABSTAIN
    
    
    @labeling_function()
    def lf_carry_subject(x):
        if x.object_category == "person":
            if x.subject_category in ["chair", "bike", "snowboard", "motorcycle", "horse"]:
                return CARRY
        return ABSTAIN
    
    
    @labeling_function()
    def lf_not_person(x):
        if x.subject_category != "person":
            return OTHER
        return ABSTAIN
    
  • 现在编码空间直觉,其中包括测量边界框之间的距离并比较它们的相对区域。

    YMIN = 0
    YMAX = 1
    XMIN = 2
    XMAX = 3
    
    import numpy as np
    
    # Distance-based LFs
    @labeling_function()
    def lf_ydist(x):
        if x.subject_bbox[XMAX] < x.object_bbox[XMAX]:
            return OTHER
        return ABSTAIN
    
    
    @labeling_function()
    def lf_dist(x):
        if np.linalg.norm(np.array(x.subject_bbox) - np.array(x.object_bbox)) <= 1000:
            return OTHER
        return ABSTAIN
    
    
    def area(bbox):
        return (bbox[YMAX] - bbox[YMIN]) * (bbox[XMAX] - bbox[XMIN])
    
    
    # Size-based LF
    @labeling_function()
    def lf_area(x):
        if area(x.subject_bbox) / area(x.object_bbox) <= 0.5:
            return OTHER
        return ABSTAIN
    
  • 标记函数具有不同的经验准确性和覆盖范围。由於我们选择的关系中的类别不平衡,标记 OTHER 的标记函数比RIDE或CARRY的标记函数具有更高的覆盖率。这也反映了数据集中类的分布。

3. 训练标签模型

  • 训练 LabelModel 来为未标记的训练集分配训练标签。

    from snorkel.labeling.model import LabelModel
    
    label_model = LabelModel(cardinality=3, verbose=True)
    label_model.fit(
        L_train, 
        seed=123, 
        lr=0.01, 
        log_freq=10, 
        n_epochs=100
        )
    

4. 训练分类器

  • 现在,您可以使用这些训练标签来训练任何标准判别模型,例如现成的 ResNet,它应该学会在我们开发的 LF 之外进行泛化。

    from snorkel.classification import DictDataLoader
    from model import SceneGraphDataset, create_model
    
    df_train["labels"] = label_model.predict(L_train)
    
    if sample:
        TRAIN_DIR = "data/VRD/sg_dataset/samples"
    else:
        TRAIN_DIR = "data/VRD/sg_dataset/sg_train_images"
    
    dl_train = DictDataLoader(
        SceneGraphDataset("train_dataset", "train", TRAIN_DIR, df_train),
        batch_size=16,
        shuffle=True,
    )
    
    dl_valid = DictDataLoader(
        SceneGraphDataset("valid_dataset", "valid", TRAIN_DIR, df_valid),
        batch_size=16,
        shuffle=False,
    )
    
  • 定义模型架构。

    import torchvision.models as models
    
    # initialize pretrained feature extractor
    cnn = models.resnet18(pretrained=True)
    model = create_model(cnn)
    
    • 训练与评估模型
    from snorkel.classification import Trainer
    
    trainer = Trainer(
        n_epochs=1,  # increase for improved performance
        lr=1e-3,
        checkpointing=True,
        checkpointer_config={"checkpoint_dir": "checkpoint"},
    )
    trainer.fit(model, [dl_train])
    
    model.score([dl_valid])
    # {'visual_relation_task/valid_dataset/valid/f1_micro':
    #  0.34615384615384615}
    
  • 我们已经成功训练了一个视觉关系检测模型!使用关於视觉关系中的对像如何相互作用的分类和空间直觉,我们能够在多类分类设置中为 VRD 数据集中的对像对分配高质量的训练标签。

  • 有关 Snorkel 如何用於视觉关系任务的更多信息,请参阅该团队 ICCV 2019 论文

小结

  • 这一篇是笔者想确认如何用弱监督的方式完成影像标注,Snorkel 确实做到了,但可惜的是相依模组版本比较旧,在 Colab 实现需要些调整,笔者调整後就分享给有兴趣的人。
  • 现在无监督式学习兴起,但如果有需要退而求其次自己写条件时,Snorkel 应可帮助到您。

参考


<<:  从 JavaScript 角度学 Python(12) - 运算子

>>:  Flutter体验 Day 5-Widget 乐高积木

JavaScript | By value V.S. By reference(传值 V.S. 传参考)

传值 (By value) 针对一个变数的纯值(Number, String, Boolean, u...

[Day3]-if叙述

if…else叙述,基本格式为: if(条件): 条件成立时执行 else: 条件不成立时执行 i...

了解内嵌、外嵌导入css方式

进入到css的环节,讲解如何使用内嵌、外嵌导入css的方式 要注意内嵌、外嵌使用方式不太一样 外嵌...

Day28 D3js Diagram常见的两点浪漫路径

D3js Diagram常见的两点浪漫路径 用途 在绘制diagram图表时,会用到的垂直水平连线,...

Day27 Data Storage in iOS 03 - File System & Sqlite

File System Apple 让iOS 应用程序内的文件编写、阅读和编辑变得非常容易。每个应用...