【Block总结】PSA,金字塔挤压注意力,解决传统注意力机制在捕获多尺度特征时的局限性

news/2025/2/8 15:18:35 标签: 人工智能, 计算机视觉, transformer, 分类

论文信息

  • 标题: EPSANet: An Efficient Pyramid Squeeze Attention Block on Convolutional Neural Network
  • 论文链接: arXiv
  • GitHub链接: https://github.com/murufeng/EPSANet

在这里插入图片描述

创新点

EPSANet提出了一种新颖的金字塔挤压注意力(PSA)模块,旨在解决传统注意力机制在捕获多尺度特征时的局限性。其主要创新点包括:

  1. 高效性: PSA模块能够在较低的计算成本下有效提取多尺度空间信息。
  2. 灵活性: 该模块可以作为即插即用的组件,轻松集成到现有的卷积神经网络(CNN)架构中。
  3. 多尺度特征表示: EPSANet通过自适应地重新校准跨维通道的注意权重,增强了特征表示能力。

方法

EPSANet的核心是PSA模块,其实现过程如下:

  1. 多尺度特征提取: 通过Squeeze and Concat (SPC)模块获得通道维度上的多尺度特征图。
  2. 注意力计算: 使用SEWeight模块提取不同尺度特征图的注意力,生成通道方向的注意力向量。
  3. 再校准: 通过Softmax对通道维度的注意向量进行再校准,得到多尺度信道的再校准权重。
  4. 特征融合: 在重新校准的权重和对应的特征图上进行按元素乘积,输出丰富的多尺度特征图。
    在这里插入图片描述

效果

EPSANet在多个计算机视觉任务中表现出色,包括图像分类、目标检测和实例分割。与传统的通道注意力方法相比,EPSANet在性能上有显著提升。例如:

  • 在ImageNet数据集上,EPSANet的Top-1准确率比SENet-50提高了1.93%。
  • 在MS-COCO数据集上,使用Mask-RCNN时,EPSANet在目标检测和实例分割任务中分别提高了2.7和1.7的AP值。

实验结果

实验结果表明,EPSANet在多个标准数据集上均超越了当前最新的技术。具体表现为:

  • 在COCO 2017数据集上,EPSANet的目标检测率超越了ECANet,AP75指标提高了1.4%。
  • 论文中进行了大量的定性和定量实验,验证了EPSANet在图像分类、目标检测和实例分割方面的先进性能。

总结

EPSANet通过引入金字塔挤压注意力模块,成功地提升了卷积神经网络在多尺度特征提取方面的能力。其灵活的设计使得EPSANet能够广泛应用于各种计算机视觉任务,展现出良好的泛化性能和高效性。该研究为未来的深度学习模型设计提供了新的思路和方法。

代码

import torch
import torch.nn as nn

class SEWeightModule(nn.Module):

    def __init__(self, channels, reduction=16):
        super(SEWeightModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(channels, channels//reduction, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(channels//reduction, channels, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.avg_pool(x)
        out = self.fc1(out)
        out = self.relu(out)
        out = self.fc2(out)
        weight = self.sigmoid(out)

        return weight

def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1, groups=1):
    """standard convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                     padding=padding, dilation=dilation, groups=groups, bias=False)

def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

class PSAModule(nn.Module):

    def __init__(self, inplans, planes, conv_kernels=[3, 5, 7, 9], stride=1, conv_groups=[1, 4, 8, 16]):
        super(PSAModule, self).__init__()
        self.conv_1 = conv(inplans, planes//4, kernel_size=conv_kernels[0], padding=conv_kernels[0]//2,
                            stride=stride, groups=conv_groups[0])
        self.conv_2 = conv(inplans, planes//4, kernel_size=conv_kernels[1], padding=conv_kernels[1]//2,
                            stride=stride, groups=conv_groups[1])
        self.conv_3 = conv(inplans, planes//4, kernel_size=conv_kernels[2], padding=conv_kernels[2]//2,
                            stride=stride, groups=conv_groups[2])
        self.conv_4 = conv(inplans, planes//4, kernel_size=conv_kernels[3], padding=conv_kernels[3]//2,
                            stride=stride, groups=conv_groups[3])
        self.se = SEWeightModule(planes // 4)
        self.split_channel = planes // 4
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        batch_size = x.shape[0]
        x1 = self.conv_1(x)
        x2 = self.conv_2(x)
        x3 = self.conv_3(x)
        x4 = self.conv_4(x)

        feats = torch.cat((x1, x2, x3, x4), dim=1)
        feats = feats.view(batch_size, 4, self.split_channel, feats.shape[2], feats.shape[3])

        x1_se = self.se(x1)
        x2_se = self.se(x2)
        x3_se = self.se(x3)
        x4_se = self.se(x4)

        x_se = torch.cat((x1_se, x2_se, x3_se, x4_se), dim=1)
        attention_vectors = x_se.view(batch_size, 4, self.split_channel, 1, 1)
        attention_vectors = self.softmax(attention_vectors)
        feats_weight = feats * attention_vectors
        for i in range(4):
            x_se_weight_fp = feats_weight[:, i, :, :]
            if i == 0:
                out = x_se_weight_fp
            else:
                out = torch.cat((x_se_weight_fp, out), 1)

        return out

if __name__ == "__main__":
    dim=512
    # 如果GPU可用,将模块移动到 GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # 输入张量 (batch_size, channels,height, width)
    x = torch.randn(1,dim,14,14).to(device)
    # 初始化 PSAModule 模块

    block = PSAModule(dim,dim) # kernel_size为height或者width
    print(block)
    block = block.to(device)
    # 前向传播
    output = block(x)
    print("输入:", x.shape)
    print("输出:", output.shape)

输出结果:
在这里插入图片描述


http://www.niftyadmin.cn/n/5845015.html

相关文章

【学习总结|DAY036】Vue工程化+ElementPlus

引言 在前端开发领域,Vue 作为一款流行的 JavaScript 框架,结合 ElementPlus 组件库,为开发者提供了强大的构建用户界面的能力。本文将结合学习内容,详细介绍 Vue 工程化开发流程以及 ElementPlus 的使用,助力开发者快…

Python3+Request+Pytest+Allure+Jenkins 接口自动化测试[手动写的和AI写的对比]

我手动写的参考 总篇:Python3+Request+Pytest+Allure+Jenkins接口自动化框架设计思路_jenkins python3+request-CSDN博客 https://blog.csdn.net/fen_fen/article/details/144269072 下面是AI写的:Python3+Request+Pytest+Allure+Jenkins 接口自动化测试[AI文章框架] 在软…

深度学习01 神经网络

目录 神经网络 ​感知器 感知器的定义 感知器的数学表达 感知器的局限性 多层感知器(MLP, Multi-Layer Perceptron) 多层感知器的定义 多层感知器的结构 多层感知器的优势 偏置 偏置的作用 偏置的数学表达 神经网络的构造 ​神经网络的基本…

3.攻防世界 Confusion1(服务器模板注入SSTI)

题目描述如下 进入题目页面如下 图片是蟒蛇、大象?python、php? 猜测需要代码审计 点击 F12查看源码,有所提示flag 但是也没有其他信息了 猜测本题存在SSTI(服务器模板注入)漏洞,为验证,构造…

SQL 秒变 ER 图 sql转er图

🚀SQL 秒变 ER 图,校园小助手神了! 学数据库的宝子们集合🙋‍♀️ 是不是每次碰到 SQL 转 ER 图就头皮发麻?看着密密麻麻的代码,脑子直接死机,好不容易理清一点头绪,又被复杂的表关…

B站自研的第二代视频连麦系统(上)

导读 本系列文章将从客户端、服务器以及音视频编码优化三个层面,介绍如何基于WebRTC构建视频连麦系统。希望通过这一系列的讲解,帮助开发者更全面地了解 WebRTC 的核心技术与实践应用。 背景 在文章《B站在实时音视频技术领域的探索与实践》中&#xff…

【LeetCode力扣】1.(简单)两数之和(JavaScript)

两数之和: 题目描述: 给定一个整数数组 nums 和一个整数目标值 target,请你在该数组中找出 和为目标值 target 的那 两个 整数,并返回它们的数组下标。 你可以假设每种输入只会对应一个答案。但是,数组中同一个元素在答案里不能重…

【人工智能】使用Python实现图像风格迁移:理论、算法与实践

《Python OpenCV从菜鸟到高手》带你进入图像处理与计算机视觉的大门! 解锁Python编程的无限可能:《奇妙的Python》带你漫游代码世界 图像风格迁移是一种利用深度学习技术将一张图像的内容与另一张图像的风格相结合的技术。本文将深入探讨图像风格迁移的基本理论和实现方法,…