img_names = os.listdir(data_path)
source = img_names
alphabet = ''.join(source)
# 读取图档,并转换大小为80*80,以及转换成RGB
def img_loader(img_path):
image = Image.open(img_path)
img = image.resize((80, 80),Image.ANTIALIAS) #resize image with high-quality
return img.convert('RGB')
# 将图档与label对应,丢入自定义的资料集内
def make_dataset(data_path, alphabet, num_class):
samples = []
for i in os.listdir(data_path):
for j in os.listdir(data_path + '/' + i):
img_path = data_path + '/' + i + '/' + j
target_str = j.split('.')[0][-1]
vec = [0] * 800
vec[alphabet.find(target_str)] = 1
target = vec
samples.append((img_path, target))
return samples
例如这个字是"不",由alphabet的位置可以看到alphabet[3]的位置是"不",故在alphabet[3]的位置为1,代表他的label,其余位置皆为0。
torch.utils.data.Dataset,是一个自定义资料集的框架。
__ init __()
def __init__(self, data_path, num_class=800,transform=None,target_transform=None, alphabet=alphabet):
super(Dataset, self).__init__()
self.data_path = data_path
self.num_class = num_class
self.transform = transform
self.target_transform = target_transform
self.alphabet = alphabet
self.samples = make_data.set(self.data_path, self.alphabet)
__ len __ ()
def __len__(self):
return len(self.samples)
__ getitem __ ()
def __getitem__(self, index):
img_path, target = self.samples[index]
img = img_loader(img_path)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, torch.Tensor(target) # 在torch里面,array都要转成Tensor型式
完整程序码
class CaptchaData(Dataset):
def __init__(self, data_path, num_class=800,
transform=None, target_transform=None, alphabet=alphabet):
super(Dataset, self).__init__()
self.data_path = data_path
self.num_class = num_class
self.transform = transform
self.target_transform = target_transform
self.alphabet = alphabet
self.samples = make_dataset(self.data_path, self.alphabet)
def __len__(self):
return len(self.samples)
def __getitem__(self, index):
img_path, target = self.samples[index]
img = img_loader(img_path)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, torch.Tensor(target)
torch.utils.data.DataLoader
Dataset设置好後,DataLoader可以依照batch_size让我们取样,非常方便。
from torchvision.transforms import Compose, ToTensor
from torch.utils.data import DataLoader
transforms = Compose([ToTensor()])
train_dataset = CaptchaData(r'C:\Users\Frank\PycharmProjects\practice\mountain\data_final_20210530\official_in_800',transform=transforms)
train_data_loader = DataLoader(train_dataset, batch_size=1, num_workers=0,
shuffle=True, drop_last=True)
for (data,label) in train_data_loader:
print((data,label))
我把batch_size设定为1,他一次就只取出一组图片样本及标签。
除了自定义资料集以外,还有可以torchvision.datasets.ImageFolder
来处理资料集,用法会在於你分好类别,他的资料夹名称就是他的label,而里面图片都属於这个label。
深度学习有很多很好玩的地方,但也有很多的坑,debug我都要找很久XDD,重点是东西太多,绝对学不完,而且很吃硬体设备。有时候会觉得自己好笨,都学不会,但看久了发现懂一点了,就又有动力继续往下学了,接触深度学习的朋友们,我们一起继续努力吧!
小弟我是试着用自定义资料集来处理,原因只想练习以及可以更弹性的操作载入资料的动作。
前面加载图片时我们把transforms设置为None,现在我们丢模型训练要对图片做transforms,他可以增加图片的多样性,例如:旋转、平移、变形等等,明天来跟大家分享torchvision很好用的套件transforms。
<<: Day15:终於要进去新手村了-Javascript-isNaN函式
嗨,各位早安 相信各位一定有听过"防火墙"这个东西吧? 之前我们前面讲过了ssh...
前言 本篇文章,是要设定系统的时区与时间,并维护时区与时间的设定。 设定本地时间与时区 作业系统的时...
今天我们来讲一下如何使用Parrot_Security的CeWL工具来收集特定网站的Wordlist...
在经历上一部函数与类别的摧残後,这两天就来教一些比较温和的程序吧~ 今天的影片内容为介绍常见的档案格...
前言 今天文章的标题完完全全打脸了笔者在 Day27 的结语,没想到在最後一天仍然还是介绍了早午餐给...