import argparse, os, sys, glob
import argparse, os, sys, glob
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from einops import rearrange
from torchvision.utils import make_grid
from pytorch_lightning import seed_everything
import re

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler, DDIMSampler1D
from ldm.models.diffusion.plms import PLMSSampler

from safetensors.torch import load_file, save_file


def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate ``config.model`` and load weights from ``ckpt``.

    Args:
        config: OmegaConf config whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: path to a checkpoint readable by ``torch.load``. Either a
            dict holding the weights under ``"state_dict"`` (Lightning
            style) or a bare state dict.
        verbose: when True, print the missing / unexpected keys reported
            by the non-strict ``load_state_dict``.

    Returns:
        The model in eval mode, on CUDA when available and on CPU otherwise.
    """
    print(f"Loading model from {ckpt}")
    # NOTE(review): torch.load unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    pl_sd = torch.load(ckpt, map_location="cpu")
    # Lightning checkpoints nest the weights under "state_dict"; also accept
    # a plain state dict.
    sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
    model = instantiate_from_config(config.model)
    # strict=False: the checkpoint may not cover every parameter.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose and missing:
        print("missing keys:")
        print(missing)
    if verbose and unexpected:
        print("unexpected keys:")
        print(unexpected)

    # Previously `model.cuda()` was unconditional and crashed on CPU-only hosts.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    return model


def _build_arg_parser():
    """Return the command-line parser for this sampling script.

    Defaults match the original script. ``--adapter_root`` and
    ``--noise_timestep`` expose values that used to be hard-coded.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--config",
        type=str,
        nargs="?",
        default="",
        help="the config file"
    )
    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="the prompt to render"
    )
    parser.add_argument(
        "--negative_prompt",
        type=str,
        nargs="?",
        default="",
        help="negative prompt used as the unconditional conditioning for guidance"
    )
    parser.add_argument(
        "--concept_prompt",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="concept name locating the adapter dir: <adapter_root>/<concept_prompt>-slider_prompt"
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        nargs="?",
        default="",
        help="the checkpoint file"
    )
    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
        default="outputs/txt2img-samples"
    )
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=200,
        help="number of ddim sampling steps",
    )
    parser.add_argument(
        "--plms",
        action='store_true',
        help="use plms sampling",
    )
    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=0.0,
        help="ddim eta (eta=0.0 corresponds to deterministic sampling)",
    )
    parser.add_argument(
        "--composable_diffusion",
        action='store_true',
        help="use composable diffusion (the prompt is split on commas)",
    )
    parser.add_argument(
        "--image_path",
        type=str,
        nargs="?",
        default="",
        help="optional .safetensors file whose first tensor is noised and passed to the sampler as x0"
    )
    parser.add_argument(
        "--noise_timestep",
        type=int,
        default=688,
        help="diffusion timestep used to noise the --image_path tensor",
    )
    parser.add_argument(
        "--adapter_root",
        type=str,
        default="/data/user/diffusers/examples/lora_inversion/output",
        help="root directory holding <concept_prompt>-slider_prompt/learned_embeds.safetensors",
    )
    parser.add_argument(
        "--prefix",
        type=str,
        nargs="?",
        default="",
        help="the prefix name"
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=1,
        help="sample this often",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="the seed (for reproducible sampling)",
    )
    parser.add_argument(
        "--H",
        type=int,
        default=307200,
        help="image height, in pixel space",
    )
    parser.add_argument(
        "--n_samples",
        type=int,
        default=1,
        help="how many samples to produce for the given prompt",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=5.0,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )
    return parser


def _adapter_paths(adapter_root, concept_prompt, prefix, iteration=0, n_iter=1):
    """Return ``(source, destination)`` paths of the adapter weight files.

    The destination sits next to the source as
    ``<prefix>diff_learned_embeds.safetensors``. When ``n_iter > 1`` an
    ``_<iteration>`` suffix is added so iterations do not overwrite each other.
    """
    adapter_dir = os.path.join(adapter_root, f"{concept_prompt}-slider_prompt")
    source = os.path.join(adapter_dir, "learned_embeds.safetensors")
    suffix = f"_{iteration}" if n_iter > 1 else ""
    destination = os.path.join(adapter_dir, f"{prefix}diff_learned_embeds{suffix}.safetensors")
    return source, destination


def _load_start_noise(image_path, model, n_samples, device, timestep=688):
    """Return ``n_samples`` noised copies of a saved tensor, concatenated on dim 0.

    The first tensor in the ``.safetensors`` file is reshaped to ``(1, 1, -1)``,
    moved to ``device`` and passed through ``model.q_sample`` at ``timestep``
    once per sample.

    Raises:
        ValueError: if ``image_path`` is not a ``.safetensors`` file or holds
            no tensors.
    """
    if not image_path.endswith('.safetensors'):
        raise ValueError(f"--image_path must be a .safetensors file, got {image_path!r}")
    tensors = load_file(image_path)
    if not tensors:
        raise ValueError(f"no tensors found in {image_path!r}")
    clean = tensors[next(iter(tensors))].view(1, 1, -1).to(device)
    t = torch.tensor([timestep], device=device)
    # One q_sample call per sample (presumably each draws its own noise).
    return torch.cat([model.q_sample(x_start=clean, t=t) for _ in range(n_samples)])


def main():
    """Sample embeddings and write them into the concept's adapter weight file."""
    opt = _build_arg_parser().parse_args()
    seed_everything(opt.seed)

    # nargs="?" yields None for a bare flag; normalise to "".
    negative_prompt = opt.negative_prompt or ""
    base_prefix = opt.prefix or ""
    prefix = 'neg_' + base_prefix if negative_prompt else base_prefix

    # Load the adapter template up front so a bad path fails before sampling.
    source_path, _ = _adapter_paths(opt.adapter_root, opt.concept_prompt, prefix)
    template = load_file(source_path)

    config = OmegaConf.load(opt.config)  # TODO: Optionally download from same location as ckpt and change this logic
    model = load_model_from_config(config, opt.ckpt)  # TODO: check path

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)

    sampler = PLMSSampler(model) if opt.plms else DDIMSampler(model)

    # Composable diffusion conditions on each comma-separated sub-prompt.
    prompt = opt.prompt.split(',') if opt.composable_diffusion else opt.prompt

    start_noise = None
    if opt.image_path:
        start_noise = _load_start_noise(
            opt.image_path, model, opt.n_samples, device, opt.noise_timestep
        )

    shape = [1, 15, opt.H] if opt.H == 20480 else [1, 1, opt.H]

    with torch.no_grad(), model.ema_scope():
        # Conditioning is identical on every iteration, so compute it once.
        uc = None
        if opt.scale != 1.0:
            # NOTE(review): composable mode uses "" and ignores --negative_prompt
            # (same as before); confirm that is intended.
            uc_prompt = "" if opt.composable_diffusion else negative_prompt
            uc = model.get_learned_conditioning(opt.n_samples * [uc_prompt])
        if opt.composable_diffusion:
            c = model.get_learned_conditioning(prompt)
        else:
            c = model.get_learned_conditioning(opt.n_samples * [prompt])

        for n in trange(opt.n_iter, desc="Sampling"):
            # NOTE(review): start noise is passed as x0 with no mask; confirm
            # this sampler actually consumes it.
            samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
                                             conditioning=c,
                                             batch_size=opt.n_samples,
                                             shape=shape,
                                             verbose=False,
                                             x0=start_noise,
                                             unconditional_guidance_scale=opt.scale,
                                             unconditional_conditioning=uc,
                                             eta=opt.ddim_eta)
            print(samples_ddim.shape)

            # Average over the batch; keep the first 768 values of the last axis.
            # [..., :768] also works when the mean is 2-D.
            embedding = samples_ddim.mean(0)
            if opt.H != 768:
                embedding = embedding[..., :768]

            # Give every key its own contiguous CPU copy: safetensors refuses to
            # save tensors that share memory or are non-contiguous.
            weights = {
                key: embedding.reshape(ref.shape).detach().cpu().clone(
                    memory_format=torch.contiguous_format)
                for key, ref in template.items()
            }
            _, dest_path = _adapter_paths(
                opt.adapter_root, opt.concept_prompt, prefix, n, opt.n_iter
            )
            save_file(weights, dest_path)


if __name__ == "__main__":
    main()