1 Star 0 Fork 0

melon / diveDL

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
d2l_NiN_5_8.py 2.35 KB
一键复制 编辑 原始数据 按行查看 历史
melon 提交于 2023-10-10 20:13 . 添加VGG与NiN
#%% 串联多个由卷积层和“全连接”层构成的小网络来构建一个深层网络
# NiN使用1×1卷积层来替代全连接层
import time
import torch
from torch import nn, optim
import d2lzh_pytorch as d2l
import torch.nn.functional as F
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#%% NiN块
def nin_block(in_channels, out_channels, kernel_size, stride, padding):
blk = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
nn.ReLU(),
nn.Conv2d(out_channels, out_channels, kernel_size=1),
nn.ReLU(),
nn.Conv2d(out_channels, out_channels, kernel_size=1),
nn.ReLU())
return blk
#%% 已保存在d2lzh_pytorch 全局平均池化层
class GlobalAvgPool2d(nn.Module):
# 全局平均池化层可通过将池化窗口形状 设置成输入的高和宽实现
def __init__(self):
super(GlobalAvgPool2d, self).__init__()
def forward(self, x):
return F.avg_pool2d(x, kernel_size=x.size()[2:])
#%%
net = nn.Sequential(
nin_block(1, 96, kernel_size=11, stride=4, padding=0),
nn.MaxPool2d(kernel_size=3, stride=2),
nin_block(96, 256, kernel_size=5, stride=1, padding=2),
nn.MaxPool2d(kernel_size=3, stride=2),
nin_block(256, 384, kernel_size=3, stride=1, padding=1),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Dropout(0.5),
# 标签类别数是10
nin_block(384, 10, kernel_size=3, stride=1, padding=1), # 输出通道数等于标签类别数的NiN块
GlobalAvgPool2d(), # 对每个通道中所有元素求平均并直接用于分类
# 将四维的输出转成二维的输出,其形状为(批量大小, 10)
d2l.FlattenLayer())
#%% 构建一个数据样本 查看每一层的输出形状
X = torch.rand(1, 1, 224, 224)
for name, blk in net.named_children():
X = blk(X)
print(name, 'output shape: ', X.shape)
#%% NiN的训练与AlexNet和VGG的类似,但这里使用的学习率更大
batch_size = 128
# 如出现“out of memory”的报错信息,可减小batch_size或resize
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)
lr, num_epochs = 0.002, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)
#%%
1
https://gitee.com/melon199/dive-dl.git
git@gitee.com:melon199/dive-dl.git
melon199
dive-dl
diveDL
master

搜索帮助