import argparse
import glob
import os
import random
import sys

import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from einops import rearrange
from torchvision.utils import make_grid

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler


def set_seeds(seed, deterministic=False):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Args:
        seed: Integer seed applied to every RNG.
        deterministic: If False (default), enable ``cudnn.benchmark``. That is
            faster but lets cuDNN pick algorithms whose results may vary run to
            run. If True, disable benchmarking and request deterministic cuDNN
            kernels for bit-reproducible sampling.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Safe to call without CUDA: these calls are deferred/ignored when no GPU exists.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    else:
        torch.backends.cudnn.benchmark = True


def load_model_from_config(config, ckpt, verbose=False, device=None):
    """Instantiate ``config.model`` and load weights from a checkpoint.

    Args:
        config: OmegaConf config whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: Path to a checkpoint file that contains a ``"state_dict"`` entry.
        verbose: If True, print the missing / unexpected keys reported by
            ``load_state_dict``.
        device: Device to move the model to. Defaults to CUDA when available,
            otherwise CPU. It does not unconditionally call ``.cuda()``, which
            would fail on CPU-only hosts.

    Returns:
        The model on ``device``, switched to eval mode.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
    """
    print(f"Loading model from {ckpt}")
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False tolerates key mismatches; they are only reported when verbose.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose and len(missing) > 0:
        print("missing keys:")
        print(missing)
    if verbose and len(unexpected) > 0:
        print("unexpected keys:")
        print(unexpected)

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    return model


def _parse_args():
    """Build the command-line interface and return the parsed, validated arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/txt2img-samples")
    parser.add_argument("--ddim_steps", type=int, default=200, help="number of ddim sampling steps")
    parser.add_argument("--plms", action='store_true', help="use plms sampling")
    parser.add_argument("--ddim_eta", type=float, default=0.0, help="ddim eta (eta=0.0 corresponds to deterministic sampling")
    # NOTE: --n_iter is parsed but not used; kept so existing command lines keep working.
    parser.add_argument("--n_iter", type=int, default=1, help="sample this often")
    parser.add_argument("--H", type=int, default=256, help="image height, in pixel space")
    parser.add_argument("--W", type=int, default=256, help="image width, in pixel space")
    parser.add_argument("--n_samples", type=int, default=5000, help="how many samples to produce for the given prompt")
    parser.add_argument("--scale", type=float, default=5.0, help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))")
    # These three have no usable default; fail at parse time rather than after sampling.
    parser.add_argument("--yaml", type=str, required=True, help='Path to the model YAML file')
    parser.add_argument("--ckpt", type=str, required=True, help='Path to the model checkpoint')
    parser.add_argument("--bs", type=int, default=300, help='mini batch size of sampling')
    parser.add_argument("--outname", type=str, required=True, help='file name of the generated .pt file')
    parser.add_argument("--seed", type=int, default=0, help="random seed used for sampling")
    parser.add_argument("--decode_bs", type=int, default=150,
                        help="max latents decoded per first-stage call (limits decoder memory)")
    parser.add_argument("--captions", type=str, default="celeba_inputcaption.txt",
                        help="text file with one caption per line")
    args = parser.parse_args()
    if args.bs <= 0:
        parser.error("--bs must be a positive integer")
    if args.decode_bs <= 0:
        parser.error("--decode_bs must be a positive integer")
    return args


def _load_prompts(path, limit):
    """Return the first ``limit`` lines of ``path`` with trailing newlines removed."""
    with open(path, encoding="utf-8") as f:
        return [line.rstrip("\n") for line in f][:limit]


def _generate_images(model, sampler, prompts, args):
    """Sample one image per prompt in batches of ``args.bs``; return decoded CPU tensors.

    The last batch may be smaller than ``args.bs``, so no prompt is dropped.
    Latents are decoded in chunks of ``args.decode_bs`` to bound decoder memory.
    """
    # Latent shape is hard-coded as 4 x H/8 x W/8; assumes the first-stage model matches.
    shape = [4, args.H // 8, args.W // 8]
    guided = args.scale != 1.0
    # Empty-prompt conditioning for a full batch is loop-invariant: compute it once.
    uc_full = model.get_learned_conditioning(args.bs * [""]) if guided else None

    images = []
    for start in trange(0, len(prompts), args.bs, desc="batches"):
        batch_prompts = prompts[start:start + args.bs]
        n = len(batch_prompts)
        if not guided:
            uc = None
        elif n == args.bs:
            uc = uc_full
        else:
            uc = model.get_learned_conditioning(n * [""])
        c = model.get_learned_conditioning(batch_prompts)
        samples, _ = sampler.sample(S=args.ddim_steps,
                                    conditioning=c,
                                    batch_size=n,
                                    shape=shape,
                                    verbose=False,
                                    unconditional_guidance_scale=args.scale,
                                    unconditional_conditioning=uc,
                                    eta=args.ddim_eta)
        for chunk in torch.split(samples, args.decode_bs):
            # Move to CPU so the accumulated results do not occupy GPU memory.
            images.append(model.decode_first_stage(chunk).cpu())
    return images


def main():
    """Generate one image per caption and save them as ``{'image': tensor}`` to ``--outname``."""
    args = _parse_args()

    # 'celeba_inputcaption.txt' is generated by taking the first description of each training image of MM-CelebA
    prompts = _load_prompts(args.captions, args.n_samples)
    if not prompts:
        sys.exit(f"No captions found in {args.captions}")
    if len(prompts) < args.n_samples:
        print(f"Warning: requested {args.n_samples} samples but only {len(prompts)} "
              f"captions are available; generating {len(prompts)}.")

    # Seed before building the model so RNG consumption matches across runs.
    set_seeds(args.seed)

    config = OmegaConf.load(args.yaml)  # TODO: Optionally download from same location as ckpt and change this logic
    model = load_model_from_config(config, args.ckpt)  # TODO: check path

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)

    sampler = PLMSSampler(model) if args.plms else DDIMSampler(model)

    # Output directories are still created to keep the original on-disk layout;
    # the result itself is written to --outname.
    os.makedirs(args.outdir, exist_ok=True)
    os.makedirs(os.path.join(args.outdir, "samples"), exist_ok=True)

    with torch.no_grad(), model.ema_scope():
        images = _generate_images(model, sampler, prompts, args)

    # All decoded images stacked along the batch dimension (not a visual grid).
    torch.save({'image': torch.cat(images)}, args.outname)


if __name__ == "__main__":
    main()
