from typing import Any, List, Optional, Tuple, Union

import hydra
import pytorch_lightning
import torch
import torch.nn.functional as F
from datamodules.video_data_api import VideoData
from model_pipeline import MoViePipeline
from omegaconf import DictConfig
from torch import Tensor
from torchmetrics.functional import peak_signal_noise_ratio
from utils.hydra_tools import OmegaConf  # ! Needed for resolvets to take effect

import lpips, clip
from transformers import CLIPTextModelWithProjection, CLIPVisionModelWithProjection 
from torchvision import transforms
import torch.nn as nn
from PIL import Image
import os
import torchvision
import random

def multiply_255(x):
    """Scale *x* by 255.0 (first stage of the CLIP-style preprocessing chain)."""
    scaled = x * 255.0
    return scaled

def divide_255(x):
    """Rescale *x* by the fixed factor 0.00392156862745098 (approximately 1/255)."""
    rescale_factor = 0.00392156862745098
    return x * rescale_factor

class MoVieModule(pytorch_lightning.LightningModule):
    """Encapsulates model + loss + optimizer + metrics in a PL module.

    Per frame, the objective assembled in ``training_step`` /
    ``validation_step`` is::

        lmbda[s] * (MSE * 255**2 + k_J * joint_image_text + k_P * LPIPS) + rate

    where ``s`` indexes ``self.lmbda`` and is also forwarded to the model.
    """

    def __init__(self, model: MoViePipeline, CLIP_text_model , cfg: DictConfig) -> None:
        """Store the model, read optimisation settings and build the loss networks.

        Args:
            model: Pipeline invoked as ``model(video, text_embeddings, s)``; its
                ``compute_rate`` method is used for the rate term.
            CLIP_text_model: Text encoder called with ``input_ids`` and
                ``attention_mask``; its ``last_hidden_state`` conditions the model.
            cfg: Configuration. The attributes ``optim``, ``scheduler``,
                ``training_loop.lr_annealing_frequency`` and
                ``training_loop.distortion_lambda`` are read; each falls back
                to ``None`` when absent.
        """
        super().__init__()
        self.model = model
        self.CLIP_text_model = CLIP_text_model
        self.cfg_optim = getattr(cfg, "optim", None)
        self.cfg_scheduler = getattr(cfg, "scheduler", None)
        self.lr_annealing_frequency = getattr(
            getattr(cfg, "training_loop", None), "lr_annealing_frequency", None
        )
        # NOTE(review): this scaling is computed but not read anywhere in this
        # class; the steps below use ``self.lmbda`` instead.
        self._distortion_loss_scaling = getattr(
            getattr(cfg, "training_loop", None), "distortion_lambda", None
        )
        if self._distortion_loss_scaling is not None:
            self._distortion_loss_scaling = self._distortion_loss_scaling * (255**2)

        # LPIPS perceptual loss.
        lpips_coefficient = 1.0
        self.k_P = lpips_coefficient
        # Written straight into ``__dict__`` so that ``nn.Module.__setattr__``
        # does not register LPIPS as a submodule (keeps its weights out of
        # ``parameters()`` and ``state_dict()``). Never re-assign it through
        # ``self.loss_fn_alex = ...`` or it will get registered after all.
        self.__dict__["loss_fn_alex"] = lpips.LPIPS(net='alex').eval()
        self.loss_fn_alex.net.requires_grad_(False)

        # Joint image-text (multi-modal) loss.
        clip_model_name = "openai/clip-vit-base-patch32"
        jit_coefficient = 0.005
        self.k_J = jit_coefficient  # weight of the joint image-text loss
        self.beta = 40.0  # weight of the embedding L2 term inside the joint loss

        self.text_encoder = CLIPTextModelWithProjection.from_pretrained(clip_model_name).eval()
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_model_name).eval()

        self.text_encoder.requires_grad_(False)
        self.image_encoder.requires_grad_(False)
        # Mimic the image transform function of CLIP.
        self.transform_for_clip = transforms.Compose([
            multiply_255,
            transforms.Resize(224),  # do_resize
            transforms.CenterCrop(224),  # do_center_crop
            divide_255,  # do_rescale
            transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])  # do_normalize
            ])

        # Log-temperature for the contrastive logits ("openai/clip-vit-base-patch32").
        self.logit_scale = nn.Parameter(torch.ones([]) * 2.6592)
        # One rate-distortion trade-off weight per level ``s``.
        self.lmbda = [0.00045, 0.0018, 0.0067, 0.0200]

    # code from modeling_clip.py in huggingface
    def contrastive_loss(self, logits: torch.Tensor) -> torch.Tensor:
        """Cross-entropy of ``logits`` against the diagonal (row i -> class i)."""
        return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))

    @staticmethod
    def get_warmup_weight_by_step(global_step, final_weight, warmup_steps):
        """Linearly ramp a weight from 0 to ``final_weight`` over ``warmup_steps``.

        Declared as a static method so it can be called both on the class and
        on an instance. A non-positive ``warmup_steps`` means "no warm-up" and
        returns ``final_weight`` directly instead of dividing by zero.
        """
        if warmup_steps <= 0:
            return final_weight
        return final_weight * min(1.0, global_step / warmup_steps)

    def clip_loss(self, similarity: torch.Tensor) -> torch.Tensor:
        """Symmetric contrastive loss: mean of the row-wise and column-wise terms."""
        caption_loss = self.contrastive_loss(similarity)
        image_loss = self.contrastive_loss(similarity.t())
        return (caption_loss + image_loss) / 2.0

    def get_joint_image_text_loss(self, output, target, text_tokens, attention_mask) :
        """Compute the CLIP-based joint image-text loss for reconstructed frames.

        Args:
            output: Reconstructed frames, shape (B, T, C, H, W).
            target: Reference frames, same shape as ``output``.
            text_tokens: Token ids passed to the CLIP text encoder
                (assumes one row per video — TODO confirm against the datamodule).
            attention_mask: Attention mask matching ``text_tokens``.

        Returns:
            Tuple ``(joint_loss, clip_loss, l2_loss)`` with
            ``joint_loss = clip_loss + beta * l2_loss``.
        """
        B, T, C, H, W = output.shape
        device = output.device

        output = output.reshape(B * T, C, H, W)
        target = target.reshape(B * T, C, H, W)
        # ``Module.to`` moves parameters in place; no re-assignment needed.
        self.image_encoder.to(device)
        self.text_encoder.to(device)
        # Same device handling as ``forward`` (no-op when already on ``device``).
        text_tokens = text_tokens.to(device)
        attention_mask = attention_mask.to(device)

        # Apply transform_for_clip on each frame individually.
        compressed_image = torch.stack([self.transform_for_clip(f) for f in output]).to(device)
        target_image = torch.stack([self.transform_for_clip(f) for f in target]).to(device)

        compressed_image_embeddings = self.image_encoder(compressed_image).image_embeds
        target_image_embeddings = self.image_encoder(target_image).image_embeds

        # Unit-normalise the embeddings before comparing them.
        compressed_image_embeddings = compressed_image_embeddings / compressed_image_embeddings.norm(p=2, dim=-1, keepdim=True)
        target_image_embeddings = target_image_embeddings / target_image_embeddings.norm(p=2, dim=-1, keepdim=True)

        text_embeddings = self.text_encoder(input_ids = text_tokens, attention_mask = attention_mask).text_embeds
        text_embeddings = text_embeddings / text_embeddings.norm(p=2, dim=-1, keepdim=True)
        # Repeat every caption embedding T times *consecutively* so that row i
        # lines up with frame i of the flattened (B*T) image batch, whose order
        # is (b0t0, b0t1, ..., b1t0, ...). For a single caption row this equals
        # the former ``repeat(T, 1)``.
        text_embeddings = text_embeddings.repeat_interleave(T, dim=0)

        logit_scale = self.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeddings, compressed_image_embeddings.t()) * logit_scale

        l2_loss = torch.norm(compressed_image_embeddings - target_image_embeddings, p=2)
        cliploss = self.clip_loss(logits_per_text)
        joint_image_text_loss = cliploss + self.beta * l2_loss

        return joint_image_text_loss, cliploss, l2_loss

    def forward(self, video: VideoData, tokens, attention_masks, s) -> Tensor:
        """Encode the caption with ``CLIP_text_model`` and run the model.

        Returns whatever ``self.model(video, text_embeddings, s)`` returns; the
        training/validation steps unpack it as ``(recon_frames, rate_args)``.
        """
        device = next(self.model.parameters()).device
        tokens = tokens.to(device)
        attention_masks = attention_masks.to(device)

        text_embeddings = self.CLIP_text_model(input_ids = tokens, attention_mask = attention_masks).last_hidden_state

        return self.model(video, text_embeddings, s)

    def rate_loss(self, video: VideoData, rate_args, per_frame: bool = True) -> Tensor:
        """Return bits per pixel from ``self.model.compute_rate``.

        The bit count is divided by ``B * H * W`` when ``per_frame`` is true
        (averaging over the batch but not over frames), otherwise by
        ``B * H * W * T``.
        """
        B, T, _, H, W = video.shape
        num_pixels = B * H * W if per_frame else B * H * W * T
        total_bits = self.model.compute_rate(*rate_args, per_frame=per_frame)
        return total_bits / num_pixels

    @staticmethod
    def _to_lpips_range(frames: Tensor) -> Tensor:
        """Flatten (B, T, C, H, W) frames to (B*T, C, H, W) and map x -> 2x - 1."""
        B, T, C, H, W = frames.shape
        return frames.reshape(B * T, C, H, W) * 2 - 1

    def training_step(
        self, batch, batch_idx: int, optimizer_idx: Optional[int] = None
    ) -> Tensor:
        """Run one optimisation step at a randomly chosen level ``s``.

        Args:
            batch: Mapping with keys ``"video"``, ``"tokens"`` and
                ``"attention_masks"``.
            batch_idx: Index of the batch (unused).
            optimizer_idx: Optimizer index supplied by Lightning (unused).

        Returns:
            The mean over frames of the combined rate-distortion loss.
        """
        video, tokens, attention_masks = batch["video"], batch["tokens"], batch["attention_masks"]
        # Choose a random level from [0, levels - 1]; one level per lmbda entry.
        s = random.randint(0, len(self.lmbda) - 1)
        recon_frames, rate_args = self(video, tokens, attention_masks, s)
        dist_loss_frame = F.mse_loss(
            recon_frames, video.video_tensor, reduction="none"
        ).mean(dim=(0, 2, 3, 4))  # [T]
        dist_loss_frame = dist_loss_frame * (255**2)
        combined_loss_frame = dist_loss_frame.clone()

        # Default to zero so the logging below stays valid when a term is
        # switched off through its coefficient.
        joint_image_text_loss = recon_frames.new_zeros(())
        perceptual_loss = recon_frames.new_zeros(())

        if self.k_J > 0.0:
            joint_image_text_loss, cliploss, l2_loss = self.get_joint_image_text_loss(recon_frames, video.video_tensor, tokens, attention_masks)
            combined_loss_frame += self.k_J * joint_image_text_loss
        if self.k_P > 0.0:
            device = next(self.model.parameters()).device
            # In-place move; re-assigning the attribute would register LPIPS
            # as a submodule (see ``__init__``).
            self.loss_fn_alex.to(device)

            # NOTE(review): the inputs are detached and LPIPS runs under
            # ``no_grad``, so this term shifts the loss value but contributes
            # no gradient. Kept as in the original — confirm this is intended.
            lpips_output = torch.clamp(self._to_lpips_range(recon_frames), -1.0, 1.0).float().detach()
            lpips_target = torch.clamp(self._to_lpips_range(video.video_tensor), -1.0, 1.0).float().detach()
            with torch.no_grad():
                perceptual_loss = self.loss_fn_alex(lpips_output, lpips_target).mean()

            combined_loss_frame += self.k_P * perceptual_loss

        rate_loss_frame = self.rate_loss(video, rate_args, per_frame=True)  # [T]
        combined_loss_frame = combined_loss_frame * self.lmbda[s] + rate_loss_frame

        with torch.no_grad():
            psnr = peak_signal_noise_ratio(
                recon_frames, video.video_tensor, data_range=1.0, dim=(2, 3, 4)
            ).item()
        self.log_dict(
            {
                "lmbda:": self.lmbda[s],
                "distortion_loss": dist_loss_frame.mean() * self.lmbda[s],
                "rate_loss": rate_loss_frame.mean(),
                "perceptual_loss": perceptual_loss * self.k_P,
                "joint_image_text_loss": joint_image_text_loss * self.k_J,
                "combined_loss": combined_loss_frame.mean(),
                "PSNR": psnr,
            },
            sync_dist=True,
        )
        return combined_loss_frame.mean()

    def validation_step(self, batch, batch_idx) -> None:
        """Evaluate one batch at the fixed level ``s = 1`` and log the metrics.

        Args:
            batch: Mapping with keys ``"video"``, ``"tokens"`` and
                ``"attention_masks"``.
            batch_idx: Index of the batch (unused).
        """
        video, tokens, attention_masks = batch["video"], batch["tokens"], batch["attention_masks"]
        s = 1
        recon_frames, rate_args = self(video, tokens, attention_masks, s)
        dist_loss_frame = F.mse_loss(
            recon_frames, video.video_tensor, reduction="none"
        ).mean(dim=(0, 2, 3, 4))  # [T]
        dist_loss_frame = dist_loss_frame * (255**2)
        combined_loss_frame = dist_loss_frame.clone()

        # Default to zero so the logging below stays valid when a term is
        # switched off through its coefficient.
        joint_image_text_loss = recon_frames.new_zeros(())
        perceptual_loss = recon_frames.new_zeros(())

        if self.k_J > 0.0:
            joint_image_text_loss, cliploss, l2_loss = self.get_joint_image_text_loss(recon_frames, video.video_tensor, tokens, attention_masks)
            combined_loss_frame += self.k_J * joint_image_text_loss
        if self.k_P > 0.0:
            device = next(self.model.parameters()).device
            # In-place move; re-assigning the attribute would register LPIPS
            # as a submodule (see ``__init__``).
            self.loss_fn_alex.to(device)

            # Only the reconstruction is clamped here; the reference is used as is.
            lpips_output = torch.clamp(self._to_lpips_range(recon_frames), -1.0, 1.0)
            lpips_target = self._to_lpips_range(video.video_tensor)
            perceptual_loss = self.loss_fn_alex(lpips_output, lpips_target).mean()
            combined_loss_frame += self.k_P * perceptual_loss
        rate_loss_frame = self.rate_loss(video, rate_args, per_frame=True)  # [T]
        combined_loss_frame = combined_loss_frame * self.lmbda[s] + rate_loss_frame

        # NOTE(review): "perceptual_loss" and "joint_image_text_loss" reuse the
        # training log keys (no "val_" prefix) — confirm this is intended.
        self.log_dict(
            {
                "val_distortion_loss": dist_loss_frame.mean().item() * self.lmbda[s],
                "lmbda:": self.lmbda[s],
                "perceptual_loss": perceptual_loss * self.k_P,
                "joint_image_text_loss": self.k_J * joint_image_text_loss,
                "val_combined_loss_weighted": combined_loss_frame.mean().item(),
                "val_rate_loss": rate_loss_frame.mean().item(),
                "val_PSNR": peak_signal_noise_ratio(
                    recon_frames, video.video_tensor, data_range=1.0, dim=(2, 3, 4)
                ).item(),
            },
            sync_dist=True,
        )

    def test_step(
        self,
        batch,
        batch_idx: int,
        run_fwd: bool = True,
    ):
        """Compute 8-bit PSNR and bits per pixel for a single video at ``s = 3``.

        Args:
            batch: Mapping with keys ``"video"``, ``"tokens"`` and
                ``"attention_masks"``; the batch size must be 1.
            batch_idx: Index of the batch (unused).
            run_fwd: Unused.

        Returns:
            Tuple ``(psnr, bpp)`` of Python floats.

        Raises:
            AssertionError: If the batch size is not 1 or the module is in
                training mode.
        """
        video, tokens, attention_masks = batch["video"], batch["tokens"], batch["attention_masks"]
        B, T, C, H, W = video.shape
        assert B == 1, "Metrics calculation only supported for batch size 1."
        assert self.training is False

        with torch.no_grad():
            video_ref_uint8 = (video.video_tensor[0] * 255 + 0.5).to(torch.uint8)
            self.model._code_to_strings = False
            s = 3
            forward_pass = self(video, tokens, attention_masks, s)
            recon, rate_args = forward_pass[-2], forward_pass[-1][0]
            # Clamp to [0, 1] before quantising to 8 bit with rounding.
            recon_clamped = torch.clamp(recon[0], 0.0, 1.0)
            recon_uint8 = (recon_clamped * 255 + 0.5).to(torch.uint8)
            psnr = peak_signal_noise_ratio(
                recon_uint8, video_ref_uint8, data_range=255
            ).item()
            bpp = self.model.compute_rate(rate_args).item() / (H * W * T)
            self.log_dict(
                {
                    "cts_psnr": psnr,
                    "cts_rate": bpp,
                },
                sync_dist=True,
            )
        return psnr, bpp

    def configure_optimizers(
        self,
    ) -> Union[
        torch.optim.Optimizer,  # Single optimizer
        Tuple[torch.optim.Optimizer, torch.optim.Optimizer],  # Tuple or list of optim
        List[torch.optim.Optimizer],
        dict,  # "optimizer" key, and (optionally) an "lr_scheduler"
        Any,  # 2 lists: first with optimizers, second has LR schedulers; or Tuple[Dict]
    ]:
        """Instantiate the optimizer and a per-step LR scheduler from the config.

        Parameters whose name ends in ``.quantiles`` and parameters with
        ``requires_grad=False`` are excluded from the optimizer.

        Returns:
            Dict with the keys ``"optimizer"`` and ``"lr_scheduler"``.
        """
        # PL allows return type to be tuple/list/dict/two lists/tuple of dicts/None
        model_params = (
            p
            for n, p in self.named_parameters()
            if not n.endswith(".quantiles") and p.requires_grad
        )

        base_optim = hydra.utils.instantiate(self.cfg_optim, params=model_params)
        base_scheduler = hydra.utils.instantiate(
            self.cfg_scheduler, optimizer=base_optim
        )
        # Fall back to stepping the scheduler every step when the config does
        # not provide ``training_loop.lr_annealing_frequency``.
        frequency = self.lr_annealing_frequency
        scheduler = {
            "scheduler": base_scheduler,
            "interval": "step",
            "frequency": 1 if frequency is None else frequency,
        }
        return {"optimizer": base_optim, "lr_scheduler": scheduler}
