Hyper-SD: A New Acceleration Framework for Diffusion Models
Hyper-SD is a diffusion-based image-generation framework developed by ByteDance to address the high computational cost of image generation with existing diffusion models. Its main goal is to drastically reduce the number of inference steps required while preserving, or even improving, image quality.
Hyper-SD introduces a technique called Trajectory Segmented Consistency Distillation (TSCD), which achieves SOTA-level image generation in 1 to 8 steps. On top of TSCD, Hyper-SD applies score distillation to strengthen one-step generation, so high-quality images can be produced even at extremely low step counts.
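The intuition behind TSCD can be sketched as follows. Standard consistency distillation trains a student to map any point on the sampling trajectory directly to its endpoint, which is a hard target; TSCD instead splits the time range into segments, enforces consistency only within each segment, and then progressively merges segments (e.g., 8 → 4 → 2 → 1) until a global consistency model remains. A schematic objective (the notation here is illustrative, not verbatim from the paper):

```latex
% Split [0, T] into k segments [T_s, T_{s+1}], s = 0, \dots, k-1.
% For two timesteps t_n > t_m inside the same segment, with \hat{x}_{t_m}
% obtained from the ODE solver starting at x_{t_n}:
\mathcal{L}_{\mathrm{TSCD}} =
  \mathbb{E}_{s,\; t_n > t_m \in [T_s, T_{s+1}]}
  \Big[ d\big( f_\theta(x_{t_n}, t_n),\; f_{\theta^-}(\hat{x}_{t_m}, t_m) \big) \Big]
```

where \(f_\theta\) is the student, \(f_{\theta^-}\) an EMA copy used as the target, and \(d\) a distance metric. Shorter within-segment targets make distillation easier at each stage, which is what lets the final model keep quality at 1 to 8 steps.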
According to experiments and user evaluations, Hyper-SD performs strongly on both the SDXL and SD1.5 architectures, surpassing existing acceleration methods such as SDXL-Lightning.
Project page: https://hyper-sd.github.io/
I. Environment Setup
1. Python environment
Python 3.10 or later is recommended.
2. pip installation
pip install torch==2.4.0+cu118 torchvision==0.19.0+cu118 torchaudio==2.4.0 --extra-index-url https://download.pytorch.org/whl/cu118
pip install diffusers transformers accelerate safetensors sentencepiece modelscope huggingface-hub protobuf peft -i https://pypi.tuna.tsinghua.edu.cn/simple
3. Download the Hyper-SD model:
git lfs install
git clone https://huggingface.co/ByteDance/Hyper-SD
4. Download the FLUX.1-dev model:
git lfs install
git clone https://huggingface.co/black-forest-labs/FLUX.1-dev
II. Functional Testing
1. Running the tests:
(1) Text-to-Image Python invocation test
import torch
from diffusers import (
    DiffusionPipeline,
    FluxPipeline,
    StableDiffusion3Pipeline,
    DDIMScheduler,
    TCDScheduler,
    LCMScheduler,
    UNet2DConditionModel,
)
from modelscope import snapshot_download
from modelscope.hub.file_download import model_file_download
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download


def download_model_assets(model_repo: str, lora_repo: str, lora_file: str):
    """Download a base model snapshot and a single LoRA file from ModelScope."""
    base_model_id = snapshot_download(model_repo)
    lora_path = model_file_download(model_id=lora_repo, file_path=lora_file)
    return base_model_id, lora_path


def create_pipeline(model_class, base_model_id, lora_path, device="cuda", dtype=torch.float16, variant=None):
    """Load a pipeline, then load and fuse the Hyper-SD LoRA weights into it."""
    pipe = model_class.from_pretrained(base_model_id, torch_dtype=dtype, variant=variant).to(device)
    pipe.load_lora_weights(lora_path)
    pipe.fuse_lora()
    return pipe


def generate_image(pipe, prompt, scheduler=None, num_steps=2, guidance_scale=0, eta=None, timesteps=None):
    """Run inference and save the result; optional args are only forwarded when set."""
    if scheduler is not None:
        pipe.scheduler = scheduler  # use the pre-configured scheduler instance as-is
    kwargs = dict(prompt=prompt, num_inference_steps=num_steps, guidance_scale=guidance_scale)
    # Not every pipeline accepts eta/timesteps, so only pass them when given.
    if eta is not None:
        kwargs["eta"] = eta
    if timesteps is not None:
        kwargs["timesteps"] = timesteps
    image = pipe(**kwargs).images[0]
    image.save("output.png")
    return image


# Example for FLUX.1-dev (no fp16 variant; FLUX weights are loaded in bfloat16)
base_model_id, lora_path = download_model_assets("AI-ModelScope/FLUX.1-dev", "ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors")
pipe_flux = create_pipeline(FluxPipeline, base_model_id, lora_path, dtype=torch.bfloat16)
generate_image(pipe_flux, "a photo of a cat, hold a sign 'I love Qwen'", num_steps=8, guidance_scale=3.5)

# Example for SD3
base_model_id, lora_path = download_model_assets("AI-ModelScope/stable-diffusion-3-medium-diffusers", "ByteDance/Hyper-SD", "Hyper-SD3-8steps-CFG-lora.safetensors")
pipe_sd3 = create_pipeline(StableDiffusion3Pipeline, base_model_id, lora_path)
generate_image(pipe_sd3, "a photo of a cat", num_steps=8, guidance_scale=5.0)

# Example for SDXL with 2-step LoRA
base_model_id, lora_path = download_model_assets("AI-ModelScope/stable-diffusion-xl-base-1.0", "ByteDance/Hyper-SD", "Hyper-SDXL-2steps-lora.safetensors")
pipe_sdxl = create_pipeline(DiffusionPipeline, base_model_id, lora_path, variant="fp16")
scheduler = DDIMScheduler.from_config(pipe_sdxl.scheduler.config, timestep_spacing="trailing")
generate_image(pipe_sdxl, "a photo of a cat", scheduler=scheduler, num_steps=2)

# Example for SDXL with unified LoRA (1 step)
base_model_id, lora_path = download_model_assets("AI-ModelScope/stable-diffusion-xl-base-1.0", "ByteDance/Hyper-SD", "Hyper-SDXL-1step-lora.safetensors")
pipe_sdxl_uni = create_pipeline(DiffusionPipeline, base_model_id, lora_path, variant="fp16")
scheduler = TCDScheduler.from_config(pipe_sdxl_uni.scheduler.config)
generate_image(pipe_sdxl_uni, "a photo of a cat", scheduler=scheduler, num_steps=1, eta=1.0)

# Example for SDXL 1-step UNet (a fully distilled UNet checkpoint, not a LoRA)
base_model_id, unet_path = download_model_assets("AI-ModelScope/stable-diffusion-xl-base-1.0", "ByteDance/Hyper-SD", "Hyper-SDXL-1step-Unet.safetensors")
unet = UNet2DConditionModel.from_config(base_model_id, subfolder="unet").to("cuda", torch.float16)
unet.load_state_dict(load_file(unet_path, device="cuda"))
pipe_sdxl_unet = DiffusionPipeline.from_pretrained(base_model_id, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
scheduler = LCMScheduler.from_config(pipe_sdxl_unet.scheduler.config)
generate_image(pipe_sdxl_unet, "a photo of a cat", scheduler=scheduler, num_steps=1, timesteps=[800])

# Example for SD1.5 with 2-step LoRA
base_model_id = "runwayml/stable-diffusion-v1-5"
lora_path = hf_hub_download("ByteDance/Hyper-SD", "Hyper-SD15-2steps-lora.safetensors")
pipe_sd15 = create_pipeline(DiffusionPipeline, base_model_id, lora_path, variant="fp16")
scheduler = DDIMScheduler.from_config(pipe_sd15.scheduler.config, timestep_spacing="trailing")
generate_image(pipe_sd15, "a photo of a cat", scheduler=scheduler, num_steps=2)

# Example for SD1.5 with unified LoRA (1 step)
lora_path = hf_hub_download("ByteDance/Hyper-SD", "Hyper-SD15-1step-lora.safetensors")
pipe_sd15_uni = create_pipeline(DiffusionPipeline, base_model_id, lora_path, variant="fp16")
scheduler = TCDScheduler.from_config(pipe_sd15_uni.scheduler.config)
generate_image(pipe_sd15_uni, "a photo of a cat", scheduler=scheduler, num_steps=1, eta=1.0)
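The checkpoint filenames used above follow a regular pattern (Hyper-{arch}-{N}steps-lora.safetensors, with a CFG variant for the guidance-preserving SD3 LoRA). A small hypothetical helper that maps an architecture tag and step count to the filenames used in this post; the name hyper_lora_filename is mine, not part of the Hyper-SD repo:

```python
# Hypothetical helper: maps (architecture, steps) to the Hyper-SD LoRA
# filenames used in the examples above.
KNOWN_LORAS = {
    ("FLUX.1-dev", 8): "Hyper-FLUX.1-dev-8steps-lora.safetensors",
    ("SD3", 8): "Hyper-SD3-8steps-CFG-lora.safetensors",
    ("SDXL", 2): "Hyper-SDXL-2steps-lora.safetensors",
    ("SDXL", 1): "Hyper-SDXL-1step-lora.safetensors",
    ("SD15", 2): "Hyper-SD15-2steps-lora.safetensors",
    ("SD15", 1): "Hyper-SD15-1step-lora.safetensors",
}

def hyper_lora_filename(arch: str, steps: int) -> str:
    """Return the LoRA filename for an (architecture, steps) pair listed above."""
    try:
        return KNOWN_LORAS[(arch, steps)]
    except KeyError:
        raise ValueError(f"no LoRA listed in this post for {arch} at {steps} steps") from None

print(hyper_lora_filename("SDXL", 2))  # Hyper-SDXL-2steps-lora.safetensors
```

Keeping the lookup explicit (rather than formatting the string) avoids guessing filenames that do not exist in the ByteDance/Hyper-SD repository.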
To be continued...
For more details, follow: 杰哥新技术