SenseFlow: Scaling Distribution Matching for Flow-based Text-to-Image Distillation
Xingtong Ge1,2, Xin Zhang2, Tongda Xu3, Yi Zhang4, Xinjie Zhang1, Yan Wang3, Jun Zhang1
1HKUST, 2SenseTime Research, 3Tsinghua University, 4CUHK MMLab
Abstract
Distribution Matching Distillation (DMD) has been successfully applied to text-to-image diffusion models such as Stable Diffusion (SD) 1.5. However, vanilla DMD suffers from convergence difficulties on large-scale flow-based text-to-image models such as SD 3.5 and FLUX. In this paper, we first analyze the issues that arise when applying vanilla DMD to large-scale models. Then, to overcome the scalability challenge, we propose implicit distribution alignment (IDA) to constrain the divergence between the generator and the fake distribution. Furthermore, we propose intra-segment guidance (ISG) to relocate the timestep denoising importance from the teacher model. With IDA alone, DMD converges for SD 3.5; with both IDA and ISG, it converges for SD 3.5 and FLUX.1 dev. Together with a scaled vision-foundation-model (VFM) based discriminator, our final model, dubbed SenseFlow, achieves superior distillation performance on both diffusion-based text-to-image models such as SDXL and flow-matching models such as SD 3.5 Large and FLUX.1 dev.
SenseFlow-FLUX.1 dev (supports 4–8-step generation)
- `SenseFlow-FLUX/diffusion_pytorch_model.safetensors`: the DiT checkpoint.
- `SenseFlow-FLUX/config.json`: the DiT config used by our model.
Usage
- Prepare the base checkpoint of FLUX.1 dev at `Path/to/FLUX`.
- Use `SenseFlow-FLUX` to replace the transformer folder `Path/to/FLUX/transformer`, obtaining `Path/to/SenseFlow-FLUX`. Alternatively, you can load the transformer directly without modifying the base folder; see the sketch after this list.
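A minimal sketch of the direct-loading alternative, using the standard diffusers `FluxTransformer2DModel` API; `"SenseFlow-FLUX"` and `"Path/to/FLUX"` are the local paths from the steps above, so adjust them to your setup:

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel

# Load the SenseFlow DiT weights on their own...
transformer = FluxTransformer2DModel.from_pretrained(
    "SenseFlow-FLUX", torch_dtype=torch.bfloat16
)
# ...and hand them to a pipeline built from the unmodified base checkpoint.
pipe = FluxPipeline.from_pretrained(
    "Path/to/FLUX", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
```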
Using the Euler sampler
```python
import torch
from diffusers import FluxPipeline

# FLUX pipelines use FlowMatchEulerDiscreteScheduler by default,
# so no scheduler change is needed for Euler sampling.
pipe = FluxPipeline.from_pretrained("Path/to/SenseFlow-FLUX", torch_dtype=torch.bfloat16).to("cuda")

prompt = "A cat sleeping on a windowsill with white curtains fluttering in the breeze"
image = pipe(
    prompt,
    height=1024,
    width=1024,
    num_inference_steps=4,  # 4-8 steps are supported
    max_sequence_length=512,
).images[0]
image.save("output.png")
```
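Optionally, you can pass a seeded `torch.Generator` for reproducible outputs (standard diffusers usage, not specific to SenseFlow):

```python
# Same call as above, with a fixed seed for reproducibility.
image = pipe(
    prompt,
    height=1024,
    width=1024,
    num_inference_steps=4,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
```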
Using the x0 sampler (similar to the LCMScheduler in diffusers)
```python
import torch
from typing import Optional, Tuple, Union

from diffusers import FluxPipeline, FlowMatchEulerDiscreteScheduler
from diffusers.schedulers.scheduling_flow_match_euler_discrete import (
    FlowMatchEulerDiscreteSchedulerOutput,
)


class FlowMatchEulerX0Scheduler(FlowMatchEulerDiscreteScheduler):
    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
        if self.step_index is None:
            self._init_step_index(timestep)

        sample = sample.to(torch.float32)  # upcast for numerical precision
        sigma = self.sigmas[self.step_index]
        sigma_next = self.sigmas[self.step_index + 1]

        # 1. Recover x0 from the model output. The model predicts the
        # flow-matching velocity v, and x_t = (1 - sigma) * x0 + sigma * noise,
        # so x0 = x_t - sigma * v.
        x0 = sample - sigma * model_output

        # 2. Re-noise x0 to the next noise level to get the next sample.
        noise = torch.randn_like(sample)
        prev_sample = (1 - sigma_next) * x0 + sigma_next * noise

        prev_sample = prev_sample.to(model_output.dtype)  # restore original dtype
        self._step_index += 1  # advance to the next step

        if not return_dict:
            return (prev_sample,)
        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)


pipe = FluxPipeline.from_pretrained("Path/to/SenseFlow-FLUX", torch_dtype=torch.bfloat16).to("cuda")
pipe.scheduler = FlowMatchEulerX0Scheduler.from_config(pipe.scheduler.config)

prompt = "A cat sleeping on a windowsill with white curtains fluttering in the breeze"
image = pipe(
    prompt,
    height=1024,
    width=1024,
    num_inference_steps=4,
    max_sequence_length=512,
).images[0]
image.save("output.png")
```
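For reference, the update implemented by `FlowMatchEulerX0Scheduler.step` above follows the standard flow-matching convention, in which the model predicts the velocity $v_\theta \approx \epsilon - x_0$:

$$
x_t = (1 - \sigma_t)\,x_0 + \sigma_t\,\epsilon
\quad\Rightarrow\quad
\hat{x}_0 = x_t - \sigma_t\,v_\theta(x_t, t),
$$

$$
x_{t'} = (1 - \sigma_{t'})\,\hat{x}_0 + \sigma_{t'}\,\epsilon',
\qquad \epsilon' \sim \mathcal{N}(0, I).
$$

Each step thus re-noises a fresh $\hat{x}_0$ estimate rather than integrating the probability-flow ODE, which is why this sampler behaves like the LCMScheduler in diffusers.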
DanceGRPO-SenseFlow (supports 4–8-step generation)
Coming soon!