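"""Sampling script for Stable Diffusion on COCO14 captions.

Loads a Stable Diffusion v1.5 checkpoint, swaps in one of several schedulers
(including the repo-local LM variants), and renders one image per caption
from a JSON prompt file.
"""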
import argparse
import json
import os
import sys

import torch

# Make the repo-local `scheduler` package importable when run from the repo root.
sys.path.append(os.getcwd())

from diffusers import StableDiffusionPipeline, PNDMScheduler, UniPCMultistepScheduler

# The LM scheduler variants are provided by this repo, not by diffusers.
from scheduler.scheduling_dpmsolver_multistep_lm import DPMSolverMultistepLMScheduler
from scheduler.scheduling_ddim_lm import DDIMLMScheduler


def main():
    parser = argparse.ArgumentParser(description="Sampling script for COCO14.")
    parser.add_argument('--test_num', type=int, default=1000)
    parser.add_argument('--start_index', type=int, default=0)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--num_inference_steps', type=int, default=20)
    parser.add_argument('--guidance', type=float, default=7.5)
    parser.add_argument('--sampler_type', type=str, default='ddim')
    parser.add_argument('--model_id', type=str, default='/xxx/xxx/stable-diffusion-v1-5')
    parser.add_argument('--save_dir', type=str, default='/xxx/xxx')
    parser.add_argument('--lamb', type=float, default=5.0)
    parser.add_argument('--kappa', type=float, default=0.0)
    parser.add_argument('--device', type=str, default='cuda')
    args = parser.parse_args()
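
    # Example invocation (a sketch; the script name and paths are
    # illustrative placeholders, not fixed by this repo):
    #   python sample_coco14.py --sampler_type dpm++_lm --num_inference_steps 20 \
    #       --guidance 7.5 --lamb 5.0 --model_id /path/to/stable-diffusion-v1-5 \
    #       --save_dir /path/to/outputs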

    start_index = args.start_index
    sampler_type = args.sampler_type
    test_num = args.test_num
    guidance_scale = args.guidance
    num_inference_steps = args.num_inference_steps
    lamb = args.lamb
    kappa = args.kappa
    device = args.device
    model_id = args.model_id

    # Load the SD pipeline in fp32 with the safety checker disabled.
    sd_pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32, safety_checker=None)
    sd_pipe = sd_pipe.to(device)
    print("sd model loaded")

    # Select the scheduler. The *_lm variants enable the LM correction via the
    # `lm` flag; `lamb` (and, for DDIM, `kappa`) parameterize it. See the local
    # scheduler implementations for the exact semantics.
    if sampler_type == 'dpm_lm':
        sd_pipe.scheduler = DPMSolverMultistepLMScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe.scheduler.config.solver_order = 3
        sd_pipe.scheduler.config.algorithm_type = "dpmsolver"
        sd_pipe.scheduler.lamb = lamb
        sd_pipe.scheduler.lm = True
    elif sampler_type == 'dpm':
        sd_pipe.scheduler = DPMSolverMultistepLMScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe.scheduler.config.solver_order = 3
        sd_pipe.scheduler.config.algorithm_type = "dpmsolver"
        sd_pipe.scheduler.lamb = lamb
        sd_pipe.scheduler.lm = False
    elif sampler_type == 'dpm++':
        sd_pipe.scheduler = DPMSolverMultistepLMScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe.scheduler.config.solver_order = 3
        sd_pipe.scheduler.config.algorithm_type = "dpmsolver++"
        sd_pipe.scheduler.lamb = lamb
        sd_pipe.scheduler.lm = False
    elif sampler_type == 'dpm++_lm':
        sd_pipe.scheduler = DPMSolverMultistepLMScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe.scheduler.config.solver_order = 3
        sd_pipe.scheduler.config.algorithm_type = "dpmsolver++"
        sd_pipe.scheduler.lamb = lamb
        sd_pipe.scheduler.lm = True
    elif sampler_type == 'pndm':
        sd_pipe.scheduler = PNDMScheduler.from_config(sd_pipe.scheduler.config)
    elif sampler_type == 'ddim':
        sd_pipe.scheduler = DDIMLMScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe.scheduler.lamb = lamb
        sd_pipe.scheduler.lm = False
        sd_pipe.scheduler.kappa = kappa
    elif sampler_type == 'ddim_lm':
        sd_pipe.scheduler = DDIMLMScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe.scheduler.lamb = lamb
        sd_pipe.scheduler.lm = True
        sd_pipe.scheduler.kappa = kappa
    elif sampler_type == 'unipc':
        sd_pipe.scheduler = UniPCMultistepScheduler.from_config(sd_pipe.scheduler.config)
    else:
        raise ValueError(f"Unknown sampler_type: {sampler_type}")

    save_dir = args.save_dir
    os.makedirs(save_dir, exist_ok=True)

    # COCO caption prompts keyed by image id (path is environment-specific).
    with open('/mnt/chongqinggeminiceph1fs/geminicephfs/mm-base-vision/pazelzhang/make_dataset/fid_3W_json.json') as fr:
        COCO_prompts_dict = json.load(fr)
    image_ids = COCO_prompts_dict.keys()

    with torch.no_grad():
        for pi, key in enumerate(image_ids):
            # Only render the [start_index, start_index + test_num) slice of prompts.
            if start_index <= pi < start_index + test_num:
                prompt = COCO_prompts_dict[key]
                print(key)
                print(prompt)
                negative_prompt = None

                # One generator per image, seeded for reproducibility; extend the
                # list to sample multiple seeds per prompt.
                for seed in [args.seed]:
                    generator = torch.Generator(device=device).manual_seed(seed)
                    res = sd_pipe(prompt=prompt, negative_prompt=negative_prompt,
                                  num_inference_steps=num_inference_steps,
                                  guidance_scale=guidance_scale, generator=generator).images[0]
                    res.save(os.path.join(save_dir, f"{pi:05d}_{key}_guidance{guidance_scale}_inference{num_inference_steps}_seed{seed}_{sampler_type}.jpg"))
                    print(f"{sampler_type}##{key},done")


if __name__ == '__main__':
    main()