from timeit import default_timer as timer
from datetime import timedelta
from PIL import Image
import os
import numpy as np
from einops import rearrange
import torch
import torch.nn.functional as F
from torchvision import transforms
import transformers
from accelerate import Accelerator
from accelerate.utils import set_seed
from packaging import version
from PIL import Image
from tqdm.auto import tqdm
import gradio as gr
import json
import tempfile
import argparse
from torchvision.utils import make_grid, save_image

from transformers import AutoTokenizer, PretrainedConfig
from sde_inversion import load_model, get_img_latent, get_text_embed

import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
from diffusers.models.attention_processor import (
    AttnAddedKVProcessor,
    AttnAddedKVProcessor2_0,
    LoRAAttnAddedKVProcessor,
    LoRAAttnProcessor,
    LoRAAttnProcessor2_0,
    SlicedAttnAddedKVProcessor,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available

# Raises an error if the installed diffusers is older than the minimum version below.
# Remove at your own risk.
check_min_version("0.17.0")


def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
    """Resolve the text-encoder class declared by a pretrained checkpoint.

    The config stored in the ``text_encoder`` sub-folder is loaded and the
    first entry of its ``architectures`` list selects the class to return.
    The class is imported lazily, so only the module that is actually needed
    gets imported.

    Args:
        pretrained_model_name_or_path: Forwarded to
            ``PretrainedConfig.from_pretrained``.
        revision: Forwarded to ``PretrainedConfig.from_pretrained``.

    Returns:
        The text-encoder class itself (not an instance).

    Raises:
        ValueError: If the declared architecture is none of
            ``CLIPTextModel``, ``RobertaSeriesModelWithTransformation`` or
            ``T5EncoderModel``.
    """
    encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
    )
    architecture = encoder_config.architectures[0]

    if architecture == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel

    if architecture == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

        return RobertaSeriesModelWithTransformation

    if architecture == "T5EncoderModel":
        from transformers import T5EncoderModel

        return T5EncoderModel

    raise ValueError(f"{architecture} is not supported.")

def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
    """Tokenize ``prompt``, padding and truncating to a fixed length.

    Args:
        tokenizer: Callable tokenizer; its ``model_max_length`` attribute is
            read only when ``tokenizer_max_length`` is ``None``.
        prompt: Text (or texts) handed to the tokenizer unchanged.
        tokenizer_max_length: Optional override for the padded length.

    Returns:
        Whatever the tokenizer returns when called with ``truncation=True``,
        ``padding="max_length"``, the chosen ``max_length`` and
        ``return_tensors="pt"``.
    """
    pad_to = tokenizer.model_max_length if tokenizer_max_length is None else tokenizer_max_length
    return tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=pad_to,
        return_tensors="pt",
    )

def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=False):
    """Run the text encoder on token ids and return the first element of its output.

    Args:
        text_encoder: Callable encoder exposing a ``device`` attribute.
        input_ids: Token ids; moved to the encoder's device before the call.
        attention_mask: Attention mask; only used (and moved to the encoder's
            device) when ``text_encoder_use_attention_mask`` is true.
        text_encoder_use_attention_mask: When false, the encoder is called
            with ``attention_mask=None``.

    Returns:
        Element ``0`` of the encoder output.
    """
    target_device = text_encoder.device
    ids_on_device = input_ids.to(target_device)
    mask_on_device = attention_mask.to(target_device) if text_encoder_use_attention_mask else None

    encoder_output = text_encoder(ids_on_device, attention_mask=mask_on_device)
    return encoder_output[0]

def _build_unet_lora_attn_procs(unet, lora_rank):
    """Create one LoRA attention processor for every attention processor of ``unet``.

    Returns a dict with the same keys as ``unet.attn_processors``; each value
    is a freshly constructed LoRA processor sized for that layer.
    """
    block_out_channels = unet.config.block_out_channels
    lora_attn_procs = {}
    for name, attn_processor in unet.attn_processors.items():
        # Processors named "...attn1.processor" are built without a cross-attention dim.
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim

        # The hidden size follows from the UNet block the processor lives in;
        # the block index is the single character right after the prefix.
        if name.startswith("mid_block"):
            hidden_size = block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = block_out_channels[block_id]
        else:
            raise NotImplementedError("name must start with up_blocks, mid_blocks, or down_blocks")

        if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
            lora_attn_processor_class = LoRAAttnAddedKVProcessor
        else:
            lora_attn_processor_class = (
                LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
            )
        lora_attn_procs[name] = lora_attn_processor_class(
            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank
        )
    return lora_attn_procs


def train_lora(image, prompt, model_path, vae_path, save_lora_path, lora_step, lora_lr, lora_rank, text_embedding=None):
    """Fit UNet LoRA weights to a single image and save them to disk.

    Args:
        image: Input image as an array accepted by ``Image.fromarray``; it has
            not been pre-processed. It is converted to RGB here so that inputs
            with an alpha channel or a single channel still reach the VAE with
            three channels.
        prompt: User prompt; only encoded when ``text_embedding`` is ``None``.
        model_path: Model path/id used to load the tokenizer, noise scheduler,
            text encoder and UNet (and the VAE when ``vae_path == "default"``).
        vae_path: ``"default"`` to load the VAE from ``model_path``'s ``vae``
            sub-folder, otherwise a path/id passed to
            ``AutoencoderKL.from_pretrained``.
        save_lora_path: Directory the trained LoRA weights are saved to.
        lora_step: Number of LoRA training steps.
        lora_lr: Learning rate of the LoRA training.
        lora_rank: Rank of the LoRA layers.
        text_embedding: Optional precomputed text embedding fed to the UNet
            instead of encoding ``prompt``.

    Returns:
        None. The weights are written with ``LoraLoaderMixin.save_lora_weights``.

    Raises:
        ValueError: If the scheduler's prediction type is neither ``"epsilon"``
            nor ``"v_prediction"``.
        NotImplementedError: If an attention processor name does not start
            with a known block prefix.
    """
    # initialize accelerator
    accelerator = Accelerator(
        gradient_accumulation_steps=1,
        mixed_precision='fp16'
    )
    set_seed(0)

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        subfolder="tokenizer",
        revision=None,
        use_fast=False,
    )
    # initialize the model
    noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
    text_encoder_cls = import_model_class_from_model_name_or_path(model_path, revision=None)
    text_encoder = text_encoder_cls.from_pretrained(
        model_path, subfolder="text_encoder", revision=None
    )
    if vae_path == "default":
        vae = AutoencoderKL.from_pretrained(
            model_path, subfolder="vae", revision=None
        )
    else:
        vae = AutoencoderKL.from_pretrained(vae_path)
    unet = UNet2DConditionModel.from_pretrained(
        model_path, subfolder="unet", revision=None
    )

    # set device and dtype
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # The base models stay frozen; only the LoRA layers added below are trained.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.requires_grad_(False)

    unet.to(device, dtype=torch.float16)
    vae.to(device, dtype=torch.float16)
    text_encoder.to(device, dtype=torch.float16)

    # initialize UNet LoRA
    unet.set_attn_processor(_build_unet_lora_attn_procs(unet, lora_rank))
    unet_lora_layers = AttnProcsLayers(unet.attn_processors)

    # Optimizer creation
    optimizer = torch.optim.AdamW(
        unet_lora_layers.parameters(),
        lr=lora_lr,
        betas=(0.9, 0.999),
        weight_decay=1e-2,
        eps=1e-08,
    )

    lr_scheduler = get_scheduler(
        "constant",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=lora_step,
        num_cycles=1,
        power=1.0,
    )

    # prepare accelerator
    unet_lora_layers = accelerator.prepare_model(unet_lora_layers)
    optimizer = accelerator.prepare_optimizer(optimizer)
    lr_scheduler = accelerator.prepare_scheduler(lr_scheduler)

    # initialize text embeddings
    if text_embedding is None:
        with torch.no_grad():
            text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None)
            text_embedding = encode_prompt(
                text_encoder,
                text_inputs.input_ids,
                text_inputs.attention_mask,
                text_encoder_use_attention_mask=False
            )

    # initialize latent distribution
    image_transforms = transforms.Compose(
        [
            transforms.RandomCrop(512),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    # Force three channels: an RGBA or grayscale array would otherwise yield a
    # tensor whose channel count the VAE encoder cannot consume.
    pil_image = Image.fromarray(image).convert("RGB")
    image = image_transforms(pil_image).to(device, dtype=torch.float16)
    image = image.unsqueeze(dim=0)
    # The VAE is frozen, so no autograd graph is needed for the encoding.
    with torch.no_grad():
        latents_dist = vae.encode(image).latent_dist

    unet.train()
    for _ in tqdm(range(lora_step), desc="training LoRA"):
        model_input = latents_dist.sample() * vae.config.scaling_factor
        # Sample noise that we'll add to the latents
        noise = torch.randn_like(model_input)
        bsz = model_input.shape[0]
        # Sample a random timestep for each image
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
        )
        timesteps = timesteps.long()

        # Add noise to the model input according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)

        # Predict the noise residual
        model_pred = unet(noisy_model_input, timesteps, text_embedding).sample

        # Get the target for loss depending on the prediction type
        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(model_input, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

    # save the trained lora
    unet = unet.to(torch.float32)
    # unwrap_model is used to remove all special modules added when doing distributed training
    # so here, there is no need to call unwrap_model
    LoraLoaderMixin.save_lora_weights(
        save_directory=save_lora_path,
        unet_lora_layers=unet_lora_layers,
        text_encoder_lora_layers=None,
    )


from torch.optim.adam import Adam
import random

def text_inversion(img, text_embeddings, scheduler, model, lr=0.03, steps=5):
    """Optimise ``text_embeddings`` in place so ``model`` better predicts noise added to ``img``.

    Every parameter of ``model`` is frozen (and stays frozen afterwards). Each
    step draws a random timestep, noises ``img`` using the scheduler's
    ``alphas_cumprod`` entry for it, and takes one Adam step on the mean
    squared difference between the model's ``.sample`` output and that noise.

    Args:
        img: Tensor that gets noised and passed to ``model``.
        text_embeddings: Tensor to optimise; updated in place.
        scheduler: Object supporting ``len()`` and exposing ``alphas_cumprod``.
        model: Called as ``model(x, timestep, encoder_hidden_states=...)``.
        lr: Adam learning rate.
        steps: Number of optimisation steps; ``0`` leaves the values untouched.

    Returns:
        The same ``text_embeddings`` tensor, with ``requires_grad`` reset to False.
    """
    # Only the embedding receives gradients.
    for weight in model.parameters():
        weight.requires_grad_(False)
    text_embeddings.requires_grad_(True)

    embed_optimizer = Adam([text_embeddings], lr=lr)
    with torch.autocast('cuda'):
        for _ in tqdm(range(steps), desc="optiming"):
            t = random.randint(0, len(scheduler) - 1)
            alpha_bar = scheduler.alphas_cumprod[t]
            noise = torch.randn_like(img, device=img.device)

            signal_scale = alpha_bar ** (0.5)
            noise_scale = (1 - alpha_bar) ** (0.5)
            noised_img = signal_scale * img + noise_scale * noise

            predicted = model(noised_img, t, encoder_hidden_states=text_embeddings).sample
            objective = (predicted - noise).square().mean()

            embed_optimizer.zero_grad()
            objective.backward()
            embed_optimizer.step()

    text_embeddings.requires_grad_(False)
    return text_embeddings


def load_data(path='drag_data', type='Animals'):
    """Collect the drag-editing samples stored under ``<path>/<type>``.

    The directory tree is walked and every directory whose path contains one
    of the selected sample names must hold a ``prompt.json`` with the keys
    ``source``, ``target`` and ``prompt``.

    Args:
        path: Root directory of the dataset.
        type: Category sub-directory (the name shadows the builtin, but is
            kept because callers may pass it by keyword).

    Returns:
        Dict mapping ``<dir>/origin_image.png`` to a dict with the keys
        ``source``, ``target``, ``prompt`` (copied from ``prompt.json``) and
        ``mask`` (the path ``<dir>/mask.png``). Neither image file is opened
        or checked here.

    Raises:
        FileNotFoundError: If a selected directory has no ``prompt.json``.
        KeyError: If ``prompt.json`` lacks one of the expected keys.
    """
    root = os.path.join(path, type)
    # Only these samples are loaded. Other names previously used here:
    # 'cat_0', 'cat_3', 'dog_0', 'dog_1', 'elephant_2', 'lion'.
    selected = ('lion_2',)

    datas = {}
    # A distinct loop name keeps the ``path`` argument intact.
    for dirpath, _, _ in os.walk(root):
        if not any(name in dirpath for name in selected):
            continue

        # JSON is read as UTF-8 explicitly instead of relying on the platform default.
        with open(os.path.join(dirpath, 'prompt.json'), 'r', encoding='utf-8') as f:
            meta = json.load(f)

        datas[os.path.join(dirpath, 'origin_image.png')] = {
            'source': meta['source'],
            'target': meta['target'],
            'prompt': meta['prompt'],
            'mask': os.path.join(dirpath, 'mask.png'),
        }

    return datas


def sample(vae, tokenizer, text_encoder, unet, scheduler, prompt, lora_path, save_path, file_name, lora_scale):
    """Generate ten images with the LoRA from ``lora_path`` applied and save them as one grid.

    Sampling uses classifier-free guidance with fixed settings: 512x512
    output, 50 scheduler steps, guidance scale 3 and seed 1.

    Args:
        vae: Autoencoder used to decode the final latents.
        tokenizer: Passed to ``get_text_embed``.
        text_encoder: Passed to ``get_text_embed``.
        unet: Denoising UNet; the LoRA attention processors are loaded into it
            (this modifies ``unet``).
        scheduler: Diffusion scheduler driving the denoising loop.
        prompt: Text prompt, repeated for every image of the grid.
        lora_path: Location passed to ``unet.load_attn_procs``.
        save_path: Output directory; created if missing.
        file_name: File name of the saved grid inside ``save_path``.
        lora_scale: Passed to the UNet as ``cross_attention_kwargs={"scale": ...}``.

    Returns:
        None. The grid image is written to ``save_path/file_name``.
    """
    unet.load_attn_procs(lora_path)

    num_samples = 10
    height = 512  # default height of Stable Diffusion
    width = 512  # default width of Stable Diffusion
    num_inference_steps = 50  # Number of denoising steps
    guidance_scale = 3  # Scale for classifier-free guidance
    # Seeds the global RNG and returns the default generator for the initial latent noise.
    generator = torch.manual_seed(1)

    prompts = [prompt] * num_samples
    text_embeddings = get_text_embed(prompts, tokenizer, text_encoder)
    uncond_embeddings = get_text_embed([""] * num_samples, tokenizer, text_encoder)

    # Unconditional half first, matching the chunk order used below.
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    # Read the channel count from the config; direct attribute access on the
    # model (``unet.in_channels``) is deprecated in diffusers.
    latents = torch.randn(
        (num_samples, unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    )
    # Follow the UNet's device instead of assuming CUDA.
    latents = latents.to(unet.device)
    latents = latents * scheduler.init_noise_sigma

    scheduler.set_timesteps(num_inference_steps)

    for t in tqdm(scheduler.timesteps):
        # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
        latent_model_input = torch.cat([latents] * 2)

        latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)

        # predict the noise residual
        with torch.no_grad():
            noise_pred = unet(
                latent_model_input,
                t,
                encoder_hidden_states=text_embeddings,
                cross_attention_kwargs={"scale": lora_scale},
            ).sample

        # perform guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        latents = scheduler.step(noise_pred, t, latents).prev_sample

    # Undo the latent scaling with the VAE's own factor (the same value
    # train_lora multiplies by) and decode the image latents.
    latents = 1 / vae.config.scaling_factor * latents
    with torch.no_grad():
        image = vae.decode(latents).sample

    # Map from [-1, 1] to [0, 1] before saving.
    image = (image / 2 + 0.5).clamp(0, 1)
    image = make_grid(image, nrow=5)

    os.makedirs(save_path, exist_ok=True)
    save_image(image, os.path.join(save_path, file_name))


def get_args():
    """Parse this script's command-line options from ``sys.argv``.

    Options:
        --text_steps (int, default 0)
        --lora_steps (int, default 200)
        --lora_scale (float, default 1.0)

    Returns:
        The ``argparse.Namespace`` holding the parsed values.
    """
    option_specs = (
        ("--text_steps", int, 0),
        ("--lora_steps", int, 200),
        ("--lora_scale", float, 1.),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, default_value in option_specs:
        parser.add_argument(flag, type=value_type, default=default_value)

    return parser.parse_args()


def main():
    """Train a LoRA for every loaded sample and save a grid of images generated with it.

    For each sample the prompt embedding is optionally refined with
    ``text_inversion`` (``--text_steps``), a LoRA is trained on the image
    (``--lora_steps``) into a temporary directory, and ``sample`` writes a
    grid named after the sample's directory to an output folder whose name
    encodes the three command-line settings.
    """
    data = load_data()
    opt = get_args()

    # Settings shared by every sample.
    model_path = 'runwayml/stable-diffusion-v1-5'
    vae_path = 'default'
    lora_step = opt.lora_steps
    lora_lr = 2e-4
    lora_rank = 16
    save_path = f'output/hp_lora/text_step={opt.text_steps}-lora_step={lora_step}-lora_scale={opt.lora_scale}'

    for image_path, item in data.items():
        prompt = item['prompt']
        image = np.array(Image.open(image_path))

        # Models are reloaded for every sample. NOTE(review): presumably so the
        # LoRA weights that sample() loads into the UNet do not carry over to
        # the next sample - confirm.
        vae, tokenizer, text_encoder, unet, scheduler = load_model()
        img = get_img_latent(image_path, vae, 512, 512)
        text_embeddings = get_text_embed(prompt, tokenizer, text_encoder)
        text_embeddings = text_inversion(img, text_embeddings, scheduler, unet, lr=0.03, steps=opt.text_steps)

        # Name the output after the sample's directory. os.path is used rather
        # than splitting on '/' because load_data builds the paths with
        # os.path.join, whose separator is platform dependent.
        sample_name = os.path.basename(os.path.dirname(image_path))
        file_name = f'{sample_name}.png'

        # The LoRA weights only need to live until sampling has finished.
        with tempfile.TemporaryDirectory() as save_lora_path:
            train_lora(
                image,
                prompt,
                model_path,
                vae_path,
                save_lora_path,
                lora_step,
                lora_lr,
                lora_rank,
                text_embedding=text_embeddings.to(dtype=torch.float16),
            )
            sample(
                vae,
                tokenizer,
                text_encoder,
                unet,
                scheduler,
                prompt,
                lora_path=save_lora_path,
                save_path=save_path,
                file_name=file_name,
                lora_scale=opt.lora_scale,
            )


# Run the experiment only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()

