Transformer中的嵌入位置编码

news/2025/2/8 14:52:46 标签: transformer, 深度学习, 人工智能

在Transformer中,使用余弦编码或其他类似的编码方式(如正弦-余弦位置编码)而不是简单的“0123456”这种数字编码,主要是因为位置编码的目标是为模型提供位置信息,同时又不引入过多的显式顺序假设。

主要原因如下:

  1. 避免数字编码的离散性: 如果使用简单的数字编码(如0, 1, 2, 3, …),这种编码方法会暗示数字之间有某种固定的数学关系,而实际上,位置之间的关系是相对的而非线性的。例如,在“0123456”这种编码下,位置1和位置2之间的差异与位置6和位置7之间的差异是相同的,但它们在语义上可能并不等价。余弦编码则避免了这种线性关系,它能以周期性的方式映射每个位置,使得模型能够更灵活地学习不同位置之间的关系。

  2. 周期性特性: 余弦编码是基于正弦和余弦函数的,具有自然的周期性。这对于处理循环性质或长距离依赖(如句子的开始和结束、或长期的语法结构)尤其有效。余弦编码的这种周期性特性使得模型可以捕捉到位置之间的相对关系,不管它们之间的距离有多远。

  3. 无固定顺序假设: Transformer的自注意力机制(Self-Attention)是无序的,也就是说,模型本身并不假设序列的顺序是固定的,位置编码的引入是为了补充这个信息。如果使用“0123456”这种数字编码,模型可能会学习到数字之间的大小顺序,而这是不必要的,尤其在处理长文本时,顺序本身的编码可能会导致模型偏向某些固定模式。使用余弦编码则帮助模型从不同的角度(不同的频率)感知位置关系,而不是仅仅依赖于线性编码。

  4. 高效的学习能力: 余弦编码和正弦编码的连续性使得模型在学习时可以更容易地捕捉位置之间的相对关系,特别是对于不同长度的序列。它们的频率变化也帮助模型将不同位置的编码“拉开”距离,使得模型能够更清晰地区分不同的位置信息。

import torch
import math

class PositionalEncoding(torch.nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        
        # 创建一个位置编码矩阵,大小为 max_len x d_model
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()  # shape: (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))  # shape: (d_model / 2,)
        
        # 计算每个位置的编码
        pe[:, 0::2] = torch.sin(position * div_term)  # 偶数位置
        pe[:, 1::2] = torch.cos(position * div_term)  # 奇数位置
        
        pe = pe.unsqueeze(0)  # 增加batch维度,shape: (1, max_len, d_model)
        
        self.register_buffer('pe', pe)  # 注册为buffer,不会更新
        
    def forward(self, x):
        # x: 输入的张量,shape: (batch_size, seq_len, d_model)
        return x + self.pe[:, :x.size(1)].detach()

# 示例
d_model = 512  # 嵌入维度
max_len = 60   # 序列最大长度
position_encoding = PositionalEncoding(d_model, max_len)

# 假设输入一个批次的序列,shape为 (batch_size, seq_len, d_model)
batch_size = 32
seq_len = 50
x = torch.randn(batch_size, seq_len, d_model)

# 加入位置编码
x_pos = position_encoding(x)
print(x_pos.shape)  # 应该是 (batch_size, seq_len, d_model)

余弦编码和其他类似的连续位置编码方式提供了一个能够捕捉更复杂的位置信息的机制,而简单的数字编码往往过于局限,难以适应Transformer模型的自注意力机制及其灵活的处理能力。


http://www.niftyadmin.cn/n/5844985.html

相关文章

SSH工具之MobaXterm

视频介绍 系统运维之SSH工具 MobaXterm 图文教程 下载MobaXterm MobaXterm下载地址:https://mobaxterm.mobatek.net/download-home-edition.html 根据需求选择便携版(Portable)或者安装版(Installer)。 生成注册文件…

拆解Kotlin中的by lazy:从语法糖到底层实现

by lazy 是Kotlin中一个强大的属性委托机制,它主要用于实现属性的延迟初始化。所谓延迟初始化,就是在第一次访问该属性时才进行初始化,而不是在对象创建时就立即初始化。这种机制在很多场景下都能带来性能优势,特别是当属性的初始…

【Linux网络编程】之配置阿里云安全组

【Linux网络编程】之配置阿里云安全组 配置阿里云安全组阿里云安全组的概念配置安全组规则入方向基本概念补充ICMP协议安全组配置UDP协议安全组配置 出方向 配置云服务器主机的防火墙什么是防火墙Linux中防火墙的管理工具防火墙的作用常用命令介绍(firewalld&#x…

搜维尔科技:Movella数字化运动领域的领先创新者

下一代游戏、视觉效果、直播、工作场所人体工程学、运动表现、海洋和机器人技术。前所未有的运动成就。让所有年龄段的观众惊叹不已的艺术创新。Movella 的全栈技术用于捕捉、数字化和分析运动,正在让世界变得更美好。 数字艺术家的创造力得到释放 灯光、摄像机、…

C++自研3D教程OPENGL版本---动态批处理的基本实现

又开始找工作了&#xff0c;借机休息出去旅行两个月&#xff0c;顺便利用这段时间整理下以前写的东西。 以下是一个简单的动态批处理实现&#xff1a; #include <GL/glew.h> #include <GLFW/glfw3.h> #include <iostream> #include <vector>// 顶点结…

大语言模型遇上自动驾驶:AsyncDriver如何巧妙解决推理瓶颈?

导读 这篇论文提出了AsyncDriver框架&#xff0c;致力于解决大语言模型在自动驾驶领域应用中的关键挑战。论文的主要创新点在于提出了大语言模型和实时规划器的异步推理机制&#xff0c;实现了在保持性能的同时显著降低计算开销。通过设计场景关联指令特征提取模块和自适应注入…

mapbox进阶,添加绘图扩展插件,绘制圆形

👨‍⚕️ 主页: gis分享者 👨‍⚕️ 感谢各位大佬 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍⚕️ 收录于专栏:mapbox 从入门到精通 文章目录 一、🍀前言1.1 ☘️mapboxgl.Map 地图对象1.2 ☘️mapboxgl.Map style属性1.3 ☘️MapboxDraw 绘图控件二、🍀添加绘图扩…

JVM 中的四类引用:强、软、弱、虚

导言 在 Java 开发中&#xff0c;垃圾收集&#xff08;GC&#xff09;机制通过自动管理内存提升了开发效率。但你是否知道 JVM 通过四种引用类型&#xff08;强、软、弱、虚&#xff09;精细控制对象生命周期&#xff1f; 强引用&#xff08;Strong Reference&#xff09; 特…