本文共 1883 字,大约阅读时间需要 6 分钟。
import numpy as npimport torchimport torch.utils.data as Data# 创建批量生成函数def load_dataset(data_file, batch_size): """数据批量生成函数 Args: data_file ([type]): 处理的文件,这里的文件是用np.savez()保存的npz文件 根据不同的数据样式进行区别处理; batch_size ([type]): 每个批次样本数量 """ data = np.load(data_file) # 分别提取data中的特征和标签 x_data = data["x_data"] y_data = data['y_data'] # 将数据封装成tensor张量 x = torch.tensor(x_data, dtype=torch.long) y = torch.tensor(y_data, dtype=torch.long) # 将数据再次封装 dataset = Data.TensorDataset(x, y) # 求解数据的总量 total_length = len(dataset) # 将80%的数据作为训练集, 将生于20%的数据作为测试 train_length = int(total_length*0.8) validation_length = total_length - train_length # 利用Data.random_split()切分数据, 按照80%和20%的比例 train_dataset, validation_dataset = Data.random_split(dataset=dataset,lengths=[train_length, validation_length]) # 将训练集进行Dataloader封装 # 参数说明如下 # dataset: 训练数据集 # batch_size: 代表批次大小,若数据集总样本数量无法被batch_size整除,则最后一批数据为余数 # 设置drop_last=True时,自动抹去最后不能被整除的剩余批次 # shuffle:是否为每个批次随机抽取, True时为随机抽取 # num_workers: 设定用多少个子进程加载数据,默认为0,数据加载到主进程中 # drop_last: 是否去除不能被整除的最后的批次,True时,舍弃最后的数据 train_loader = Data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True) validation_loader = Data.DataLoader(validation_dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True) # 将两个数据生成器封装成一个字典类型 data_loaders = { "train":train_loader, "validation":validation_loader} # 将两个数据集的长度也封装成一个字典类型 data_size = { "train":train_length, "validation":validation_length} return data_loaders, data_size def call_data_loader(): BATCH_SIZE = 8 DATA_FILE = "data/train.npz" data_loader, data_size = load_dataset(DATA_FILE, BATCH_SIZE) print("data_loader:", data_loader, 'data_size', data_size)if __name__ == "__main__": call_data_loader()
转载地址:http://vvnws.baihongyu.com/