import argparse, os, sys, glob
import torch
import clip
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from einops import rearrange
from torchvision.utils import make_grid
import re

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler, DDIMSampler1D
from ldm.models.diffusion.plms import PLMSSampler
import safetensors
from pytorch_lightning import seed_everything
from safetensors.torch import load_file, save_file
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

# from dassl.data.datasets import Datum
from torchvision.transforms.functional import InterpolationMode


def _convert_image_to_rgb(image):
    """Return ``image`` converted to RGB mode.

    ``image`` only needs to expose a ``convert(mode)`` method; the result of
    ``image.convert("RGB")`` is returned unchanged.
    """
    rgb_image = image.convert("RGB")
    return rgb_image

    
def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate the model described by ``config`` and load ``ckpt`` into it.

    Args:
        config: Configuration object whose ``model`` attribute is passed to
            ``instantiate_from_config``.
        ckpt: Path to a checkpoint readable by ``torch.load`` that contains a
            ``"state_dict"`` entry.
        verbose: When true, print the missing and unexpected keys reported by
            ``load_state_dict``.

    Returns:
        The model in eval mode, moved to CUDA when a CUDA device is available
        and left on the CPU otherwise.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
    """
    print(f"Loading model from {ckpt}")
    # NOTE(security): torch.load unpickles the file, so only trusted
    # checkpoints should be loaded here.
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)

    # strict=False tolerates key mismatches; they are only reported on request.
    missing_keys, unexpected_keys = model.load_state_dict(sd, strict=False)
    if verbose and missing_keys:
        print("missing keys:")
        print(missing_keys)
    if verbose and unexpected_keys:
        print("unexpected keys:")
        print(unexpected_keys)

    # Only move to the GPU when one exists, so CPU-only hosts do not crash
    # here (the script's entry point already selects a CPU fallback device).
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model


def preprocess_clip(n_px=224):
    """Build the image preprocessing pipeline for a square ``n_px`` input.

    The returned ``Compose`` applies, in order: bicubic resize to ``n_px``,
    centre crop to ``n_px``, RGB conversion, tensor conversion and
    per-channel normalisation.

    Args:
        n_px: Target size handed to both ``Resize`` and ``CenterCrop``.

    Returns:
        A ``torchvision.transforms.Compose`` instance.
    """
    # Per-channel normalisation statistics (presumably the ones used by CLIP
    # preprocessing, given the function name -- confirm before changing).
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)

    steps = [
        Resize(n_px, interpolation=InterpolationMode.BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize(channel_mean, channel_std),
    ]
    return Compose(steps)

def load_embedding_batch(path, n_samples, device):
    """Load one conditioning embedding and repeat it ``n_samples`` times.

    Args:
        path: Embedding file. ``.safetensors`` files are read with
            ``load_file`` and ``.bin`` files with ``torch.load``; any other
            value selects an all-zero embedding of width 768.
        n_samples: Number of copies stacked along the first dimension.
        device: Device the resulting tensor is moved to.

    Returns:
        A tensor of shape ``(n_samples, 1, 1, D)`` on ``device``, where ``D``
        is the number of elements of the first tensor stored in the file
        (768 for the all-zero fallback).

    Raises:
        ValueError: If the embedding file contains no tensors.
    """
    if path.endswith('.safetensors'):
        state = load_file(path)
    elif path.endswith('.bin'):
        # NOTE(security): torch.load unpickles the file; only load trusted
        # embeddings. map_location="cpu" lets files saved on a GPU be read on
        # CPU-only hosts; the tensor is moved to `device` below.
        state = torch.load(path, map_location="cpu")
    else:
        return torch.cat([torch.zeros(1, 1, 1, 768).to(device)] * n_samples)

    if not state:
        raise ValueError(f"embedding file {path!r} contains no tensors")
    # Only the first stored tensor is used, flattened into the last dimension.
    first_key = next(iter(state))
    embedding = state[first_key].view(1, 1, 1, -1).to(device)
    return torch.cat([embedding] * n_samples)


def resolve_embedding_paths(emb_path, expects_pair):
    """Split and validate the ``--emb_path`` value.

    Args:
        emb_path: Raw command-line value.
        expects_pair: True when a two-embedding mode is active, in which case
            ``emb_path`` must hold exactly two comma-separated entries.

    Returns:
        ``(first, second)``; ``second`` is ``None`` unless ``expects_pair``.

    Raises:
        ValueError: If a pair is expected but not supplied, or if the first
            path is empty.
    """
    if expects_pair:
        parts = emb_path.split(',')
        if len(parts) != 2:
            raise ValueError(
                "--emb_path must contain exactly two comma-separated paths "
                "when --composable_diffusion, --negation_diffusion or "
                "--interpolate is set"
            )
        first, second = parts
    else:
        first, second = emb_path, None
    if first == "":
        raise ValueError("--emb_path is required and must not be empty")
    return first, second


def build_arg_parser():
    """Create the command-line parser for this script."""
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--config",
        type=str,
        nargs="?",
        default="",
        help="the config file"
    )
    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="the prompt to render (currently not read by this script)"
    )
    parser.add_argument(
        "--negative_prompt",
        type=str,
        nargs="?",
        default="",
        help="the negative prompt (currently not read by this script)"
    )
    parser.add_argument(
        "--emb_path",
        type=str,
        nargs="?",
        default="",
        help="conditioning embedding file (.safetensors or .bin); any other "
             "non-empty value selects an all-zero embedding. Give two "
             "comma-separated paths for the two-embedding modes"
    )
    parser.add_argument(
        "--output_folder_name",
        type=str,
        default="all_synthetic",
        help="folder name substituted for 'sd_prompts' in the output path"
    )
    parser.add_argument(
        "--concept_prompt",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="concept name used to locate "
             "identity-<concept>-prompts/learned_embeds.safetensors"
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        nargs="?",
        default="",
        help="the checkpoint file"
    )
    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to (currently not read by this script)",
        default="outputs/txt2img-samples"
    )
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=200,
        help="number of ddim sampling steps",
    )
    parser.add_argument(
        "--plms",
        action='store_true',
        help="use plms sampling",
    )
    parser.add_argument(
        "--ipadapter",
        action='store_true',
        help="reserved flag (currently not read by this script)",
    )
    parser.add_argument(
        "--composable_diffusion",
        action='store_true',
        help="sample from two embeddings with composable_diffusion=1",
    )
    parser.add_argument(
        "--negation_diffusion",
        action='store_true',
        help="sample from two embeddings with negation_diffusion=1",
    )
    parser.add_argument(
        "--interpolate",
        action='store_true',
        help="sample from two embeddings with interpolate=1",
    )
    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=0.0,
        help="ddim eta (eta=0.0 corresponds to deterministic sampling)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="the seed (for reproducible sampling)",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=1,
        help="sample this often",
    )
    parser.add_argument(
        "--H",
        type=int,
        default=307200,
        help="size of the last dimension of the sampled tensor (sample shape "
             "is [1, 1, H], or [1, 15, H] when H is 20480)",
    )
    parser.add_argument(
        "--n_samples",
        type=int,
        default=1,
        help="how many samples to produce for the given prompt",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=5.0,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )
    return parser


def main():
    """Sample with the diffusion model and store the result as embeddings.

    The sampled tensor is averaged over its first dimension, reshaped to the
    shape of every tensor in the concept's ``learned_embeds.safetensors``
    template and written to the matching ``diff_learned_embeds.safetensors``
    file under the ``--output_folder_name`` folder.
    """
    parser = build_arg_parser()
    opt = parser.parse_args()

    # Validate the embedding arguments before the expensive model load.
    use_pair = opt.composable_diffusion or opt.negation_diffusion or opt.interpolate
    try:
        emb_path, emb_path2 = resolve_embedding_paths(opt.emb_path, use_pair)
    except ValueError as err:
        parser.error(str(err))

    seed_everything(opt.seed)

    config = OmegaConf.load(opt.config)  # TODO: optionally download from the same location as ckpt and change this logic
    model = load_model_from_config(config, opt.ckpt)  # TODO: check path

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)

    sampler = PLMSSampler(model) if opt.plms else DDIMSampler(model)

    x_start = load_embedding_batch(emb_path, opt.n_samples, device)
    if use_pair:
        x_start2 = load_embedding_batch(emb_path2, opt.n_samples, device)
        x_start = torch.cat([x_start, x_start2])

    # The first matching mode wins, in this fixed priority order.
    if opt.composable_diffusion:
        mode_kwargs = {"composable_diffusion": 1}
    elif opt.interpolate:
        mode_kwargs = {"interpolate": 1}
    elif opt.negation_diffusion:
        mode_kwargs = {"negation_diffusion": 1}
    else:
        mode_kwargs = {}
    # Two-embedding modes use a fixed batch of 2 and the user's guidance
    # scale; plain sampling uses n_samples and no guidance (scale 1.0).
    batch_size = 2 if mode_kwargs else opt.n_samples
    guidance_scale = opt.scale if mode_kwargs else 1.0

    shape = [1, 15, opt.H] if opt.H == 20480 else [1, 1, opt.H]

    adapter_weights_path = f"/data/user/diffusers/examples/textual_inversion/sd_prompts/identity-{opt.concept_prompt}-prompts/learned_embeds.safetensors"
    dest_adapter_weights_path = adapter_weights_path.replace('learned_embeds', 'diff_learned_embeds').replace('sd_prompts', opt.output_folder_name)

    # Only the tensor names and shapes of the template are needed; reading
    # them up front makes a missing template fail before sampling starts.
    template_shapes = {
        name: tensor.shape
        for name, tensor in load_file(adapter_weights_path).items()
    }
    os.makedirs(os.path.dirname(dest_adapter_weights_path), exist_ok=True)

    with torch.no_grad(), model.ema_scope():
        for _ in trange(opt.n_iter, desc="Sampling"):
            samples_ddim, _intermediates = sampler.sample(
                S=opt.ddim_steps,
                conditioning=x_start,
                batch_size=batch_size,
                shape=shape,
                verbose=False,
                x0=None,
                pcond=None,
                unconditional_guidance_scale=guidance_scale,
                unconditional_conditioning=None,
                eta=opt.ddim_eta,
                **mode_kwargs,
            )

            print(samples_ddim.shape)
            samples_ddim = samples_ddim.mean(0)

            # Every iteration writes to the same file, so the last one wins.
            new_weights = {
                name: samples_ddim.reshape(target_shape)
                for name, target_shape in template_shapes.items()
            }
            save_file(new_weights, dest_adapter_weights_path, metadata={"format": "pt"})


if __name__ == "__main__":
    main()
                