import os
from lightning.pytorch.core.optimizer import LightningOptimizer
import numpy as np
from PIL import Image
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple
from einops import rearrange, repeat

import torch
import torch.nn.functional as F
from torch.optim.optimizer import Optimizer
import torchvision 
from torch import nn
import lightning
from lightning import LightningModule
import wandb
from lightning_utilities.core.rank_zero import rank_zero_only
from torchmetrics import MaxMetric, MeanMetric
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure
from torchmetrics.image import PeakSignalNoiseRatio
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from diffusers import EulerDiscreteScheduler, UNetSpatioTemporalConditionModel, AutoencoderKLTemporalDecoder, StableVideoDiffusionPipeline
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import _resize_with_antialiasing
from src.models.ema import LitEma
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType
from ..base import BaseSystem
from safetensors import safe_open
from torch.cuda.amp import autocast
from torch import inf
from src.utils import RankedLogger

# Module-level logger. ``rank_zero_only=True`` is passed to ``RankedLogger`` (defined in
# ``src.utils``, not visible here) — presumably restricting output to global rank 0.
log = RankedLogger(__name__, rank_zero_only=True)

# Process-wide side effect at import time: allow TF32 math for CUDA matmuls and
# cuDNN ops, trading some float32 precision for speed on GPUs that support TF32.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

def get_grad_norm(
        parameters, norm_type: float = 2.0) -> torch.Tensor:
    r"""Return the total gradient norm of ``parameters`` without modifying them.

    Mirrors the norm computation of ``torch.nn.utils.clip_grad_norm_``: the
    norm is taken over all gradients together, as if they were concatenated
    into a single vector. Unlike ``clip_grad_norm_``, gradients are left
    untouched.

    Args:
        parameters (Iterable[Tensor] or Tensor): tensors (or a single tensor)
            whose ``.grad`` is read; entries whose ``grad`` is ``None`` are
            skipped.
        norm_type (float or int): order of the p-norm. Can be ``inf`` for the
            max-abs norm.

    Returns:
        Total norm of the gradients (viewed as a single vector) as a scalar
        tensor on the device of the first gradient, or ``torch.tensor(0.)``
        when no parameter has a gradient.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if not grads:
        return torch.tensor(0.)
    device = grads[0].device
    if norm_type == inf:
        norms = [g.detach().abs().max().to(device) for g in grads]
        return norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
    # Norm of the per-tensor norms == norm of the concatenated gradients.
    per_grad_norms = [
        torch.linalg.vector_norm(g.detach(), norm_type).to(device) for g in grads
    ]
    return torch.linalg.vector_norm(torch.stack(per_grad_norms), norm_type)

def clip_grad_norm_(
        parameters, max_norm: float, norm_type: float = 2.0,
        error_if_nonfinite: bool = False, clip_grad = True) -> torch.Tensor:
    r"""Compute the total gradient norm and optionally clip gradients in place.

    Adapted from ``torch.nn.utils.clip_grad_norm_``. The norm is computed over
    all gradients together, as if they were concatenated into a single vector.
    When ``clip_grad`` is true, every gradient is scaled in place by
    ``min(1, max_norm / (total_norm + 1e-6))``.

    Args:
        parameters (Iterable[Tensor] or Tensor): tensors (or a single tensor)
            whose ``.grad`` is used; entries whose ``grad`` is ``None`` are
            skipped.
        max_norm (float or int): max norm of the gradients.
        norm_type (float or int): order of the p-norm. Can be ``inf`` for the
            max-abs norm.
        error_if_nonfinite (bool): if True (and ``clip_grad`` is True), raise
            when the total norm is ``nan``, ``inf`` or ``-inf``. Default: False.
        clip_grad (bool): if False, only the norm is computed; gradients are
            left untouched and ``error_if_nonfinite`` is not checked.

    Returns:
        Total norm of the gradients before clipping (viewed as a single
        vector), or ``torch.tensor(0.)`` when no parameter has a gradient.

    Raises:
        RuntimeError: if the total norm is non-finite, ``clip_grad`` is True
            and ``error_if_nonfinite`` is True.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if not grads:
        return torch.tensor(0.)
    device = grads[0].device

    if norm_type == inf:
        norms = [g.detach().abs().max().to(device) for g in grads]
        total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
    else:
        # Norm of the per-tensor norms == norm of the concatenated gradients.
        per_grad_norms = [
            torch.linalg.vector_norm(g.detach(), norm_type).to(device) for g in grads
        ]
        total_norm = torch.linalg.vector_norm(torch.stack(per_grad_norms), norm_type)

    if not clip_grad:
        return total_norm

    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
        raise RuntimeError(
            f'The total norm of order {norm_type} for gradients from '
            '`parameters` is non-finite, so it cannot be clipped. To disable '
            'this error and scale the gradients by the non-finite norm anyway, '
            'set `error_if_nonfinite=False`')
    clip_coef = max_norm / (total_norm + 1e-6)
    # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
    # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
    # when the gradients do not reside in CPU memory.
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    for g in grads:
        g.detach().mul_(clip_coef_clamped.to(g.device))

    return total_norm

def _get_add_time_ids(
    unet,
    fps,
    motion_bucket_id,
    noise_aug_strength,
    dtype,
    batch_size,
):
    add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
    passed_add_embed_dim = unet.config.addition_time_embed_dim * len(add_time_ids)
    expected_add_embed_dim = unet.add_embedding.linear_1.in_features

    if expected_add_embed_dim != passed_add_embed_dim:
        raise ValueError(
            f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
        )
    add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
    add_time_ids = add_time_ids.repeat(batch_size, 1)
    return add_time_ids

def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
    """Draw samples of ``exp(X)`` with ``X ~ Normal(loc, scale)`` (log-normal).

    Uniform draws are squeezed into ``[1e-7, 1 - 1e-7]`` before the inverse
    CDF so the normal quantile stays finite.
    """
    uniform = torch.rand(shape, dtype=dtype, device=device)
    uniform = uniform * (1 - 2e-7) + 1e-7
    normal = torch.distributions.Normal(loc, scale)
    gaussian = normal.icdf(uniform)
    return gaussian.exp()


class SVDSystem(BaseSystem):
    """Lightning system built around a Stable Video Diffusion pipeline.

    The SVD pipeline supplies the scheduler, CLIP image encoder, VAE and image
    processors. ``mv_model`` is constructed from the pipeline's UNet, and
    ``recon_model`` is optionally warm-started from a safetensors checkpoint.
    ``training_step`` optimizes only reconstruction losses produced by
    ``recon_model``: L1 + LPIPS on the rendered images and L1 on the rendered
    depths.
    """

    def __init__(
        self,
        lr: float,
        mv_model: torch.nn.Module,
        recon_model: torch.nn.Module,
        base_model_id: str = "stabilityai/stable-video-diffusion-img2vid",
        variant: str = "fp16",
        cfg: float = 0.1,
        report_to: str = "wandb",
        ema_decay_rate: float = 0.9999,
        compile: bool = False,
        use_ema: bool = False,
        recon_ckpt_path: Optional[str] = "pretrain/LGM/model_fp16.safetensors",
        **kwargs,
    ) -> None:
        """Build the pipeline components, models, metrics and optional EMA.

        Args:
            lr: base learning rate; each trainable group uses ``lr * scale``.
            mv_model: called with the pipeline UNet to build the multi-view
                model (presumably a factory/partial — confirm against config).
            recon_model: reconstruction module; see ``recon_ckpt_path``.
            base_model_id: identifier passed to
                ``StableVideoDiffusionPipeline.from_pretrained``.
            variant: weight variant passed to ``from_pretrained``.
            cfg: stored as ``self.conditioning_dropout_prob``.
            report_to: kept in hparams; not read by this class.
            ema_decay_rate: decay of the ``LitEma`` tracking ``mv_model``.
            compile: ``torch.compile`` ``mv_model`` during ``setup("fit")``.
            use_ema: keep an EMA copy of ``mv_model`` weights.
            recon_ckpt_path: safetensors file loaded non-strictly into
                ``recon_model``; ``None`` skips loading.
            **kwargs: forwarded to ``BaseSystem``.
        """
        super().__init__(**kwargs)

        # Stores init params in ``self.hparams`` and in checkpoints; the two
        # model arguments are excluded.
        self.save_hyperparameters(logger=False, ignore=["mv_model", "recon_model"])

        self.pipeline = StableVideoDiffusionPipeline.from_pretrained(
            base_model_id, subfolder="pipeline", variant=variant
        )

        # Expose pipeline components; the image encoder and VAE are frozen,
        # the UNet stays trainable.
        self.scheduler = self.pipeline.scheduler
        self.image_encoder = self.pipeline.image_encoder
        self.vae = self.pipeline.vae
        self.feature_extractor = self.pipeline.feature_extractor
        self.image_processor = self.pipeline.image_processor
        self.image_encoder.requires_grad_(False)
        self.vae.requires_grad_(False)
        self.pipeline.unet.requires_grad_(True)

        self.mv_model = mv_model(self.pipeline.unet)
        self.mv_model.set_use_memory_efficient_attention_xformers(True)
        self.mv_model.set_gradient_checkpointing(True)
        self.mv_model.train()
        self.recon_model = recon_model

        if recon_ckpt_path is not None:
            with safe_open(recon_ckpt_path, framework="pt", device="cpu") as f:
                tensors = {key: f.get_tensor(key) for key in f.keys()}
            # Non-strict load: key mismatches are reported instead of raising.
            missing_keys, unexpected_keys = self.recon_model.load_state_dict(tensors, strict=False)
            print(missing_keys, unexpected_keys)

        # Metric objects (LPIPS is also used as a training loss).
        self.psnr = PeakSignalNoiseRatio(data_range=1.0)
        self.ssim = StructuralSimilarityIndexMeasure()
        self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True)

        self.use_ema = use_ema
        if use_ema:
            self.model_ema = LitEma(self.mv_model, decay=ema_decay_rate)

        # (parameters, lr_scale) pairs consumed by ``configure_optimizers``.
        # NOTE(review): ``Module.parameters()`` returns one-shot generators, so
        # a second ``configure_optimizers`` call (e.g. after an LR finder run)
        # would see them exhausted.
        self.trainable_parameters = [
            (self.mv_model.parameters(), 1.0),
            (self.recon_model.parameters(), 1.0),
        ]

        self.num_inference_steps = 25
        self.min_guidance_scale = 1.0
        self.max_guidance_scale = 3.0
        self.conditioning_dropout_prob = cfg

    def _ema_target(self):
        """Return ``mv_model`` with any ``torch.compile`` wrapper removed.

        ``setup`` may replace ``mv_model`` with ``torch.compile(mv_model)``.
        That wrapper exposes the original module as ``_orig_mod``, and its
        ``named_parameters()`` are prefixed accordingly. The EMA was built
        from the uncompiled module in ``__init__``, so it must be applied to
        that module.
        """
        return getattr(self.mv_model, "_orig_mod", self.mv_model)

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap EMA weights into ``mv_model`` (no-op without EMA).

        Training weights are restored on exit, including when the body raises.

        Args:
            context: optional label; when given, the swap and restore are printed.
        """
        model = self._ema_target() if self.use_ema else None
        if self.use_ema:
            self.model_ema.store(model.parameters())
            self.model_ema.copy_to(model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def setup(self, stage: str) -> None:
        """Optionally compile ``mv_model`` and pick the image-logging backend."""
        super().setup(stage)
        if self.hparams.compile and stage == "fit":
            self.mv_model = torch.compile(self.mv_model)

        # ``log_image`` stays None for loggers other than TensorBoard / W&B,
        # in which case image logging is skipped.
        self.log_image = None
        if isinstance(self.logger, lightning.pytorch.loggers.TensorBoardLogger):
            self.log_image = self.tensorboard_log_image
        elif isinstance(self.logger, lightning.pytorch.loggers.WandbLogger):
            self.log_image = self.wandb_log_image
            self.logger.watch(self.mv_model, log_graph=False)

    def configure_optimizers(self):
        """Configure an AdamW optimizer with one param group per trainable entry."""
        param_groups = [
            {"params": params, "lr": self.hparams.lr * lr_scale}
            for params, lr_scale in self.trainable_parameters
        ]
        return torch.optim.AdamW(param_groups)

    def forward(self, latents, timestep, encoder_hidden_states, added_time_ids, cond):
        """Delegate to ``mv_model`` with the same positional arguments."""
        return self.mv_model(latents, timestep, encoder_hidden_states, added_time_ids, cond)

    def on_train_batch_end(self, *args, **kwargs):
        """Update the EMA of ``mv_model`` after each training batch (if enabled)."""
        if self.use_ema:
            self.model_ema(self._ema_target())

    def training_step(self, batch, batch_idx):
        """Compute the reconstruction loss for one training batch.

        ``recon_model`` renders from ``batch["diffusion_images"]``
        (b x m x c x h x w, annotated as [-1, 1]) at timesteps derived from
        log-normally sampled noise levels. The loss is:

        * L1 + LPIPS between rendered images and ground truth mapped to [0, 1];
        * plus L1 between rendered depths and ``batch["depths"]``.

        Returns:
            ``recon_loss + depth_loss``.
        """
        diffusion_images = batch["diffusion_images"]  # b x m x c x h x w #  [-1, 1]
        bsz, _, _, h, w = diffusion_images.shape
        # The recon model consumes pixel-space images, so no VAE encode is
        # needed here; only the device is required for the sampled timesteps.
        device = diffusion_images.device

        # Sample one noise level per sample and map it to a timestep value.
        # NOTE(review): the inherited comment says P_mean=0.7 P_std=1.6, but
        # loc=1.0 is used — confirm which is intended.
        sigmas = rand_log_normal(shape=[bsz, 1, 1, 1, 1], loc=1.0, scale=1.6).to(device)
        timesteps = torch.Tensor(
                    [0.25 * sigma.log() for sigma in sigmas]).to(device)

        render_x1 = self.recon_model(diffusion_images, batch, timesteps)
        recon_images = diffusion_images * 0.5 + 0.5  # [-1, 1] -> [0, 1]
        recon_loss = torch.abs(render_x1['images_pred'] - recon_images).mean()
        recon_loss += self.lpips(
            render_x1['images_pred'].reshape(-1, 3, h, w).float(),
            recon_images.reshape(-1, 3, h, w).float()
        )
        depth_loss = torch.abs(render_x1['depths_pred'] - batch['depths']).mean()
        self.log("train_recon_loss", recon_loss, on_step=True, on_epoch=False, prog_bar=True, add_dataloader_idx=False, sync_dist=True)
        self.log("train_depth_loss", depth_loss, on_step=True, on_epoch=False, prog_bar=True, add_dataloader_idx=False, sync_dist=True)
        return recon_loss + depth_loss

    def inference_step(self, batch, batch_idx, dataloader_idx=0, stage = "val"):
        """Shared validation/test logic: render, save/log images, log metrics.

        Returns:
            The saved comparison-image path from ``_save_image``. That method
            is rank-zero-only, so other ranks presumably receive ``None``.
        """
        with self.ema_scope():
            preds = self._generate_images(batch)  # images_pred expected in [0, 1]

        image_fp = self._save_image(preds, batch, batch["prompt"], f"{batch_idx}_{self.global_rank}", stage=stage)

        images_gt = rearrange(batch['diffusion_images'] * 0.5 + 0.5, "b m c h w -> (b m) c h w")
        images_pred = rearrange(preds["images_pred"], "b m c h w -> (b m) c h w")

        # torchmetrics convention is metric(preds, target).
        with autocast(dtype=torch.float32):
            psnr = self.psnr(images_pred.float(), images_gt.float())
            ssim = self.ssim(images_pred.float(), images_gt.float())
            lpips = self.lpips(images_pred.float(), images_gt.float())
        self.log(f"{stage}_psnr", psnr, on_step=False, on_epoch=True, prog_bar=True, add_dataloader_idx=False, sync_dist=True)
        self.log(f"{stage}_ssim", ssim, on_step=False, on_epoch=True, prog_bar=True, add_dataloader_idx=False, sync_dist=True)
        self.log(f"{stage}_lpips", lpips, on_step=False, on_epoch=True, prog_bar=True, add_dataloader_idx=False, sync_dist=True)

        return image_fp

    @rank_zero_only
    def save_cond_image(self, cond, stage, batch_idx):
        """Log conditioning images (b x m x c x h x w) with views side by side.

        Does nothing when no supported image logger was configured in ``setup``.
        """
        if self.log_image is None:
            return
        cond_image = rearrange(cond, "b m c h w -> b c h (m w)")
        grid = torchvision.utils.make_grid(cond_image, nrow=2).to(torch.float32)
        self.log_image(
            tag="{}_cond_images/{}".format(stage, batch_idx),
            image_tensor=grid,
        )

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Run ``inference_step`` with ``stage="val"``."""
        return self.inference_step(batch, batch_idx, dataloader_idx, stage = "val")

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """Run ``inference_step`` with ``stage="test"``.

        ``dataloader_idx`` defaults to 0 because Lightning omits it when
        there is a single test dataloader.
        """
        return self.inference_step(batch, batch_idx, dataloader_idx, stage = "test")

    @torch.no_grad()
    def _encode_cond(self, image, do_classifier_free_guidance=False):
        """Encode conditioning images into CLIP image embeddings.

        The image is resized to 224x224 and mapped from [-1, 1] to [0, 1]
        before CLIP normalization.

        Args:
            image: image batch; the result keeps its device and dtype.
            do_classifier_free_guidance: if True, prepend all-zero "negative"
                embeddings along the batch dimension.

        Returns:
            Embeddings with a singleton sequence axis (b x 1 x D), doubled
            along the batch axis when guidance is enabled.
        """
        device, dtype = image.device, image.dtype
        image = image.to(torch.float32)
        image = _resize_with_antialiasing(image, (224, 224))
        image = (image + 1.0) / 2.0
        # Normalize the image for CLIP input (resize/rescale already done above).
        image = self.feature_extractor(
            images=image,
            do_normalize=True,
            do_center_crop=False,
            do_resize=False,
            do_rescale=False,
            return_tensors="pt",
        ).pixel_values

        image = image.to(device=device, dtype=dtype)
        image_embeddings = self.image_encoder(image).image_embeds
        image_embeddings = image_embeddings.unsqueeze(1)

        if do_classifier_free_guidance:
            negative_image_embeddings = torch.zeros_like(image_embeddings)
            image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])
        return image_embeddings

    @torch.no_grad()
    def tensor_to_vae_latent(self, t, vae, needs_upcasting = True, micro_bs = 1):
        """Encode images or videos into scaled VAE latents.

        Args:
            t: ``b x c x h x w`` or ``b x f x c x h x w`` tensor. A 4-D input
                is treated as one frame and squeezed back on output.
            vae: autoencoder exposing ``encode(...).latent_dist`` and
                ``config.scaling_factor``.
            needs_upcasting: run the encoder in float32, then cast the latents
                back to the VAE's original parameter dtype.
            micro_bs: number of chunks the flattened frames are split into
                (``Tensor.chunk`` semantics). NOTE(review): the name suggests a
                micro-batch *size*; confirm whether ``split`` was intended.

        Returns:
            Sampled latents times ``vae.config.scaling_factor``, in the same
            leading layout as ``t``.
        """
        ori_shape_len = t.dim()
        model_dtype = next(vae.parameters()).dtype
        if ori_shape_len == 4:
            t = t.unsqueeze(1)
        video_length = t.shape[1]
        t = rearrange(t, "b f c h w -> (b f) c h w")
        if needs_upcasting:
            vae.to(dtype=torch.float32)
            t = t.to(torch.float32)
        try:
            latents = torch.cat(
                [vae.encode(t_chunk).latent_dist.sample() for t_chunk in t.chunk(micro_bs, dim=0)],
                dim=0,
            )
        finally:
            # Restore the VAE dtype even if encoding raises.
            if needs_upcasting:
                vae.to(dtype=model_dtype)

        if needs_upcasting:
            latents = latents.to(model_dtype)
        latents = rearrange(latents, "(b f) c h w -> b f c h w", f=video_length)
        if ori_shape_len == 4:
            latents = latents.squeeze(1)
        return latents * vae.config.scaling_factor

    @torch.no_grad()
    def _generate_images(self, batch):
        """Render predictions with ``recon_model`` at the scheduler's last timestep."""
        diffusion_images = batch["diffusion_images"]
        t = self.scheduler.timesteps[-1]
        return self.recon_model(diffusion_images, batch, t)

    @torch.no_grad()
    @rank_zero_only
    def _save_image(self, preds, gt, prompt, batch_idx, stage="validation"):
        """Save, and optionally log, ground-truth vs. predicted RGB and depth.

        Writes the following into ``self.save_dir``:

        * ``{stage}_{global_step}_{batch_idx}.txt`` with the prompts, one per
          line;
        * a PNG with ground-truth images stacked above predictions (rows are
          batch items, columns are views);
        * a matching ``*_depth.png`` stacking min-max-normalized ground-truth
          depth, predicted depth and their absolute difference.

        Returns:
            Path of the RGB comparison PNG.
        """
        images_pred = preds["images_pred"]
        images = gt["diffusion_images"] * 0.5 + 0.5  # [-1, 1] -> [0, 1]
        save_dir = self.save_dir
        if self.log_image is not None:
            _images = rearrange(images, "b m c h w -> 1 c (b h) (m w)")
            _images_pred = rearrange(images_pred, "b m c h w ->1 c (b h) (m w)")
            _full_image = torch.concat([_images, _images_pred], dim=2)
            grid = torchvision.utils.make_grid(_full_image, nrow=2)
            self.log_image(
                tag="{}_images/{}".format(stage, batch_idx),
                image_tensor=grid,
            )
        images = rearrange(images, "b m c h w -> (b h) (m w) c")
        images_pred = rearrange(images_pred, "b m c h w -> (b h) (m w) c")
        full_image = torch.concat([images, images_pred], dim=0)
        # Clamp before the uint8 cast (out-of-range values would otherwise wrap
        # around) and go through float32 because NumPy has no bfloat16.
        full_image = (full_image.float().clamp(0, 1) * 255).cpu().numpy().astype(np.uint8)
        with open(
            os.path.join(save_dir, f"{stage}_{self.global_step}_{batch_idx}.txt"), "w", encoding="utf-8"
        ) as f:
            f.write("\n".join(prompt))

        im = Image.fromarray(full_image)
        im_fp = os.path.join(
            save_dir,
            f"{stage}_{self.global_step}_{batch_idx}--{prompt[0].replace(' ', '_').replace('/', '_')}.png",
        )
        im.save(im_fp)

        depths_pred = preds["depths_pred"]
        depths = gt["depths"]

        # Min-max normalize each map to [0, 1]; the clamped range keeps
        # constant maps from dividing by zero.
        depths_pred = (depths_pred - depths_pred.min()) / (depths_pred.max() - depths_pred.min()).clamp_min(1e-8)
        depths = (depths - depths.min()) / (depths.max() - depths.min()).clamp_min(1e-8)
        depths = rearrange(depths, "b m c h w ->  1 c (b h) (m w)")
        depths_pred = rearrange(depths_pred, "b m c h w -> 1 c (b h) (m w)")
        delta_depth = torch.abs(depths - depths_pred)
        full_depth = torch.concat([depths, depths_pred, delta_depth], dim=2)

        if self.log_image is not None:
            grid = torchvision.utils.make_grid(full_depth, nrow=2)
            self.log_image(
                tag="{}_depths/{}".format(stage, batch_idx),
                image_tensor=grid,
            )
        torchvision.utils.save_image(full_depth, im_fp.replace(".png", "_depth.png"))
        return im_fp

    def tensorboard_log_image(self, tag: str, image_tensor):
        """Log an image tensor to TensorBoard at the current global step."""
        self.logger.experiment.add_image(
            tag,
            image_tensor,
            self.trainer.global_step,
        )

    def wandb_log_image(self, tag: str, image_tensor):
        """Log an image tensor to Weights & Biases at the current global step."""
        image_dict = {
            tag: wandb.Image(image_tensor),
        }
        self.logger.experiment.log(
            image_dict,
            step=self.trainer.global_step,
        )

