import os
import cv2
import inspect
import numpy as np
from typing import Optional, List, Dict

import click
from omegaconf import OmegaConf

import torch
import torch.utils.data
from torch.cuda.amp import autocast
import torch.nn.functional as F
import torch.utils.checkpoint

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler
from diffusers.optimization import get_scheduler
from diffusers.utils.import_utils import is_xformers_available
from tqdm.auto import tqdm
from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection

from utils.util import get_time_string, get_function_args
from model.unet_2d_condition import UNet2DConditionModel
from model.pipeline import StableDiffusionPipeline
from dataset import StorySeqDataset

# Module-level logger created with accelerate's get_logger; used in train() to
# report failures when enabling xformers memory-efficient attention.
logger = get_logger(__name__)

class SampleLogger:
    """Generate and save validation samples during training.

    Each call to :meth:`log_sample_images` writes, for every prompt in the batch
    and every sampled seed, four files into ``<logdir>/<subdir>``:
    ``{step}_{idx}_{seed}.png`` (target frame), ``..._ref.png`` (reference frame),
    ``..._output.png`` (pipeline output) and ``....txt`` (reference prompt and
    current prompt, separated by a newline).
    """

    def __init__(
        self,
        logdir: str,
        subdir: str = "sample",
        num_inference_steps: int = 20,
        guidance_scale: float = 2,
    ) -> None:
        """
        Args:
            logdir: Run directory; samples are written to ``logdir/subdir``.
            subdir: Name of the sample sub-directory.
            num_inference_steps: Denoising steps forwarded to the pipeline call.
            guidance_scale: Guidance scale forwarded to the pipeline call.
        """
        self.guidance_scale = guidance_scale
        self.num_inference_steps = num_inference_steps

        self.logdir = os.path.join(logdir, subdir)
        # exist_ok: re-creating a logger for an existing run directory must not
        # crash with FileExistsError.
        os.makedirs(self.logdir, exist_ok=True)

    @staticmethod
    def _to_bgr_uint8(tensor: torch.Tensor) -> np.ndarray:
        """Convert a ``(1, C, H, W)`` or ``(C, H, W)`` tensor to an ``(H, W, C)`` uint8 array.

        Values are mapped with ``(x + 1) / 2`` (so inputs are presumably in
        [-1, 1]), scaled to [0, 255], clipped and rounded, and the channel axis is
        reversed for ``cv2.imwrite`` (presumably RGB -> BGR). Converting explicitly
        avoids handing float (possibly fp16) arrays to OpenCV.
        """
        if tensor.dim() == 4:
            # Index rather than squeeze(): squeeze would also drop H/W axes of size 1.
            tensor = tensor[0]
        array = ((tensor.detach().float().cpu() + 1.0) / 2.0).permute(1, 2, 0).numpy()
        array = np.clip(array[:, :, ::-1] * 255.0, 0, 255).round().astype(np.uint8)
        return np.ascontiguousarray(array)

    def log_sample_images(
        self, batch, visual_projection, image_encoder, text_encoder, pipeline: "StableDiffusionPipeline", device: torch.device, step: int
    ):
        """Run the pipeline on every prompt of ``batch`` and write the results to disk.

        Args:
            batch: Validation batch providing ``"prompt"``, ``"ref_prompt"``,
                ``"image"`` and ``"ref_image"``; the image tensors are indexed as
                ``[idx]`` to obtain one ``(C, H, W)`` frame per prompt.
            visual_projection: Applied to the image encoder's last hidden state to
                build the ``cond`` tensor passed to the pipeline.
            image_encoder: Frozen encoder for the reference frame.
            text_encoder: Currently unused (the text+image conditioning variant is
                disabled); kept so existing callers keep working.
            pipeline: Sampling pipeline, called with ``cond`` and ``prompt``.
            device: Device the inputs and the RNG generator are placed on.
            step: Training step, used as the file-name prefix.
        """
        # One random seed per call, stored as a sorted list so more can be added.
        self.sample_seeds = sorted(torch.randint(0, 100000, (1,)).tolist())
        self.prompts = batch["prompt"]

        for idx, prompt in enumerate(tqdm(self.prompts, desc="Generating sample images")):
            # Everything below up to the seed loop is independent of the seed,
            # so compute it once per prompt.
            ref_image = batch["ref_image"][idx].unsqueeze(0).to(device=device)
            image = batch["image"][idx].unsqueeze(0).to(device=device)

            with torch.no_grad():
                ref_img_feature = image_encoder(ref_image).last_hidden_state
                # Past-frame conditioning: projected reference-image features only.
                cross_frame_feature = visual_projection(ref_img_feature)

            image_bgr = self._to_bgr_uint8(image)
            ref_image_bgr = self._to_bgr_uint8(ref_image)

            for seed in self.sample_seeds:
                generator = torch.Generator(device=device)
                generator.manual_seed(seed)
                with torch.no_grad():
                    sequence = pipeline(
                        cond=cross_frame_feature,  # past frame
                        prompt=prompt,  # current frame
                        height=image.shape[2],
                        width=image.shape[3],
                        generator=generator,
                        num_inference_steps=self.num_inference_steps,
                        guidance_scale=self.guidance_scale,
                        num_images_per_prompt=1,
                    ).images[0]

                stem = os.path.join(self.logdir, f"{step}_{idx}_{seed}")
                cv2.imwrite(f"{stem}.png", image_bgr)
                cv2.imwrite(f"{stem}_ref.png", ref_image_bgr)
                # NOTE(review): indexing [0] again assumes the custom pipeline's
                # `.images[0]` is itself a sequence of savable images — kept as-is.
                sequence[0].save(f"{stem}_output.png")
                with open(f"{stem}.txt", "a", encoding="utf-8") as f:
                    f.write(batch["ref_prompt"][idx])
                    f.write("\n")
                    f.write(batch["prompt"][idx])


def _mark_trainable(model: torch.nn.Module, name_suffixes) -> None:
    """Enable gradients for all parameters of submodules whose qualified name ends with a suffix.

    ``name_suffixes`` may be a single string or an iterable of strings.
    """
    # A bare string must not be split into characters by tuple().
    suffixes = (name_suffixes,) if isinstance(name_suffixes, str) else tuple(name_suffixes)
    for name, module in model.named_modules():
        if name.endswith(suffixes):
            module.requires_grad_(True)


def _weight_dtype_for(mixed_precision: Optional[str]) -> torch.dtype:
    """Return the dtype used for the frozen models under accelerate's mixed-precision mode."""
    return {"fp16": torch.float16, "bf16": torch.bfloat16}.get(mixed_precision, torch.float32)


def _save_checkpoint(accelerator, unet, vae, text_encoder, image_encoder, tokenizer, pretrained_model_path, save_path) -> None:
    """Wrap the unwrapped UNet and the frozen components in a pipeline and save it to ``save_path``."""
    # Only pass keep_fp32_wrapper if this accelerate version's unwrap_model accepts it.
    accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in inspect.signature(accelerator.unwrap_model).parameters
    extra_args = {"keep_fp32_wrapper": True} if accepts_keep_fp32_wrapper else {}
    pipeline_save = StableDiffusionPipeline(
        vae=vae,
        text_encoder=text_encoder,
        image_encoder=image_encoder,
        tokenizer=tokenizer,
        unet=accelerator.unwrap_model(unet, **extra_args),
        scheduler=DDIMScheduler.from_pretrained(
            pretrained_model_path,
            subfolder="scheduler",
        ),
    )
    pipeline_save.save_pretrained(save_path)


def train(
    pretrained_model_path: str,
    logdir: str,
    train_steps: int = 300,
    validation_steps: int = 1000,
    validation_sample_logger: Optional[Dict] = None,
    gradient_accumulation_steps: int = 20,
    seed: Optional[int] = None,
    mixed_precision: Optional[str] = "fp16",
    train_batch_size: int = 6,
    val_batch_size: int = 1,
    learning_rate: float = 3e-5,
    scale_lr: bool = False,
    lr_scheduler: str = "constant",  # ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
    lr_warmup_steps: int = 0,
    use_8bit_adam: bool = True,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.999,
    adam_weight_decay: float = 1e-2,
    adam_epsilon: float = 1e-08,
    max_grad_norm: float = 1.0,
    gradient_checkpointing: bool = False,
    checkpointing_steps: int = 1000,
):
    """Fine-tune the cross-frame attention / feed-forward modules of the UNet.

    All other weights (VAE, text encoder, image encoder, rest of the UNet) stay frozen.

    Args:
        pretrained_model_path: Directory with ``tokenizer``, ``text_encoder``,
            ``image_encoder``, ``vae``, ``unet`` and ``scheduler`` subfolders.
        logdir: Output directory prefix; ``_<time string>`` is appended.
        train_steps: Number of optimizer steps (each one spans
            ``gradient_accumulation_steps`` micro-batches).
        validation_steps: Generate samples every this many optimizer steps.
        validation_sample_logger: Keyword arguments for :class:`SampleLogger`, or
            ``None`` to disable sampling.
        gradient_accumulation_steps: Micro-batches accumulated per optimizer step.
        seed: Passed to accelerate's ``set_seed`` when not ``None``.
        mixed_precision: accelerate mixed-precision mode; ``"fp16"``/``"bf16"`` also
            set the dtype of the frozen models.
        train_batch_size: Per-process training batch size.
        val_batch_size: Validation batch size.
        learning_rate: Base learning rate.
        scale_lr: If True, multiply the learning rate by
            ``gradient_accumulation_steps * train_batch_size * num_processes``.
        lr_scheduler: Scheduler name passed to diffusers' ``get_scheduler``.
        lr_warmup_steps: Warmup length, in optimizer steps.
        use_8bit_adam: Use bitsandbytes ``AdamW8bit`` instead of ``torch.optim.AdamW``.
        adam_beta1, adam_beta2, adam_weight_decay, adam_epsilon: Optimizer settings.
        max_grad_norm: Gradient-clipping norm.
        gradient_checkpointing: Enable UNet gradient checkpointing.
        checkpointing_steps: Save a pipeline checkpoint every this many optimizer steps.

    Side effects:
        Writes ``config.yml``, samples and ``checkpoint_<step>`` directories under
        the run directory, and appends a loss value to ``./loss_hf.txt`` every
        100 optimizer steps (main process only).

    Raises:
        ImportError: ``use_8bit_adam`` is set but bitsandbytes is not installed.
        ValueError: The noise scheduler has an unsupported ``prediction_type``.
    """
    # Keep this the first statement: the helper presumably inspects this call's
    # arguments (they are saved as config.yml below), so no other locals should exist yet.
    args = get_function_args()

    logdir += f"_{get_time_string()}"

    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision=mixed_precision,
    )
    if accelerator.is_main_process:
        os.makedirs(logdir, exist_ok=True)
        OmegaConf.save(args, os.path.join(logdir, "config.yml"))

    if seed is not None:
        set_seed(seed)

    # ---- Models -------------------------------------------------------------
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_path,
        subfolder="tokenizer",
        use_fast=False,
    )
    text_encoder = CLIPTextModelWithProjection.from_pretrained(pretrained_model_path, subfolder="text_encoder")
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(pretrained_model_path, subfolder="image_encoder")
    vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet")
    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")

    # Sampling pipeline used for validation (DDIM scheduler instead of DDPM).
    pipeline = StableDiffusionPipeline(
        vae=vae,
        text_encoder=text_encoder,
        image_encoder=image_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler"),
    )
    pipeline.set_progress_bar_config(disable=True)

    if is_xformers_available():
        try:
            pipeline.enable_xformers_memory_efficient_attention()
        except Exception as e:
            # Best effort: training still works without xformers.
            logger.warning("Could not enable memory efficient attention. Make sure xformers is installed" f" correctly and a GPU is available: {e}")

    for frozen_model in (vae, text_encoder, image_encoder, unet):
        frozen_model.requires_grad_(False)

    # Only the cross-frame branches of the UNet are trained. Other configurations
    # tried: ("attn2", ".attn1"), ("attn1", "attn2", "ff") and their union with
    # the *_cross modules.
    trainable_modules = ("attn1_cross", "attn2_cross", "ff_cross")
    _mark_trainable(unet, trainable_modules)
    # Hand only trainable tensors to the optimizer / clipping instead of every UNet weight.
    trainable_params = [p for p in unet.parameters() if p.requires_grad]

    if gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    if scale_lr:
        learning_rate = learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes

    # ---- Optimizer ----------------------------------------------------------
    # 8-bit Adam lowers optimizer-state memory.
    if use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError as err:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            ) from err
        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    optimizer = optimizer_class(
        trainable_params,
        lr=learning_rate,
        betas=(adam_beta1, adam_beta2),
        weight_decay=adam_weight_decay,
        eps=adam_epsilon,
    )

    # ---- Data -----------------------------------------------------------------
    train_dataset = StorySeqDataset(root="./Dataset/", dataset_name='train', tokenizer=tokenizer)
    val_dataset = StorySeqDataset(root="./Dataset/", dataset_name='test', tokenizer=tokenizer)
    accelerator.print(f"train samples: {len(train_dataset)}")
    accelerator.print(f"val samples: {len(val_dataset)}")
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=8)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=val_batch_size, shuffle=False, num_workers=8)

    # accelerate's prepared scheduler only advances on real optimizer steps (after
    # gradient accumulation) and then steps once per process, so size it in those units.
    lr_sched = get_scheduler(
        lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps * accelerator.num_processes,
        num_training_steps=train_steps * accelerator.num_processes,
    )

    pipeline, unet, optimizer, train_dataloader, lr_sched = accelerator.prepare(
        pipeline, unet, optimizer, train_dataloader, lr_sched
    )
    accelerator.register_for_checkpointing(lr_sched)

    # Frozen models run in the mixed-precision dtype; the trained UNet is left to accelerate.
    weight_dtype = _weight_dtype_for(accelerator.mixed_precision)
    vae.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    image_encoder.to(accelerator.device, dtype=weight_dtype)

    visual_projection = image_encoder.visual_projection

    if accelerator.is_main_process:
        accelerator.init_trackers("video")

    sample_logger = None
    if validation_sample_logger is not None and accelerator.is_main_process:
        sample_logger = SampleLogger(**validation_sample_logger, logdir=logdir)

    def make_data_yielder(dataloader, sync_at_epoch_end: bool = True):
        """Yield batches from ``dataloader`` forever, restarting it after each pass."""
        while True:
            produced_any = False
            for batch in dataloader:
                produced_any = True
                yield batch
            if not produced_any:
                # An empty loader would otherwise spin forever.
                raise RuntimeError("dataloader produced no batches")
            # Barrier only for loaders every rank iterates. The validation loader
            # is consumed by the main process alone; a barrier there would never be
            # matched by the other ranks and would hang the run.
            if sync_at_epoch_end:
                accelerator.wait_for_everyone()

    train_data_yielder = make_data_yielder(train_dataloader)
    val_data_yielder = make_data_yielder(val_dataloader, sync_at_epoch_end=False)

    # ---- Training loop --------------------------------------------------------
    vae.eval()
    text_encoder.eval()
    image_encoder.eval()
    unet.train()

    step = 0
    progress_bar = tqdm(range(step, train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    while step < train_steps:
        batch = next(train_data_yielder)

        # accumulate() sets sync_gradients only on every gradient_accumulation_steps-th
        # micro-batch; the prepared optimizer/scheduler skip step and zero_grad otherwise.
        with accelerator.accumulate(unet):
            ref_image = batch["ref_image"].to(dtype=weight_dtype)
            image = batch["image"].to(dtype=weight_dtype)
            mask = batch["mask"].to(dtype=weight_dtype)

            # Take one mask channel, tile it to the 4 latent channels and downsample
            # by 8 to latent resolution.
            mask = mask[:, [0], :, :].repeat(1, 4, 1, 1)
            mask = F.interpolate(mask, scale_factor=1 / 8., mode="bilinear", align_corners=False)
            bsz = image.shape[0]

            with torch.no_grad():
                # Frozen encoders: no autograd graph needed.
                ref_img_feature = image_encoder(ref_image).last_hidden_state
                # Past-frame conditioning: projected reference-image features only.
                cross_frame_feature = visual_projection(ref_img_feature)
                latents = vae.encode(image).latent_dist.sample() * 0.18215
                encoder_hidden_states = text_encoder(batch["prompt_ids"].to(accelerator.device)).last_hidden_state

            noise = torch.randn_like(latents)
            if step % 100 == 0:
                # Fixed timestep on logging steps, presumably so the losses written
                # to loss_hf.txt are comparable with each other.
                timesteps = torch.full((bsz,), 100, device=latents.device, dtype=torch.long)
            else:
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long()

            noisy_latent = noise_scheduler.add_noise(latents, noise, timesteps)

            model_pred = unet(noisy_latent, timesteps, cross_frame_feature, encoder_hidden_states).sample

            prediction_type = noise_scheduler.config.prediction_type
            if prediction_type == "epsilon":
                target = noise
            elif prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            # Masked MSE: positions where mask == 1 are zeroed in both terms
            # (the mean is still taken over all elements).
            keep = 1. - mask
            loss = F.mse_loss(model_pred.float() * keep, target.float() * keep, reduction="mean")

            accelerator.backward(loss)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(trainable_params, max_grad_norm)
            optimizer.step()
            lr_sched.step()
            optimizer.zero_grad()

        if not accelerator.sync_gradients:
            continue  # still accumulating; no optimizer step happened

        # ---- One full optimizer step completed ----
        if accelerator.is_main_process and step % 100 == 0:
            # Open per write (and close) instead of leaking a handle every iteration.
            with open('loss_hf.txt', 'a') as loss_file:
                print(loss.item(), file=loss_file)

        progress_bar.update(1)
        step += 1

        if accelerator.is_main_process:
            if sample_logger is not None and step % validation_steps == 0:
                unet.eval()
                val_batch = next(val_data_yielder)
                with torch.no_grad(), autocast():
                    sample_logger.log_sample_images(
                        batch=val_batch,
                        visual_projection=visual_projection,
                        image_encoder=image_encoder,
                        text_encoder=text_encoder,
                        pipeline=pipeline,
                        device=accelerator.device,
                        step=step,
                    )
                unet.train()
            if step % checkpointing_steps == 0:
                _save_checkpoint(
                    accelerator,
                    unet,
                    vae,
                    text_encoder,
                    image_encoder,
                    tokenizer,
                    pretrained_model_path,
                    os.path.join(logdir, f"checkpoint_{step}"),
                )

        logs = {"loss": loss.detach().item(), "lr": lr_sched.get_last_lr()[0]}
        progress_bar.set_postfix(**logs)
        accelerator.log(logs, step=step)

    accelerator.end_training()


@click.command()
@click.option("--config", type=str, default="config/sample.yml")
def run(config):
    """CLI entry point: load the YAML file at ``config`` and pass its top-level keys to ``train``."""
    train_kwargs = OmegaConf.load(config)
    train(**train_kwargs)


if __name__ == "__main__":
    # Executed as a script: click parses argv and dispatches to run().
    run()


# CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py 
# tmux new -s mcaphf
# tmux attach -t mcaphf
# tmux detach  ctrl+b d
# tmux kill-session -t mcaphf

# tmux new -s hf
# tmux attach -t hf
# tmux detach  ctrl+b d
# tmux kill-session -t hf
# CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py 

# CUDA_VISIBLE_DEVICES=2,3 accelerate launch --main_process_port 6666 train.py
# CUDA_VISIBLE_DEVICES=2,3 accelerate launch train.py
# CUDA_VISIBLE_DEVICES=6 accelerate launch train.py