
Hyper-SD: A New Acceleration Framework for Diffusion Models

Hyper-SD is an image-generation framework built on diffusion models and developed by ByteDance. It targets the high computational cost of image generation with existing diffusion models.

The main goal of Hyper-SD is to drastically reduce the number of inference steps needed to generate an image while preserving, or even improving, image quality.

Hyper-SD introduces a technique called Trajectory Segmented Consistency Distillation (TSCD), which enables state-of-the-art image generation with only 1 to 8 inference steps.
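At a high level, TSCD splits the denoising time axis into segments and enforces consistency only within each segment, before progressively merging segments (e.g. 8 → 4 → 2 → 1) so that the student eventually becomes consistent over the whole trajectory. The toy sketch below is only a rough, hypothetical illustration of a per-segment consistency loss; the names student, teacher_step and the segment schedule are assumptions, not ByteDance's actual training code.

import torch
import torch.nn.functional as F

# Rough, hypothetical sketch of a segment-wise consistency loss (NOT the official TSCD implementation).
# `student(x, t, s)` is assumed to predict the sample at segment boundary s from the noisy sample x_t;
# `teacher_step(x, t, t_prev)` is assumed to run one solver step of the frozen teacher from t to t_prev.
def segment_consistency_loss(student, teacher_step, x_t, t, num_segments=8, T=1000):
    seg_len = T // num_segments
    seg_start = (t // seg_len) * seg_len              # boundary of the segment that contains t
    t_prev = torch.clamp(t - seg_len // 4, min=0)     # a nearby, less-noisy timestep...
    t_prev = torch.maximum(t_prev, seg_start)         # ...kept inside the same segment
    with torch.no_grad():
        x_prev = teacher_step(x_t, t, t_prev)         # move x_t slightly along the teacher's trajectory
        target = student(x_prev, t_prev, seg_start)   # map the less-noisy point to the segment boundary
    pred = student(x_t, t, seg_start)                 # map the noisier point to the same boundary
    return F.mse_loss(pred, target)                   # both predictions should agree (self-consistency)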

In addition to TSCD, Hyper-SD applies score distillation to strengthen one-step generation, which helps produce high-quality images even with an extremely small number of steps.

According to the authors' experiments and user evaluations, Hyper-SD performs well on both the SDXL and SD1.5 architectures, achieving SOTA-level generation quality in the 1- to 8-step regime and outperforming existing acceleration methods such as SDXL-Lightning.

Project page: https://hyper-sd.github.io/

I. Environment Setup

1. Python environment

Python 3.10 or later is recommended.
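A quick way to confirm the running interpreter meets this requirement (a trivial check, not from the original post):

# Sanity check: make sure the interpreter is Python 3.10 or newer.
import sys
assert sys.version_info >= (3, 10), f"Python 3.10+ recommended, found {sys.version.split()[0]}"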

2. Installing the pip packages

pip install torch==2.4.0+cu118 torchvision==0.19.0+cu118 torchaudio==2.4.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118

pip install diffusers transformers accelerate safetensors sentencepiece modelscope huggingface-hub protobuf peft -i https://pypi.tuna.tsinghua.edu.cn/simple
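After installation, a short sanity check like the sketch below can confirm that PyTorch sees the GPU and that the core libraries import correctly (the printed versions will depend on your environment):

# Verify the GPU build of PyTorch and the main libraries used by the test script below.
import torch, diffusers, peft
print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("diffusers:", diffusers.__version__, "| peft:", peft.__version__)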

3. Downloading the Hyper-SD model

git lfs install

git clone https://huggingface.co/ByteDance/Hyper-SD
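Cloning the full repository pulls every checkpoint and is quite large. If only one accelerator file is needed, it can also be fetched individually with huggingface_hub; the file name below is one of those used in the test code later:

# Optional: download a single Hyper-SD LoRA instead of cloning the whole repository.
from huggingface_hub import hf_hub_download
lora_path = hf_hub_download("ByteDance/Hyper-SD", "Hyper-SDXL-2steps-lora.safetensors")
print(lora_path)  # local cache path of the downloaded file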

4. Downloading the FLUX.1-dev model

git lfs install

git clone https://huggingface.co/black-forest-labs/FLUX.1-dev
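If Hugging Face is slow to reach, the FLUX.1-dev weights can alternatively be pulled from ModelScope; this mirror id is also the one used by the test code below:

# Optional: fetch FLUX.1-dev from the ModelScope mirror instead of git-cloning from Hugging Face.
from modelscope import snapshot_download
base_model_id = snapshot_download("AI-ModelScope/FLUX.1-dev")
print(base_model_id)  # local directory containing the downloaded weights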

II. Functional Testing

1. Running the tests

(1) Text-to-Image Python invocation test

import torch
from diffusers import (
    DiffusionPipeline, FluxPipeline, StableDiffusion3Pipeline,
    DDIMScheduler, TCDScheduler, LCMScheduler, UNet2DConditionModel,
)
from modelscope import snapshot_download
from modelscope.hub.file_download import model_file_download
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download


def download_model_assets(model_repo: str, lora_repo: str, lora_file: str):
    # Download the base model from ModelScope and a single accelerator file from the Hyper-SD repo.
    base_model_id = snapshot_download(model_repo)
    lora_path = model_file_download(model_id=lora_repo, file_path=lora_file)
    return base_model_id, lora_path


def create_pipeline(model_class, base_model_id, lora_path, device="cuda", dtype=torch.float16, variant=None):
    # Load the base pipeline, then attach and fuse the Hyper-SD LoRA.
    # variant="fp16" is only passed for repos that actually ship fp16 variant weights (SDXL, SD1.5).
    pipe = model_class.from_pretrained(base_model_id, torch_dtype=dtype, variant=variant).to(device)
    pipe.load_lora_weights(lora_path)
    pipe.fuse_lora()
    return pipe


def generate_image(pipe, prompt, scheduler=None, num_steps=2, guidance_scale=0,
                   eta=None, timesteps=None, out_file="output.png"):
    # Use the pre-configured scheduler as-is so options such as timestep_spacing="trailing" are kept.
    if scheduler is not None:
        pipe.scheduler = scheduler
    kwargs = dict(prompt=prompt, num_inference_steps=num_steps, guidance_scale=guidance_scale)
    # eta / timesteps are not accepted by every pipeline, so pass them only when set.
    if eta is not None:
        kwargs["eta"] = eta
    if timesteps is not None:
        kwargs["timesteps"] = timesteps
    image = pipe(**kwargs).images[0]
    image.save(out_file)
    return image


# Note: each example below builds its own pipeline; run them one at a time (or free the previous
# pipeline first) if GPU memory is limited.

# Example for FLUX.1-dev with the 8-step LoRA (bfloat16 is the dtype recommended for FLUX.1-dev)
base_model_id, lora_path = download_model_assets("AI-ModelScope/FLUX.1-dev", "ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors")
pipe_flux = create_pipeline(FluxPipeline, base_model_id, lora_path, dtype=torch.bfloat16)
generate_image(pipe_flux, "a photo of a cat, hold a sign 'I love Qwen'", num_steps=8, guidance_scale=3.5, out_file="flux_8step.png")

# Example for SD3 with the 8-step CFG-preserved LoRA
base_model_id, lora_path = download_model_assets("AI-ModelScope/stable-diffusion-3-medium-diffusers", "ByteDance/Hyper-SD", "Hyper-SD3-8steps-CFG-lora.safetensors")
pipe_sd3 = create_pipeline(StableDiffusion3Pipeline, base_model_id, lora_path)
generate_image(pipe_sd3, "a photo of a cat", num_steps=8, guidance_scale=5.0, out_file="sd3_8step.png")

# Example for SDXL with the 2-step LoRA
base_model_id, lora_path = download_model_assets("AI-ModelScope/stable-diffusion-xl-base-1.0", "ByteDance/Hyper-SD", "Hyper-SDXL-2steps-lora.safetensors")
pipe_sdxl = create_pipeline(DiffusionPipeline, base_model_id, lora_path, variant="fp16")
scheduler = DDIMScheduler.from_config(pipe_sdxl.scheduler.config, timestep_spacing="trailing")
generate_image(pipe_sdxl, "a photo of a cat", scheduler=scheduler, num_steps=2, out_file="sdxl_2step.png")

# Example for SDXL with the unified (1-step) LoRA and the TCD scheduler
base_model_id, lora_path = download_model_assets("AI-ModelScope/stable-diffusion-xl-base-1.0", "ByteDance/Hyper-SD", "Hyper-SDXL-1step-lora.safetensors")
pipe_sdxl_uni = create_pipeline(DiffusionPipeline, base_model_id, lora_path, variant="fp16")
scheduler = TCDScheduler.from_config(pipe_sdxl_uni.scheduler.config)
generate_image(pipe_sdxl_uni, "a photo of a cat", scheduler=scheduler, num_steps=1, eta=1.0, out_file="sdxl_1step_lora.png")

# Example for the SDXL 1-step distilled UNet (full UNet weights instead of a LoRA)
base_model_id, unet_ckpt_path = download_model_assets("AI-ModelScope/stable-diffusion-xl-base-1.0", "ByteDance/Hyper-SD", "Hyper-SDXL-1step-Unet.safetensors")
unet = UNet2DConditionModel.from_config(base_model_id, subfolder="unet").to("cuda", torch.float16)
unet.load_state_dict(load_file(unet_ckpt_path, device="cuda"))
pipe_sdxl_unet = DiffusionPipeline.from_pretrained(base_model_id, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
scheduler = LCMScheduler.from_config(pipe_sdxl_unet.scheduler.config)
generate_image(pipe_sdxl_unet, "a photo of a cat", scheduler=scheduler, num_steps=1, timesteps=[800], out_file="sdxl_1step_unet.png")

# Example for SD1.5 with the 2-step LoRA (base model pulled directly from Hugging Face)
base_model_id = "runwayml/stable-diffusion-v1-5"
lora_path = hf_hub_download("ByteDance/Hyper-SD", "Hyper-SD15-2steps-lora.safetensors")
pipe_sd15 = create_pipeline(DiffusionPipeline, base_model_id, lora_path, variant="fp16")
scheduler = DDIMScheduler.from_config(pipe_sd15.scheduler.config, timestep_spacing="trailing")
generate_image(pipe_sd15, "a photo of a cat", scheduler=scheduler, num_steps=2, out_file="sd15_2step.png")

# Example for SD1.5 with the unified (1-step) LoRA and the TCD scheduler
lora_path = hf_hub_download("ByteDance/Hyper-SD", "Hyper-SD15-1step-lora.safetensors")
pipe_sd15_uni = create_pipeline(DiffusionPipeline, base_model_id, lora_path, variant="fp16")
scheduler = TCDScheduler.from_config(pipe_sd15_uni.scheduler.config)
generate_image(pipe_sd15_uni, "a photo of a cat", scheduler=scheduler, num_steps=1, eta=1.0, out_file="sd15_1step.png")
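FLUX.1-dev in particular needs a lot of VRAM. If memory is tight, diffusers' built-in CPU offloading can be enabled on any of the pipelines above; this is a general diffusers option rather than something specific to Hyper-SD, and it trades some speed for memory:

# Optional: reduce peak GPU memory by keeping only the active sub-module on the GPU (requires accelerate).
pipe_flux.enable_model_cpu_offload()
image = pipe_flux("a photo of a cat", num_inference_steps=8, guidance_scale=3.5).images[0]
image.save("flux_8step_offload.png")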

To be continued...

For more details, follow: 杰哥新技术

