import os
import numpy as np
from PIL import Image
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple
from einops import rearrange, repeat

import torch
import torch.nn.functional as F
import torchvision 
from torch import nn
import lightning
import wandb
from lightning_utilities.core.rank_zero import rank_zero_only
from diffusers import AutoencoderKLTemporalDecoder, StableVideoDiffusionPipeline
from models.ema import LitEma
from ..base import BaseSystem
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import text_preprocessing

def loss_fn(model_input, pred, noise, timesteps, scheduler):
    """Return the mean-squared error between a prediction and its diffusion target.

    The regression target is chosen from ``scheduler.config.prediction_type``:

    * ``"epsilon"``      -> the injected ``noise`` itself.
    * ``"v_prediction"`` -> ``scheduler.get_velocity(model_input, noise, timesteps)``.

    Both tensors are cast to float32 before the loss is computed.

    Raises:
        ValueError: if the scheduler reports any other prediction type.
    """
    prediction_type = scheduler.config.prediction_type
    if prediction_type not in ("epsilon", "v_prediction"):
        raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")

    target = (
        noise
        if prediction_type == "epsilon"
        else scheduler.get_velocity(model_input, noise, timesteps)
    )
    return F.mse_loss(pred.float(), target.float(), reduction="mean")

class SVDSystem(BaseSystem):
    def __init__(
        self,
        lr: float,
        num_video_frames: int,
        num_image_frames: int,
        text_model_path,
        dit: nn.Module,
        scheduler,
        base_model_id: str = "stabilityai/stable-video-diffusion-img2vid",
        variant: str = "fp16",
        cfg: float = 0.1, 
        num_inference_steps: int = 50,
        max_length: int = 120,
        ema_decay_rate: float = 0.9999,
        use_ema: bool = False,
        prediction_type: str = "epsilon",
        strict_loading=False,
        **kwargs
    ) -> None:
        """Assemble the frozen VAE / text encoder and the trainable denoiser.

        Args:
            lr: Base learning rate; captured by ``save_hyperparameters`` and read
                back as ``self.hparams.lr`` in ``configure_optimizers``.
            num_video_frames: Stored as ``self.num_video_frames``.
            num_image_frames: Stored as ``self.num_image_frames``.
            text_model_path: Passed to ``AutoTokenizer.from_pretrained`` and
                ``AutoModelForCausalLM.from_pretrained``.
            dit: Denoising network; the only module whose parameters are listed
                in ``self.trainable_parameters``.
            scheduler: Noise scheduler, stored as ``self.scheduler``.
            base_model_id: Model id the VAE is loaded from (``subfolder="vae"``).
            variant: Weight variant forwarded to the VAE ``from_pretrained`` call.
            cfg: Stored as ``self.conditioning_dropout_prob``.
            num_inference_steps: Stored as ``self.num_inference_steps``.
            max_length: Not read here; reachable as ``self.hparams.max_length``
                (used by ``tokenization``).
            ema_decay_rate: Decay passed to ``LitEma`` when ``use_ema`` is true.
            use_ema: Whether to keep an EMA copy of ``dit``.
            prediction_type: Not read in this constructor beyond being captured
                by ``save_hyperparameters``.
            strict_loading: Stored as ``self.strict_loading``.
            **kwargs: Forwarded unchanged to ``BaseSystem.__init__``.
        """
        super().__init__(**kwargs)

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        # NOTE(review): "mv_model" is not a parameter of this constructor, so the
        # ignore list filters nothing here; `dit` and `scheduler` are therefore
        # captured as hyperparameters too -- confirm that is intended.
        self.save_hyperparameters(logger=False, ignore=["mv_model"])
        self.vae = AutoencoderKLTemporalDecoder.from_pretrained(base_model_id, subfolder="vae", variant=variant)
        self.dit = dit
        self.tokenizer = AutoTokenizer.from_pretrained(text_model_path, add_bos_token=True, add_eos_token=True)
        self.text_encoder = AutoModelForCausalLM.from_pretrained(text_model_path).get_decoder() # ! default dtype is float32
        # Freeze the VAE and the text encoder: neither receives gradients, and
        # the text encoder is additionally switched to eval mode.
        self.vae.requires_grad_(False)
        self.text_encoder.eval()
        self.text_encoder.requires_grad_(False)
        self.scheduler = scheduler
        self.strict_loading=strict_loading

        self.use_ema = use_ema
        print(f"Using EMA: {use_ema}")
        if use_ema:
            print(f"Using EMA with decay rate {ema_decay_rate}")
            # EMA shadow of the denoiser weights; excluded from gradient updates.
            self.model_ema = LitEma(self.dit, decay=ema_decay_rate)
            self.model_ema.requires_grad_(False)

        # (parameters, lr multiplier) pairs consumed by configure_optimizers.
        # NOTE(review): nn.Module.parameters() returns a one-shot iterator, so
        # this entry can only be consumed once -- confirm configure_optimizers
        # is never invoked a second time on the same instance.
        self.trainable_parameters = [
            (self.dit.parameters(), 1.0),
        ]

        self.num_inference_steps = num_inference_steps
        # Used both as the caption-dropout probability during training and as
        # the switch that enables classifier-free guidance at inference time.
        self.conditioning_dropout_prob = cfg
        self.num_video_frames = num_video_frames
        self.num_image_frames = num_image_frames
        # Classifier-free guidance weight; fixed here rather than configurable.
        self.guidance_scale = 3.0

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.dit.parameters())
            self.model_ema.copy_to(self.dit)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.dit.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def setup(self, stage: str) -> None:
        """Bind the image/video logging callbacks that match the attached logger.

        A ``TensorBoardLogger`` selects the ``tensorboard_log_*`` methods and a
        ``WandbLogger`` selects the ``wandb_log_*`` methods.  For any other
        logger (or none) both ``self.log_image`` and ``self.log_video`` are
        left as ``None``.

        Args:
            stage: Lightning stage name, forwarded to the parent ``setup``.
        """
        super().setup(stage)
        # Initialise both hooks so that ``log_video`` always exists; previously
        # only ``log_image`` was set, which left ``log_video`` undefined when
        # the logger was neither TensorBoard nor W&B.
        self.log_image = None
        self.log_video = None
        if isinstance(self.logger, lightning.pytorch.loggers.TensorBoardLogger):
            self.log_image = self.tensorboard_log_image
            self.log_video = self.tensorboard_log_video
        elif isinstance(self.logger, lightning.pytorch.loggers.WandbLogger):
            self.log_image = self.wandb_log_image
            self.log_video = self.wandb_log_video

    def configure_optimizers(self):
        """Configure optimizers and learning rate schedulers for training."""
        param_groups = []
        for params, lr_scale in self.trainable_parameters:
            param_groups.append({"params": params, "lr": self.hparams.lr * lr_scale})

        optimizer = torch.optim.AdamW(param_groups)
        return optimizer
    
    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # pop the text encoder and vae
        for key in list(checkpoint['state_dict'].keys()):
            if "text_encoder" in key or "vae" in key:
                checkpoint['state_dict'].pop(key)
        return checkpoint


    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self.dit)
    
    def on_test_epoch_end(self):
        fid = self.fid.compute()
        self.fid.reset()
        self.log("fid", fid, prog_bar=True)

    def tokenization(self, text, device, drop_text=False):
        if drop_text:
            if np.random.rand() < self.conditioning_dropout_prob:
                text = ""
        text_inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", max_length=self.hparams.max_length, truncation=True)
        return text_inputs.input_ids.to(device), text_inputs.attention_mask.to(device)

    @torch.no_grad()
    def split_video_image_prompts(self, video_image_prompts, device, drop_text=False, do_classifier_free_guidance=False):
        """Encode per-sample video and image captions into text features.

        Each entry of ``video_image_prompts`` is indexed as
        ``[video_prompt, image_prompt_0, image_prompt_1, ...]``.  The video
        caption features are repeated ``self.num_video_frames`` times along a
        new frame axis, and the image caption features (if the first sample
        has any) are appended along that same axis.

        Args:
            video_image_prompts: Sequence of per-sample prompt sequences.
            device: Device for the token tensors.
            drop_text: Forwarded to ``tokenization`` (random caption dropout).
            do_classifier_free_guidance: When true, one empty prompt is appended
                for every prompt, doubling the leading (batch) dimension.

        Returns:
            ``(cap_feats, cap_masks)`` shaped ``b x f x l x c`` and
            ``b x f x l``, where ``f`` covers the video frames followed by the
            images.
        """
        video_prompts = [sample[0] for sample in video_image_prompts]
        images_prompts = [sample[1:] for sample in video_image_prompts]

        if do_classifier_free_guidance:
            # Unconditional half: one empty caption per conditional caption.
            video_prompts += [""] * len(video_prompts)

        batch_size = len(video_prompts)

        def encode_prompts(prompts):
            # Clean and tokenize prompts in order, then stack them on dim 0.
            ids_parts, mask_parts = [], []
            for prompt in prompts:
                token_ids, token_mask = self.tokenization(
                    text_preprocessing(prompt), device, drop_text=drop_text
                )
                ids_parts.append(token_ids)
                mask_parts.append(token_mask)
            return torch.cat(ids_parts, dim=0), torch.cat(mask_parts, dim=0)

        video_ids, video_masks = encode_prompts(video_prompts)  # b x max_len each
        video_feats = self.text_encoder(input_ids=video_ids).last_hidden_state
        cap_feats = repeat(video_feats, 'b l c -> b f l c', f=self.num_video_frames)
        cap_masks = repeat(video_masks, 'b l -> b f l', f=self.num_video_frames)

        # Only the first sample is inspected to decide whether images are present.
        if len(images_prompts[0]) == 0:
            return cap_feats, cap_masks

        if do_classifier_free_guidance:
            images_prompts += [[""] * len(images_prompts[0])] * len(images_prompts)

        per_sample = [encode_prompts(prompts) for prompts in images_prompts]
        images_ids = torch.cat([ids for ids, _ in per_sample], dim=0)        # (b f) x max_len
        images_masks = torch.cat([masks for _, masks in per_sample], dim=0)  # (b f) x max_len
        images_feats = self.text_encoder(input_ids=images_ids).last_hidden_state
        images_feats = rearrange(images_feats, '(b f) l c -> b f l c', b=batch_size)
        images_masks = rearrange(images_masks, '(b f) l -> b f l', b=batch_size)

        cap_feats = torch.cat([cap_feats, images_feats], dim=1)  # b x f1+f2 x max_length x c
        cap_masks = torch.cat([cap_masks, images_masks], dim=1)  # b x f1+f2 x max_length
        return cap_feats, cap_masks


    def training_step(self, batch, batch_idx):
        """Compute the joint video + image denoising loss for one batch.

        Each sample's video and its images are encoded with the VAE, noised at
        one shared timestep, and denoised by ``self.dit`` together with the
        caption features; the per-sample losses are then averaged.

        Args:
            batch: Mapping with ``'model_input'`` (per-sample sequences whose
                first element is the video and whose remaining elements are
                images) and ``'video_image_prompts'`` (matching captions).
            batch_idx: Index of the batch; stored on ``self.batch_idx``.

        Returns:
            The scalar training loss: mean video loss plus mean image loss.
        """
        self.batch_idx = batch_idx
        model_input = batch['model_input']
        diffusion_video_list = [model_input[i][0] for i in range(len(model_input))] # [f x c x h x w, xxx]
        diffusion_images_list = [model_input[i][1:] for i in range(len(model_input))] # [[c x h x w, xxx], [xxx]]
        
        b = len(diffusion_video_list)
        device = diffusion_video_list[0].device
        dtype = diffusion_video_list[0].dtype

        video_image_prompts = batch['video_image_prompts'] # b x 9
        
        # drop_text=True enables random caption dropout inside `tokenization`.
        cap_feats, cap_masks = self.split_video_image_prompts(video_image_prompts, device, drop_text=True)

        video_latent_list, images_latent_list = [], []
        noise_video_latent_list, noise_images_latent_list = [], []

        # One timestep per sample, drawn uniformly from [0, num_train_timesteps).
        timesteps = torch.randint(0, self.scheduler.config.num_train_timesteps, (b,), device=device).long()

        video_noise_list, images_noise_list = [], []
        for i in range(b):
            video_latent = self.tensor_to_vae_latent(diffusion_video_list[i], self.vae)
            noise = torch.randn_like(video_latent)
            noise_video_latent = self.scheduler.add_noise(video_latent, noise, timesteps[i]).to(dtype)
            video_latent_list.append(video_latent)
            noise_video_latent_list.append(noise_video_latent)
            video_noise_list.append(noise)

            image_latent_list, noise_image_latent_list, image_noise_list = [], [], []
            for j in range(len(diffusion_images_list[i])):
                # `[0]` keeps the first entry along dim 0 of the encoder output.
                # NOTE(review): presumably this strips a leading singleton
                # batch/frame axis -- confirm the image tensor layout expected
                # by `tensor_to_vae_latent`.
                image_latent = self.tensor_to_vae_latent(diffusion_images_list[i][j], self.vae)[0]
                noise = torch.randn_like(image_latent)
                image_noise_list.append(noise)
                # The images of a sample reuse that sample's video timestep.
                noise_image_latent = self.scheduler.add_noise(image_latent, noise, timesteps[i]).to(dtype)
                # TODO: maybe use different timesteps
                image_latent_list.append(image_latent)
                noise_image_latent_list.append(noise_image_latent)

            images_latent_list.append(image_latent_list)
            noise_images_latent_list.append(noise_image_latent_list)
            images_noise_list.append(image_noise_list)

        video_pred_list, images_pred_list = self.dit(noise_video_latent_list, noise_images_latent_list, timesteps, cap_feats, cap_masks)


        video_loss_list, image_loss_list = [], []
        for i in range(b):
            video_loss = loss_fn(video_latent_list[i], video_pred_list[i], video_noise_list[i], timesteps[i], self.scheduler)
            # Image losses are summed (not averaged) over the images of a
            # sample; the value stays the int 0 when the sample has no images.
            image_loss = 0
            for j in range(len(diffusion_images_list[i])):
                image_loss += loss_fn(images_latent_list[i][j], images_pred_list[i][j], images_noise_list[i][j], timesteps[i], self.scheduler)
            video_loss_list.append(video_loss)
            image_loss_list.append(image_loss)
        # Average each loss over the samples of the batch.
        mean_video_loss = sum(video_loss_list) / len(video_loss_list)
        mean_image_loss = sum(image_loss_list) / len(image_loss_list)
        loss = mean_video_loss + mean_image_loss

        self.log("train_loss", loss, prog_bar=True)
        self.log("video_loss", mean_video_loss, prog_bar=True)
        self.log("image_loss", mean_image_loss, prog_bar=True)
        return loss
    
    def inference_step(self, batch, batch_idx, dataloader_idx=0, stage = "val"):
        with self.ema_scope():
            video_pred_list, image_pred_list = self._generate_images(batch) # image in [0, 1]
        image_fp = self._save(video_pred_list, image_pred_list, batch["video_image_prompts"], f"{self.global_rank}_{batch_idx}", stage=stage)
        return image_fp

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        return self.inference_step(batch, batch_idx, dataloader_idx, stage = "val")

    def test_step(self, batch, batch_idx, dataloader_idx):
        return self.inference_step(batch, batch_idx, dataloader_idx, stage = "test")

    @torch.no_grad()
    def tensor_to_vae_latent(self, t, vae, needs_upcasting = True, micro_bs = 2):
        """Encode pixels into scaled VAE latents, sampling from the posterior.

        A 4-D input gets a singleton axis inserted at dim 1 before encoding and
        removed again afterwards; otherwise the input is expected to match the
        ``b f c h w`` layout.

        Args:
            t: Pixel tensor to encode.
            vae: Autoencoder providing ``encode`` and ``config.scaling_factor``.
            needs_upcasting: When true, the VAE and the input are cast to
                float32 for encoding, then the VAE and the latents are cast
                back to the VAE's original dtype.
            micro_bs: Forwarded to ``Tensor.chunk``, i.e. it is the number of
                chunks the flattened frames are split into (not the number of
                frames per chunk).

        Returns:
            Latents multiplied by ``vae.config.scaling_factor``.
        """
        added_frame_axis = len(t.shape) == 4
        vae_dtype = next(vae.parameters()).dtype
        frames = t.unsqueeze(1) if added_frame_axis else t
        clip_len = frames.shape[1]
        frames = rearrange(frames, "b f c h w -> (b f) c h w")

        if needs_upcasting:
            vae.to(dtype=torch.float32)
            frames = frames.to(torch.float32)

        # Encode chunk by chunk to bound peak memory, then stitch back together.
        encoded = [
            vae.encode(piece).latent_dist.sample()
            for piece in frames.chunk(micro_bs, dim=0)
        ]
        latents = torch.cat(encoded, dim=0)

        if needs_upcasting:
            vae.to(dtype=vae_dtype)
            latents = latents.to(vae_dtype)

        latents = rearrange(latents, "(b f) c h w -> b f c h w", f=clip_len)
        if added_frame_axis:
            latents = latents.squeeze(1)
        return latents * vae.config.scaling_factor
    
    @torch.no_grad()
    def decode(self, latents, vae, num_frames, needs_upcasting = True):
        latents = latents / vae.config.scaling_factor
        dtype = next(vae.parameters()).dtype
        if needs_upcasting:
            vae.to(dtype=torch.float32)
            latents = latents.to(torch.float32)
        video = vae.decode(latents, return_dict=False, num_frames=num_frames)[0]
        if needs_upcasting:
            vae.to(dtype=dtype)
            video = video.to(dtype)
        return video
    
    @torch.no_grad()
    def _generate_images(self, batch, generator=None):
        model_input = batch['model_input']
        diffusion_video_list = [model_input[i][0] for i in range(len(model_input))]
        diffusion_images_list = [model_input[i][1:] for i in range(len(model_input))]
        video_image_prompts = batch['video_image_prompts']
        dtype = diffusion_video_list[0].dtype

        do_classifier_free_guidance = self.conditioning_dropout_prob > 0.0
        cap_feats, cap_masks = self.split_video_image_prompts(video_image_prompts, diffusion_video_list[0].device, drop_text=False, do_classifier_free_guidance=do_classifier_free_guidance)

        b = len(diffusion_video_list)
        device = diffusion_video_list[0].device

        self.scheduler.set_timesteps(self.num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        video_latent_list, images_latent_list = [], [] # [ f x c x h x w, f x c x h x w, ...], [[c x h x w, c x h x w, ...], ...
        for i in range(b):
            f, c, h, w = diffusion_video_list[i].shape
            video_latent = torch.randn((f, 4, h//8, w//8), device=device, dtype=dtype)
            video_latent_list.append(video_latent)
            image_latent_list = []
            for j in range(len(diffusion_images_list[i])):
                c, h, w = diffusion_images_list[i][j].shape
                image_latent = torch.randn((4, h//8, w//8), device=device, dtype=dtype)
                image_latent_list.append(image_latent)
            images_latent_list.append(image_latent_list)
            
        if do_classifier_free_guidance:
            video_latent_list = video_latent_list * 2
            images_latent_list = images_latent_list * 2

        for idx, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            for i in range(b):
                video_latent_list[i] = self.scheduler.scale_model_input(video_latent_list[i], t)
                for j in range(len(images_latent_list[i])):
                    images_latent_list[i][j] = self.scheduler.scale_model_input(images_latent_list[i][j], t)

            video_pred_list, images_pred_list = self.dit(
                video_latent_list,
                images_latent_list,
                t,
                cap_feats,
                cap_masks,
            )
            for i in range(b):
                if do_classifier_free_guidance:
                    video_noise_pred_text, video_noise_pred_uncond = video_pred_list[i], video_pred_list[i + b]        
                    video_noise_pred = video_noise_pred_uncond + self.guidance_scale * (video_noise_pred_text - video_noise_pred_uncond)
                    video_latent_list[i] = self.scheduler.step(video_noise_pred, t, video_latent_list[i], return_dict=False)[0]
                    video_latent_list[i+b] = video_latent_list[i]
                    for j in range(len(images_latent_list[i])):
                        image_noise_pred_text, image_noise_pred_uncond = images_pred_list[i][j], images_pred_list[i + b][j]
                        image_noise_pred = image_noise_pred_uncond + self.guidance_scale * (image_noise_pred_text - image_noise_pred_uncond)
                        images_latent_list[i][j] = self.scheduler.step(image_noise_pred, t, images_latent_list[i][j], return_dict=False)[0]
                        images_latent_list[i+b][j] = images_latent_list[i][j]
                else:
                    video_latent_list[i] = self.scheduler.step(video_pred_list[i], t, video_latent_list[i], return_dict=False)[0]
                    for j in range(len(images_latent_list[i])):
                        images_latent_list[i][j] = self.scheduler.step(images_pred_list[i][j], t, images_latent_list[i][j], return_dict=False)[0]
        # decode the latents
        video_list = [] 
        images_list = []
        for i in range(b):
            video_latent = video_latent_list[i]
            # video = self.vae.decode(video_latent / self.vae.config.scaling_factor, return_dict=False, num_frames=self.num_video_frames)[0]
            video = self.decode(video_latent, self.vae, self.num_video_frames)
            video = torch.clamp(video * 0.5 + 0.5, 0, 1)
            video = (video * 255).to(torch.uint8)
            video_list.append(video)
            image_list = []
            for j in range(len(images_latent_list[i])):
                image_latent = images_latent_list[i][j].unsqueeze(0)
                # image = self.vae.decode(image_latent / self.vae.config.scaling_factor, return_dict=False, num_frames=1)[0]
                image = self.decode(image_latent, self.vae, 1)
                image = torch.clamp(image * 0.5 + 0.5, 0, 1)
                image_list.append(image)
            images_list.append(image_list)

        return video_list, images_list  
    

    @torch.no_grad()
    def _save(self, video_list, images_list, prompt_list, batch_idx, stage="validation"):
        """Log generated videos/images and write them under ``self.save_dir``.

        Videos are written as ``{stage}_video_{batch_idx}_{idx}.mp4`` at 4 fps.
        The images of each sample are zero-padded to a common height/width,
        tiled into one grid and written as
        ``{stage}_image_{batch_idx}_{idx}.png``.  Logging is skipped when no
        ``log_video`` / ``log_image`` callback is bound; files are still
        written in that case.

        Args:
            video_list: One ``f x c x h x w`` tensor per sample.
            images_list: One list of image tensors per sample; each image is
                ``c x h x w`` or ``n x c x h x w``.
            prompt_list: Per-sample prompts; index 0 captions the video and
                the remaining entries caption the images.
            batch_idx: Tag inserted into log keys and file names.
            stage: Prefix for log keys and file names.

        Returns:
            None.
        """
        save_dir = self.save_dir
        # Make sure the output directory exists before any file is written.
        os.makedirs(save_dir, exist_ok=True)
        # The callbacks may be missing or None when the logger is unsupported.
        log_video = getattr(self, "log_video", None)
        log_image = getattr(self, "log_image", None)

        for idx, video in enumerate(video_list):
            if log_video is not None:
                log_video(
                    tag=f"{stage}_videos/{batch_idx}_{idx}",
                    video_tensor=video[None, ].cpu().numpy(),
                    caption=prompt_list[idx][0],
                )
            # save video
            video_fp = os.path.join(save_dir, f"{stage}_video_{batch_idx}_{idx}.mp4")
            frames = rearrange(video, "f c h w -> f h w c").cpu()
            torchvision.io.write_video(video_fp, frames, fps=4)

        for idx, image_list in enumerate(images_list):
            if not image_list:
                continue
            # Normalise every image to n x c x h x w so they concatenate on dim 0.
            batched = [image[None, ] if image.dim() == 3 else image for image in image_list]
            # Height/width are the last two axes regardless of leading axes
            # (indexing shape[1]/shape[2] would read channels/height on the
            # 4-D tensors produced by the decoder).
            max_h = max(image.shape[-2] for image in batched)
            max_w = max(image.shape[-1] for image in batched)
            padded = [
                F.pad(image, (0, max_w - image.shape[-1], 0, max_h - image.shape[-2]))
                for image in batched
            ]
            stacked = torch.cat(padded, dim=0)  # [n x c x h x w]
            grid = torchvision.utils.make_grid(stacked, nrow=8).cpu()
            if log_image is not None:
                log_image(
                    tag=f"{stage}_images/{batch_idx}_{idx}",
                    image_tensor=grid.numpy(),
                    caption="\n".join(prompt_list[idx][1:]),
                )
            # save image
            image_fp = os.path.join(save_dir, f"{stage}_image_{batch_idx}_{idx}.png")
            torchvision.utils.save_image(grid, image_fp)
        

    def tensorboard_log_image(self, tag: str, image_tensor, caption=None):
        self.logger.experiment.add_image(
            tag,
            image_tensor,
            self.trainer.global_step,
        )

    def wandb_log_image(self, tag: str, image_tensor, caption=None):
        """Log ``image_tensor`` to the logger experiment as a ``wandb.Image``.

        When the first axis has length 3 the array is transposed with axes
        ``(1, 2, 0)`` before being wrapped.  The entry is logged under ``tag``
        at the trainer's current global step.
        """
        leading_axis_is_rgb = image_tensor.shape[0] == 3
        pixels = image_tensor.transpose(1, 2, 0) if leading_axis_is_rgb else image_tensor
        payload = {tag: wandb.Image(pixels, caption=caption)}
        self.logger.experiment.log(payload, step=self.trainer.global_step)

    def wandb_log_video(self, tag: str, video_tensor, caption=None):
        """Log the first entry of ``video_tensor`` as a 4-fps ``wandb.Video``.

        The entry is logged under ``tag`` at the trainer's current global step.
        """
        clip = wandb.Video(video_tensor[0], caption=caption, fps=4)
        payload = {tag: clip}
        self.logger.experiment.log(payload, step=self.trainer.global_step)
    def tensorboard_log_video(self, tag: str, video_tensor, caption=None):
        self.logger.experiment.add_video(
            tag,
            video_tensor,
            self.trainer.global_step,
            fps=4,
        )

