ReXNet: Eliminating Representational Bottlenecks, a Chat About Network Design

ReXNet

A CVPR 2021 paper. The code isn't long, so I did a quick reimplementation; the real point is the paper's thinking about network design, so come on in and let's chat about those ideas.

  • Based on the original PyTorch code, with as few changes as possible

  • This post is mostly about absorbing some useful ideas in network design

  • The last part of the paper explains why their accuracy is so high: they stacked every training trick available. Not exactly a fair fight.

Starting From MobileNet V2

First, look at this figure (from the MobileNet V2 paper). In one sentence: when the input dimension is low, a nonlinearity like ReLU destroys a lot of information; when the input dimension is high enough, far less information is lost after the nonlinearity.
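To make this concrete, here is a tiny demo in the spirit of the MobileNet V2 figure (my own NumPy sketch, not the paper's code): embed a 2-D spiral into n dimensions with a random matrix, apply ReLU, and count how many points get zeroed out entirely.

import numpy as np

np.random.seed(0)
# 1000 points on a 2-D spiral: a low-dimensional input manifold
t = np.linspace(0.1, 4 * np.pi, 1000)
X = np.stack([t * np.cos(t), t * np.sin(t)])      # shape (2, 1000)

for n in [2, 3, 5, 15, 30]:
    T = np.random.randn(n, 2)                     # random embedding into n dims
    Y = np.maximum(T @ X, 0)                      # ReLU in the n-dim space
    dead = (Y.max(axis=0) == 0).mean()            # points collapsed to exactly zero
    print(f"n={n:2d}  fraction of points ReLU wipes out: {dead:.3f}")

At n=2 a sizeable chunk of the points is wiped out entirely (every coordinate negative), while at n=30 essentially none are: expanding the dimension before the nonlinearity protects the information.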

Based on this, MobileNet V2 proposed the linear bottleneck and inverted residuals.

Linear bottleneck

The input goes through a 1x1 conv to expand the channels, then a depthwise conv, then a 1x1 conv to project back down.

Of these three operations, the first two are followed by ReLU6; the last one is not, because its output dimension is low.

Inverted residuals

A regular residual block is thin in the middle and fat at the ends; an inverted residual is fat in the middle and thin at the ends.

Putting this together, the basic block is shown below, split into a stride=1 block and a stride=2 downsampling block. Note again that the final 1x1 has no ReLU6 after it (a minimal sketch of the block follows the structure diagram).

(Figure: model structure diagram)
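For reference, here is a minimal Paddle sketch of the inverted residual block just described (my own paraphrase, not the official MobileNet V2 code): 1x1 expansion with ReLU6, 3x3 depthwise with ReLU6, then a linear 1x1 projection, with a shortcut only when stride=1 and the channel counts match.

import paddle
import paddle.nn as nn

class InvertedResidual(nn.Layer):
    # expand (1x1 + ReLU6) -> depthwise (3x3 + ReLU6) -> project (1x1, linear)
    def __init__(self, in_ch, out_ch, stride=1, t=6):
        super().__init__()
        mid = in_ch * t
        self.use_shortcut = stride == 1 and in_ch == out_ch
        self.block = nn.Sequential(
            nn.Conv2D(in_ch, mid, 1, bias_attr=False),
            nn.BatchNorm2D(mid), nn.ReLU6(),
            nn.Conv2D(mid, mid, 3, stride=stride, padding=1, groups=mid, bias_attr=False),
            nn.BatchNorm2D(mid), nn.ReLU6(),
            nn.Conv2D(mid, out_ch, 1, bias_attr=False),  # no activation: linear bottleneck
            nn.BatchNorm2D(out_ch),
        )

    def forward(self, x):
        y = self.block(x)
        return x + y if self.use_shortcut else y

x = paddle.randn([1, 16, 56, 56])
print(InvertedResidual(16, 16, stride=1)(x).shape)  # [1, 16, 56, 56]
print(InvertedResidual(16, 24, stride=2)(x).shape)  # [1, 24, 28, 28]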

ReXNet starts from the weaknesses of this design and improves on it once more.

Some Interesting Ideas in ReXNet

ReXNet's main idea is eliminating representational bottlenecks. Here is the summary up front, which is also the set of design principles the authors propose:

  • Expand the input channel size

    Starting from the softmax bottleneck, the authors draw an analogy to a layer-level bottleneck: when a layer's input dimension is smaller than its output dimension, the low-rank input cannot represent the higher-rank output space. For instance, a 2-dimensional input feeding a 10-way classifier cannot express 10 classes well.

    The authors argue the input and output dimensions should satisfy an inequality (the exact form is in the paper) that keeps this layer bottleneck small: since rank(f(WX)) <= min(d_in, d_out), the input dimension should not be much smaller than the output dimension.

  • A suitable activation function: Swish-1

    Recommended reading: https://arxiv.org/abs/1710.05941

    Experiments show the Swish function yields higher-rank features: it raises the rank of the data so the input rank gets closer to the output rank, shrinking the layer bottleneck.

  • Multiple expansion layers, with channels increasing gradually

    An expansion layer is a layer whose output channels exceed its input channels. Stacking many expansion layers keeps the rank gap between input and output small at every step; this gradual progression reduces the representational bottleneck (see the small rank demo after this list).
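As a quick illustration of the rank bound above (a toy NumPy check, not the paper's experiment): however wide you make a layer, a linear map out of a 2-dimensional input can only produce rank-2 features.

import numpy as np

np.random.seed(0)
d_in, d_out, n = 2, 10, 1000

X = np.random.randn(d_in, n)      # features with only 2 independent directions
W = np.random.randn(d_out, d_in)  # 1x1-conv-like expansion, d_in -> d_out

print(np.linalg.matrix_rank(W @ X))  # 2, capped by min(d_in, d_out)

# The paper goes further: it measures rank(f(WX)) for nonlinear f over many
# random networks and reports that Swish-1 keeps the output rank closer to
# d_out than ReLU/ReLU6 does, hence the first two design principles.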

Compare this with the MobileNet V2 structure above. What differences do you notice?

  • Multiple expansion layers
  • Input and output channels increase layer by layer and stay close to each other
  • Swish-1 is used after the first 1x1 in each block

Full Code

1. Building the Model

import paddle
import paddle.nn as nn
from math import ceil

print(paddle.__version__)
2.0.1
def ConvBNAct(out, in_channels, channels, kernel=1, stride=1, pad=0,
              num_group=1, active=True, relu6=False):
    # Conv -> BN, optionally followed by ReLU / ReLU6
    out.append(nn.Conv2D(in_channels, channels, kernel,
                         stride, pad, groups=num_group, bias_attr=False))
    out.append(nn.BatchNorm2D(channels))
    if active:
        out.append(nn.ReLU6() if relu6 else nn.ReLU())


def ConvBNSwish(out, in_channels, channels, kernel=1, stride=1, pad=0, num_group=1):
    # Conv -> BN -> Swish-1 (x * sigmoid(x))
    out.append(nn.Conv2D(in_channels, channels, kernel,
                         stride, pad, groups=num_group, bias_attr=False))
    out.append(nn.BatchNorm2D(channels))
    out.append(nn.Swish())


class SE(nn.Layer):
    # Squeeze-and-Excitation: global average pooling, channel squeeze/excite
    # via two 1x1 convs, then a sigmoid gate multiplied back onto the input
    def __init__(self, in_channels, channels, se_ratio=12):
        super(SE, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2D(1)
        self.fc = nn.Sequential(
            nn.Conv2D(in_channels, channels // se_ratio, kernel_size=1, padding=0),
            nn.BatchNorm2D(channels // se_ratio),
            nn.ReLU(),
            nn.Conv2D(channels // se_ratio, channels, kernel_size=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.fc(y)
        return x * y


class LinearBottleneck(nn.Layer):
    def __init__(self, in_channels, channels, t, stride, use_se=True, se_ratio=12,
                 **kwargs):
        super(LinearBottleneck, self).__init__(**kwargs)
        # Partial shortcut is only possible when the spatial size is kept and
        # the input channels fit inside the output channels
        self.use_shortcut = stride == 1 and in_channels <= channels
        self.in_channels = in_channels
        self.out_channels = channels

        out = []
        if t != 1:
            # 1x1 expansion by factor t, followed by Swish-1
            dw_channels = in_channels * t
            ConvBNSwish(out, in_channels=in_channels, channels=dw_channels)
        else:
            dw_channels = in_channels

        # 3x3 depthwise conv (groups == channels); the activation is deferred
        # until after the optional SE block
        ConvBNAct(out, in_channels=dw_channels, channels=dw_channels, kernel=3, stride=stride, pad=1,
                  num_group=dw_channels, active=False)

        if use_se:
            out.append(SE(dw_channels, dw_channels, se_ratio))

        out.append(nn.ReLU6())
        # Final 1x1 projection: linear, no activation (the "linear bottleneck")
        ConvBNAct(out, in_channels=dw_channels, channels=channels, active=False, relu6=True)
        self.out = nn.Sequential(*out)

    def forward(self, x):
        out = self.out(x)
        if self.use_shortcut:
            # Add the input onto only the first in_channels output channels
            out[:, 0:self.in_channels] += x

        return out
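A quick sanity check of the block (my own snippet; the shapes match LinearBottleneck-2 in the model summary further down):

block = LinearBottleneck(in_channels=16, channels=27, t=6, stride=2, use_se=False)
x = paddle.randn([1, 16, 112, 112])
print(block(x).shape)  # [1, 27, 56, 56]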


class ReXNetV1(nn.Layer):
    def __init__(self, input_ch=16, final_ch=180, width_mult=1.0, depth_mult=1.0, classes=1000,
                 use_se=True,
                 se_ratio=12,
                 dropout_ratio=0.2,
                 bn_momentum=0.9):
        super(ReXNetV1, self).__init__()

        layers = [1, 2, 2, 3, 3, 5]               # blocks per stage
        strides = [1, 2, 2, 2, 1, 2]              # stride of the first block in each stage
        use_ses = [False, False, True, True, True, True]

        layers = [ceil(element * depth_mult) for element in layers]
        # Expand per-stage strides into per-block strides (only the first
        # block of a stage downsamples)
        strides = sum([[element] + [1] * (layers[idx] - 1)
                       for idx, element in enumerate(strides)], [])
        if use_se:
            use_ses = sum([[element] * layers[idx] for idx, element in enumerate(use_ses)], [])
        else:
            use_ses = [False] * sum(layers[:])
        ts = [1] * layers[0] + [6] * sum(layers[1:])   # expansion factor t per block

        self.depth = sum(layers[:]) * 3
        stem_channel = 32 / width_mult if width_mult < 1.0 else 32
        inplanes = input_ch / width_mult if width_mult < 1.0 else input_ch

        features = []
        in_channels_group = []
        channels_group = []

        # Channel schedule: grow the output width of every block linearly by
        # final_ch / num_blocks, instead of the usual stage-wise doubling
        for i in range(self.depth // 3):
            if i == 0:
                in_channels_group.append(int(round(stem_channel * width_mult)))
                channels_group.append(int(round(inplanes * width_mult)))
            else:
                in_channels_group.append(int(round(inplanes * width_mult)))
                inplanes += final_ch / (self.depth // 3 * 1.0)
                channels_group.append(int(round(inplanes * width_mult)))

        # Stem: 3x3 stride-2 conv with Swish-1
        ConvBNSwish(features, 3, int(round(stem_channel * width_mult)), kernel=3, stride=2, pad=1)

        for block_idx, (in_c, c, t, s, se) in enumerate(zip(in_channels_group, channels_group, ts, strides, use_ses)):
            features.append(LinearBottleneck(in_channels=in_c,
                                             channels=c,
                                             t=t,
                                             stride=s,
                                             use_se=se, se_ratio=se_ratio))

        # Penultimate 1x1 conv up to 1280 channels (c is the last block's width)
        pen_channels = int(1280 * width_mult)
        ConvBNSwish(features, c, pen_channels)

        features.append(nn.AdaptiveAvgPool2D(1))
        self.features = nn.Sequential(*features)
        # Head: dropout + 1x1 conv acting as the classifier on pooled features
        self.output = nn.Sequential(
            nn.Dropout(dropout_ratio),
            nn.Conv2D(pen_channels, classes, 1, bias_attr=True))

    def forward(self, x):
        x = self.features(x)
        x = self.output(x).squeeze()
        return x

rexnet=ReXNetV1(classes=10)
rexnet
ReXNetV1(
  (features): Sequential(
    (0): Conv2D(3, 32, kernel_size=[3, 3], stride=[2, 2], padding=1, data_format=NCHW)
    (1): BatchNorm2D(num_features=32, momentum=0.9, epsilon=1e-05)
    (2): Swish()
    (3): LinearBottleneck(
      (out): Sequential(
        (0): Conv2D(32, 32, kernel_size=[3, 3], padding=1, groups=32, data_format=NCHW)
        (1): BatchNorm2D(num_features=32, momentum=0.9, epsilon=1e-05)
        (2): ReLU6()
        (3): Conv2D(32, 16, kernel_size=[1, 1], data_format=NCHW)
        (4): BatchNorm2D(num_features=16, momentum=0.9, epsilon=1e-05)
      )
    )
    (4): LinearBottleneck(
      (out): Sequential(
        (0): Conv2D(16, 96, kernel_size=[1, 1], data_format=NCHW)
        (1): BatchNorm2D(num_features=96, momentum=0.9, epsilon=1e-05)
        (2): Swish()
        (3): Conv2D(96, 96, kernel_size=[3, 3], stride=[2, 2], padding=1, groups=96, data_format=NCHW)
        (4): BatchNorm2D(num_features=96, momentum=0.9, epsilon=1e-05)
        (5): ReLU6()
        (6): Conv2D(96, 27, kernel_size=[1, 1], data_format=NCHW)
        (7): BatchNorm2D(num_features=27, momentum=0.9, epsilon=1e-05)
      )
    )
    (5): LinearBottleneck(
      (out): Sequential(
        (0): Conv2D(27, 162, kernel_size=[1, 1], data_format=NCHW)
        (1): BatchNorm2D(num_features=162, momentum=0.9, epsilon=1e-05)
        (2): Swish()
        (3): Conv2D(162, 162, kernel_size=[3, 3], padding=1, groups=162, data_format=NCHW)
        (4): BatchNorm2D(num_features=162, momentum=0.9, epsilon=1e-05)
        (5): ReLU6()
        (6): Conv2D(162, 38, kernel_size=[1, 1], data_format=NCHW)
        (7): BatchNorm2D(num_features=38, momentum=0.9, epsilon=1e-05)
      )
    )
    (6): LinearBottleneck(
      (out): Sequential(
        (0): Conv2D(38, 228, kernel_size=[1, 1], data_format=NCHW)
        (1): BatchNorm2D(num_features=228, momentum=0.9, epsilon=1e-05)
        (2): Swish()
        (3): Conv2D(228, 228, kernel_size=[3, 3], stride=[2, 2], padding=1, groups=228, data_format=NCHW)
        (4): BatchNorm2D(num_features=228, momentum=0.9, epsilon=1e-05)
        (5): SE(
          (avg_pool): AdaptiveAvgPool2D(output_size=1)
          (fc): Sequential(
            (0): Conv2D(228, 19, kernel_size=[1, 1], data_format=NCHW)
            (1): BatchNorm2D(num_features=19, momentum=0.9, epsilon=1e-05)
            (2): ReLU()
            (3): Conv2D(19, 228, kernel_size=[1, 1], data_format=NCHW)
            (4): Sigmoid()
          )
        )
        (6): ReLU6()
        (7): Conv2D(228, 50, kernel_size=[1, 1], data_format=NCHW)
        (8): BatchNorm2D(num_features=50, momentum=0.9, epsilon=1e-05)
      )
    )
    (7): LinearBottleneck(
      (out): Sequential(
        (0): Conv2D(50, 300, kernel_size=[1, 1], data_format=NCHW)
        (1): BatchNorm2D(num_features=300, momentum=0.9, epsilon=1e-05)
        (2): Swish()
        (3): Conv2D(300, 300, kernel_size=[3, 3], padding=1, groups=300, data_format=NCHW)
        (4): BatchNorm2D(num_features=300, momentum=0.9, epsilon=1e-05)
        (5): SE(
          (avg_pool): AdaptiveAvgPool2D(output_size=1)
          (fc): Sequential(
            (0): Conv2D(300, 25, kernel_size=[1, 1], data_format=NCHW)
            (1): BatchNorm2D(num_features=25, momentum=0.9, epsilon=1e-05)
            (2): ReLU()
            (3): Conv2D(25, 300, kernel_size=[1, 1], data_format=NCHW)
            (4): Sigmoid()
          )
        )
        (6): ReLU6()
        (7): Conv2D(300, 61, kernel_size=[1, 1], data_format=NCHW)
        (8): BatchNorm2D(num_features=61, momentum=0.9, epsilon=1e-05)
      )
    )
    (8): LinearBottleneck(
      (out): Sequential(
        (0): Conv2D(61, 366, kernel_size=[1, 1], data_format=NCHW)
        (1): BatchNorm2D(num_features=366, momentum=0.9, epsilon=1e-05)
        (2): Swish()
        (3): Conv2D(366, 366, kernel_size=[3, 3], stride=[2, 2], padding=1, groups=366, data_format=NCHW)
        (4): BatchNorm2D(num_features=366, momentum=0.9, epsilon=1e-05)
        (5): SE(
          (avg_pool): AdaptiveAvgPool2D(output_size=1)
          (fc): Sequential(
            (0): Conv2D(366, 30, kernel_size=[1, 1], data_format=NCHW)
            (1): BatchNorm2D(num_features=30, momentum=0.9, epsilon=1e-05)
            (2): ReLU()
            (3): Conv2D(30, 366, kernel_size=[1, 1], data_format=NCHW)
            (4): Sigmoid()
          )
        )
        (6): ReLU6()
        (7): Conv2D(366, 72, kernel_size=[1, 1], data_format=NCHW)
        (8): BatchNorm2D(num_features=72, momentum=0.9, epsilon=1e-05)
      )
    )
    (9): LinearBottleneck(
      (out): Sequential(
        (0): Conv2D(72, 432, kernel_size=[1, 1], data_format=NCHW)
        (1): BatchNorm2D(num_features=432, momentum=0.9, epsilon=1e-05)
        (2): Swish()
        (3): Conv2D(432, 432, kernel_size=[3, 3], padding=1, groups=432, data_format=NCHW)
        (4): BatchNorm2D(num_features=432, momentum=0.9, epsilon=1e-05)
        (5): SE(
          (avg_pool): AdaptiveAvgPool2D(output_size=1)
          (fc): Sequential(
            (0): Conv2D(432, 36, kernel_size=[1, 1], data_format=NCHW)
            (1): BatchNorm2D(num_features=36, momentum=0.9, epsilon=1e-05)
            (2): ReLU()
            (3): Conv2D(36, 432, kernel_size=[1, 1], data_format=NCHW)
            (4): Sigmoid()
          )
        )
        (6): ReLU6()
        (7): Conv2D(432, 84, kernel_size=[1, 1], data_format=NCHW)
        (8): BatchNorm2D(num_features=84, momentum=0.9, epsilon=1e-05)
      )
    )
    (10): LinearBottleneck(
      (out): Sequential(
        (0): Conv2D(84, 504, kernel_size=[1, 1], data_format=NCHW)
        (1): BatchNorm2D(num_features=504, momentum=0.9, epsilon=1e-05)
        (2): Swish()
        (3): Conv2D(504, 504, kernel_size=[3, 3], padding=1, groups=504, data_format=NCHW)
        (4): BatchNorm2D(num_features=504, momentum=0.9, epsilon=1e-05)
        (5): SE(
          (avg_pool): AdaptiveAvgPool2D(output_size=1)
          (fc): Sequential(
            (0): Conv2D(504, 42, kernel_size=[1, 1], data_format=NCHW)
            (1): BatchNorm2D(num_features=42, momentum=0.9, epsilon=1e-05)
            (2): ReLU()
            (3): Conv2D(42, 504, kernel_size=[1, 1], data_format=NCHW)
            (4): Sigmoid()
          )
        )
        (6): ReLU6()
        (7): Conv2D(504, 95, kernel_size=[1, 1], data_format=NCHW)
        (8): BatchNorm2D(num_features=95, momentum=0.9, epsilon=1e-05)
      )
    )
    (11): LinearBottleneck(
      (out): Sequential(
        (0): Conv2D(95, 570, kernel_size=[1, 1], data_format=NCHW)
        (1): BatchNorm2D(num_features=570, momentum=0.9, epsilon=1e-05)
        (2): Swish()
        (3): Conv2D(570, 570, kernel_size=[3, 3], padding=1, groups=570, data_format=NCHW)
        (4): BatchNorm2D(num_features=570, momentum=0.9, epsilon=1e-05)
        (5): SE(
          (avg_pool): AdaptiveAvgPool2D(output_size=1)
          (fc): Sequential(
            (0): Conv2D(570, 47, kernel_size=[1, 1], data_format=NCHW)
            (1): BatchNorm2D(num_features=47, momentum=0.9, epsilon=1e-05)
            (2): ReLU()
            (3): Conv2D(47, 570, kernel_size=[1, 1], data_format=NCHW)
            (4): Sigmoid()
          )
        )
        (6): ReLU6()
        (7): Conv2D(570, 106, kernel_size=[1, 1], data_format=NCHW)
        (8): BatchNorm2D(num_features=106, momentum=0.9, epsilon=1e-05)
      )
    )
    (12): LinearBottleneck(
      (out): Sequential(
        (0): Conv2D(106, 636, kernel_size=[1, 1], data_format=NCHW)
        (1): BatchNorm2D(num_features=636, momentum=0.9, epsilon=1e-05)
        (2): Swish()
        (3): Conv2D(636, 636, kernel_size=[3, 3], padding=1, groups=636, data_format=NCHW)
        (4): BatchNorm2D(num_features=636, momentum=0.9, epsilon=1e-05)
        (5): SE(
          (avg_pool): AdaptiveAvgPool2D(output_size=1)
          (fc): Sequential(
            (0): Conv2D(636, 53, kernel_size=[1, 1], data_format=NCHW)
            (1): BatchNorm2D(num_features=53, momentum=0.9, epsilon=1e-05)
            (2): ReLU()
            (3): Conv2D(53, 636, kernel_size=[1, 1], data_format=NCHW)
            (4): Sigmoid()
          )
        )
        (6): ReLU6()
        (7): Conv2D(636, 117, kernel_size=[1, 1], data_format=NCHW)
        (8): BatchNorm2D(num_features=117, momentum=0.9, epsilon=1e-05)
      )
    )
    (13): LinearBottleneck(
      (out): Sequential(
        (0): Conv2D(117, 702, kernel_size=[1, 1], data_format=NCHW)
        (1): BatchNorm2D(num_features=702, momentum=0.9, epsilon=1e-05)
        (2): Swish()
        (3): Conv2D(702, 702, kernel_size=[3, 3], padding=1, groups=702, data_format=NCHW)
        (4): BatchNorm2D(num_features=702, momentum=0.9, epsilon=1e-05)
        (5): SE(
          (avg_pool): AdaptiveAvgPool2D(output_size=1)
          (fc): Sequential(
            (0): Conv2D(702, 58, kernel_size=[1, 1], data_format=NCHW)
            (1): BatchNorm2D(num_features=58, momentum=0.9, epsilon=1e-05)
            (2): ReLU()
            (3): Conv2D(58, 702, kernel_size=[1, 1], data_format=NCHW)
            (4): Sigmoid()
          )
        )
        (6): ReLU6()
        (7): Conv2D(702, 128, kernel_size=[1, 1], data_format=NCHW)
        (8): BatchNorm2D(num_features=128, momentum=0.9, epsilon=1e-05)
      )
    )
    (14): LinearBottleneck(
      (out): Sequential(
        (0): Conv2D(128, 768, kernel_size=[1, 1], data_format=NCHW)
        (1): BatchNorm2D(num_features=768, momentum=0.9, epsilon=1e-05)
        (2): Swish()
        (3): Conv2D(768, 768, kernel_size=[3, 3], stride=[2, 2], padding=1, groups=768, data_format=NCHW)
        (4): BatchNorm2D(num_features=768, momentum=0.9, epsilon=1e-05)
        (5): SE(
          (avg_pool): AdaptiveAvgPool2D(output_size=1)
          (fc): Sequential(
            (0): Conv2D(768, 64, kernel_size=[1, 1], data_format=NCHW)
            (1): BatchNorm2D(num_features=64, momentum=0.9, epsilon=1e-05)
            (2): ReLU()
            (3): Conv2D(64, 768, kernel_size=[1, 1], data_format=NCHW)
            (4): Sigmoid()
          )
        )
        (6): ReLU6()
        (7): Conv2D(768, 140, kernel_size=[1, 1], data_format=NCHW)
        (8): BatchNorm2D(num_features=140, momentum=0.9, epsilon=1e-05)
      )
    )
    (15): LinearBottleneck(
      (out): Sequential(
        (0): Conv2D(140, 840, kernel_size=[1, 1], data_format=NCHW)
        (1): BatchNorm2D(num_features=840, momentum=0.9, epsilon=1e-05)
        (2): Swish()
        (3): Conv2D(840, 840, kernel_size=[3, 3], padding=1, groups=840, data_format=NCHW)
        (4): BatchNorm2D(num_features=840, momentum=0.9, epsilon=1e-05)
        (5): SE(
          (avg_pool): AdaptiveAvgPool2D(output_size=1)
          (fc): Sequential(
            (0): Conv2D(840, 70, kernel_size=[1, 1], data_format=NCHW)
            (1): BatchNorm2D(num_features=70, momentum=0.9, epsilon=1e-05)
            (2): ReLU()
            (3): Conv2D(70, 840, kernel_size=[1, 1], data_format=NCHW)
            (4): Sigmoid()
          )
        )
        (6): ReLU6()
        (7): Conv2D(840, 151, kernel_size=[1, 1], data_format=NCHW)
        (8): BatchNorm2D(num_features=151, momentum=0.9, epsilon=1e-05)
      )
    )
    (16): LinearBottleneck(
      (out): Sequential(
        (0): Conv2D(151, 906, kernel_size=[1, 1], data_format=NCHW)
        (1): BatchNorm2D(num_features=906, momentum=0.9, epsilon=1e-05)
        (2): Swish()
        (3): Conv2D(906, 906, kernel_size=[3, 3], padding=1, groups=906, data_format=NCHW)
        (4): BatchNorm2D(num_features=906, momentum=0.9, epsilon=1e-05)
        (5): SE(
          (avg_pool): AdaptiveAvgPool2D(output_size=1)
          (fc): Sequential(
            (0): Conv2D(906, 75, kernel_size=[1, 1], data_format=NCHW)
            (1): BatchNorm2D(num_features=75, momentum=0.9, epsilon=1e-05)
            (2): ReLU()
            (3): Conv2D(75, 906, kernel_size=[1, 1], data_format=NCHW)
            (4): Sigmoid()
          )
        )
        (6): ReLU6()
        (7): Conv2D(906, 162, kernel_size=[1, 1], data_format=NCHW)
        (8): BatchNorm2D(num_features=162, momentum=0.9, epsilon=1e-05)
      )
    )
    (17): LinearBottleneck(
      (out): Sequential(
        (0): Conv2D(162, 972, kernel_size=[1, 1], data_format=NCHW)
        (1): BatchNorm2D(num_features=972, momentum=0.9, epsilon=1e-05)
        (2): Swish()
        (3): Conv2D(972, 972, kernel_size=[3, 3], padding=1, groups=972, data_format=NCHW)
        (4): BatchNorm2D(num_features=972, momentum=0.9, epsilon=1e-05)
        (5): SE(
          (avg_pool): AdaptiveAvgPool2D(output_size=1)
          (fc): Sequential(
            (0): Conv2D(972, 81, kernel_size=[1, 1], data_format=NCHW)
            (1): BatchNorm2D(num_features=81, momentum=0.9, epsilon=1e-05)
            (2): ReLU()
            (3): Conv2D(81, 972, kernel_size=[1, 1], data_format=NCHW)
            (4): Sigmoid()
          )
        )
        (6): ReLU6()
        (7): Conv2D(972, 174, kernel_size=[1, 1], data_format=NCHW)
        (8): BatchNorm2D(num_features=174, momentum=0.9, epsilon=1e-05)
      )
    )
    (18): LinearBottleneck(
      (out): Sequential(
        (0): Conv2D(174, 1044, kernel_size=[1, 1], data_format=NCHW)
        (1): BatchNorm2D(num_features=1044, momentum=0.9, epsilon=1e-05)
        (2): Swish()
        (3): Conv2D(1044, 1044, kernel_size=[3, 3], padding=1, groups=1044, data_format=NCHW)
        (4): BatchNorm2D(num_features=1044, momentum=0.9, epsilon=1e-05)
        (5): SE(
          (avg_pool): AdaptiveAvgPool2D(output_size=1)
          (fc): Sequential(
            (0): Conv2D(1044, 87, kernel_size=[1, 1], data_format=NCHW)
            (1): BatchNorm2D(num_features=87, momentum=0.9, epsilon=1e-05)
            (2): ReLU()
            (3): Conv2D(87, 1044, kernel_size=[1, 1], data_format=NCHW)
            (4): Sigmoid()
          )
        )
        (6): ReLU6()
        (7): Conv2D(1044, 185, kernel_size=[1, 1], data_format=NCHW)
        (8): BatchNorm2D(num_features=185, momentum=0.9, epsilon=1e-05)
      )
    )
    (19): Conv2D(185, 1280, kernel_size=[1, 1], data_format=NCHW)
    (20): BatchNorm2D(num_features=1280, momentum=0.9, epsilon=1e-05)
    (21): Swish()
    (22): AdaptiveAvgPool2D(output_size=1)
  )
  (output): Sequential(
    (0): Dropout(p=0.2, axis=None, mode=upscale_in_train)
    (1): Conv2D(1280, 10, kernel_size=[1, 1], data_format=NCHW)
  )
)

2. Data Preparation

We use the Cifar10 dataset, with minimal data augmentation

import paddle.vision.transforms as T
from paddle.vision.datasets import Cifar10

# data preparation: resize to 224x224, normalize in HWC layout, then convert to a CHW tensor
transform = T.Compose([
    T.Resize(size=(224,224)),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],data_format='HWC'),
    T.ToTensor()
])

train_dataset = Cifar10(mode='train', transform=transform)
val_dataset = Cifar10(mode='test',  transform=transform)
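As a quick check (assuming the dataset downloads successfully), the two splits should hold 50,000 and 10,000 images:

print(len(train_dataset), len(val_dataset))  # 50000 10000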

model=paddle.Model(rexnet)
model.summary((1,3,224,224))
--------------------------------------------------------------------------------
    Layer (type)         Input Shape          Output Shape         Param #    
================================================================================
      Conv2D-1        [[1, 3, 224, 224]]   [1, 32, 112, 112]         864      
   BatchNorm2D-1     [[1, 32, 112, 112]]   [1, 32, 112, 112]         128      
      Swish-1        [[1, 32, 112, 112]]   [1, 32, 112, 112]          0       
      Conv2D-2       [[1, 32, 112, 112]]   [1, 32, 112, 112]         288      
   BatchNorm2D-2     [[1, 32, 112, 112]]   [1, 32, 112, 112]         128      
      ReLU6-1        [[1, 32, 112, 112]]   [1, 32, 112, 112]          0       
      Conv2D-3       [[1, 32, 112, 112]]   [1, 16, 112, 112]         512      
   BatchNorm2D-3     [[1, 16, 112, 112]]   [1, 16, 112, 112]         64       
 LinearBottleneck-1  [[1, 32, 112, 112]]   [1, 16, 112, 112]          0       
      Conv2D-4       [[1, 16, 112, 112]]   [1, 96, 112, 112]        1,536     
   BatchNorm2D-4     [[1, 96, 112, 112]]   [1, 96, 112, 112]         384      
      Swish-2        [[1, 96, 112, 112]]   [1, 96, 112, 112]          0       
      Conv2D-5       [[1, 96, 112, 112]]    [1, 96, 56, 56]          864      
   BatchNorm2D-5      [[1, 96, 56, 56]]     [1, 96, 56, 56]          384      
      ReLU6-2         [[1, 96, 56, 56]]     [1, 96, 56, 56]           0       
      Conv2D-6        [[1, 96, 56, 56]]     [1, 27, 56, 56]         2,592     
   BatchNorm2D-6      [[1, 27, 56, 56]]     [1, 27, 56, 56]          108      
 LinearBottleneck-2  [[1, 16, 112, 112]]    [1, 27, 56, 56]           0       
      Conv2D-7        [[1, 27, 56, 56]]     [1, 162, 56, 56]        4,374     
   BatchNorm2D-7      [[1, 162, 56, 56]]    [1, 162, 56, 56]         648      
      Swish-3         [[1, 162, 56, 56]]    [1, 162, 56, 56]          0       
      Conv2D-8        [[1, 162, 56, 56]]    [1, 162, 56, 56]        1,458     
   BatchNorm2D-8      [[1, 162, 56, 56]]    [1, 162, 56, 56]         648      
      ReLU6-3         [[1, 162, 56, 56]]    [1, 162, 56, 56]          0       
      Conv2D-9        [[1, 162, 56, 56]]    [1, 38, 56, 56]         6,156     
   BatchNorm2D-9      [[1, 38, 56, 56]]     [1, 38, 56, 56]          152      
 LinearBottleneck-3   [[1, 27, 56, 56]]     [1, 38, 56, 56]           0       
     Conv2D-10        [[1, 38, 56, 56]]     [1, 228, 56, 56]        8,664     
   BatchNorm2D-10     [[1, 228, 56, 56]]    [1, 228, 56, 56]         912      
      Swish-4         [[1, 228, 56, 56]]    [1, 228, 56, 56]          0       
     Conv2D-11        [[1, 228, 56, 56]]    [1, 228, 28, 28]        2,052     
   BatchNorm2D-11     [[1, 228, 28, 28]]    [1, 228, 28, 28]         912      
AdaptiveAvgPool2D-1   [[1, 228, 28, 28]]     [1, 228, 1, 1]           0       
     Conv2D-12         [[1, 228, 1, 1]]      [1, 19, 1, 1]          4,351     
   BatchNorm2D-12      [[1, 19, 1, 1]]       [1, 19, 1, 1]           76       
       ReLU-1          [[1, 19, 1, 1]]       [1, 19, 1, 1]            0       
     Conv2D-13         [[1, 19, 1, 1]]       [1, 228, 1, 1]         4,560     
     Sigmoid-1         [[1, 228, 1, 1]]      [1, 228, 1, 1]           0       
        SE-1          [[1, 228, 28, 28]]    [1, 228, 28, 28]          0       
      ReLU6-4         [[1, 228, 28, 28]]    [1, 228, 28, 28]          0       
     Conv2D-14        [[1, 228, 28, 28]]    [1, 50, 28, 28]        11,400     
   BatchNorm2D-13     [[1, 50, 28, 28]]     [1, 50, 28, 28]          200      
 LinearBottleneck-4   [[1, 38, 56, 56]]     [1, 50, 28, 28]           0       
     Conv2D-15        [[1, 50, 28, 28]]     [1, 300, 28, 28]       15,000     
   BatchNorm2D-14     [[1, 300, 28, 28]]    [1, 300, 28, 28]        1,200     
      Swish-5         [[1, 300, 28, 28]]    [1, 300, 28, 28]          0       
     Conv2D-16        [[1, 300, 28, 28]]    [1, 300, 28, 28]        2,700     
   BatchNorm2D-15     [[1, 300, 28, 28]]    [1, 300, 28, 28]        1,200     
AdaptiveAvgPool2D-2   [[1, 300, 28, 28]]     [1, 300, 1, 1]           0       
     Conv2D-17         [[1, 300, 1, 1]]      [1, 25, 1, 1]          7,525     
   BatchNorm2D-16      [[1, 25, 1, 1]]       [1, 25, 1, 1]           100      
       ReLU-2          [[1, 25, 1, 1]]       [1, 25, 1, 1]            0       
     Conv2D-18         [[1, 25, 1, 1]]       [1, 300, 1, 1]         7,800     
     Sigmoid-2         [[1, 300, 1, 1]]      [1, 300, 1, 1]           0       
        SE-2          [[1, 300, 28, 28]]    [1, 300, 28, 28]          0       
      ReLU6-5         [[1, 300, 28, 28]]    [1, 300, 28, 28]          0       
     Conv2D-19        [[1, 300, 28, 28]]    [1, 61, 28, 28]        18,300     
   BatchNorm2D-17     [[1, 61, 28, 28]]     [1, 61, 28, 28]          244      
 LinearBottleneck-5   [[1, 50, 28, 28]]     [1, 61, 28, 28]           0       
     Conv2D-20        [[1, 61, 28, 28]]     [1, 366, 28, 28]       22,326     
   BatchNorm2D-18     [[1, 366, 28, 28]]    [1, 366, 28, 28]        1,464     
      Swish-6         [[1, 366, 28, 28]]    [1, 366, 28, 28]          0       
     Conv2D-21        [[1, 366, 28, 28]]    [1, 366, 14, 14]        3,294     
   BatchNorm2D-19     [[1, 366, 14, 14]]    [1, 366, 14, 14]        1,464     
AdaptiveAvgPool2D-3   [[1, 366, 14, 14]]     [1, 366, 1, 1]           0       
     Conv2D-22         [[1, 366, 1, 1]]      [1, 30, 1, 1]         11,010     
   BatchNorm2D-20      [[1, 30, 1, 1]]       [1, 30, 1, 1]           120      
       ReLU-3          [[1, 30, 1, 1]]       [1, 30, 1, 1]            0       
     Conv2D-23         [[1, 30, 1, 1]]       [1, 366, 1, 1]        11,346     
     Sigmoid-3         [[1, 366, 1, 1]]      [1, 366, 1, 1]           0       
        SE-3          [[1, 366, 14, 14]]    [1, 366, 14, 14]          0       
      ReLU6-6         [[1, 366, 14, 14]]    [1, 366, 14, 14]          0       
     Conv2D-24        [[1, 366, 14, 14]]    [1, 72, 14, 14]        26,352     
   BatchNorm2D-21     [[1, 72, 14, 14]]     [1, 72, 14, 14]          288      
 LinearBottleneck-6   [[1, 61, 28, 28]]     [1, 72, 14, 14]           0       
     Conv2D-25        [[1, 72, 14, 14]]     [1, 432, 14, 14]       31,104     
   BatchNorm2D-22     [[1, 432, 14, 14]]    [1, 432, 14, 14]        1,728     
      Swish-7         [[1, 432, 14, 14]]    [1, 432, 14, 14]          0       
     Conv2D-26        [[1, 432, 14, 14]]    [1, 432, 14, 14]        3,888     
   BatchNorm2D-23     [[1, 432, 14, 14]]    [1, 432, 14, 14]        1,728     
AdaptiveAvgPool2D-4   [[1, 432, 14, 14]]     [1, 432, 1, 1]           0       
     Conv2D-27         [[1, 432, 1, 1]]      [1, 36, 1, 1]         15,588     
   BatchNorm2D-24      [[1, 36, 1, 1]]       [1, 36, 1, 1]           144      
       ReLU-4          [[1, 36, 1, 1]]       [1, 36, 1, 1]            0       
     Conv2D-28         [[1, 36, 1, 1]]       [1, 432, 1, 1]        15,984     
     Sigmoid-4         [[1, 432, 1, 1]]      [1, 432, 1, 1]           0       
        SE-4          [[1, 432, 14, 14]]    [1, 432, 14, 14]          0       
      ReLU6-7         [[1, 432, 14, 14]]    [1, 432, 14, 14]          0       
     Conv2D-29        [[1, 432, 14, 14]]    [1, 84, 14, 14]        36,288     
   BatchNorm2D-25     [[1, 84, 14, 14]]     [1, 84, 14, 14]          336      
 LinearBottleneck-7   [[1, 72, 14, 14]]     [1, 84, 14, 14]           0       
     Conv2D-30        [[1, 84, 14, 14]]     [1, 504, 14, 14]       42,336     
   BatchNorm2D-26     [[1, 504, 14, 14]]    [1, 504, 14, 14]        2,016     
      Swish-8         [[1, 504, 14, 14]]    [1, 504, 14, 14]          0       
     Conv2D-31        [[1, 504, 14, 14]]    [1, 504, 14, 14]        4,536     
   BatchNorm2D-27     [[1, 504, 14, 14]]    [1, 504, 14, 14]        2,016     
AdaptiveAvgPool2D-5   [[1, 504, 14, 14]]     [1, 504, 1, 1]           0       
     Conv2D-32         [[1, 504, 1, 1]]      [1, 42, 1, 1]         21,210     
   BatchNorm2D-28      [[1, 42, 1, 1]]       [1, 42, 1, 1]           168      
       ReLU-5          [[1, 42, 1, 1]]       [1, 42, 1, 1]            0       
     Conv2D-33         [[1, 42, 1, 1]]       [1, 504, 1, 1]        21,672     
     Sigmoid-5         [[1, 504, 1, 1]]      [1, 504, 1, 1]           0       
        SE-5          [[1, 504, 14, 14]]    [1, 504, 14, 14]          0       
      ReLU6-8         [[1, 504, 14, 14]]    [1, 504, 14, 14]          0       
     Conv2D-34        [[1, 504, 14, 14]]    [1, 95, 14, 14]        47,880     
   BatchNorm2D-29     [[1, 95, 14, 14]]     [1, 95, 14, 14]          380      
 LinearBottleneck-8   [[1, 84, 14, 14]]     [1, 95, 14, 14]           0       
     Conv2D-35        [[1, 95, 14, 14]]     [1, 570, 14, 14]       54,150     
   BatchNorm2D-30     [[1, 570, 14, 14]]    [1, 570, 14, 14]        2,280     
      Swish-9         [[1, 570, 14, 14]]    [1, 570, 14, 14]          0       
     Conv2D-36        [[1, 570, 14, 14]]    [1, 570, 14, 14]        5,130     
   BatchNorm2D-31     [[1, 570, 14, 14]]    [1, 570, 14, 14]        2,280     
AdaptiveAvgPool2D-6   [[1, 570, 14, 14]]     [1, 570, 1, 1]           0       
     Conv2D-37         [[1, 570, 1, 1]]      [1, 47, 1, 1]         26,837     
   BatchNorm2D-32      [[1, 47, 1, 1]]       [1, 47, 1, 1]           188      
       ReLU-6          [[1, 47, 1, 1]]       [1, 47, 1, 1]            0       
     Conv2D-38         [[1, 47, 1, 1]]       [1, 570, 1, 1]        27,360     
     Sigmoid-6         [[1, 570, 1, 1]]      [1, 570, 1, 1]           0       
        SE-6          [[1, 570, 14, 14]]    [1, 570, 14, 14]          0       
      ReLU6-9         [[1, 570, 14, 14]]    [1, 570, 14, 14]          0       
     Conv2D-39        [[1, 570, 14, 14]]    [1, 106, 14, 14]       60,420     
   BatchNorm2D-33     [[1, 106, 14, 14]]    [1, 106, 14, 14]         424      
 LinearBottleneck-9   [[1, 95, 14, 14]]     [1, 106, 14, 14]          0       
     Conv2D-40        [[1, 106, 14, 14]]    [1, 636, 14, 14]       67,416     
   BatchNorm2D-34     [[1, 636, 14, 14]]    [1, 636, 14, 14]        2,544     
      Swish-10        [[1, 636, 14, 14]]    [1, 636, 14, 14]          0       
     Conv2D-41        [[1, 636, 14, 14]]    [1, 636, 14, 14]        5,724     
   BatchNorm2D-35     [[1, 636, 14, 14]]    [1, 636, 14, 14]        2,544     
AdaptiveAvgPool2D-7   [[1, 636, 14, 14]]     [1, 636, 1, 1]           0       
     Conv2D-42         [[1, 636, 1, 1]]      [1, 53, 1, 1]         33,761     
   BatchNorm2D-36      [[1, 53, 1, 1]]       [1, 53, 1, 1]           212      
       ReLU-7          [[1, 53, 1, 1]]       [1, 53, 1, 1]            0       
     Conv2D-43         [[1, 53, 1, 1]]       [1, 636, 1, 1]        34,344     
     Sigmoid-7         [[1, 636, 1, 1]]      [1, 636, 1, 1]           0       
        SE-7          [[1, 636, 14, 14]]    [1, 636, 14, 14]          0       
      ReLU6-10        [[1, 636, 14, 14]]    [1, 636, 14, 14]          0       
     Conv2D-44        [[1, 636, 14, 14]]    [1, 117, 14, 14]       74,412     
   BatchNorm2D-37     [[1, 117, 14, 14]]    [1, 117, 14, 14]         468      
LinearBottleneck-10   [[1, 106, 14, 14]]    [1, 117, 14, 14]          0       
     Conv2D-45        [[1, 117, 14, 14]]    [1, 702, 14, 14]       82,134     
   BatchNorm2D-38     [[1, 702, 14, 14]]    [1, 702, 14, 14]        2,808     
      Swish-11        [[1, 702, 14, 14]]    [1, 702, 14, 14]          0       
     Conv2D-46        [[1, 702, 14, 14]]    [1, 702, 14, 14]        6,318     
   BatchNorm2D-39     [[1, 702, 14, 14]]    [1, 702, 14, 14]        2,808     
AdaptiveAvgPool2D-8   [[1, 702, 14, 14]]     [1, 702, 1, 1]           0       
     Conv2D-47         [[1, 702, 1, 1]]      [1, 58, 1, 1]         40,774     
   BatchNorm2D-40      [[1, 58, 1, 1]]       [1, 58, 1, 1]           232      
       ReLU-8          [[1, 58, 1, 1]]       [1, 58, 1, 1]            0       
     Conv2D-48         [[1, 58, 1, 1]]       [1, 702, 1, 1]        41,418     
     Sigmoid-8         [[1, 702, 1, 1]]      [1, 702, 1, 1]           0       
        SE-8          [[1, 702, 14, 14]]    [1, 702, 14, 14]          0       
      ReLU6-11        [[1, 702, 14, 14]]    [1, 702, 14, 14]          0       
     Conv2D-49        [[1, 702, 14, 14]]    [1, 128, 14, 14]       89,856     
   BatchNorm2D-41     [[1, 128, 14, 14]]    [1, 128, 14, 14]         512      
LinearBottleneck-11   [[1, 117, 14, 14]]    [1, 128, 14, 14]          0       
     Conv2D-50        [[1, 128, 14, 14]]    [1, 768, 14, 14]       98,304     
   BatchNorm2D-42     [[1, 768, 14, 14]]    [1, 768, 14, 14]        3,072     
      Swish-12        [[1, 768, 14, 14]]    [1, 768, 14, 14]          0       
     Conv2D-51        [[1, 768, 14, 14]]     [1, 768, 7, 7]         6,912     
   BatchNorm2D-43      [[1, 768, 7, 7]]      [1, 768, 7, 7]         3,072     
AdaptiveAvgPool2D-9    [[1, 768, 7, 7]]      [1, 768, 1, 1]           0       
     Conv2D-52         [[1, 768, 1, 1]]      [1, 64, 1, 1]         49,216     
   BatchNorm2D-44      [[1, 64, 1, 1]]       [1, 64, 1, 1]           256      
       ReLU-9          [[1, 64, 1, 1]]       [1, 64, 1, 1]            0       
     Conv2D-53         [[1, 64, 1, 1]]       [1, 768, 1, 1]        49,920     
     Sigmoid-9         [[1, 768, 1, 1]]      [1, 768, 1, 1]           0       
        SE-9           [[1, 768, 7, 7]]      [1, 768, 7, 7]           0       
      ReLU6-12         [[1, 768, 7, 7]]      [1, 768, 7, 7]           0       
     Conv2D-54         [[1, 768, 7, 7]]      [1, 140, 7, 7]        107,520    
   BatchNorm2D-45      [[1, 140, 7, 7]]      [1, 140, 7, 7]          560      
LinearBottleneck-12   [[1, 128, 14, 14]]     [1, 140, 7, 7]           0       
     Conv2D-55         [[1, 140, 7, 7]]      [1, 840, 7, 7]        117,600    
   BatchNorm2D-46      [[1, 840, 7, 7]]      [1, 840, 7, 7]         3,360     
      Swish-13         [[1, 840, 7, 7]]      [1, 840, 7, 7]           0       
     Conv2D-56         [[1, 840, 7, 7]]      [1, 840, 7, 7]         7,560     
   BatchNorm2D-47      [[1, 840, 7, 7]]      [1, 840, 7, 7]         3,360     
AdaptiveAvgPool2D-10   [[1, 840, 7, 7]]      [1, 840, 1, 1]           0       
     Conv2D-57         [[1, 840, 1, 1]]      [1, 70, 1, 1]         58,870     
   BatchNorm2D-48      [[1, 70, 1, 1]]       [1, 70, 1, 1]           280      
      ReLU-10          [[1, 70, 1, 1]]       [1, 70, 1, 1]            0       
     Conv2D-58         [[1, 70, 1, 1]]       [1, 840, 1, 1]        59,640     
     Sigmoid-10        [[1, 840, 1, 1]]      [1, 840, 1, 1]           0       
       SE-10           [[1, 840, 7, 7]]      [1, 840, 7, 7]           0       
      ReLU6-13         [[1, 840, 7, 7]]      [1, 840, 7, 7]           0       
     Conv2D-59         [[1, 840, 7, 7]]      [1, 151, 7, 7]        126,840    
   BatchNorm2D-49      [[1, 151, 7, 7]]      [1, 151, 7, 7]          604      
LinearBottleneck-13    [[1, 140, 7, 7]]      [1, 151, 7, 7]           0       
     Conv2D-60         [[1, 151, 7, 7]]      [1, 906, 7, 7]        136,806    
   BatchNorm2D-50      [[1, 906, 7, 7]]      [1, 906, 7, 7]         3,624     
      Swish-14         [[1, 906, 7, 7]]      [1, 906, 7, 7]           0       
     Conv2D-61         [[1, 906, 7, 7]]      [1, 906, 7, 7]         8,154     
   BatchNorm2D-51      [[1, 906, 7, 7]]      [1, 906, 7, 7]         3,624     
AdaptiveAvgPool2D-11   [[1, 906, 7, 7]]      [1, 906, 1, 1]           0       
     Conv2D-62         [[1, 906, 1, 1]]      [1, 75, 1, 1]         68,025     
   BatchNorm2D-52      [[1, 75, 1, 1]]       [1, 75, 1, 1]           300      
      ReLU-11          [[1, 75, 1, 1]]       [1, 75, 1, 1]            0       
     Conv2D-63         [[1, 75, 1, 1]]       [1, 906, 1, 1]        68,856     
     Sigmoid-11        [[1, 906, 1, 1]]      [1, 906, 1, 1]           0       
       SE-11           [[1, 906, 7, 7]]      [1, 906, 7, 7]           0       
      ReLU6-14         [[1, 906, 7, 7]]      [1, 906, 7, 7]           0       
     Conv2D-64         [[1, 906, 7, 7]]      [1, 162, 7, 7]        146,772    
   BatchNorm2D-53      [[1, 162, 7, 7]]      [1, 162, 7, 7]          648      
LinearBottleneck-14    [[1, 151, 7, 7]]      [1, 162, 7, 7]           0       
     Conv2D-65         [[1, 162, 7, 7]]      [1, 972, 7, 7]        157,464    
   BatchNorm2D-54      [[1, 972, 7, 7]]      [1, 972, 7, 7]         3,888     
      Swish-15         [[1, 972, 7, 7]]      [1, 972, 7, 7]           0       
     Conv2D-66         [[1, 972, 7, 7]]      [1, 972, 7, 7]         8,748     
   BatchNorm2D-55      [[1, 972, 7, 7]]      [1, 972, 7, 7]         3,888     
AdaptiveAvgPool2D-12   [[1, 972, 7, 7]]      [1, 972, 1, 1]           0       
     Conv2D-67         [[1, 972, 1, 1]]      [1, 81, 1, 1]         78,813     
   BatchNorm2D-56      [[1, 81, 1, 1]]       [1, 81, 1, 1]           324      
      ReLU-12          [[1, 81, 1, 1]]       [1, 81, 1, 1]            0       
     Conv2D-68         [[1, 81, 1, 1]]       [1, 972, 1, 1]        79,704     
     Sigmoid-12        [[1, 972, 1, 1]]      [1, 972, 1, 1]           0       
       SE-12           [[1, 972, 7, 7]]      [1, 972, 7, 7]           0       
      ReLU6-15         [[1, 972, 7, 7]]      [1, 972, 7, 7]           0       
     Conv2D-69         [[1, 972, 7, 7]]      [1, 174, 7, 7]        169,128    
   BatchNorm2D-57      [[1, 174, 7, 7]]      [1, 174, 7, 7]          696      
LinearBottleneck-15    [[1, 162, 7, 7]]      [1, 174, 7, 7]           0       
     Conv2D-70         [[1, 174, 7, 7]]     [1, 1044, 7, 7]        181,656    
   BatchNorm2D-58     [[1, 1044, 7, 7]]     [1, 1044, 7, 7]         4,176     
      Swish-16        [[1, 1044, 7, 7]]     [1, 1044, 7, 7]           0       
     Conv2D-71        [[1, 1044, 7, 7]]     [1, 1044, 7, 7]         9,396     
   BatchNorm2D-59     [[1, 1044, 7, 7]]     [1, 1044, 7, 7]         4,176     
AdaptiveAvgPool2D-13  [[1, 1044, 7, 7]]     [1, 1044, 1, 1]           0       
     Conv2D-72        [[1, 1044, 1, 1]]      [1, 87, 1, 1]         90,915     
   BatchNorm2D-60      [[1, 87, 1, 1]]       [1, 87, 1, 1]           348      
      ReLU-13          [[1, 87, 1, 1]]       [1, 87, 1, 1]            0       
     Conv2D-73         [[1, 87, 1, 1]]      [1, 1044, 1, 1]        91,872     
     Sigmoid-13       [[1, 1044, 1, 1]]     [1, 1044, 1, 1]           0       
       SE-13          [[1, 1044, 7, 7]]     [1, 1044, 7, 7]           0       
      ReLU6-16        [[1, 1044, 7, 7]]     [1, 1044, 7, 7]           0       
     Conv2D-74        [[1, 1044, 7, 7]]      [1, 185, 7, 7]        193,140    
   BatchNorm2D-61      [[1, 185, 7, 7]]      [1, 185, 7, 7]          740      
LinearBottleneck-16    [[1, 174, 7, 7]]      [1, 185, 7, 7]           0       
     Conv2D-75         [[1, 185, 7, 7]]     [1, 1280, 7, 7]        236,800    
   BatchNorm2D-62     [[1, 1280, 7, 7]]     [1, 1280, 7, 7]         5,120     
      Swish-17        [[1, 1280, 7, 7]]     [1, 1280, 7, 7]           0       
AdaptiveAvgPool2D-14  [[1, 1280, 7, 7]]     [1, 1280, 1, 1]           0       
     Dropout-1        [[1, 1280, 1, 1]]     [1, 1280, 1, 1]           0       
     Conv2D-76        [[1, 1280, 1, 1]]      [1, 10, 1, 1]         12,810     
================================================================================
Total params: 3,570,061
Trainable params: 3,487,305
Non-trainable params: 82,756
--------------------------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 179.95
Params size (MB): 13.62
Estimated Total Size (MB): 194.15
--------------------------------------------------------------------------------

{'total_params': 3570061, 'trainable_params': 3487305}

3. Model Training

model.prepare(optimizer=paddle.optimizer.Adam(learning_rate=0.001,parameters=model.parameters()),
              loss=paddle.nn.CrossEntropyLoss(),
              metrics=paddle.metric.Accuracy())


model.fit(
    train_data=train_dataset,
    eval_data=val_dataset,
    batch_size=128,
    epochs=10,
    verbose=1,
)
The loss value printed in the log is the current step, and the metric is the average value of previous step.
Epoch 1/10


/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:77: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  return (isinstance(seq, collections.Sequence) and
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/nn/layer/norm.py:648: UserWarning: When training, we now always track global mean and variance.
  "When training, we now always track global mean and variance.")


step 391/391 [==============================] - loss: 1.4491 - acc: 0.3863 - 575ms/step        
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 79/79 [==============================] - loss: 1.7882 - acc: 0.4751 - 456ms/step         
Eval samples: 10000
Epoch 2/10
step 391/391 [==============================] - loss: 1.1541 - acc: 0.5422 - 581ms/step        
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 79/79 [==============================] - loss: 1.2044 - acc: 0.5535 - 452ms/step         
Eval samples: 10000
Epoch 3/10
step 391/391 [==============================] - loss: 0.9998 - acc: 0.6278 - 580ms/step        
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 79/79 [==============================] - loss: 0.9721 - acc: 0.6489 - 445ms/step         
Eval samples: 10000
Epoch 4/10
step 391/391 [==============================] - loss: 0.9672 - acc: 0.6923 - 580ms/step        
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 79/79 [==============================] - loss: 0.9084 - acc: 0.6742 - 463ms/step         
Eval samples: 10000
Epoch 5/10
step 391/391 [==============================] - loss: 0.7523 - acc: 0.7179 - 589ms/step        
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 79/79 [==============================] - loss: 1.0783 - acc: 0.7013 - 457ms/step         
Eval samples: 10000
Epoch 6/10
step 391/391 [==============================] - loss: 0.5859 - acc: 0.7411 - 586ms/step        
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 79/79 [==============================] - loss: 0.7804 - acc: 0.7353 - 448ms/step         
Eval samples: 10000
Epoch 7/10
step 391/391 [==============================] - loss: 0.9060 - acc: 0.7618 - 591ms/step        
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 79/79 [==============================] - loss: 1.1076 - acc: 0.7308 - 454ms/step         
Eval samples: 10000
Epoch 8/10
step 391/391 [==============================] - loss: 0.5531 - acc: 0.7673 - 592ms/step        
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 79/79 [==============================] - loss: 0.9240 - acc: 0.7258 - 459ms/step         
Eval samples: 10000
Epoch 9/10
step 391/391 [==============================] - loss: 0.6456 - acc: 0.7740 - 594ms/step        
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 79/79 [==============================] - loss: 0.6419 - acc: 0.7475 - 461ms/step         
Eval samples: 10000
Epoch 10/10
step 391/391 [==============================] - loss: 0.4422 - acc: 0.8106 - 597ms/step         
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 79/79 [==============================] - loss: 0.6156 - acc: 0.7785 - 469ms/step         
Eval samples: 10000

Summary

  • The paper's main aim is to reduce the representational bottleneck through small modifications to MobileNet

  • It reaches 77.9% top-1 accuracy on ImageNet

  • The one blemish: the authors trained ReXNet with a pile of tricks. Without those tricks and without pretrained weights, Paddle's built-in MobileNet V2 actually fits better and converges faster

  • At inference time ReXNet is slower than a MobileNet V2-1.2 with the same FLOPs, a consequence of the architecture itself; the real highlight is the set of design principles, which should encourage NAS to search for better networks
