import os
import torch
torch.manual_seed(500)
import argparse
from omegaconf import OmegaConf

import sys
sys.path.append(".")

from ldm.util import instantiate_from_config

def create_argparser():
    """Build the command-line parser for the sampling script.

    Options, in registration order:
        --label        int, default 207; class label to condition on
                       (help text: "Label for pruning").
        --config       str; path to the OmegaConf model config.
        --ckpt         str; path to the model checkpoint.
        --save_path    str; output directory for generated images.
        --batch_size   int, default 4; images sampled per sampler call.
        --num_samples  int, default 4; total images requested.

    Returns:
        argparse.ArgumentParser: the configured (not yet parsed) parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--label", type=int, default=207, help="Label for pruning")

    # String-valued options; dict order keeps the original registration order.
    path_defaults = {
        "--config": "configs/latent-diffusion/cin256-v2.yaml",
        "--ckpt": "models/ldm/cin256-v2/model.ckpt",
        "--save_path": "xxx",
    }
    for flag, default_value in path_defaults.items():
        parser.add_argument(flag, type=str, default=default_value)

    # Integer counts that share the same default.
    for flag in ("--batch_size", "--num_samples"):
        parser.add_argument(flag, type=int, default=4)

    return parser


def save_torch_example(img, name, idxs, base_path):
    """Save selected images of a batch as PNG files.

    Each selected image ``i`` is written to
    ``<base_path>/<name>id_in_batch_is_<i>.png``.

    Args:
        img: 4-D tensor laid out as (N, C, H, W) -- the permute below moves
            channels last -- with values expected in [-1, 1]; they are mapped
            to [0, 255] and clamped.
        name: filename prefix (e.g. ``'sample_0_batch_'``).
        idxs: iterable of batch indices to save.
        base_path: output directory; ``~`` is expanded and the directory is
            created if it does not already exist.
    """
    # Imported here so this function does not rely on module-level imports
    # that only appear further down in the file.
    from PIL import Image

    # Expand "~" so the written path matches what os.makedirs created.
    base_path = os.path.expanduser(base_path)
    os.makedirs(base_path, exist_ok=True)

    # [-1, 1] -> [0, 255]; the uint8 cast truncates toward zero.
    img = ((img + 1) * 127.5).clamp(0, 255).to(torch.uint8)
    # (N, C, H, W) -> contiguous (N, H, W, C); already uint8, so no extra cast.
    img = img.permute(0, 2, 3, 1).contiguous().cpu().numpy()

    for i in idxs:
        path_i = os.path.join(base_path, f"{name}id_in_batch_is_{i}.png")
        Image.fromarray(img[i]).save(path_i)


def load_model_from_config(config, ckpt):
    """Instantiate ``config.model`` and load its weights from a checkpoint.

    Args:
        config: OmegaConf config whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: path to a checkpoint file containing a ``"state_dict"`` entry.

    Returns:
        The model with weights loaded, moved to CUDA and set to eval mode.

    Raises:
        RuntimeError: from ``load_state_dict`` if the state dict has missing
            or unexpected keys (``strict=True``).
    """
    print(f"Loading model from {ckpt}")
    # Load onto the CPU first. Without map_location, tensors are restored to
    # the device they were saved from. That fails when that device does not
    # exist here, and it keeps a second copy of the weights on the GPU
    # alongside the model after .cuda().
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # With strict=True any key mismatch raises, so the returned
    # (missing_keys, unexpected_keys) are always empty and are not kept.
    model.load_state_dict(sd, strict=True)
    model.cuda()
    model.eval()
    return model


def get_model(config_path, ckpt):
    """Read the YAML config at ``config_path`` and build the model from ``ckpt``.

    Args:
        config_path: path to an OmegaConf-loadable config file.
        ckpt: path to the checkpoint passed through to
            ``load_model_from_config``.

    Returns:
        Whatever ``load_model_from_config`` returns for this config/checkpoint.
    """
    return load_model_from_config(OmegaConf.load(config_path), ckpt)


from ldm.models.diffusion.ddim import DDIMSampler


args = create_argparser().parse_args()
# Expand "~" once and use the same path both for creating the output
# directory and for writing images. Previously only makedirs saw the
# expanded path.
save_path = os.path.expanduser(args.save_path)
os.makedirs(save_path, exist_ok=True)
print("saving at:", save_path)
model = get_model(args.config, args.ckpt)

#################################################################
# Extra additions for DDDM.
# layer_idxs[i] holds `num_expand_layer` entries, each equal to [i], for
# every i in range(expand_scale). It is passed to the sampler together with
# `exp_scale` (both tagged "DPDM" at the call site below).
# NOTE(review): presumably [i] selects the i-th expanded branch at each
# expanded layer of the diffusion model -- confirm against DDIMSampler.
expand_scale = model.model.diffusion_model.expand_scale
num_expand_layer = model.model.diffusion_model.num_expand_layer
layer_idxs = [
    [[i] for _ in range(num_expand_layer)]
    for i in range(expand_scale)
]
#################################################################

sampler = DDIMSampler(model)

# Kept at module level for code that looks these names up globally.
import numpy as np
from PIL import Image


ddim_steps = 20
ddim_eta = 0.0
scale = 3.0  # passed to the sampler as unconditional_guidance_scale

# Round up so that a num_samples which is not a multiple of batch_size is
# still fully produced. Floor division silently dropped the remainder.
num_batches = (args.num_samples + args.batch_size - 1) // args.batch_size

with torch.no_grad():
    with model.ema_scope():
        # Conditioning used as `unconditional_conditioning` for guidance.
        # NOTE(review): class index 1000 is presumably the extra "null" class
        # of the class-conditional model -- confirm against the config.
        uc = model.get_learned_conditioning(
            {model.cond_stage_key: torch.tensor(args.batch_size*[1000]).to(model.device)}
            )

        print(f"rendering {args.batch_size} examples of class '{args.label}' in {ddim_steps} steps and using s={scale:.2f}.")
        xc = torch.tensor(args.batch_size*[args.label])
        c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})

        for i in range(num_batches):
            samples_ddim, _ = sampler.sample(S=ddim_steps,
                                             conditioning=c,
                                             batch_size=args.batch_size,
                                             shape=[3, 64, 64],
                                             verbose=False,
                                             unconditional_guidance_scale=scale,
                                             unconditional_conditioning=uc,
                                             eta=ddim_eta,
                                             layer_idxs=layer_idxs,  # DPDM
                                             exp_scale=expand_scale,  # DPDM
                                             )

            x_samples_ddim = model.decode_first_stage(samples_ddim)
            # Every batch is sampled at full size, because c/uc are built for
            # batch_size. Only the still-missing samples of the last batch are
            # saved.
            num_to_save = min(args.batch_size, args.num_samples - i * args.batch_size)
            save_torch_example(x_samples_ddim, 'sample_'+str(i)+'_batch_', list(range(num_to_save)), save_path)


print("sample finished")