import argparse, os, sys, glob
import cv2
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
import time
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext

from ldm.util import instantiate_from_config

from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor

from torch.utils.data import DataLoader
from data.data_util import EvalDataset, CenterCropLongEdge

# Load the safety checker and its feature extractor once, at module import time.
# check_safety() reads both objects as module-level globals.
# NOTE(review): from_pretrained presumably downloads/caches the weights on first
# use, so importing this module may need network access — confirm.
safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)


def chunk(it, size):
    """Return an iterator over consecutive tuples of at most ``size`` items.

    The input is consumed lazily. The last tuple may be shorter than ``size``;
    iteration stops as soon as an empty tuple would be produced.
    """
    source = iter(it)

    def _take_next():
        # Pull the next ``size`` items (fewer if the source runs dry).
        return tuple(islice(source, size))

    # Two-argument iter(): call _take_next repeatedly until it returns ().
    return iter(_take_next, ())


def numpy_to_pil(images):
    """
    Convert a numpy image, or a batch of numpy images, into a list of PIL images.

    A 3-dimensional input is treated as a single image and wrapped into a batch
    of one. Values are multiplied by 255, rounded and cast to ``uint8`` before
    conversion.
    """
    batch = images[None, ...] if images.ndim == 3 else images
    batch = (batch * 255).round().astype("uint8")

    pil_images = []
    for frame in batch:
        pil_images.append(Image.fromarray(frame))
    return pil_images


def load_model_from_config(config, ckpt, verbose=False):
    """Build the model described by ``config.model`` and load weights from ``ckpt``.

    Args:
        config: configuration object whose ``model`` attribute is passed to
            ``instantiate_from_config``.
        ckpt: path to a checkpoint containing a ``"state_dict"`` entry.
        verbose: when true, print the missing / unexpected state-dict keys.

    Returns:
        The model, moved to CUDA and switched to eval mode.
    """
    print(f"Loading model from {ckpt}")
    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    checkpoint = torch.load(ckpt, map_location="cpu")
    if "global_step" in checkpoint:
        print(f"Global Step: {checkpoint['global_step']}")
    weights = checkpoint["state_dict"]

    model = instantiate_from_config(config.model)

    # strict=False: mismatching keys are tolerated and only reported below.
    missing, unexpected = model.load_state_dict(weights, strict=False)

    if verbose:
        for label, keys in (("missing keys:", missing), ("unexpected keys:", unexpected)):
            if len(keys) > 0:
                print(label)
                print(keys)

    model.cuda()
    model.eval()
    return model


def load_replacement(x):
    """Return the placeholder image resized to match ``x``, or ``x`` on any failure.

    The placeholder is read from ``assets/rick.jpeg``, converted to RGB, resized
    to the first two dimensions of ``x`` (treated as height and width) and
    scaled to [0, 1] in ``x``'s dtype.

    This is deliberately best-effort: if the asset is missing, ``x`` is not an
    image-like array, or the resulting shape differs from ``x.shape``, the
    original ``x`` is returned unchanged.
    """
    try:
        height, width = x.shape[0], x.shape[1]
        # Context manager guarantees the file handle is released.
        with Image.open("assets/rick.jpeg") as source:
            # PIL's resize takes (width, height).
            resized = source.convert("RGB").resize((width, height))
        y = (np.array(resized) / 255.0).astype(x.dtype)
    except Exception:
        return x

    # Explicit check instead of ``assert`` so the fallback survives ``python -O``.
    if y.shape != x.shape:
        return x
    return y


def check_safety(x_image):
    """Run the module-level safety checker on a batch and swap out flagged images.

    ``x_image`` is converted to PIL images for the feature extractor and is also
    passed unchanged to the safety checker. Every entry the checker flags is
    replaced by the result of ``load_replacement``.

    Returns:
        A ``(checked_images, has_nsfw_concept)`` pair as produced by the safety
        checker, with flagged images replaced.
    """
    pil_batch = numpy_to_pil(x_image)
    clip_features = safety_feature_extractor(pil_batch, return_tensors="pt")
    x_checked_image, has_nsfw_concept = safety_checker(
        images=x_image,
        clip_input=clip_features.pixel_values,
    )
    assert x_checked_image.shape[0] == len(has_nsfw_concept)

    for idx, flagged in enumerate(has_nsfw_concept):
        if flagged:
            x_checked_image[idx] = load_replacement(x_checked_image[idx])
    return x_checked_image, has_nsfw_concept


def _build_arg_parser():
    """Create the command-line parser with every option understood by main()."""
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="the prompt to render"
    )
    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
        default="../image_result/val2014"
    )
    parser.add_argument(
        "--skip_grid",
        action='store_true',
        help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
    )
    parser.add_argument(
        "--skip_save",
        action='store_true',
        help="do not save individual samples. For speed measurements.",
    )
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=10,
        help="number of ddim sampling steps",
    )
    parser.add_argument(
        "--plms",
        action='store_true',
        help="use plms sampling",
    )
    parser.add_argument(
        "--dpm_solver",
        action='store_true',
        help="use dpm_solver sampling",
    )
    parser.add_argument(
        "--laion400m",
        action='store_true',
        help="uses the LAION400M model",
    )
    parser.add_argument(
        "--fixed_code",
        action='store_true',
        help="if enabled, uses the same starting code across samples ",
    )
    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=0.0,
        help="ddim eta (eta=0.0 corresponds to deterministic sampling",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=1,
        help="sample this often",
    )
    parser.add_argument(
        "--H",
        type=int,
        default=512,
        help="image height, in pixel space",
    )
    parser.add_argument(
        "--W",
        type=int,
        default=512,
        help="image width, in pixel space",
    )
    parser.add_argument(
        "--C",
        type=int,
        default=4,
        help="latent channels",
    )
    parser.add_argument(
        "--f",
        type=int,
        default=8,
        help="downsampling factor",
    )
    parser.add_argument(
        "--n_samples",
        type=int,
        default=4,
        help="how many samples to produce for each given prompt. A.k.a. batch size",
    )
    parser.add_argument(
        "--n_rows",
        type=int,
        default=0,
        help="rows in the grid (default: n_samples)",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=3,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )
    parser.add_argument(
        "--from-file",
        type=str,
        help="if specified, load prompts from this file",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="configs/stable-diffusion/v1-inference.yaml",
        help="path to config which constructs model",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default="../ckpt_sd1/model-v1-5.ckpt",
        help="path to checkpoint of model",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="the seed (for reproducible sampling)",
    )
    parser.add_argument(
        "--precision",
        type=str,
        help="evaluate at this precision",
        choices=["full", "autocast"],
        default="autocast"
    )
    parser.add_argument(
        "--ours",
        action='store_true',
        help="ours",
    )
    parser.add_argument(
        "--ddim_steps_uncond",
        type=int,
        default=50,
        help="number of ddim sampling steps",
    )
    parser.add_argument(
        "--uncond_guide",
        type=float,
        default=0.5,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )
    parser.add_argument(
        "--coef_type",
        type=str,
        help="coefficient type passed to the sampler as coef_type",
        choices=["lambda", "sigma_lambda", "square_lambda"],
        default="lambda"
    )

    parser.add_argument(
        "--layer_skip",
        action='store_true',
        help="layer_skip",
    )
    parser.add_argument(
        "--n_layer",
        type=int,
        default=3,
        help="n_layer value passed to the sampler; appended to the output dir name when --layer_skip is set",
    )
    parser.add_argument(
        "--target_count",
        type=int,
        default=3000,
        help="total number of dataset samples to process across all workers",
    )
    parser.add_argument(
        "--n_world",
        type=int,
        default=1,
        help="number of workers that --target_count is split across",
    )
    parser.add_argument(
        "--rank",
        type=int,
        default=0,
        help="zero-based index of this worker, selecting which slice of --target_count it processes",
    )
    return parser


def main():
    """Generate images for dataset captions and save them with their prompts.

    Steps:
      1. Parse command-line options and seed the RNGs.
      2. Load the model and pick a sampler (DPM-Solver, PLMS or DDIM; the
         "ours" or "vanilla" DDIM/PLMS implementation depending on ``--ours``).
      3. Iterate over the ``EvalDataset`` batches, skipping batches that come
         before this worker's slice
         ``[rank * (target_count // n_world), (rank + 1) * (target_count // n_world))``.
      4. For each batch: encode the prompts, sample latents, decode them, run
         the safety check and (unless ``--skip_save``) write one PNG per
         sample, named by its running index.
      5. Write all processed prompts to ``captions.txt`` in the output dir.
    """
    opt = _build_arg_parser().parse_args()

    if opt.laion400m:
        print("Falling back to LAION 400M model...")
        opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
        opt.ckpt = "models/ldm/text2img-large/model.ckpt"
        opt.outdir = "outputs/txt2img-samples-laion400m"

    seed_everything(opt.seed)

    config = OmegaConf.load(f"{opt.config}")
    model = load_model_from_config(config, f"{opt.ckpt}")

    device = 'cuda'
    dtype = torch.float32
    model = model.to(device)
    model = model.to(dtype)
    model.to_device(device)

    # Sampler implementations are imported lazily because the "ours" and
    # "vanilla" variants live in different modules under the same class names.
    # The output sub-directory name encodes the run configuration.
    if opt.ours:
        from ldm.models.diffusion.ddim import DDIMSampler
        from ldm.models.diffusion.plms import PLMSSampler
        from ldm.models.diffusion.dpm_solver import DPMSolverSampler
        if opt.plms:
            opt.outdir = os.path.join(opt.outdir, f"inv_ours_plms_{opt.scale:.1f}_{opt.coef_type}_{opt.ddim_steps}_{opt.uncond_guide}_{opt.target_count}")
        else:
            opt.outdir = os.path.join(opt.outdir, f"inv_ours_{opt.scale:.1f}_{opt.coef_type}_{opt.ddim_steps}_{opt.uncond_guide}_{opt.target_count}")
    else:
        from ldm.models.diffusion.ddim_vanilla import DDIMSampler
        from ldm.models.diffusion.plms_vanilla import PLMSSampler
        from ldm.models.diffusion.dpm_solver import DPMSolverSampler
        if opt.plms:
            opt.outdir = os.path.join(opt.outdir, f"vanilla_plms_{opt.scale:.1f}_{opt.ddim_steps}_{opt.target_count}")
        else:
            opt.outdir = os.path.join(opt.outdir, f"vanilla_{opt.scale:.1f}_{opt.ddim_steps}_{opt.target_count}")

    if opt.layer_skip:
        opt.outdir = opt.outdir + f"_layer_skip_{opt.n_layer}"

    if opt.dpm_solver:
        sampler = DPMSolverSampler(model)
    elif opt.plms:
        sampler = PLMSSampler(model)
    else:
        sampler = DDIMSampler(model)

    os.makedirs(opt.outdir, exist_ok=True)
    outpath = opt.outdir

    # Reference dataset settings are hard-coded rather than exposed as options.
    opt.ref_data = 'coco2014'
    opt.ref_dir = '../COCO2014'
    opt.eval_res = 512

    dset2 = EvalDataset(data_name=opt.ref_data,
                        data_dir=opt.ref_dir,
                        data_type="real_images",
                        crop_long_edge=True,
                        resize_size=opt.eval_res,
                        resizer="lanczos",
                        normalize=True,
                        load_txt_from_file=False)

    dset2_dataloader = DataLoader(dataset=dset2,
                                  batch_size=opt.n_samples,
                                  shuffle=False,
                                  pin_memory=True,
                                  drop_last=False)
    batch_size = opt.n_samples

    sample_path = outpath
    os.makedirs(sample_path, exist_ok=True)
    # Running index over dataset samples; also used as the output file name.
    base_count = 0

    start_code = None
    if opt.fixed_code:
        start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)

    # Explicitly move the cond_stage_model transformer submodules to the GPU.
    model.cond_stage_model.transformer = model.cond_stage_model.transformer.to('cuda')
    model.cond_stage_model.transformer.text_model = model.cond_stage_model.transformer.text_model.to('cuda')
    model.cond_stage_model.transformer.text_model.embeddings = model.cond_stage_model.transformer.text_model.embeddings.to('cuda')
    model.cond_stage_model.transformer.text_model.encoder = model.cond_stage_model.transformer.text_model.encoder.to('cuda')
    model.cond_stage_model.transformer.text_model.final_layer_norm = model.cond_stage_model.transformer.text_model.final_layer_norm.to('cuda')

    # Same for the first-stage decoder path used by decode_first_stage.
    model.first_stage_model.decoder = model.first_stage_model.decoder.to('cuda')
    model.first_stage_model.post_quant_conv = model.first_stage_model.post_quant_conv.to('cuda')

    # Slice of the dataset handled by this worker: [start_count, end_count).
    sample_per_gpu = opt.target_count // opt.n_world
    start_count = sample_per_gpu * opt.rank
    end_count = start_count + sample_per_gpu

    total_prompts_list = []
    precision_scope = autocast if opt.precision=="autocast" else nullcontext
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                # Unconditional embedding, built once for the nominal batch
                # size and rebuilt only if a batch of a different size arrives.
                uc = None
                uc_batch = 0
                if opt.scale != 1.0:
                    uc = model.get_learned_conditioning(batch_size * [""])
                    uc_batch = batch_size

                for _, prompts in tqdm(dset2_dataloader):

                    # Fast-forward over batches that belong to lower ranks.
                    if base_count < start_count:
                        base_count += opt.n_samples
                        continue

                    if isinstance(prompts, tuple):
                        prompts = list(prompts)
                    total_prompts_list += prompts

                    # drop_last=False means the final batch can be shorter than
                    # n_samples; keep uc, batch_size and x_T consistent with it.
                    n_batch = len(prompts)
                    if uc is not None and n_batch != uc_batch:
                        uc = model.get_learned_conditioning(n_batch * [""])
                        uc_batch = n_batch
                    x_T = start_code
                    if start_code is not None and n_batch != start_code.shape[0]:
                        x_T = start_code[:n_batch]

                    c = model.get_learned_conditioning(prompts)

                    shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
                    samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
                                                    S_uncond=opt.ddim_steps_uncond,
                                                    conditioning=c,
                                                    batch_size=n_batch,
                                                    shape=shape,
                                                    verbose=False,
                                                    unconditional_guidance_scale=opt.scale,
                                                    unconditional_conditioning=uc,
                                                    uncond_guide=opt.uncond_guide,
                                                    coef_type=opt.coef_type,
                                                    layer_skip=opt.layer_skip,
                                                    n_layer=opt.n_layer,
                                                    eta=opt.ddim_eta,
                                                    x_T=x_T,
                                                    dtype=dtype)
                    x_samples_ddim = model.decode_first_stage(samples_ddim)

                    # Map decoder output from [-1, 1] to [0, 1], then to NHWC numpy.
                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                    x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()

                    x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim)

                    x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)

                    for x_sample in x_checked_image_torch:
                        if not opt.skip_save:
                            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                            img = Image.fromarray(x_sample.astype(np.uint8))
                            img.save(os.path.join(sample_path, f"{base_count:05}.png"))
                        # Count every processed sample, saved or not, so the
                        # end_count check below also terminates with --skip_save.
                        base_count += 1

                    if base_count >= end_count:
                        break

    # NOTE(review): the file name does not depend on --rank, so workers sharing
    # an output directory overwrite each other's captions.txt — confirm intended.
    with open(os.path.join(outpath,'captions.txt'), "w") as f:
        for prompt in total_prompts_list:
            f.write(prompt + "\n")


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
