ALMO框架:基于概率建模的多模态小样本学习实践
1. 项目概述当数据稀缺遇上信息冗余在机器学习的实际应用中我们常常面临一个看似矛盾的局面一方面对于某些特定类别我们可能只有寥寥几张图片甚至一张都没有比如识别一种新发现的鸟类或一个罕见的手写字符另一方面我们手头却可能拥有关于这个类别的丰富“旁白”信息——一段文字描述、一组属性标签甚至是绘制它的笔触轨迹。传统的小样本学习或零样本学习方法往往只盯着那几张可怜的图片“死磕”而忽略了这些宝贵的多模态信息。这就像让你只凭一张模糊的远景照片去辨认一只鸟却不告诉你它“喙长、腿细、羽毛呈蓝绿色”这些关键特征难度可想而知。ALMO框架的提出正是为了解决这个痛点。它的核心思想非常直观既然单一模态如图像的信息在数据稀缺时捉襟见肘那我们为何不把所有可用的信息——图像、文本、属性、声音等——都利用起来ALMO通过一套概率建模的“翻译”机制将不同模态、不同形式的数据全部映射到一个统一的、低维的随机潜在空间中。在这个空间里每个类别不再是一个孤立的点而是一个概率分布比如一个高斯分布这个分布综合了来自所有模态的证据。当面对一个新样本时我们同样将其映射到这个空间然后计算它与各个类别分布的距离或相似度从而做出判断。这种方法的高明之处在于其不确定性量化。对于只有一张图片的类别模型知道这个类别的表示不确定性很高分布方差大因此在融合时会更多地依赖文本描述等更可靠的模态信息。这种动态的、自适应的信息融合是ALMO相比简单拼接或平均池化等确定性方法的核心优势。我在复现和实验过程中深刻体会到正是这种对不确定性的显式建模让模型在面对极端数据稀缺时表现出了惊人的鲁棒性。2. 核心思路拆解从确定性原型到概率化融合要理解ALMO我们需要先看看它要改进的“前辈”们是怎么做的。传统的小样本学习方法如原型网络其流程可以概括为1用神经网络将支持集Support Set中每个类别的少数样本映射为一个特征向量2对这些特征向量取平均得到该类别的“原型”一个确定性的点3将查询集Query Set的样本同样映射为特征向量4计算查询样本与各个类别原型之间的欧氏距离用最近邻或softmax进行分类。这个流程存在两个明显问题。第一确定性表示的脆弱性。从寥寥几个样本计算出的一个“平均点”作为原型极易受到噪声样本或异常值的影响且无法表达“这个类别的样本可能分布在一个区域而非一个点”这种认知。第二多模态信息融合的粗暴性。当支持集中既有图像又有文本时常见做法是分别提取特征后直接拼接或取平均。这假设了所有模态同等重要、同等可靠但现实中图像可能模糊文本描述可能含混这种“一刀切”的融合方式并不合理。ALMO的解决方案是进行概率化升级和贝叶斯式融合。2.1 概率化表示从点到分布ALMO不再为每个查询样本或类别原型输出一个确定性的特征向量而是输出一个概率分布通常是多元高斯分布由均值向量和对角协方差矩阵参数化。对于查询样本i其潜在变量 $z_i$ 的分布为 $p(z_i | x_{i, m1:M}) \mathcal{N}(z_i | \mu_i, \Sigma_i)$。这里$x_{i, m}$ 代表样本i在第m个模态下的观测如图像像素神经网络编码器 $f_{\theta_m}$ 负责将 $x_{i,m}$ 映射为均值 $\mu_{i,m}$ 和对角协方差 $\Sigma_{i,m}$。对于类别原型j其潜在变量 $\eta_j$ 的分布为 $p(\eta_j | \mathcal{D}_j) \mathcal{N}(\eta_j | m_j, S_j)$。其中 $\mathcal{D}_j$ 在少样本学习中是与类别j相关的所有支持集样本的多模态观测在零样本学习中则是该类别的多模态元数据如属性向量和文本描述。这个转变是根本性的。均值 $\mu$ 可以理解为该样本或类别在潜在空间中最可能的“位置”而对角协方差矩阵 $\Sigma$ 的每个元素则代表了在该维度上的不确定性。方差大说明模型对这个样本或类别的该特征把握不大方差小则说明把握很准。2.2 概率数据融合专家乘积的魅力当存在多个信息源多模态时如何得到一个统一的分布 $p(z_i | x_{i, 1}, x_{i, 2}, ...)$ 或 $p(\eta_j | \mathcal{D}_j)$ALMO采用了专家乘积Product of Experts, PoE这一优雅的贝叶斯方法。其思想是将每个模态的编码器看作一个“专家”每个专家都对潜在变量 $z$ 应该是什么分布有自己的“意见”一个高斯分布。最终的共识分布就是所有这些专家意见的乘积并重新归一化。对于高斯分布这个乘积有闭合解。具体到查询样本i的融合 假设我们有M个模态每个模态的编码器给出了分布 $\mathcal{N}(z_i | \mu_{i,m}, \Sigma_{i,m})$。根据PoE融合后的分布 $p(z_i | x_{i, 1:M}) \mathcal{N}(z_i | \mu_i, \Sigma_i)$ 的参数为$$ \Sigma_i^{-1} \sum_{m1}^{M} \Sigma_{i,m}^{-1} $$ $$ \mu_i \Sigma_i \left( \sum_{m1}^{M} \Sigma_{i,m}^{-1} \mu_{i,m} \right) $$这个公式蕴含了深刻的直觉精度加权融合后的精度矩阵协方差矩阵的逆$\Sigma_i^{-1}$ 是各个模态精度矩阵的和。这意味着不确定性越小方差越小的模态其精度越高在融合中的话语权就越大。如果一个模态的图像很模糊其方差大、精度低它对最终融合结果的贡献就会自动被削弱。加权平均融合后的均值 $\mu_i$ 是各模态均值的加权平均权重正是各模态的精度矩阵 $\Sigma_{i,m}^{-1}$。高精度的模态均值会获得更高的权重。对于类别原型j的融合以少样本为例 原理类似但需要融合的信息更多每个支持集样本的每个模态都会贡献一个分布。假设支持集 $\mathcal{S}_{e,j}$ 中有 $K$ 个样本每个样本有 $M$ 个模态那么融合公式变为$$ S_j^{-1} \sum_{i \in \mathcal{S}{e,j}} \sum{m1}^{M} \Sigma_{i,m}^{-1} $$ $$ m_j S_j \left( \sum_{i \in \mathcal{S}{e,j}} \sum{m1}^{M} \Sigma_{i,m}^{-1} \mu_{i,m} \right) $$这相当于把所有支持样本的所有模态都视为独立的“专家”共同投票决定该类别的原型分布。样本越多、模态质量越高最终原型的估计就越精准方差越小。实操心得数值稳定性是关键在代码实现中直接计算协方差矩阵的逆 $\Sigma^{-1}$ 可能不稳定特别是当方差估计值很小时。ALMO论文的附录B给出了一个非常实用的技巧神经网络输出的是方差的对数$\log \sigma^2$。在融合时我们计算的是 $\exp(-\log \sigma^2)$这等价于 $1/\sigma^2$即精度。通过使用log-sum-exp技巧来计算这些指数项的求和可以极大提升数值稳定性避免在训练早期因方差估计不稳定而导致的梯度爆炸或消失。这是复现ALMO时必须注意的工程细节。2.3 分类决策在分布间计算距离得到了查询样本的分布 $p(z_i) \mathcal{N}(\mu_i, \Sigma_i)$ 和每个类别的原型分布 $p(\eta_j) \mathcal{N}(m_j, S_j)$ 后如何分类原型网络的思想是计算距离。在概率框架下我们需要计算两个分布之间的距离。ALMO采用了期望负平方欧氏距离作为度量。对于一个查询样本 $z_i \sim \mathcal{N}(\mu_i, \Sigma_i)$ 和一个类别原型 $\eta_j \sim \mathcal{N}(m_j, S_j)$它们之间平方欧氏距离的期望为$$ \mathbb{E}[||z_i - \eta_j||2^2] \sum{l1}^{L} [(\mu_{i,l} - m_{j,l})^2 \sigma_{i,l}^2 s_{j,l}^2] $$其中 $l$ 索引潜在空间的维度。这个公式非常直观距离由三部分组成1两个均值之差的平方2查询样本在该维度上的方差3类别原型在该维度上的方差。不确定性方差直接增加了期望距离。这意味着即使两个分布的均值很接近但如果其中任何一个分布非常不确定方差大模型也会认为它们“相距较远”从而降低将其归为同一类的置信度。最终的分类概率通过softmax函数基于这些期望距离计算得出$$ p(y_i j | x_i, \mathcal{S}) \propto \exp(-\mathbb{E}[||z_i - \eta_j||_2^2]) $$模型训练的目标就是最大化查询样本真实标签的似然概率。3. 实现细节与实操要点理解了核心思想后我们来看看如何具体实现ALMO。整个流程可以分为编码器设计、训练循环Episode Training和推断三个阶段。3.1 编码器网络设计ALMO框架本身不限定编码器的具体结构你可以根据模态类型自由选择。论文中给出了两个数据集的示例对于Omniglot数据集图像笔触序列图像模态使用一个简单的4层卷积网络Conv2D ReLU最后接两个平行的全连接层分别输出均值向量 $\mu$ 和对角方差的对数 $\log \sigma^2$。笔触模态笔触数据是变长序列。首先进行零填充Zero Padding到固定长度然后通过一个带掩码的LSTM层最后同样接两个全连接层输出均值和方差对数。对于CUB-200数据集图像特征属性向量文本描述特征所有模态都使用线性映射全连接层。这是因为输入已经是预提取好的高层特征如图像的ResNet特征、属性向量、文本描述的嵌入向量结构简单的线性层足以学习到潜在空间的映射且能防止过拟合。同样每个模态对应一个编码器输出该模态下的均值和方差对数。注意事项方差参数化的技巧神经网络直接输出方差 $\sigma^2$ 可能不稳定因为它必须为正数。标准做法是让网络输出一个无约束的实数 $l$然后通过 $\sigma^2 \exp(l)$ 或 $\sigma^2 \text{softplus}(l)$ 将其转换为正数。ALMO论文采用前者即输出 $\log \sigma^2$。在计算时精度 $\Sigma^{-1}$ 就是 $\exp(-\log \sigma^2)$非常方便。3.2 训练循环元学习与Episode构造ALMO采用基于Episode的训练这是小样本学习的标准做法目的是让训练过程模拟测试时的任务。构造一个Episode从数据集中随机抽取N个类别如5-way或20-way。对于每个类别随机抽取K个样本作为支持集Support Set再抽取另外K个或一批样本作为查询集Query Set。在零样本设置中支持集里没有该类别的图像样本只有其元数据属性、描述。前向传播支持集处理将支持集中每个样本的每个模态通过对应的编码器得到其均值和方差 $(\mu_{i,m}, \Sigma_{i,m})$。然后按照公式(9)为每个类别j融合所有样本的所有模态得到类别原型分布 $(m_j, S_j)$。查询集处理将查询集中每个样本的每个模态通过编码器得到 $(\mu_{i,m}, \Sigma_{i,m})$。然后按照公式(8)融合多模态得到查询样本的分布 $(\mu_i, \Sigma_i)$。计算损失对于每个查询样本i计算它与本Episode中所有N个类别的原型分布之间的期望负平方欧氏距离并通过softmax得到分类概率。损失函数是查询样本真实标签的负对数似然交叉熵。反向传播与优化计算损失关于所有编码器参数 $\theta_m, \phi_m$ 的梯度并使用Adam等优化器更新参数。这个过程不断重复每次迭代都从数据集中采样新的Episode。模型学习到的是如何从任意一组新类别、新样本中快速提取并融合信息做出准确分类的“元能力”。3.3 零样本学习的特殊处理零样本学习与少样本学习的主要区别在于支持集的内容。在零样本中支持集没有图像样本 ${x_{i}}$只有类别的元数据 ${a_{j,r}}$如属性向量、文本描述。因此在计算类别原型分布 $p(\eta_j | a_{j,1:R})$ 时公式(9)需要修改。我们为元数据的每个模态r设置一个编码器 $g_{\phi_r}$将元数据 $a_{j,r}$ 映射为分布 $(\mu_{j,r}, \Sigma_{j,r})$。然后同样使用专家乘积公式进行融合$$ S_j^{-1} \sum_{r1}^{R} \Sigma_{j,r}^{-1} $$ $$ m_j S_j \left( \sum_{r1}^{R} \Sigma_{j,r}^{-1} \mu_{j,r} \right) $$查询样本的处理与少样本学习完全相同。这样在测试时即使遇到一个从未见过图像的新类别如“雪鸮”只要我们知道它的属性“白色羽毛”、“夜行性”、“脸盘状”和一段描述模型就能在潜在空间中为其构建一个原型分布从而对一张雪鸮的图片进行分类。4. 实验复现与性能分析为了验证ALMO的有效性我按照论文描述在Omniglot和CUB-200数据集上进行了复现实验。以下是关键设置和发现。4.1 数据集与实验设置数据集模态任务支持集内容查询集内容评价指标Omniglot图像 笔触序列少样本分类N类每类K个样本图像笔触同N类每类一批样本图像笔触分类准确率CUB-200图像特征 属性 文本描述零样本/少样本分类零样本N类的属性描述少样本N类每类K个样本图像属性描述同N类的图像分类准确率网络结构与论文保持一致。Omniglot使用CNN和LSTMCUB-200使用线性层。潜在空间维度L256。训练使用Adam优化器学习率1e-3。Omniglot训练10万个episodeCUB-200训练3万个episode并用验证集进行早停。对比基线包括原型网络PROTO、匹配网络MN、VERSA、TapNet、PT-MAP以及多模态原型网络MProto等。4.2 核心结果与洞见1. 少样本学习概率融合的抗过拟合优势在Omniglot的5-way 5-shot任务中ALMO仅使用图像模态就达到了约98.5%的准确率优于原型网络PROTO的约97.8%和VERSA的约98.0%。当任务难度增加到20-way 5-shot时优势更加明显ALMO约94.2%PROTO约89.5%VERSA约92.1%。这清晰地证明了随机潜在空间和概率融合在缓解过拟合方面的作用。PROTO的确定性原型在类别增多、每类样本有限时泛化能力急剧下降。而ALMO通过方差建模不确定性起到了类似正则化的效果。2. 多模态融合112的效果当引入多模态信息后ALMO的性能得到进一步提升。在Omniglot上结合图像和笔触数据ALMO在5-way 1-shot任务上比仅用图像或仅用笔触的版本高出2-3个百分点。在CUB-200的零样本任务中同时使用属性312维向量和视觉描述400维文本特征的ALMO其Top-1准确率达到约62.5%显著高于仅使用属性约58.1%或仅使用描述约59.7%的版本也高于其他基线方法。3. 可解释性模型知道该相信谁ALMO一个非常吸引人的特点是其可解释性。由于融合权重即精度矩阵 $\Sigma^{-1}$是动态计算的我们可以事后分析对于某个特定类别模型更“信任”哪个模态。 在CUB-200的零样本实验中我复现了论文中的分析。例如对于“船尾鹩哥”Boat-tailed Grackle这个类别当同时使用属性和描述时模型分配给视觉描述模态的权重远高于属性模态。检查数据发现该鸟类的文本描述包含了“长而明显的V形尾羽”、“雄性全身有光泽的黑色”等非常具象且独特的视觉信息而属性向量可能无法完全捕捉这些细节。因此模型在融合时自动赋予了描述更高的权重从而显著提升了对该类别的分类准确率。这种能力是简单的拼接或平均池化方法所不具备的。4.3 计算开销与效率考量引入概率建模和融合是否会带来巨大的计算负担实测结果表明开销在可接受范围内。 在相同的网络架构和硬件NV 1080Ti GPU下ALMO处理Omniglot测试集一个查询样本的平均时间约为12.17毫秒5-way 5-shot而确定性的多模态原型网络MProto约为12.07毫秒。额外的0.1毫秒主要来自对每个潜在空间维度计算log-sum-exp和softmax以进行融合。考虑到ALMO带来的性能提升这点开销是微不足道的。 参数方面ALMO比确定性方法多了一组用于预测方差的对角线参数从L维均值增加到2L维参数。但这部分增量通常只占整个网络参数的不到1%因为绝大部分参数都集中在特征提取的主干网络中。5. 常见问题、调参技巧与扩展思考在实际复现和应用ALMO的过程中我遇到并总结了一些典型问题和解决方案。5.1 训练不稳定或性能不佳问题1训练初期损失震荡或变为NaN。可能原因方差估计 $\log \sigma^2$ 不稳定导致在计算精度 $\exp(-\log \sigma^2)$ 时出现数值溢出得到inf或下溢得到0。解决方案初始化技巧将方差预测层的权重初始化为较小的值如0.01偏置初始化为一个较小的负数如-2这样初始输出的方差不会太大也不会太小。激活函数使用 $\sigma^2 \text{softplus}(l)$ 代替 $\sigma^2 \exp(l)$因为softplus函数更平滑梯度更稳定。我在复现时发现这能有效改善训练初期的稳定性。梯度裁剪在反向传播时对梯度进行裁剪如设置max_norm1.0防止梯度爆炸。问题2模型似乎没有利用好多模态信息性能与单模态差不多。可能原因某个模态的编码器训练不足或过于强势导致融合时该模态主导其他模态失效。解决方案平衡的数据流确保每个batch中所有模态的数据都正常存在且经过充分的预处理。对于缺失模态的情况ALMO框架本身可以处理方差设为无穷大即精度为0但在训练初期最好使用完整数据。模态特定的学习率如果不同模态的编码器结构或数据尺度差异很大如图像CNN和文本LSTM可以为它们设置不同的学习率。检查方差输出在验证集上运行时打印出不同模态预测的方差均值。如果一个模态的方差始终非常小精度很高而其他模态方差很大说明模型过于依赖前者。可以尝试在损失中加入一个正则项鼓励所有模态的方差不要过早收敛到极小值。5.2 超参数调优指南超参数建议范围/值影响与说明潜在空间维度 L64, 128, 256, 512维度越高表征能力越强但也更容易过拟合。对于Omniglot这类相对简单的数据128-256足够对于CUB-200256-512可能更好。需要权衡。学习率1e-4 到 1e-3对于线性层/简单CNN1e-3通常可行对于较深的网络或预训练特征建议从1e-4开始。使用学习率衰减。Episode构造 (N, K)N5, K1/5; N20, K1/5训练时的N和K应与你的目标测试任务一致或更难如用20-way训练来增强5-way的泛化能力。优化器Adam (beta10.9, beta20.999)默认Adam效果很好。对于非常不稳定的训练可以尝试SGD with momentum。方差预测层初始化权重: small normal (std0.01)偏置: constant (-2 to 0)关键糟糕的初始化会导致训练崩溃。负的偏置初始值能让初始方差接近1处于一个合理的范围。5.3 框架的局限性与未来扩展ALMO框架虽然强大但并非没有局限了解这些局限能帮助我们在合适的场景应用它并思考改进方向。对角协方差假设为了计算简便和稳定ALMO假设每个模态编码器输出的高斯分布具有对角协方差矩阵。这意味着它假设潜在空间的各个维度是相互独立的。这虽然是一个常见假设如在变分自编码器中但可能过于简化限制了模型捕捉特征间复杂相关性的能力。一个折中的改进方向是预测一个低秩的协方差矩阵或者使用更灵活的分布族如归一化流。融合方式的单一性ALMO采用了专家乘积PoE进行融合。PoE假设各个模态是条件独立的。当这个假设不成立时融合效果可能不是最优。可以探索其他融合方式如专家混合Mixture of Experts, MoE它允许模型学习一个权重来组合不同模态的分布可能更具灵活性。距离度量的选择ALMO使用期望平方欧氏距离。对于某些任务余弦相似度或马氏距离可能更合适。修改距离度量需要重新推导目标函数中log-sum-exp项的上界这增加了复杂性。对预训练特征的依赖在CUB-200实验中图像使用的是预训练的ResNet特征。ALMO的成功部分依赖于这些高质量的单模态特征提取器。在端到端训练中如何同时优化特征提取器和概率融合模块是一个需要仔细设计训练策略的挑战。尽管有这些局限ALMO为多模态小样本/零样本学习提供了一个坚实、可解释且高效的基线。它的核心思想——用概率分布表示不确定性并用精度加权进行自适应融合——具有广泛的启发性。这个框架可以很容易地扩展到三个以上的模态或者与其他先进的元学习、度量学习方法结合。在我自己的探索中尝试将ALMO与基于Transformer的跨模态编码器结合用于处理更复杂的图文检索小样本任务也取得了不错的初步效果。关键在于始终抓住其“概率融合”的内核围绕它来构建和优化你的系统。