ReF-LDM 环境搭建与复现记录(含踩坑总结)
本文记录NeurIPS 2024 论文 ReF-LDM: A Latent Diffusion Model for Reference-based Face Image Restoration的复现过程欢迎批评指正。本文将介绍Ref-LDM 项目环境搭建Ref-LDM 项目结构解析如何利用现有代码构建完整 evaluation pipelineRef-LDM论文结果复现一、背景与介绍图片出自论文《ReF-LDM: A Latent Diffusion Model for Reference-based Face Image Restoration》论文项目地址ChiWeiHsiao/ref-ldm: The official implementation of ReF-LDM: A Latent Diffusion Model for Reference-based Face Image Restoration [NeurIPS 2024]论文链接2412.050431这篇论文当前公开到了什么程度ReF-LDM 是基于Latent Diffusion Model (LDM)的人脸修复模型通过引入reference image来保持人物身份一致性。官方 GitHub 已公开代码、模型权重下载入口、推理脚本inference.pyREADME 里也给了单图推理命令但 README 同时还写着“inference script for testing datasets”在 TODO 里说明完整 benchmark 级批量评测脚本很可能还需要自己补。这意味着复现工作最好分两层第 1 层功能复现跑通官方权重跑通 demo理清输入输出格式看多 reference 的组织方式第 2 层实验复现搭数据集写 batch inference写 metric 计算对齐论文 setting复现实验表格这两层不要混着做不然很容易一开始就陷进数据和评测细节里。2论文里真正需要复现的 setting论文实验设置里比较关键的点有这些分辨率是512×512训练时固定5 张 reference推理时使用100 步 DDIM对 reference 用了classifier-free guidancescale1.5评测数据集包括FFHQ-Ref-Severe、FFHQ-Ref-Moderate、CelebA-Test-Ref指标包括IDS、LPIPS、fLPIPS、NIQE、FID其中 IDS 是用ArcFace cosine similarity算的身份相似度。其中最关键的是IDS因为这题本质上不是只比清晰度而是比“清晰 身份不漂”。论文里 ReF-LDM 在严重退化的 FFHQ-Ref-Severe 上IDS 相比 LDM、CodeFormer、VQFR、DAEFR、DMDNet 都明显更高这也说明它更适合作为“ID 保持”型 baseline。二、环境搭建1 创建 Conda 环境建议使用Python 3.10conda create -n refldm python3.10 -y conda activate refldm2 降级 pip 版本项目依赖pytorch-lightning1.6.3而新版pip会拒绝安装旧 metadata 格式的包因此需要降级 pippython -m pip install pip24.0检查版本pip --version3 安装 PyTorch项目要求torch1.13.1 torchvision0.14.1安装适合的PyTorchpip install torch1.13.1cu117 torchvision0.14.1cu117 \ --extra-index-url https://download.pytorch.org/whl/cu117验证python -c import torch;print(torch.__version__, torch.cuda.is_available())PS.注意该ldm太旧优先换到 A40/A100 等兼容 GPU上运行否则要重新安装torch版本。重新安装如下pip install torch torchvision torchaudio装完验证python -c import torch; print(torch.__version__) python -c import torch; print(torch.cuda.is_available()) python -c import torch; print(torch.cuda.get_arch_list())其他见项目依赖安装如果遇到问题建议不要用requirements.txt而是分步安装package4 安装项目依赖项目提供requirements.txtpip install -r requirements.txt环境搭建中问题总结与解决办法在复现过程中遇到了多个典型问题下面逐一总结。问题 1GitHub 依赖下载失败安装requirements.txt时出现错误fatal: unable to access https://github.com/CompVis/taming-transformers.git/: Failed to connect to github.com port 443原因服务器无法直接访问 GitHub。解决方法方法一推荐手动下载源码git clone https://github.com/CompVis/taming-transformers git clone https://github.com/openai/CLIP上传到服务器后安装或者后面换成源码的绝对路径pip install -e taming-transformers pip install -e CLIP同时在requirements.txt中注释# -e githttps://github.com/CompVis/taming-transformers.gitmaster#eggtaming-transformers # -e githttps://github.com/openai/CLIP.gitmain#eggclip安装后验证python -c import taming; print(taming ok) python -c import clip; print(clip ok)补充taming-transformers依赖问题的解决过程在复现 ReF-LDM 时环境中的torch、clip和ldm都可以正常导入但执行推理脚本时仍报错ModuleNotFoundError: No module named taming进一步排查发现虽然pip install -e ~/Projects/ref-ldm/taming-transformers显示安装成功但在项目根目录下执行python -c import taming仍然失败只有切换到taming-transformers源码目录中时相关模块才可以被导入。说明问题并不在源码缺失而在于taming包没有被 Python 正确识别和暴露到当前环境中。随后检查taming-transformers/taming目录结构确认源码完整包含data、models、modules等子目录但一开始缺少__init__.py。补充一个空的__init__.py文件后重新安装taming-transformers再次测试发现import taming已经可以在 ReF-LDM 根目录下正常执行。最终确认问题根因是taming目录缺少__init__.py导致包虽然显示安装成功但无法被稳定识别为可导入模块补上该文件并重新安装后No module named taming问题得到解决。python -c import taming; print(taming ok)问题 2pytorch-lightning 安装失败报错Ignoring version 1.6.3 of pytorch-lightning since it has invalid metadata Please use pip24.1原因新版本pip不兼容旧版 Lightning 的 metadata。解决方法降级 pippython -m pip install pip24.0然后重新安装依赖。如果安装Pytorch时遇到版本问题建议先安装新版本pip install lightning或者#下载 pip download pytorch-lightning1.8,2.0 -d ./wheels #安装 pip install --no-index --find-links./wheels pytorch-lightning #验证 python -c import pytorch_lightning as pl; print(pl.__version__)问题 3torch 版本不匹配如果先执行pip install -e .可能会自动安装最新torch 2.x。但 ReF-LDM 依赖torch1.13.1解决方法先安装 PyTorch再安装项目依赖。问题4 PyTorch 2.x 加载旧 Lightning checkpoint 的兼容问题在迁移 ReF-LDM 到新的 H20 环境PyTorch 2.11后推理阶段加载vqgan.ckpt时出现如下错误_pickle.UnpicklingError: Weights only load failed... Unsupported global: pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint原因是PyTorch 2.6 之后torch.load默认开启weights_onlyTrue的安全加载机制。而 ReF-LDM 的 checkpoint 是在旧版本PyTorch PyTorch-Lightning下保存的其中包含ModelCheckpoint等 Lightning 对象因此在新版本 PyTorch 中会被安全机制阻止反序列化。方案一可以通过torch.load(path, weights_onlyFalse)绕过限制但这会关闭安全检查因此更推荐使用 方案二safe allowlist 方式。方案二最终采用 PyTorch 官方推荐的safe_globals方法解决在加载 checkpoint 时允许 Lightning 的ModelCheckpoint类型from pytorch_lightning.callbacks import ModelCheckpoint import torch with torch.serialization.safe_globals([ModelCheckpoint]): sd torch.load(path, map_locationcpu, weights_onlyTrue)[state_dict]这样既能保持weights_onlyTrue的安全加载策略又能兼容旧 Lightning checkpoint。修改后 ReF-LDM 的inference.py可以正常加载模型并继续推理。5 安装项目本体项目中包含setup.py用于注册ldm包pip install -e .安装完成后验证python -c import ldm; print(ldm ok)完整的导入链测试python -c import torch; print(torch.__version__, torch.cuda.is_available(), torch.cuda.get_arch_list()) python -c import taming; print(taming ok) python -c from taming.modules.vqvae.quantize import VectorQuantizer2; print(quantize ok) python -c import clip; print(clip ok) python -c import ldm; print(ldm ok) python -c import pytorch_lightning as pl; print(pl.__version__)三、模型下载Download modelsRelease ReF-LDM model weight · ChiWeiHsiao/ref-ldm需要下载两个模型权重refldm.ckptvqgan.ckpt并放到ckpts/最终目录结构ref-ldm ├─ assets ├─ ckpts │ ├─ refldm.ckpt │ └─ vqgan.ckpt ├─ configs ├─ ldm ├─ inference.py ├─ requirements.txt └─ setup.py四、运行 Demo运行官方提供的推理脚本python inference.py \ --ddim_step 50 \ --output_path result.png \ --lq_path assets/demo/lq.png \ --ref_paths assets/demo/ref0.png assets/demo/ref1.png assets/demo/ref2.png assets/demo/ref3.png如果环境正确将生成修复后的图像result.png五、论文复现1.项目结构Ref-LDM 的代码仓库采用较为标准的Latent Diffusion 项目结构包含模型实现、推理脚本、数据处理以及评测代码。ckpts/— 模型权重该目录存放训练好的 Ref-LDM 权重configs/— 实验配置该目录保存对应的模型结构配置ldm/— 核心模型实现这是整个项目的核心模块负责Latent Diffusion ModelUNetDDIM sampler数据加载Identity lossscripts/— Pipeline脚本scripts/├── align_face_image.py├── inference_dataset.py└── eval.py这三个脚本实际上已经构成了一条完整 pipeline数据预处理→ 批量推理→ 指标评测data/— 数据集自行创建存放FFHQFFHQ-RefCelebA-Test-Ref在复现 Ref-LDM 时仅仅能够运行单张图像推理是不够的还需要复现论文中的evaluation pipeline即在标准测试集上进行批量推理并计算定量指标。Ref-LDM 项目中实际上已经提供了完整的 pipeline 组件主要由三个脚本构成align_face_image.py数据预处理人脸对齐作用使用 dlib 68 landmark对齐人脸裁剪为 512×512与 FFHQ 数据集对齐方式一致如果数据已经是FFHQ aligned可以跳过该步骤。运行方式python scripts/align_face_image.py \ -i input_dir \ -o aligned_dirinference_dataset.py批量推理该脚本完成加载数据集 ↓ 加载Ref-LDM模型 ↓ DDIM采样 ↓ 生成恢复结果支持三个测试集ffhqtest_severe ffhqtest_moderate celebatesteval.py指标评估该脚本计算以下指标指标说明IDSidentity similarityLPIPS感知差PSNR像素质量SSIM结构相似NIQE无参考质量MUSIQ视觉质量FID分布距离这三部分组合起来即可构成完整的 Ref-LDM evaluation pipeline数据集 ↓ align_face_image.py (optional) ↓ inference_dataset.py ↓ 生成 restored images ↓ eval.py ↓ metrics.csv2.数据集下载这里直接按照官方README下载数据集下载数据数据集内容FFHQ-Ref数据集FFHQ-Ref/ │ ├── reference_mapping/ │ ├── train_references.csv │ ├── val_references.csv │ └── test_references.csv │ ├── id_based_ffhq_split/ │ ├── train_image.txt │ ├── val_image.txt │ └── test_image.txt │ └── test_images/ ├── severe_degrad/ └── moderate_degrad/FFHQ-Ref包含20,405张高质量面部图像及其对应的参考图像。 它基于FFHQ数据集中的7万张图像使用ArcFace预测的面部身份标签构建而成。高质量图像从FFHQ数据集下载 (images1024x1024/)Huggingface上也有公开的数据集Iceclear/FFHQ-HQ1024 · Datasets at Hugging Facereference_mapping/CSV文件列出目标图像及其对应的参考图像id_based_ffhq_split/列出基于身份的 FFHQ 数据集 train/val/test 拆分图像的文本文件70,000 张图片为什么需要这样做以往的研究随机拆分了FFHQ数据集导致同一人的图像分布在训练和评估数据集中。我们提供基于身份的数据分配以解决这个问题。test_images/具有两种降解水平的低质量测试图像CelebA-test-ref数据集另一个基于参考的面部修复测试数据集包含2533张图像及其对应的参考图像。高质量图像从CelebAMask-HQ下载 (CelebAMask-HQ/CelebA-HQ-img/)test_references.csv目标图像及其对应参考图像列表celeba_test_images/包含低质量测试图像和高质量的真实地面图像3.实验设置论文附录明确写了图像分辨率512×512参考图数量固定为5inference 100 DDIM stepsclassifier-free guidance towards referenceascale of 1.5 towards reference imagesArcFace 人脸识别模型在评估阶段需要计算Identity Similarity (IDS)指标。该指标基于 ArcFace 人脸识别网络。我们下载 ArcFace ONNX 模型webface_r50_pfc.onnx并将其放置在pretrained/insightface_webface_r50.onnx该模型用于提取人脸 embedding 并计算余弦相似度其中恢复图像真实图像ArcFace embedding未完待续