这篇文章主要讲述用 pytorch 完成简单 CNN 图片分类任务,如果想对 CNN 的理论知识进行了解,可以看我的这篇文章,深度学习(一)——CNN卷积神经网络。
图片分类
我们以美食图片分类为例,有testing、training、validation文件夹。下载链接放下面。
点击提取, 提取码:nefu
前面的 0 表示其为 0 类,后面为其编号。
导入必要的包
1 | # Import需要的套件 |
cv2 我是通过如下命令下载
1 | pip install opencv-python |
torch 我下载的是 cuda10.2 的版本,这里就简单放一下下载 pytorch 的代码,至于如何使用 GPU 加速,可以上网查查。
1 | pip3 install torch==1.10.0+cu102 torchvision==0.11.1+cu102 torchaudio===0.10.0+cu102 -f https://download.pytorch.org/whl/cu102/torch_stable.html |
读取数据
把训练集、验证集和测试集读取进来,放入 numpy 数组。 x 为其图片的像素张量,y 为其标签。
1 | # Read image 利用 OpenCV(cv2) 读入照片并存放在 numpy array 中 |
数据处理
定义数据增强操作(随机翻转、随机旋转),定义 batch 的大小。
1 | ''' Dataset ''' |
模型结构
定义CNN的结构。
1 | ''' Model ''' |
训练模型
对模型进行训练,迭代30次,并用验证集测试,最后将训练集和验证集合并在进行训练。
1 | ''' Training ''' |
Output:
1 | [001/030] 70.94 sec(s) Train Acc: 0.260997 Loss: 0.065946 | Val Acc: 0.303499 loss: 0.060955 |
测试
对测试集进行预测
1 | ''' Testing ''' |