PerturBench:单细胞扰动预测的标准化基准测试框架解析
1. 项目概述为什么我们需要一个统一的扰动预测基准测试库如果你在单细胞转录组学或者计算生物学领域做过一阵子尤其是尝试过构建或复现那些预测基因敲除、药物处理后细胞表达谱变化的模型那你大概率经历过这样的痛苦好不容易从GitHub上扒拉下来一个模型的代码吭哧吭哧配好环境跑通了作者提供的示例数据结果想换到自己的数据集上或者想跟另一个SOTA模型公平地比一比发现根本无从下手。数据格式千奇百怪预处理流程各说各话评估指标五花八门甚至连“训练-验证-测试”划分都能搞出七八种不同的定义。最后你花在数据对齐和工程实现上的时间可能比研究模型本身还要多。这就是单细胞扰动预测领域长期存在的一个痛点缺乏一个公认的、标准化的基准测试框架。大家各玩各的论文里的漂亮数字换个环境可能就复现不出来模型之间的比较也常常是“关公战秦琼”说服力有限。PerturBench的出现就是为了填上这个坑。它不是一个新模型而是一个基于PyTorch的开源软件库和基准测试套件目标很明确为这个领域的研究者提供一套统一的“尺子”和“跑道”。简单来说PerturBench想解决三个核心问题可复现性让任何人拿到一篇论文的模型都能用同一套数据、同一种划分、同一个评估流程跑出可比较的结果。可扩展性让研究者能轻松地将自己的新模型、新数据集接入这个生态快速进行基准测试而不是从头造轮子。深入评估不仅看传统的RMSE、余弦相似度还要引入更鲁棒、更能揭示模型本质问题的评估指标比如它重点提出的排序指标Rank Metrics专门用来诊断模型是否发生了“模式崩溃”Mode Collapse——也就是模型不管输入什么扰动都输出差不多的结果这种“偷懒”行为在生成式模型中很常见但传统指标很难发现。这个库主要面向两类人一是模型开发者你可以用它作为脚手架快速搭建和训练你的新模型并确保你的评估是标准、公平的二是模型使用者或评估者你可以用它来系统性地比较不同模型在你关心的任务比如跨细胞系预测、组合扰动预测上的优劣为你的生物学问题选择最合适的工具。接下来我会结合自己折腾这类模型的经验带你深入拆解PerturBench的设计哲学、核心模块并分享一些在真实使用中可能遇到的“坑”和实战技巧。2. 核心设计思路模块化、兼容性与面向未来的架构PerturBench的架构设计充分体现了其“基准测试平台”的定位核心思路是高内聚、低耦合同时最大化地拥抱社区已有的成熟工具链。它不是要取代什么而是要成为连接它们的“胶水”。2.1 基石技术栈的选择站在巨人的肩膀上PerturBench没有重复发明轮子而是明智地选择了几个在各自领域已成为事实标准的库作为基础这大大降低了用户的学习和使用成本。PyTorch深度学习模型构建的绝对主流。Pytorch的动态计算图和灵活的Dataset/DataLoader抽象为定义复杂的数据加载逻辑比如单细胞数据特有的控制组匹配提供了完美的基础。PerturBench的数据模块就是构建在PyTorch的Dataset抽象之上的。PyTorch Lightning用过纯PyTorch写训练循环的人都知道里面充斥着大量重复的样板代码设备管理、梯度累积、日志记录、检查点保存等。PyTorch Lightning将这些抽象成LightningModule和LightningDataModule让研究者能更专注于模型结构和数据逻辑本身。PerturBench的基类PerturbationModel就继承自LightningModule使得模型训练、验证、测试的流程变得极其简洁和标准化。Hydra做基准测试意味着有海量的超参数、数据集路径、模型配置需要管理。用一堆命令行参数或者散落的配置文件很容易失控。Hydra提供了一个强大的层级化配置管理系统支持从命令行动态组合配置还能方便地集成超参数优化工具如Optuna。这使得复现任何一个实验比如“用CPA模型在Norman19数据集上跑第6号数据划分”只需要一条清晰的命令。AnnData单细胞数据分析领域的“普通话”。AnnData对象将基因表达矩阵X、细胞元数据obs、基因元数据var以及各种中间计算结果obsm,obsp,layers统一封装在一个数据结构里。PerturBench将AnnData作为数据的输入、内部交换和输出的标准格式。这意味着你可以直接用Scanpy等工具预处理你的数据保存为.h5ad文件然后无缝喂给PerturBench。模型预测的输出也是一个AnnData对象你可以直接用Scanpy进行下游的可视化或分析。这种设计消除了格式转换的麻烦是工程上的一大亮点。实操心得这种技术栈选择非常“接地气”。一个新手研究者很可能已经熟悉了PyTorch和Scanpy/AnnData。PerturBench在此基础上搭建学习曲线相对平缓。当你需要调试时由于底层都是这些熟悉的库你很容易定位问题是出在数据预处理、模型定义还是训练流程上。2.2 核心抽象数据、模型、评估的三位一体PerturBench的代码库清晰地分为三个核心模块这也是大多数机器学习项目的通用范式但在单细胞扰动预测这个具体领域做了高度特化。perturbench.data定义了如何将单细胞数据AnnData转化为模型可用的张量Tensor。其核心是Example这个命名元组NamedTuple它规范了一个样本一个细胞必须包含和可选包含哪些信息。比如gene_expression基因表达向量和perturbations施加的扰动列表是必需的而controls匹配的控制细胞表达和covariates协变量如细胞类型、供体等是可选的用于支持更复杂的任务如控制匹配、条件预测。perturbench.model提供了模型基类PerturbationModel。任何想要接入PerturBench基准测试的模型都需要继承这个类并实现关键的predict方法给定一个反事实批次数据预测扰动后的表达。基类已经帮你封装好了PyTorch Lightning的训练/测试循环、优化器配置、以及调用评估流程的钩子。你只需要关心模型的前向传播逻辑。perturbench.analysis这是评估部分的核心。它提供了一套统一的评估指标计算流程并且设计成了类似Kaggle竞赛的API风格你只需要将模型预测结果一个AnnData对象丢给Evaluator指定任务名称如sciplex3-transfer它就会返回一个包含所有指标结果的DataFrame。这极大地简化了评估流程保证了公平性。这种模块化设计的好处是你可以像搭积木一样使用它。你可以只用它的数据模块来加载和处理你的数据然后用你自己的训练框架。你也可以只继承它的模型基类来获得标准化的评估接口。当然最完整的用法是三者全用享受一站式的便利。3. 数据抽象详解从AnnData到模型输入的关键转换数据是机器学习的燃料在单细胞扰动预测中数据准备更是重中之重且异常繁琐。PerturBench在数据抽象上花了很大功夫旨在覆盖从简单到复杂的各种实验设计。3.1 核心数据结构Example与Batch一切始于Example类。它定义了一个样本一个细胞的标准结构。理解每个字段的含义至关重要class Example(NamedTuple): gene_expression: Tensor # 必需该细胞的基因表达向量例如top 4000高变基因的logCPM值 perturbations: list[str] # 必需应用于该细胞的扰动列表如 [‘TP53_knockout‘]组合扰动则为 [‘DrugA‘, ‘DrugB‘] covariates: dict[str, str] # 可选协变量字典如 {‘cell_type‘: ‘T细胞‘, ‘donor‘: ‘D1‘} controls: Tensor # 可选匹配的控制细胞表达向量用于某些需要显式控制匹配的模型或损失函数 gene_names: list[str] # 可选基因名列表与gene_expression的维度顺序对应 extra: dict[str, Any] # 可选其他任意额外信息如预计算的细胞嵌入cell embeddings一个Batch就是多个Example的集合通常由PyTorch的DataLoader自动组装而成并增加了批次维度。注意事项perturbations字段是列表这巧妙支持了组合扰动Combinatorial Perturbations的表示。这是很多早期工具忽略的一点。例如同时敲除基因A和B就表示为[‘GeneA_KO‘, ‘GeneB_KO‘]。协变量covariates使用字典也使得多条件、多因素的实验设计如不同细胞系不同药物处理能够被统一表征。3.2 四类数据集Dataset及其应用场景PerturBench提供了四种PyTorchDataset类对应不同的任务阶段和实验设计SingleCellPerturbation基础训练数据集。它从AnnData对象中读取表达矩阵和元数据构建一个用于模型训练的数据集。每个样本返回一个Example。它通过工厂方法from_anndata创建需要指定哪个obs列是扰动标签perturbation_key以及如何分隔组合扰动中的多个扰动名称perturbation_combination_delimiter例如用分号;。SingleCellPerturbationWithControls带控制匹配的训练数据集。它继承自上一个类但增加了一个关键功能为每个扰动细胞动态采样一个匹配的控制细胞。匹配通常基于协变量例如同一个供体、同一种细胞类型、同一块培养板。这个功能对于某些强调“差异效应”的模型或损失函数如基于对照的损失非常有用。它在内部维护了一个从协变量组合到控制细胞索引的映射表control_indexes以实现高效采样。Counterfactual反事实预测数据集。用于模型推理inference阶段。它的目的不是提供有标签的数据而是定义“你想让模型预测什么”。你给它一组控制细胞control_expression和一组你感兴趣的“反事实”扰动条件perturbations和covariates它就会生成相应的Batch输入给模型的predict方法。例如你可以问“如果把这些原本健康的T细胞控制组都用药物X处理它们的表达谱会变成什么样”CounterfactualWithReference带真实参考的反事实预测数据集。这是评估阶段的核心。它继承自Counterfactual但额外包含了一个真实的、观测到的扰动数据集reference_adata。同时它还有一个映射表reference_indexes能将每个“反事实查询”特定扰动特定协变量映射到reference_adata中对应的、真实观测到的细胞索引上。这样在评估时系统就能自动将模型的预测值与真实观测值进行逐对比较。实操心得理解这四种Dataset的区别是正确使用PerturBench的关键。很多错误源于用错了数据集类型。比如在评估模型时你必须使用CounterfactualWithReference来构建测试集这样才能获得用于计算指标的真实标签。SingleCellPerturbationWithControls在训练时非常有用特别是当你的数据集中扰动细胞和对照细胞数量不平衡时它能确保每个批次内都有合理的对照信号。3.3 数据划分策略超越简单的随机分割在基准测试中如何划分训练集、验证集和测试集直接决定了任务的难度和评估的公正性。PerturBench实现了三种更具生物学意义和挑战性的划分策略通过DataSplitter类来管理跨协变量划分这是为了测试模型的泛化能力。例如在训练集中包含细胞类型A、B、C的扰动数据但在测试集中要求模型预测细胞类型D训练时从未见过对某些扰动的响应。这模拟了现实场景我们无法为所有细胞类型都做一遍实验。组合划分这是为了测试模型对高阶相互作用的理解。训练集包含单一扰动A、B、C的数据以及部分组合扰动如AB的数据。测试集则包含训练时未见过的组合如AC或BC甚至三重组ABC。这要求模型能够推理出扰动之间的非线性叠加效应。逆向组合划分这是一个更刁钻的任务。训练集包含了组合扰动AB的数据以及单一扰动B的数据。测试集要求模型预测单一扰动A的效应。这测试了模型能否从组合效应中“解耦”出单个成分的贡献。DataSplitter通过两个主要参数来控制划分的严格程度m最多保留的协变量类型数量和f每个协变量类型中保留的扰动比例。m越大保留的协变量越少任务越难f越大保留的扰动越少任务也越难。这种设计允许研究者系统地研究数据量、数据平衡性对模型性能的影响。4. 模型接口与评估流程如何让你的模型“上车”4.1 实现一个兼容PerturBench的模型要让你的自定义模型能在PerturBench的框架下训练和评估你需要创建一个继承自PerturbationModel的类。这个基类本身也是LightningModule已经为你处理了大部分样板代码。你需要完成的最核心工作是实现抽象方法predict(self, counterfactual_batch: Batch) - torch.Tensor。这个方法接收一个由Counterfactual数据集产生的批次数据返回模型预测的扰动后基因表达张量。这个Batch里包含了控制细胞的表达、协变量信息以及你想要施加的扰动列表。此外你通常还需要重写__init__方法来定义你的网络结构以及forward方法可选用于训练时的前向传播。configure_optimizers方法基类提供了默认的Adam优化器你也可以按需重写。import torch import pytorch_lightning as pl from perturbench.model.base import PerturbationModel class MyCustomModel(PerturbationModel): def __init__(self, n_genes: int, perturbation_vocab_size: int, hidden_dim: int 128): super().__init__() # 定义你的网络层 self.pert_embedding torch.nn.Embedding(perturbation_vocab_size, hidden_dim) self.encoder torch.nn.Linear(n_genes hidden_dim, hidden_dim) self.decoder torch.nn.Linear(hidden_dim, n_genes) # 可以在这里保存一些必要的配置到 self.training_record self.training_record[‘n_genes‘] n_genes def predict(self, counterfactual_batch: Batch) - torch.Tensor: 核心推理方法。 counterfactual_batch 包含 - control_expression: 控制细胞表达 [batch_size, n_genes] - perturbations: 扰动列表的列表 - covariates: 协变量字典 # 1. 获取控制细胞表达 x_control counterfactual_batch.gene_expression # 2. 将扰动名称转换为嵌入这里需要你自己的词表映射逻辑 pert_indices self._convert_pert_names_to_indices(counterfactual_batch.perturbations) pert_emb self.pert_embedding(pert_indices) # [batch_size, hidden_dim] # 3. 将控制表达和扰动嵌入拼接通过网络 combined torch.cat([x_control, pert_emb], dim-1) latent self.encoder(combined) x_pred self.decoder(latent) # 预测的扰动后表达 return x_pred def _convert_pert_names_to_indices(self, pert_list): # 实现将扰动名称字符串映射到索引ID的逻辑 # 例如使用一个预定义的字典 pass def training_step(self, batch, batch_idx): # 如果你需要自定义训练步骤例如使用不同的损失函数可以重写这里 # 否则基类会使用默认的MSE损失在 predict 的输出和真实标签之间计算 x_real, y_real batch # batch 来自 SingleCellPerturbation 数据集 y_pred self.forward(x_real) # 需要你实现 forward 方法 loss torch.nn.functional.mse_loss(y_pred, y_real) self.log(‘train_loss‘, loss) return loss注意事项predict方法的输入是Counterfactual批次输出是预测的表达值而不是表达值的变化量Δ。这意味着模型需要学习从“控制状态”到“扰动状态”的完整映射。你的模型内部如何处理控制细胞信息是直接拼接还是作为解码器的条件抑或是用其他方式完全由你决定。这给了模型设计很大的灵活性。4.2 训练与评估的标准化流程得益于PyTorch Lightning和Hydra的集成训练一个模型变得非常简洁。通常会有一个train.py脚本它通过Hydra读取配置文件。配置文件中定义了模型架构、数据集路径、数据划分策略、训练超参数学习率、批次大小、轮数、日志记录器等所有设置。运行训练就像这样python train.py modelmy_custom_model datasetnorman19 splitcross_covariate训练结束后PyTorch Lightning的Trainer会自动调用test_step使用CounterfactualWithReference测试集对模型进行评估。所有指标会被计算并汇总。4.3 核心评估指标解读超越RMSEPerturBench计算一套综合的评估指标从不同角度衡量预测质量RMSE (均方根误差)最直接的回归指标衡量预测表达值与真实值在绝对数值上的差距。但它对异常值敏感且无法衡量分布层面的相似性。Cosine Similarity of LogFC (对数倍变化余弦相似度)这是一个更生物相关的指标。它先计算每个基因在扰动组 vs 控制组之间的对数倍变化Log Fold Change然后比较预测的LogFC向量与真实的LogFC向量之间的余弦相似度。这个指标关注的是变化的方向和模式而不是绝对表达值对于很多下游分析如通路富集更有意义。MMD (最大均值差异)一个衡量两个概率分布差异的指标。PerturBench分别在原始基因表达空间MMD GEX和PCA降维后的空间MMD PCA计算MMD。这用于评估模型预测的细胞群体分布是否与真实扰动后的细胞群体分布一致。一个好的模型应该能捕捉到扰动引入的细胞异质性而不仅仅是预测一个“平均”响应。DEG Recall (差异表达基因召回率)识别哪些基因因扰动而发生显著变化差异表达基因DEG是许多生物学分析的关键。这个指标计算模型预测出的top N个变化最大的基因中有多少是真实的DEG。它衡量模型捕捉最显著生物学信号的能力。Rank Metrics (排序指标)这是PerturBench的一大创新。对于每个测试样本一个特定的扰动-协变量组合它计算模型预测值与所有其他真实观测样本之间的相似度如余弦相似度然后看真实匹配的样本在其中排第几名排名越靠前越好最后对所有测试样本取平均。这个指标非常严苛因为它要求模型不仅预测得准还要在众多可能的“错误答案”中脱颖而出。它能有效暴露“模式崩溃”问题——如果一个模型对所有输入都输出相似的结果那么它的预测与真实匹配样本的相似度排名可能会很差即使它的RMSE看起来还不错。5. 实战经验与避坑指南在复现和使用PerturBench以及基于它进行开发的过程中我踩过不少坑也总结出一些让流程更顺畅的经验。5.1 数据预处理的一致性成败的关键基准测试的核心是公平比较而公平比较的前提是所有模型使用完全相同的数据。PerturBench虽然提供了数据加载模块但原始数据的预处理如基因过滤、归一化仍需用户自己完成并保存为标准的AnnData文件。这里有几个关键点高变基因选择PerturBench建议使用Seurat v3方法选择top 4000个高变基因。务必记录下这些基因的列表。在将新数据或另一个模型所需的数据输入基准测试前必须将其投影到完全相同的基因空间。如果两个模型用了不同的基因集比较就失去了意义。一个实用的做法是在第一次处理基准数据集时就将筛选出的高变基因列表保存下来作为该数据集“官方”特征集。归一化与缩放是使用CPM每百万计数、logCPM还是SCTransform的残差不同的归一化方法会极大影响数据的分布从而影响模型性能。必须在整个基准测试中固定使用一种方法。PerturBench的示例通常使用log1p(CPM)后的数据。扰动与协变量编码确保你的AnnData对象的obsDataFrame中扰动和协变量列的名称、格式与PerturBench的from_anndata工厂方法所期望的一致。特别是对于组合扰动要明确分隔符如;或并在所有数据集中保持一致。控制细胞的明确标识数据集中必须有一列来明确标识哪些细胞是未受扰动的控制组Control。通常这列的值是一个特定的字符串如‘control‘或‘NT‘Non-Targeting。在构建SingleCellPerturbationWithControls数据集时需要正确指定perturbation_control_value参数。5.2 模型实现中的常见陷阱predict方法的输入输出维度这是最容易出错的地方。务必清楚counterfactual_batch.gene_expression的形状是[batch_size, n_genes]你的模型输出也必须是这个形状。如果你的模型内部使用了基因嵌入或注意力机制要确保最终解码回原始基因空间。扰动嵌入的处理许多先进模型如CPA, GEARS会使用预训练的扰动嵌入如药物的分子指纹、基因的GO注释向量。如果你在模型中使用这类嵌入需要在Example的extra字段中提供或者在模型内部实现一个嵌入层。要确保在训练和推理时嵌入的查找方式一致。与PyTorch Lightning的配合PerturbationModel是LightningModule的子类。这意味着你需要遵循PyTorch Lightning的规范。例如在training_step中计算损失并使用self.log记录优化器配置在configure_optimizers中定义。如果你需要更复杂的训练逻辑如对抗训练、多任务学习可能需要仔细重写这些方法。设备管理PyTorch Lightning会自动处理设备CPU/GPU放置。但在你的模型forward或predict方法内部如果进行了自定义的张量操作如从外部文件加载嵌入矩阵要确保这些张量也在正确的设备上。可以使用self.device属性来获取模型当前所在的设备。5.3 评估阶段的关键步骤正确构建测试集评估必须使用CounterfactualWithReference数据集。你需要提供control_expression用于生成反事实预测的控制细胞表达矩阵。perturbations和covariates你想要评估的扰动协变量对列表。reference_adata包含真实扰动细胞观测值的AnnData对象。reference_indexes一个字典将每个扰动协变量组合映射到reference_adata中的细胞索引列表。构建这个映射表是评估准备中最容易出错的一环需要仔细检查每个组合在参考数据中是否都有对应的观测细胞。理解评估输出Evaluator.evaluate()返回的DataFrame包含了所有指标在每个模型上的结果。注意有些指标是“越高越好”如Cosine LogFC有些是“越低越好”如RMSE, MMD。排序指标*_rank的值在0到1之间越接近0越好排名越靠前。仔细阅读文档理解每个指标的确切含义。模式崩溃的诊断如果你的模型在RMSE上表现尚可但排序指标特别是transposed-rank非常差接近0.5即随机水平那么很可能发生了模式崩溃。此时应该像PerturBench论文附录中那样绘制预测结果之间的余弦相似度矩阵热图。如果热图看起来“块状”明显同一协变量下的不同扰动预测结果相似或者整体颜色单一那就是模式崩溃的典型标志。这时需要检查模型容量是否不足、正则化是否过强、或者训练数据是否存在严重不平衡。5.4 性能优化与调试技巧利用Hydra进行超参数扫描PerturBench与Hydra和Optuna的集成使得超参数优化变得非常方便。你可以定义一个超参数搜索空间然后并行运行数十个实验。这对于寻找最佳模型配置至关重要。记得合理设置搜索范围并利用Hydra的多运行multirun功能。注意内存使用单细胞数据矩阵可能非常大。使用SingleCellPerturbation数据集时如果整个表达矩阵加载到内存中导致OOM内存溢出可以考虑使用torch.utils.data.Subset进行子采样或者使用IterableDataset进行流式加载但这可能会增加代码复杂度。Counterfactual数据集通常较小因为只涉及控制细胞。日志与可视化充分利用PyTorch Lightning的日志回调如TensorBoardLogger。监控训练损失、验证损失以及任何你自定义的指标。对于生成式模型定期可视化一些样本的预测结果与真实值的对比例如选择几个关键基因绘制预测值与真实值的散点图可以给你带来比数字指标更直观的感受。从简单基线开始在实现复杂的SOTA模型之前强烈建议先实现并运行PerturBench提供的几个基线模型如LatentAdditive或DecoderOnly。这能帮助你快速验证整个PerturBench pipeline在你的机器和环境上是畅通的并且能建立一个性能基准。之后任何复杂模型都应该显著超越这些基线否则就需要反思是模型问题还是实现bug。6. 总结与展望PerturBench的生态价值折腾完PerturBench这一套我的感受是它确实为单细胞扰动预测这个快速发展的领域注入了一剂“标准化”的强心针。它通过严谨的软件工程实践把数据接口、模型训练、评估指标这些琐碎但关键的部分给统一和自动化了让研究者能更专注于模型创新本身。它的几个设计选择我认为尤其值得称道一是深度绑定AnnData无缝对接现有单细胞分析生态二是引入排序指标来诊断模式崩溃戳中了许多生成式模型的软肋三是提供了多种具有生物学意义的数据划分策略推动模型向更具泛化能力的方向发展。当然作为一个工具它也有其边界和可扩展的方向。例如目前它主要处理的是静态的、单时间点的扰动响应预测。对于时间序列的扰动数据如scRNA-seq时间序列或者需要整合多组学信息如ATAC-seq, 蛋白质组的预测任务当前的抽象可能需要进一步扩展。此外对于如何将先验知识如基因调控网络、通路数据库更有效地融入模型框架PerturBench提供了extra字段这样的扩展口但更深入的集成可能需要社区共同探索。对于想要进入这个领域的新手我的建议是把PerturBench的代码仓库克隆下来从运行它的教程Notebook开始尤其是复现一两个基线模型在某个数据集如Srivatsan20上的结果。在这个过程中你会被迫理解它的数据流、模型接口和评估流程。一旦跑通你就掌握了在这个领域进行可复现研究的“标准语言”。之后无论是评估别人的模型还是发布自己的模型你都会发现沟通和比较的效率大大提升。最终像PerturBench这样的基准测试框架其价值不仅在于提供了一组数字排名更在于它定义了一套清晰的“游戏规则”促进了开源、透明和累积性的科学研究。它让这个领域的进步变得可以被客观地衡量和比较。