LML-diffusion-sampler / scripts /StableDiffusion_COCO.py
王方懿康
Initial commit
ab2369a
import sys
import torch
import os
import json
import argparse
sys.path.append(os.getcwd())
from diffusers import StableDiffusionPipeline, DPMSolverMultistepLMScheduler, DDIMLMScheduler, PNDMScheduler, UniPCMultistepScheduler
from scheduler.scheduling_dpmsolver_multistep_lm import DPMSolverMultistepLMScheduler
from scheduler.scheduling_ddim_lm import DDIMLMScheduler
# Every --sampler_type value this script knows how to configure.
SAMPLER_TYPES = ('dpm_lm', 'dpm', 'dpm++', 'dpm++_lm', 'pndm', 'ddim', 'ddim_lm', 'unipc')

# DPM-Solver family: sampler_type -> (algorithm_type, lm flag).
_DPM_VARIANTS = {
    'dpm_lm': ("dpmsolver", True),
    'dpm': ("dpmsolver", False),
    'dpm++': ("dpmsolver++", False),
    'dpm++_lm': ("dpmsolver++", True),
}

# DDIM family: sampler_type -> lm flag.
_DDIM_VARIANTS = {
    'ddim': False,
    'ddim_lm': True,
}

# Previously hard-coded prompt location; kept as the --prompt_file default.
DEFAULT_PROMPT_FILE = '/mnt/chongqinggeminiceph1fs/geminicephfs/mm-base-vision/pazelzhang/make_dataset/fid_3W_json.json'


def _build_parser():
    """Return the command-line parser for this sampling script."""
    parser = argparse.ArgumentParser(description="sampling script for COCO14.")
    parser.add_argument('--test_num', type=int, default=1000)
    parser.add_argument('--start_index', type=int, default=0)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--num_inference_steps', type=int, default=20)
    parser.add_argument('--guidance', type=float, default=7.5)
    # `choices` rejects typos up front instead of silently falling back to
    # the pipeline's default scheduler while labelling files with the typo.
    parser.add_argument('--sampler_type', type=str, default='ddim', choices=SAMPLER_TYPES)
    parser.add_argument('--model_id', type=str, default='/xxx/xxx/stable-diffusion-v1-5')
    parser.add_argument('--save_dir', type=str, default='/xxx/xxx')
    parser.add_argument('--lamb', type=float, default=5.0)
    parser.add_argument('--kappa', type=float, default=0.0)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--prompt_file', type=str, default=DEFAULT_PROMPT_FILE,
                        help="JSON file holding the prompt mapping")
    return parser


def _load_prompts(prompt_file):
    """Load and return the JSON prompt mapping stored in ``prompt_file``.

    NOTE(review): assumes a JSON object mapping image id -> caption string;
    keys are used in output filenames and values are passed as prompts.
    """
    with open(prompt_file, encoding='utf-8') as fr:
        return json.load(fr)


def _configure_scheduler(sd_pipe, sampler_type, lamb, kappa):
    """Replace ``sd_pipe.scheduler`` with the scheduler selected by ``sampler_type``.

    The new scheduler is built from the pipeline's current scheduler config.
    DPM variants use a third-order solver; DPM and DDIM variants also get the
    ``lamb`` and ``lm`` attributes, and DDIM variants additionally ``kappa``.

    Raises:
        ValueError: if ``sampler_type`` is not one of SAMPLER_TYPES.
    """
    base_config = sd_pipe.scheduler.config
    if sampler_type in _DPM_VARIANTS:
        algorithm_type, use_lm = _DPM_VARIANTS[sampler_type]
        # Pass overrides to from_config so the constructor sees them; mutating
        # `scheduler.config` after construction bypasses __init__ and leaves
        # the config's dict items with the old values.
        scheduler = DPMSolverMultistepLMScheduler.from_config(
            base_config, solver_order=3, algorithm_type=algorithm_type)
        scheduler.lamb = lamb
        scheduler.lm = use_lm
    elif sampler_type in _DDIM_VARIANTS:
        scheduler = DDIMLMScheduler.from_config(base_config)
        scheduler.lamb = lamb
        scheduler.lm = _DDIM_VARIANTS[sampler_type]
        scheduler.kappa = kappa
    elif sampler_type == 'pndm':
        scheduler = PNDMScheduler.from_config(base_config)
    elif sampler_type == 'unipc':
        scheduler = UniPCMultistepScheduler.from_config(base_config)
    else:
        raise ValueError(f"unknown sampler_type: {sampler_type!r}")
    sd_pipe.scheduler = scheduler


def main(argv=None):
    """Generate one Stable Diffusion image per prompt in a window of the prompt file.

    Prompts with enumeration index in ``[start_index, start_index + test_num)``
    are rendered with the chosen sampler and saved as JPEGs in ``save_dir``.

    Args:
        argv: optional argument list; ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as a plain script invocation.
    """
    args = _build_parser().parse_args(argv)
    start_index = args.start_index
    sampler_type = args.sampler_type
    test_num = args.test_num
    guidance_scale = args.guidance
    num_inference_steps = args.num_inference_steps
    device = args.device
    seed = args.seed
    save_dir = args.save_dir

    # Read prompts before the slow model load so a bad path fails fast.
    coco_prompts = _load_prompts(args.prompt_file)

    # load model
    sd_pipe = StableDiffusionPipeline.from_pretrained(
        args.model_id, torch_dtype=torch.float32, safety_checker=None)
    sd_pipe = sd_pipe.to(device)
    print("sd model loaded")
    _configure_scheduler(sd_pipe, sampler_type, args.lamb, args.kappa)

    os.makedirs(save_dir, exist_ok=True)

    end_index = start_index + test_num
    with torch.no_grad():
        for pi, (key, prompt) in enumerate(coco_prompts.items()):
            if pi < start_index:
                continue
            if pi >= end_index:
                break  # past the requested window; nothing left to do
            print(key)
            print(prompt)
            # The generator follows --device (a hard-coded CUDA generator does
            # not match --device cpu) and is reseeded per prompt so each image
            # is reproducible on its own.
            generator = torch.Generator(device=device).manual_seed(seed)
            res = sd_pipe(prompt=prompt, negative_prompt=None,
                          num_inference_steps=num_inference_steps,
                          guidance_scale=guidance_scale,
                          generator=generator).images[0]
            # The filename records the seed actually used (previously always "1").
            file_name = (f"{pi:05d}_{key}_guidance{guidance_scale}"
                         f"_inference{num_inference_steps}_seed{seed}_{sampler_type}.jpg")
            res.save(os.path.join(save_dir, file_name))
            print(f"{sampler_type}##{key},done")
# Run the sampler only when executed as a script, not when imported.
if __name__ == "__main__":
    main()