博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
数据批量生成函数代码实现
阅读量:4298 次
发布时间:2019-05-27

本文共 1883 字,大约阅读时间需要 6 分钟。

import numpy as npimport torchimport torch.utils.data as Data# 创建批量生成函数def load_dataset(data_file, batch_size):    """数据批量生成函数    Args:        data_file ([type]): 处理的文件,这里的文件是用np.savez()保存的npz文件        根据不同的数据样式进行区别处理;        batch_size ([type]): 每个批次样本数量    """    data = np.load(data_file)    # 分别提取data中的特征和标签    x_data = data["x_data"]    y_data = data['y_data']    # 将数据封装成tensor张量    x = torch.tensor(x_data, dtype=torch.long)    y = torch.tensor(y_data, dtype=torch.long)    # 将数据再次封装    dataset = Data.TensorDataset(x, y)    # 求解数据的总量    total_length = len(dataset)    # 将80%的数据作为训练集, 将生于20%的数据作为测试    train_length = int(total_length*0.8)    validation_length = total_length - train_length    # 利用Data.random_split()切分数据, 按照80%和20%的比例    train_dataset, validation_dataset = Data.random_split(dataset=dataset,lengths=[train_length, validation_length])    # 将训练集进行Dataloader封装    # 参数说明如下    # dataset: 训练数据集    # batch_size: 代表批次大小,若数据集总样本数量无法被batch_size整除,则最后一批数据为余数    #            设置drop_last=True时,自动抹去最后不能被整除的剩余批次    # shuffle:是否为每个批次随机抽取, True时为随机抽取    # num_workers: 设定用多少个子进程加载数据,默认为0,数据加载到主进程中    # drop_last: 是否去除不能被整除的最后的批次,True时,舍弃最后的数据    train_loader = Data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)    validation_loader = Data.DataLoader(validation_dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)    # 将两个数据生成器封装成一个字典类型    data_loaders = {
"train":train_loader, "validation":validation_loader} # 将两个数据集的长度也封装成一个字典类型 data_size = {
"train":train_length, "validation":validation_length} return data_loaders, data_size def call_data_loader(): BATCH_SIZE = 8 DATA_FILE = "data/train.npz" data_loader, data_size = load_dataset(DATA_FILE, BATCH_SIZE) print("data_loader:", data_loader, 'data_size', data_size)if __name__ == "__main__": call_data_loader()

转载地址:http://vvnws.baihongyu.com/

你可能感兴趣的文章
C指针声明解读之左右法则
查看>>
一个异步网络请求的坑:关于NSURLConnection和NSRunLoopCommonModes
查看>>
iOS 如何放大按钮点击热区
查看>>
ios设备唯一标识获取策略
查看>>
获取推送通知的DeviceToken
查看>>
Could not find a storyboard named 'Main' in bundle NSBundle
查看>>
CocoaPods安装和使用教程
查看>>
Beginning Auto Layout Tutorial
查看>>
block使用小结、在arc中使用block、如何防止循环引用
查看>>
iPhone开发学习笔记002——Xib设计UITableViewCell然后动态加载
查看>>
iOS开发中遇到的问题整理 (一)
查看>>
Swift code into Object-C 出现 ***-swift have not found this file 的问题
查看>>
为什么你的App介绍写得像一坨翔?
查看>>
RTImageAssets插件--@3x可自动生成@2x图片
查看>>
iOS开发的一些奇巧淫技
查看>>
常浏览的博客和网站
查看>>
Xcode 工程文件打开不出来, cannot be opened because the project file cannot be parsed.
查看>>
点击button实现Storyboard中TabBar Controller的tab切换
查看>>
Xcode 的正确打开方式——Debugging
查看>>
打包app出现的一个问题
查看>>