import argparse, os, sys, glob
import torch
import clip
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from einops import rearrange
from torchvision.utils import make_grid
import re

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler, DDIMSampler1D
from ldm.models.diffusion.plms import PLMSSampler
import safetensors
from pytorch_lightning import seed_everything
from safetensors.torch import load_file, save_file
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

# from dassl.data.datasets import Datum
from torchvision.transforms.functional import InterpolationMode


def _convert_image_to_rgb(image):
    """Return a copy of *image* in the three-channel ``"RGB"`` mode.

    Used as a step in the CLIP preprocessing pipeline so that grayscale,
    palette or alpha images all reach ``ToTensor`` with three channels.
    """
    rgb_image = image.convert("RGB")
    return rgb_image

    
def load_model_from_config(config, ckpt, verbose=False):
    """Build ``config.model`` and load the weights stored in *ckpt*.

    Args:
        config: OmegaConf config; its ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: path to a checkpoint readable by ``torch.load`` that contains a
            ``"state_dict"`` entry.
        verbose: when True, print the missing and unexpected keys reported by
            the non-strict ``load_state_dict``.

    Returns:
        The model in eval mode, on CUDA when available and on CPU otherwise.
    """
    print(f"Loading model from {ckpt}")
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False: key mismatches are reported (when verbose) instead of raising.
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    # Fall back to CPU instead of crashing on hosts without CUDA
    # (the previous unconditional ``model.cuda()`` raised there).
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)
    model.eval()
    return model


def preprocess_clip(n_px=224):
    """Build the CLIP image preprocessing pipeline.

    Steps: bicubic resize of the shorter side to *n_px*, center crop to
    ``n_px x n_px``, conversion to RGB, conversion to a tensor, and
    per-channel normalization with CLIP's mean/std constants.

    Args:
        n_px: side length of the square output image, in pixels.

    Returns:
        A torchvision ``Compose`` applying the steps above.
    """
    clip_mean = (0.48145466, 0.4578275, 0.40821073)
    clip_std = (0.26862954, 0.26130258, 0.27577711)
    steps = [
        Resize(n_px, interpolation=InterpolationMode.BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize(clip_mean, clip_std),
    ]
    return Compose(steps)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--config",
        type=str,
        nargs="?",
        default="",
        help="path to the model config file (loaded with OmegaConf)",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="the prompt to render; comma-separated when using "
             "--composable_diffusion, --interpolate or --negation_diffusion",
    )
    parser.add_argument(
        "--negative_prompt",
        type=str,
        nargs="?",
        default="",
        help="prompt for the unconditional guidance branch "
             "(single-prompt mode only, when --scale != 1.0)",
    )
    parser.add_argument(
        "--image_path",
        type=str,
        nargs="?",
        default="",
        help="optional conditioning input: an image file, a directory of "
             "images, or a .safetensors embedding",
    )
    parser.add_argument(
        "--output_folder_name",
        type=str,
        default="all_synthetic",
        help="folder under --embeds_root that receives the sampled embeddings",
    )
    parser.add_argument(
        "--concept_prompt",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="concept name selecting the identity-<concept>-prompts folder",
    )
    parser.add_argument(
        "--embeds_root",
        type=str,
        default="/data/user/diffusers/examples/textual_inversion",
        help="root folder containing sd_prompts/ and the output folder",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        nargs="?",
        default="",
        help="the checkpoint file",
    )
    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="currently unused; outputs are written under --embeds_root",
        default="outputs/txt2img-samples",
    )
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=200,
        help="number of ddim sampling steps",
    )
    parser.add_argument(
        "--plms",
        action='store_true',
        help="use plms sampling",
    )
    parser.add_argument(
        "--ipadapter",
        action='store_true',
        help="encode a single conditioning image with the model's IP-Adapter",
    )
    parser.add_argument(
        "--composable_diffusion",
        action='store_true',
        help="use composable diffusion (comma-separated --prompt)",
    )
    parser.add_argument(
        "--interpolate",
        action='store_true',
        help="use interpolate mode (comma-separated --prompt)",
    )
    parser.add_argument(
        "--negation_diffusion",
        action='store_true',
        help="use negation diffusion (comma-separated --prompt)",
    )
    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=0.0,
        help="ddim eta (eta=0.0 corresponds to deterministic sampling)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="the seed (for reproducible sampling)",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=1,
        help="number of sampling iterations; each gets its own file when > 1",
    )
    parser.add_argument(
        "--H",
        type=int,
        default=307200,
        help="size of the last axis of the sample shape; samples are trimmed "
             "to the first 768 entries when this differs from 768",
    )
    parser.add_argument(
        "--n_samples",
        type=int,
        default=1,
        help="how many samples to produce for the given prompt",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=5.0,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )
    opt = parser.parse_args()
    seed_everything(opt.seed)

    config = OmegaConf.load(opt.config)  # TODO: Optionally download from same location as ckpt and chnage this logic
    model = load_model_from_config(config, opt.ckpt)  # TODO: check path

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)

    sampler = PLMSSampler(model) if opt.plms else DDIMSampler(model)

    # These modes take a comma-separated list of prompts.
    multi_prompt = opt.composable_diffusion or opt.negation_diffusion or opt.interpolate
    prompt = opt.prompt.split(',') if multi_prompt else opt.prompt
    negative_prompt = opt.negative_prompt
    mprompt = opt.concept_prompt

    # Input template and output location for the sampled embedding. The paths
    # are built explicitly (not via str.replace) so a concept name containing
    # "sd_prompts" or "learned_embeds" cannot corrupt the output path.
    concept_subdir = f"identity-{mprompt}-prompts"
    adapter_weights_path = f"{opt.embeds_root}/sd_prompts/{concept_subdir}/learned_embeds.safetensors"
    dest_adapter_weights_path = (
        f"{opt.embeds_root}/{opt.output_folder_name}/{concept_subdir}/diff_learned_embeds.safetensors"
    )
    # Loaded once, up front, so a missing template fails before sampling.
    template_weights = load_file(adapter_weights_path)

    start_noise = None  # passed to the sampler as x0; never set by this script
    x_start = None      # image conditioning passed to the sampler as pcond
    if opt.image_path != "":
        if opt.image_path.endswith('.safetensors'):
            # NOTE(review): this embedding is loaded but never used afterwards
            # (x_start stays None) — confirm the intended use.
            start_noise_embed = load_file(opt.image_path)
            noisy = next(iter(start_noise_embed.values())).view(1, 1, -1).to(device)
        else:
            preprocess_clip_ = preprocess_clip()
            if os.path.isdir(opt.image_path):
                # Sorted so the concatenation order (and thus x_start) is reproducible.
                image = torch.cat(
                    [
                        preprocess_clip_(Image.open(os.path.join(opt.image_path, item))).unsqueeze(0).to(device)
                        for item in sorted(os.listdir(opt.image_path))
                    ],
                    0,
                )
            elif opt.ipadapter:
                # Convert to RGB so grayscale/CMYK files still yield 3 channels (HWC -> CHW).
                rgb = np.array(_convert_image_to_rgb(Image.open(opt.image_path)))
                img = np.expand_dims(rgb.transpose([2, 0, 1]), axis=0)
                image, _uncond_image_features = model.ip_adapter_model.get_image_embeds(img)
                if opt.scale != 1.0:
                    image = torch.cat([image] * 2)
                if opt.n_samples > 1:
                    image = torch.cat([image] * opt.n_samples)
            else:
                image = preprocess_clip_(Image.open(opt.image_path)).unsqueeze(0).to(device)

            if '-learn-image' not in opt.ckpt and not opt.ipadapter:
                clip_model, _ = clip.load('ViT-L/14', device=device)
                with torch.no_grad():
                    image_features = clip_model.encode_image(image)
                    image_features /= image_features.norm(dim=-1, keepdim=True)  # unit-normalize
                    print(image_features.shape)
                x_start = image_features.view(1, 1, -1).to(dtype=torch.float32)
                del clip_model  # free the CLIP weights before sampling
            else:
                x_start = image

    # Each multi-prompt mode receives the number of extra prompts through its
    # own keyword argument; single-prompt mode passes none.
    mode_kwargs = {}
    if opt.composable_diffusion:
        mode_kwargs["composable_diffusion"] = len(prompt) - 1
    elif opt.interpolate:
        mode_kwargs["interpolate"] = len(prompt) - 1
    elif opt.negation_diffusion:
        mode_kwargs["negation_diffusion"] = len(prompt) - 1

    with torch.no_grad():
        with model.ema_scope():
            uc = None
            if opt.scale != 1.0:
                uc = model.get_learned_conditioning(opt.n_samples * [negative_prompt])
            for n in trange(opt.n_iter, desc="Sampling"):
                if multi_prompt:
                    c = model.get_learned_conditioning(prompt)
                    if opt.scale != 1.0:
                        uc = model.get_learned_conditioning(len(prompt) * [""])
                else:
                    c = model.get_learned_conditioning(opt.n_samples * [prompt])
                shape = [1, 15, opt.H] if opt.H == 20480 else [1, 1, opt.H]

                samples_ddim, _ = sampler.sample(
                    S=opt.ddim_steps,
                    conditioning=c,
                    batch_size=opt.n_samples,
                    shape=shape,
                    verbose=False,
                    unconditional_guidance_scale=opt.scale,
                    unconditional_conditioning=uc,
                    x0=start_noise,
                    eta=opt.ddim_eta,
                    pcond=x_start,
                    **mode_kwargs,
                )

                print(samples_ddim.shape)
                if opt.composable_diffusion or opt.interpolate:
                    samples_ddim = samples_ddim[-1]
                else:
                    samples_ddim = samples_ddim.mean(0)
                if opt.H != 768:
                    samples_ddim = samples_ddim[:, :, :768]

                # Write the sample into every template entry, reshaped to that
                # entry's shape. Clone so entries never share storage:
                # safetensors refuses to save tensors that share memory.
                out_weights = {
                    key: samples_ddim.reshape(tmpl.shape).clone()
                    for key, tmpl in template_weights.items()
                }

                # Keep every iteration's result instead of overwriting one file.
                dest_path = dest_adapter_weights_path
                if opt.n_iter > 1:
                    stem, ext = os.path.splitext(dest_adapter_weights_path)
                    dest_path = f"{stem}-{n}{ext}"

                os.makedirs(os.path.dirname(dest_path), exist_ok=True)
                save_file(out_weights, dest_path, metadata={"format": "pt"})