什么是PyTorch
PyTorch是一个Python库,并不是某些庞大的C++框架的PyThon接口,PyTorch从底层就与Python紧密结合,你可以像使用numpy/scipy/scikit-learn一样自然地使用PyTorch。可以用Python实现一个神经网络地层,能使用各种Python的优化库,如Cython,Numba等。
PyTorch认同:在适当的时候,不要重新发明轮子。
是一个Python包,提供两个高级功能:
- 具有强大的GPU加速的张量计算(如NumPy)
- 包含自动求导系统的深度神经网络
基础概念
科学计算工具包,一般有两个问题:
怎么定义数据?(怎么存数据,拿什么容器存数据)—-torch.Tensor
怎么定义数据操作?(怎么对数据进行操作)—-torch.autograd.Function
torch.Tensor:张量,用来存储数据,是各种类型的数据的封装。
torch.autograd.Function:函数类,是定义在Tensor类上的操作,从加减乘除到矩阵计算,应有尽有。
计算图(计算过程)分为静态图和动态图
Tensor:data存数据,grad存导数,grad_fn指向创造自己的Function,用户创建的为None
可以通过grad_fn来计算偏导数(反向传播)
常见的Research Workflow
idea-设计实验-处理数据-实现模型- 训练测试-写作
常见的代码实现流程
加载、预处理数据集-构建模型-定义损失函数-实现优化算法-迭代训练-加速计算(GPU)-存储模型-构建baseline
如何用PyTorch完成实验
PyTorch里有vision和text两套开源数据集。
快速构建自定义数据集
torch.utils.data
细粒度构建自定义数据集
``
加载数据集
torch.utils.data.Dataloader
为什么要加载:可以做batch、shuffle,准备怎么抽样,抽样出来的样本怎么处理。
1 | torch_utils_data_data_loader = Data.DataLoader(dataset = person_dataset, ) |
汉字不能直接放到数据集去运行,要变成数字或者tensor,这时候就要用到collate_fn