a subject <predictate> object
。例如, person riding bicycle
, “person” 和 “bicycle” 分别是主词和受词, “riding” 是关系动词。model.py
遇到pandas.DataFrame.as_matrix()
是旧的语法,也提供修正方式pandas.DataFrame.values
。我们现在编写标记函数来检测边界框对之间存在什麽关系。为此,我们可以将各种直觉编码到标记函数中:
分类直觉:关於这些关系中通常涉及的主词和受词类别的知识(例如,person通常是谓词 RIDE 和的主词 CARRY)
空间直觉:关於主词和受词的相对位置的知识(例如,主词通常高於动词的受词RIDE)
RIDE = 0
CARRY = 1
OTHER = 2
ABSTAIN = -1
我们从编码分类直觉的标记函数开始:我们使用关於共同的主题-客体类别对的知识 RIDE,CARRY 以及关於哪些主题或客体不太可能涉及这两种关系的知识。
from snorkel.labeling import labeling_function
# Category-based LFs
@labeling_function()
def lf_ride_object(x):
if x.subject_category == "person":
if x.object_category in [
"bike",
"snowboard",
"motorcycle",
"horse",
"bus",
"truck",
"elephant",
]:
return RIDE
return ABSTAIN
@labeling_function()
def lf_carry_object(x):
if x.subject_category == "person":
if x.object_category in ["bag", "surfboard", "skis"]:
return CARRY
return ABSTAIN
@labeling_function()
def lf_carry_subject(x):
if x.object_category == "person":
if x.subject_category in ["chair", "bike", "snowboard", "motorcycle", "horse"]:
return CARRY
return ABSTAIN
@labeling_function()
def lf_not_person(x):
if x.subject_category != "person":
return OTHER
return ABSTAIN
现在编码空间直觉,其中包括测量边界框之间的距离并比较它们的相对区域。
YMIN = 0
YMAX = 1
XMIN = 2
XMAX = 3
import numpy as np
# Distance-based LFs
@labeling_function()
def lf_ydist(x):
if x.subject_bbox[XMAX] < x.object_bbox[XMAX]:
return OTHER
return ABSTAIN
@labeling_function()
def lf_dist(x):
if np.linalg.norm(np.array(x.subject_bbox) - np.array(x.object_bbox)) <= 1000:
return OTHER
return ABSTAIN
def area(bbox):
return (bbox[YMAX] - bbox[YMIN]) * (bbox[XMAX] - bbox[XMIN])
# Size-based LF
@labeling_function()
def lf_area(x):
if area(x.subject_bbox) / area(x.object_bbox) <= 0.5:
return OTHER
return ABSTAIN
标记函数具有不同的经验准确性和覆盖范围。由於我们选择的关系中的类别不平衡,标记 OTHER 的标记函数比RIDE或CARRY的标记函数具有更高的覆盖率。这也反映了数据集中类的分布。
训练 LabelModel
来为未标记的训练集分配训练标签。
from snorkel.labeling.model import LabelModel
label_model = LabelModel(cardinality=3, verbose=True)
label_model.fit(
L_train,
seed=123,
lr=0.01,
log_freq=10,
n_epochs=100
)
现在,您可以使用这些训练标签来训练任何标准判别模型,例如现成的 ResNet,它应该学会在我们开发的 LF 之外进行泛化。
from snorkel.classification import DictDataLoader
from model import SceneGraphDataset, create_model
df_train["labels"] = label_model.predict(L_train)
if sample:
TRAIN_DIR = "data/VRD/sg_dataset/samples"
else:
TRAIN_DIR = "data/VRD/sg_dataset/sg_train_images"
dl_train = DictDataLoader(
SceneGraphDataset("train_dataset", "train", TRAIN_DIR, df_train),
batch_size=16,
shuffle=True,
)
dl_valid = DictDataLoader(
SceneGraphDataset("valid_dataset", "valid", TRAIN_DIR, df_valid),
batch_size=16,
shuffle=False,
)
定义模型架构。
import torchvision.models as models
# initialize pretrained feature extractor
cnn = models.resnet18(pretrained=True)
model = create_model(cnn)
from snorkel.classification import Trainer
trainer = Trainer(
n_epochs=1, # increase for improved performance
lr=1e-3,
checkpointing=True,
checkpointer_config={"checkpoint_dir": "checkpoint"},
)
trainer.fit(model, [dl_train])
model.score([dl_valid])
# {'visual_relation_task/valid_dataset/valid/f1_micro':
# 0.34615384615384615}
我们已经成功训练了一个视觉关系检测模型!使用关於视觉关系中的对像如何相互作用的分类和空间直觉,我们能够在多类分类设置中为 VRD 数据集中的对像对分配高质量的训练标签。
有关 Snorkel 如何用於视觉关系任务的更多信息,请参阅该团队 ICCV 2019 论文。
<<: 从 JavaScript 角度学 Python(12) - 运算子
>>: Flutter体验 Day 5-Widget 乐高积木
传值 (By value) 针对一个变数的纯值(Number, String, Boolean, u...
if…else叙述,基本格式为: if(条件): 条件成立时执行 else: 条件不成立时执行 i...
进入到css的环节,讲解如何使用内嵌、外嵌导入css的方式 要注意内嵌、外嵌使用方式不太一样 外嵌...
D3js Diagram常见的两点浪漫路径 用途 在绘制diagram图表时,会用到的垂直水平连线,...
File System Apple 让iOS 应用程序内的文件编写、阅读和编辑变得非常容易。每个应用...