当前位置: 首页 > news >正文

深度学习(视觉注意力SeNet/CbmaNet/SkNet/EcaNet)

这些网络提供了一种即插即用的注意力模块,可以嵌入到现有的主流架构(如ResNet, VGG, MobileNet等)中,带来几乎无成本的性能提升。

四种网络核心思想:

1. SENet (Squeeze-and-Excitation Network): 通道注意力(Channel Attention)。专注于建模通道之间的相互依赖关系,自动学习到每个通道的重要程度,然后为重要的通道赋予更大的权重。

2. CBAM (Convolutional Block Attention Module): 通道注意力 + 空间注意力 的串联结构。认为只关注通道维度是不够的,空间位置上的信息也同样重要。CBAM依次从通道和空间两个维度计算注意力图。

3. SKNet (Selective Kernel Networks): 动态选择不同大小的卷积核(感受野)。让网络能够根据输入信息的复杂程度,自适应地调节其感受野的大小。

4. ECANet (Efficient Channel Attention Network): 对SENet的轻量化和改进。认为SENet中的降维操作对通道注意力预测会产生副作用,并且两个全连接层显得笨重。ECANet提出了一种不降维的、更高效的局部跨通道交互策略。

总结与对比:

网络核心思想注意力维度主要特点优点缺点
SENet 通道重要性 通道 GAP + FC + Sigmoid 开创性强,即插即用,效果显著 降维可能破坏通道关系,有参数量
CBAM 通道+空间重要性 通道 & 空间 GAP+GMP → MLP;通道池化 → Conv 注意力更全面,效果通常优于SENet 顺序结构可能非最优
SKNet 动态选择感受野 尺度/核 多分支卷积,自适应加权融合 自适应能力强,多尺度性能优异 计算和参数量较大
ECANet 高效通道交互 通道 GAP + 1DConv + Sigmoid 极其轻量,无降维,效率极高 仅通道维度

代码如下:

import torch
import torch.nn as nnclass SeNet(nn.Module):def __init__(self, inchannel, ratio=16):super(SeNet, self).__init__()self.gap = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Linear(inchannel, inchannel // ratio, bias=False),  # 从 c -> c/rnn.ReLU(inplace=True),nn.Linear(inchannel // ratio, inchannel, bias=False),  # 从 c/r -> c
            nn.Sigmoid())def forward(self, x):b, c, _, _ = x.size()y = self.gap(x).view(b, c)y = self.fc(y).view(b, c, 1, 1)return x * y.expand_as(x)class CbamNet(nn.Module):def __init__(self, channels, reduction=16):super(CbamNet, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1,padding=0)self.relu = nn.ReLU(inplace=True)self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1,padding=0)self.sigmoid_channel = nn.Sigmoid()self.conv_after_concat = nn.Conv2d(2,1,kernel_size=3,stride=1,padding=1)self.sigmoid_spatial = nn.Sigmoid()def forward(self, x):# avg全局池化+MLPavg = self.avg_pool(x)   avg = self.fc1(avg)    avg = self.relu(avg)  avg = self.fc2(avg)  # max全局池化+MLPmx = self.max_pool(x)  mx = self.fc1(mx)  mx = self.relu(mx) mx = self.fc2(mx) x = x * self.sigmoid_channel(avg+mx)module_input = x avg = torch.mean(x, 1, True)mx, _ = torch.max(x, 1, True)x = torch.cat((avg, mx), 1)x = self.conv_after_concat(x)x = self.sigmoid_spatial(x)x = module_input * xreturn xclass SkNet(nn.Module):def __init__(self,inchannel,ratio=16):super(SkNet,self).__init__()self.conv3x3 = nn.Conv2d(inchannel,inchannel,kernel_size=3,dilation=1, padding=1)self.conv5x5 = nn.Conv2d(inchannel,inchannel,kernel_size=3,dilation=2, padding=2)self.avg = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Linear(inchannel, inchannel // ratio, bias=False),  # 从 c -> c/rnn.ReLU(inplace=True),nn.Linear(inchannel // ratio, inchannel*2, bias=False)  # 从 c/r -> c
        )self.softmax = nn.Softmax(dim=1)def forward(self,x):x1 = self.conv3x3(x)x2 = self.conv5x5(x)z = x1 + x2B, C, _, _ = z.size()z = self.avg(z).view(B, C)z = self.fc(z)z = z.view(B, 2, C) a = z[:, 0, :].unsqueeze(-1).unsqueeze(-1)  # [B, C, 1, 1]b = z[:, 1, :].unsqueeze(-1).unsqueeze(-1)x1 = x1 * ax2 = x2 * bx = x1 + x2return xclass EcaNet(nn.Module):def __init__(self,k_size=3):super(EcaNet, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):y = self.avg_pool(x)y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)y = self.sigmoid(y)return x * y.expand_as(x)se_model = SeNet(128)
cbma_model = CbamNet(128,16)
sk_model = SkNet(128,16)
eca_model = EcaNet()x = torch.randn([1,128,200,200])
# y = cbma_model(x)

torch.onnx.export(se_model,x,'se_net.onnx',opset_version=11)
torch.onnx.export(cbma_model,x,'cbma_net.onnx',opset_version=11)
torch.onnx.export(sk_model,x,'sk_net.onnx',opset_version=11)
torch.onnx.export(eca_model,x,'eca_net.onnx',opset_version=11)
http://www.hskmm.com/?act=detail&tid=10480

相关文章:

  • 起床
  • qoj6277 Linear Congruential Generator
  • docker+k8s
  • 多模型适配突围:JBoltAI如何重构企业数智化转型新范式?
  • JBoltAI赋能制造业数智化转型:AI从概念到落地的Java实践
  • JBoltAI赋能医疗数智化转型:AI大模型如何重塑医疗健康新范式
  • JBoltAI多模态赋能:制造业数智化升级的新引擎
  • 深入解析:YARN架构解析:深入理解Hadoop资源管理核心
  • JBoltAI:破解Java企业级AI应用落地难题的利器
  • 直播软件开发,单例设计模式很简单吗? - 云豹科技
  • Java开发者的AI革命:如何用JBoltAI应对数智化转型挑战
  • JBoltAI:赋能Java老项目快速接入AI能力的创新之道
  • Day04 C:\Users\Lenovo\Desktop\note\code\JavaSE\Basic\src\com\David\operator Demo01-08+Doc
  • 实用指南:养老专业实训室建设方案的分级设计与人才培养适配
  • 物业企业绩效考核制度与考核体系 - 指南
  • Java开发生态的数智化升级:JBoltAI如何重塑企业AI应用架构
  • Mapper.xml与数据库进行映射的sql语言注意事项
  • 直播软件搭建,如何实现伪分布式平台部署? - 云豹科技
  • 初步研究vivio的互传的备份数据格式
  • 完整教程:C#.NetCore NPOI 导出excel 单元格内容换行
  • resultMap和resultType
  • 直播软件怎么开发,自适应两栏布局方式 - 云豹科技
  • resultMap和自定义映射结果形式(ResultMapManage)以及ResultMap Vs ResultType
  • 嵌入式设备不能正常上网问题
  • 2、论文固定模板(背景过度结尾)
  • go: 图片文件上传
  • go: 生成缩略图
  • git: 报错: fatal: 协议错误:错误的行长度字符串:This 或 fatal: protocol error: bad line length character: This
  • jquery: Justified gallery
  • 安装crmeb