mnist.py
data/
    train.csv
    test.csv
path = './data/train.csv'
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

batch_size = 100  # assumed value; the original snippet does not define it


# 0) data import and set as pytorch dataset
class MnistDataset(Dataset):
    # data loading: column 0 is the label, the remaining columns are pixels
    def __init__(self, path):
        xy = np.loadtxt(path, delimiter=',', dtype=np.float32, skiprows=1)
        self.x = torch.from_numpy(xy[:, 1:])
        self.y = torch.from_numpy(xy[:, 0].astype(np.int64))
        self.n_samples = xy.shape[0]

    # working for indexing
    def __getitem__(self, index):
        return self.x[index], self.y[index]

    # return the length of our dataset
    def __len__(self):
        return self.n_samples


train_path = './data/train.csv'
test_path = './data/test.csv'
train_dataset = MnistDataset(train_path)
test_dataset = MnistDataset(test_path)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
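To make the column layout that the Dataset above relies on more concrete, here is a minimal sketch that writes a tiny stand-in CSV and parses it with the same np.loadtxt call. The header names, the three made-up rows, and the temporary file are assumptions for illustration only; the real train.csv is not touched.

```python
import os
import tempfile

import numpy as np
import torch

# Build a tiny stand-in for train.csv: one header row, then one row per image
# with the label in column 0 and 784 pixel values after it (assumed layout).
header = ",".join(["label"] + [f"pixel{i}" for i in range(784)])
rows = np.zeros((3, 785), dtype=np.float32)
rows[:, 0] = [5, 0, 4]   # made-up labels
rows[:, 1:] = 255.0      # made-up pixel values

csv_path = os.path.join(tempfile.mkdtemp(), "train.csv")
np.savetxt(csv_path, rows, delimiter=",", header=header, comments="", fmt="%d")

# Same parsing call as in the Dataset: split on commas, skip the header row.
xy = np.loadtxt(csv_path, delimiter=",", dtype=np.float32, skiprows=1)
x = torch.from_numpy(xy[:, 1:])                  # pixels, one flat row per image
y = torch.from_numpy(xy[:, 0].astype(np.int64))  # labels as 64-bit integers

print(x.shape, x.dtype)
print(y.shape, y.dtype, y.tolist())
```

The labels are cast to a 64-bit integer type because that is what loss functions such as torch.nn.CrossEntropyLoss expect for class-index targets.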
The values in the CSV file are separated by commas, so we set delimiter to ','. Each row holds one flattened image, so we use .view(28, 28) to reshape the data into a 28x28 image and .unsqueeze(0) to add one extra dimension in front.

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

batch_size = 100  # assumed value; the original snippet does not define it


# 0) data import and set as pytorch dataset
class MNISTDataset(Dataset):
    # data loading: column 0 is the label, the remaining columns are pixels
    def __init__(self, path, transform=None):
        xy = np.loadtxt(path, delimiter=',', dtype=np.float32, skiprows=1)
        self.x = torch.from_numpy(xy[:, 1:])
        self.y = torch.from_numpy(xy[:, 0].astype(np.int64))
        self.n_samples = xy.shape[0]
        self.transform = transform

    # working for indexing; the optional transform is applied per sample
    def __getitem__(self, index):
        sample = self.x[index], self.y[index]
        if self.transform:
            sample = self.transform(sample)
        return sample

    # return the length of our dataset
    def __len__(self):
        return self.n_samples


class ToImage:
    # turn a flat row of 784 pixels into a (1, 28, 28) image tensor
    def __call__(self, sample):
        inputs, targets = sample
        return inputs.view(28, 28).unsqueeze(0), targets


train_path = './data/train.csv'
test_path = './data/test.csv'
train_dataset = MNISTDataset(train_path, transform=ToImage())
test_dataset = MNISTDataset(test_path, transform=ToImage())
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
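The effect of the ToImage transform on tensor shapes can be checked without the CSV files. The sketch below copies ToImage from the listing above and feeds it made-up samples; the zero-filled pixels, the ten fake labels, and the batch size of 4 are assumptions chosen only to keep the output small.

```python
import torch
from torch.utils.data import DataLoader


class ToImage:
    # Same transform as above: flat 784 pixels -> (1, 28, 28) image tensor.
    def __call__(self, sample):
        inputs, targets = sample
        return inputs.view(28, 28).unsqueeze(0), targets


to_image = ToImage()

# One made-up sample: a flat row of 784 pixels plus a label.
flat_pixels = torch.arange(784, dtype=torch.float32)
image, label = to_image((flat_pixels, torch.tensor(7)))
print(image.shape)  # torch.Size([1, 28, 28])

# A plain list of samples is enough for DataLoader, since it only needs
# __getitem__ and __len__. The default collate function stacks the samples.
samples = [to_image((torch.zeros(784), torch.tensor(i))) for i in range(10)]
loader = DataLoader(samples, batch_size=4)
batch_x, batch_y = next(iter(loader))
print(batch_x.shape)  # torch.Size([4, 1, 28, 28])
print(batch_y.shape)  # torch.Size([4])
```

The leading dimension added by .unsqueeze(0) acts as the channel axis, so a batch comes out in the (batch, channel, height, width) layout that convolution layers such as torch.nn.Conv2d expect.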
MNIST is also available through the torchvision library, so below we demonstrate how to get the MNIST data through a PyTorch Dataset.

import torchvision
import torchvision.transforms as transforms

# MNIST
train_dataset = torchvision.datasets.MNIST(root='./data', train=True,
                                           transform=transforms.ToTensor(), download=True)
# download=False works here because the call above already fetched the files
test_dataset = torchvision.datasets.MNIST(root='./data', train=False,
                                          transform=transforms.ToTensor(), download=False)
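Datasets are obtained through torchvision.datasets.<dataset name>. Splitting the data into training data and testing data is very simple: PyTorch provides a parameter called train that selects between the training data and the testing data in the dataset.

The root parameter gives the location of the data. Below is a simplified diagram of the resulting layout:

mnist.py
data/
    MNIST/
        train-images-idx3-ubyte
        train-labels-idx1-ubyte
        t10k-images-idx3-ubyte
        t10k-labels-idx1-ubyte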
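Another parameter is transform, which defines how the data is converted. For example, when the data consists of images, we can use transform to convert each image to a tensor, which is why we attach transforms.ToTensor() here to make sure the data type is Tensor.

from torch.utils.data import DataLoader

batch_size = 100  # assumed value; the original snippet does not define it

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                          shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size,
                         shuffle=True)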
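The DataLoader also takes a shuffle parameter, which shuffles the data before it is used. This is handy because many ready-made datasets are stored in some sorted order, and if the data is not shuffled first, training can easily become biased. So whenever it is needed, shuffle=True is easy to add and helps training.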
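To illustrate why order matters, here is a minimal sketch with a made-up dataset whose labels are sorted by class. The 30 fake labels, the batch size of 6, and the fixed seed are assumptions for the demonstration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Made-up labels sorted by class: 0,0,0,1,1,1,...,9,9,9
labels = torch.arange(10).repeat_interleave(3)
dataset = TensorDataset(labels)

# Without shuffling, the first batch contains only the first two classes.
ordered = DataLoader(dataset, batch_size=6, shuffle=False)
first_ordered = next(iter(ordered))[0]
print(first_ordered.tolist())  # [0, 0, 0, 1, 1, 1]

# With shuffle=True the sample order is re-drawn at random every epoch,
# so each batch mixes classes. The seed only makes this run repeatable.
torch.manual_seed(0)
shuffled = DataLoader(dataset, batch_size=6, shuffle=True)
seen = torch.cat([batch[0] for batch in shuffled])
print(seen[:6].tolist())

# Shuffling changes the order only: every sample is still visited once.
print(sorted(seen.tolist()) == labels.tolist())  # True
```

For evaluation the order of samples does not affect the metrics, so shuffle=False is a common choice for a test loader; the listings above keep shuffle=True as in the original.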