import os
import json
import random
import yaml
from einops import rearrange
import cv2
import csv
import time
import glob
import numpy as np
import PIL
import PIL.Image
from collections import defaultdict
from typing import Any, Dict, List, Tuple
import torch
from torch import nn
from transformers import Blip2QFormerModel, PretrainedConfig, Blip2QFormerConfig
from peft import LoraConfig, PeftModel, get_peft_model

import llava
from src.VILA.llava.constants import MEDIA_TOKENS
from src.VILA.llava.mm_utils import process_image, process_images
from src.VILA.llava.train.train import find_all_linear_names

from diffusers import (
    AutoencoderKLCogVideoX,
    CogVideoXDPMScheduler,
    CogVideoXImageToVideoPipeline,
    CogVideoXTransformer3DModel,
)

class CogVideoModel(nn.Module):
    def __init__(self, cogvideo_weight):
        """Load the CogVideoX transformer, VAE and scheduler from one checkpoint.

        Args:
            cogvideo_weight: Path or identifier handed to ``from_pretrained``;
                the ``transformer``, ``vae`` and ``scheduler`` subfolders are
                read from it.
        """
        super().__init__()
        weight_dtype = torch.bfloat16

        # Transformer and VAE are loaded in bfloat16; the scheduler is loaded
        # without a dtype argument.
        self.transformer = CogVideoXTransformer3DModel.from_pretrained(
            cogvideo_weight,
            subfolder="transformer",
            torch_dtype=weight_dtype,
        )
        self.vae = AutoencoderKLCogVideoX.from_pretrained(
            cogvideo_weight,
            subfolder="vae",
            torch_dtype=weight_dtype,
        )
        self.scheduler = CogVideoXDPMScheduler.from_pretrained(
            cogvideo_weight,
            subfolder="scheduler",
        )

        # Keep a direct handle on the transformer config; compute_loss reads
        # patch_size_t and use_rotary_positional_embeddings from it.
        self.transformer_config = self.transformer.config

        # Turn on the VAE's sliced and tiled modes.
        vae = self.vae
        vae.enable_slicing()
        vae.enable_tiling()
    
    @torch.no_grad()
    def encode_video(self, video: torch.Tensor) -> torch.Tensor:
        """Encode video frames using VAE
        Args:
            video: Input video tensor of shape [B, C, F, H, W] in range [-1, 1]
        Returns:
            Encoded latent tensor
        """
        self.vae.to("cuda")
        video = video.to(self.vae.device, dtype=self.vae.dtype)
        latent_dist = self.vae.encode(video).latent_dist
        latent = latent_dist.sample() * self.vae.config.scaling_factor
        self.vae.to("cpu")
        return latent

    def compute_loss(self, batch) -> torch.Tensor:
        prompt_embedding = batch["prompt_embedding"]
        videos = batch["videos"]  # [B, C, F, H, W] in range [-1, 1]
        
        # Encode videos to latent space
        with torch.no_grad():
            latent = self.encode_video(videos)

        # Shape of prompt_embedding: [B, seq_len, hidden_size]
        # Shape of latent: [B, C, F, H, W]

        patch_size_t = self.transformer_config.patch_size_t
        if patch_size_t is not None:
            ncopy = latent.shape[2] % patch_size_t
            # Copy the first frame ncopy times to match patch_size_t
            first_frame = latent[:, :, :1, :, :]  # Get first frame [B, C, 1, H, W]
            latent = torch.cat([first_frame.repeat(1, 1, ncopy, 1, 1), latent], dim=2)
            assert latent.shape[2] % patch_size_t == 0

        batch_size, num_channels, num_frames, height, width = latent.shape

        # Get prompt embeddings
        _, seq_len, _ = prompt_embedding.shape
        prompt_embedding = prompt_embedding.view(batch_size, seq_len, -1).to(dtype=latent.dtype)

        # Sample a random timestep for each sample
        timesteps = torch.randint(
            0, self.scheduler.config.num_train_timesteps, (batch_size,), device=self.transformer.device
        )
        timesteps = timesteps.long()

        # Add noise to latent
        latent = latent.permute(0, 2, 1, 3, 4)  # from [B, C, F, H, W] to [B, F, C, H, W]
        noise = torch.randn_like(latent)
        latent_added_noise = self.scheduler.add_noise(latent, noise, timesteps)

        # Prepare rotary embeds
        vae_scale_factor_spatial = 2 ** (len(self.vae.config.block_out_channels) - 1)
        transformer_config = self.transformer_config
        rotary_emb = (
            self.prepare_rotary_positional_embeddings(
                height=height * vae_scale_factor_spatial,
                width=width * vae_scale_factor_spatial,
                num_frames=num_frames,
                transformer_config=transformer_config,
                vae_scale_factor_spatial=vae_scale_factor_spatial,
                device=self.transformer.device,
            )
            if transformer_config.use_rotary_positional_embeddings
            else None
        )

        # Predict noise
        predicted_noise = self.transformer(
            hidden_states=latent_added_noise,
            encoder_hidden_states=prompt_embedding,
            timestep=timesteps,
            image_rotary_emb=rotary_emb,
            return_dict=False,
        )[0]

        # Denoise
        latent_pred = self.scheduler.get_velocity(predicted_noise, latent_added_noise, timesteps)

        alphas_cumprod = self.scheduler.alphas_cumprod[timesteps]
        weights = 1 / (1 - alphas_cumprod)
        while len(weights.shape) < len(latent_pred.shape):
            weights = weights.unsqueeze(-1)

        loss = torch.mean((weights * (latent_pred - latent) ** 2).reshape(batch_size, -1), dim=1)
        loss = loss.mean()

        return loss

class AnimeShooterGen(nn.Module):
    def __init__(
            self,
            image_prefix_length,
            max_text_seq_length,
            backbone_weight,
            cogvideo_weight,
            add_reference_image,
            last_clip_frame_num,
            uncond_prob,
            qformer_num_hidden_layers,
            **kwargs
    ):
        """Assemble the NVILA backbone, the CogVideoX wrapper and the Q-Former adapter.

        Args:
            image_prefix_length: Number of "[IMG_P]" slots appended after each caption.
            max_text_seq_length: Number of learnable Q-Former query tokens.
            backbone_weight: Directory of the NVILA checkpoint (its ``config.json`` is read).
            cogvideo_weight: Directory of the CogVideoX checkpoint
                (its ``text_encoder/config.json`` is read).
            add_reference_image: Whether a reference image is prepended to each sample.
            last_clip_frame_num: Number of "<image>" slots inserted after each
                non-final caption.
            uncond_prob: Probability used in ``forward`` to drop the conditioning.
            qformer_num_hidden_layers: Depth of the Q-Former.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        adapter_dtype = torch.bfloat16

        # Plain configuration values.
        self.image_prefix_length = image_prefix_length
        self.max_text_seq_length = max_text_seq_length
        self.add_reference_image = add_reference_image
        self.last_clip_frame_num = last_clip_frame_num
        self.uncond_prob = uncond_prob
        # Blank white image used wherever the conditioning is dropped.
        self.empty_image = PIL.Image.new('RGB', (256, 256), color='white')
        print(f"\nAdd reference image: {add_reference_image}, last_clip_frame_num: {last_clip_frame_num}, uncond_prob: {uncond_prob}\n")

        # NVILA backbone; special tokens are registered right after loading so
        # the embedding table is resized before anything else touches it.
        print("Loading NVILA model and config...")
        self.backbone_model = llava.load(
            backbone_weight,
            device_map=None,
            device="cuda",
            torch_dtype=torch.bfloat16,
            attn_implementation='flash_attention_2',
        )
        self._add_special_tokens()
        self.backbone_config = PretrainedConfig.from_json_file(os.path.join(backbone_weight, "config.json"))
        backbone_hidden = self.backbone_config.hidden_size
        # Only the weight matrix of this layer is used (as the prefix embeddings).
        self.image_prefix = nn.Linear(backbone_hidden, image_prefix_length, bias=False).to(adapter_dtype)

        # CogVideoX components.
        print("Loading CogVideoX and config...")
        self.cogvideo = CogVideoModel(cogvideo_weight)
        text_encoder_config_path = os.path.join(cogvideo_weight, "text_encoder/config.json")
        self.cogvideo_text_encoder_config = PretrainedConfig.from_json_file(text_encoder_config_path)

        # Adapter: projection -> Q-Former with learnable queries -> projection.
        print(f"Initializing QFormer with {qformer_num_hidden_layers} hidden layers...")
        self.qformer_config = Blip2QFormerConfig(num_hidden_layers=qformer_num_hidden_layers)
        qformer_hidden = self.qformer_config.hidden_size
        self.diffusion_qformer = Blip2QFormerModel(self.qformer_config).to(adapter_dtype)
        self.diffusion_qformer_proj = nn.Linear(
            backbone_hidden, self.qformer_config.encoder_hidden_size
        ).to(adapter_dtype)
        self.diffusion_query_tokens = nn.Parameter(
            torch.zeros(max_text_seq_length, qformer_hidden, dtype=adapter_dtype)
        )
        self.diffusion_proj = nn.Linear(
            qformer_hidden, self.cogvideo_text_encoder_config.d_model
        ).to(adapter_dtype)
    
    def _add_special_tokens(self):
        """Register the extra special tokens and resize the LLM embeddings to match.

        The tokenizer of the backbone model is extended in place; the LLM's
        token embedding table is then resized to the tokenizer's new length.
        """
        backbone = self.backbone_model
        extra_tokens = [
            "<|im_start|>",
            "<|im_end|>",
            "<vila/sentinel>",
            "<image>",
            "<vila/video>",
            "[IMG_P]",
        ]
        backbone.tokenizer.add_special_tokens({"additional_special_tokens": extra_tokens})

        vocab_size = len(backbone.tokenizer)
        backbone.llm.model.resize_token_embeddings(vocab_size)

    def get_diffusion_conditioning(self, prompts: list[list[str]], images: list[list[PIL.Image.Image]], reference_images: list[PIL.Image.Image]):
        """Build per-clip conditioning for the diffusion model from interleaved text and images.

        For every sample one prompt string is assembled: an optional leading
        "<image>" (reference image), then each caption followed by
        ``image_prefix_length`` "[IMG_P]" slots; every caption except the last
        is additionally followed by ``last_clip_frame_num`` "<image>" slots.
        The LLM hidden states at the "[IMG_P]" positions are passed through the
        Q-Former adapter.

        Args:
            prompts: Captions; outer list is the batch, inner list the clips of a sample
                (samples may have different numbers of clips).
            images: Frames matching the "<image>" slots of each sample's non-final clips.
            reference_images: One reference image per sample.

        Returns:
            diffusion_conditioning: Tensor whose first dimension is sum(clip_num);
                the remaining dimensions follow the query tokens and the final
                projection built in ``__init__``.
            clip_num: Number of clips of each sample, in batch order.
        """
        backbone = self.backbone_model
        tokenizer = backbone.tokenizer

        # Token ids of the placeholder tokens.
        prefix_token_id = tokenizer.convert_tokens_to_ids("[IMG_P]")
        image_token_id = tokenizer.convert_tokens_to_ids("<image>")  # looked up but not used below

        prefix_slots = "[IMG_P]" * self.image_prefix_length
        frame_slots = "<image>" * self.last_clip_frame_num

        # Flatten the batch: one prompt string per sample, one flat image list.
        all_prompts = []
        all_images = []
        clip_num = []
        for sample_prompts, sample_images, sample_reference in zip(prompts, images, reference_images):
            if self.add_reference_image:
                sample_media = [sample_reference] + sample_images
                pieces = ["<image>"]
                expected_image_count = 1
            else:
                sample_media = sample_images
                pieces = []
                expected_image_count = 0

            # Every clip but the last carries image slots for its trailing frames.
            for caption in sample_prompts[:-1]:
                pieces.append(caption + prefix_slots + frame_slots)
                expected_image_count += self.last_clip_frame_num
            pieces.append(sample_prompts[-1] + prefix_slots)
            combined_prompt = "".join(pieces)

            all_prompts.append(combined_prompt)
            all_images.extend(sample_media)
            clip_num.append(len(sample_prompts))
            assert expected_image_count == len(sample_media), f"Mismatch between <image> tokens ({expected_image_count}) and actual images ({len(sample_media)}). Prompt: {combined_prompt}."

        print(f"\nprompts: {all_prompts}")
        print(f"images num: {len(all_images)}")
        print(f"clip_num: {clip_num}\n")

        # Preprocess images (if any) into the media dict expected by _embed.
        if all_images:
            with torch.no_grad():
                pixel_inputs = process_images(all_images, backbone.vision_tower.image_processor, backbone.config)
            media = {"image": list(pixel_inputs)}
        else:
            media = {}
        media_config = defaultdict(dict)

        # Tokenize; labels keep the original ids so the "[IMG_P]" positions can
        # still be located after their input ids are overwritten with 0.
        tokenized = tokenizer(all_prompts, padding=True, return_tensors="pt").to(backbone.device)
        input_ids = tokenized.input_ids
        attention_mask = tokenized.attention_mask.bool()
        labels = input_ids.clone()
        input_ids[input_ids.eq(prefix_token_id)] = 0

        # Fuse image embeddings into the token embeddings.
        inputs_embeds, labels, attention_mask = backbone._embed(
            input_ids,
            media,
            media_config,
            labels=labels,
            attention_mask=attention_mask,
        )

        # Overwrite the "[IMG_P]" positions with the learnable prefix embeddings.
        bs, seq_len = labels.shape
        flat_labels = labels.reshape(-1)
        image_prefix_mask = flat_labels.eq(prefix_token_id)
        flat_embeds = inputs_embeds.reshape(bs * seq_len, -1)
        image_num, leftover = divmod(image_prefix_mask.sum().item(), self.image_prefix_length)
        assert leftover == 0
        prefix_embeddings = self.image_prefix.weight.repeat(image_num, 1).to(flat_embeds.dtype)
        flat_embeds[image_prefix_mask] = prefix_embeddings
        inputs_embeds = flat_embeds.reshape(bs, seq_len, -1)

        # LLM forward; keep only the last layer's states at the prefix positions.
        outputs = backbone.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        last_hidden = outputs.hidden_states[-1].reshape(bs * seq_len, -1)
        prefix_hidden = last_hidden[image_prefix_mask]

        # Q-Former: one group of image_prefix_length states per clip.
        qformer_inputs = prefix_hidden.view(-1, self.image_prefix_length, self.backbone_config.hidden_size)
        qformer_inputs = self.diffusion_qformer_proj(qformer_inputs)
        query_tokens = self.diffusion_query_tokens.expand(qformer_inputs.shape[0], -1, -1)
        qformer_output = self.diffusion_qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=qformer_inputs,
        )[0]
        diffusion_conditioning = self.diffusion_proj(qformer_output)

        assert diffusion_conditioning.shape[0] == sum(clip_num), f"Mismatch between diffusion conditioning dimension ({diffusion_conditioning.shape[0]}) and total number of clips ({sum(clip_num)})."
        return diffusion_conditioning, clip_num
    
    def forward(self, captions, videos, images, reference_images):
        """Run one training step and return the diffusion loss.

        With probability ``uncond_prob`` (one random draw per call, so it
        applies to the whole batch) every caption is replaced by an empty
        string and every image by the blank ``empty_image``. The replacement
        is done on fresh lists: the caller's ``captions``, ``images`` and
        ``reference_images`` are never modified, and tuple inputs are accepted
        on this path.

        Args:
            captions: Per-sample sequence of clip captions.
            videos: Video tensor forwarded to ``CogVideoModel.compute_loss``
                (presumably [B, C, F, H, W] -- confirm against the data loader).
            images: Per-sample sequence of frames; when ``last_clip_frame_num``
                is positive, the trailing ``last_clip_frame_num`` entries of
                each sample are dropped before conditioning.
            reference_images: One reference image per sample.

        Returns:
            Dict with a single key ``"loss"``.
        """
        # classifier-free guidance
        random_num = torch.rand(1, device=videos.device)
        condition_mask = rearrange(random_num < self.uncond_prob, "n -> n 1 1")
        print(f"\ncondition_mask: {condition_mask}\n")
        if condition_mask:
            # Rebind to new lists instead of assigning into the caller's
            # containers (which mutated the batch and failed on tuples).
            captions = [[""] * len(sample_captions) for sample_captions in captions]
            images = [[self.empty_image] * len(sample_images) for sample_images in images]
            reference_images = [self.empty_image] * len(reference_images)

        # exclude last clip images
        if self.last_clip_frame_num > 0:
            images = [imgs[:-self.last_clip_frame_num] for imgs in images]

        # calculate loss
        with torch.autocast("cuda", dtype=torch.bfloat16):
            conditioning, _ = self.get_diffusion_conditioning(captions, images, reference_images)
            batch = {
                "prompt_embedding": conditioning,
                "videos": videos.to(self.cogvideo.transformer.device),
            }
            loss = self.cogvideo.compute_loss(batch)

        return {"loss": loss}
    
    def adding_LLM_lora(self, peft_config):
        """Wrap the backbone model with LoRA adapters when the LLM section enables it.

        Args:
            peft_config: Nested mapping; the ``'llm'`` entry supplies ``enabled``,
                ``r``, ``lora_alpha``, ``lora_dropout`` and ``task_type``.
        """
        llm_cfg = peft_config['llm']
        if not llm_cfg['enabled']:
            return

        print("Adding LoRA adapters to LLM...")
        rank = llm_cfg['r']
        alpha = llm_cfg['lora_alpha']
        # Target the linear layers of the LLM only (vision tower excluded).
        target_modules = find_all_linear_names(self.backbone_model, lora_llm=True, lora_vt=False)
        dropout = llm_cfg['lora_dropout']
        task_type = llm_cfg['task_type']

        lora_config_llm = LoraConfig(
            r=rank,
            lora_alpha=alpha,
            target_modules=target_modules,
            lora_dropout=dropout,
            task_type=task_type,
        )
        self.backbone_model = get_peft_model(self.backbone_model, lora_config_llm)
    
    def prepare_trainable_parameters_cogvideo_lora(self, peft_config):
        """Freeze all current parameters, optionally add CogVideo LoRA, and set memory options.

        Every parameter that exists when this is called gets ``requires_grad``
        set to False. If the ``'cogvideo'`` section is enabled, a LoRA adapter
        is then added to the CogVideoX transformer. Gradient checkpointing is
        switched on for the LLM and the transformer, and the transformer's
        attention processor is replaced by the xformers-based one.

        Args:
            peft_config: Nested mapping; the ``'cogvideo'`` entry supplies
                ``enabled``, ``r``, ``lora_alpha``, ``init_lora_weights`` and
                ``target_modules``.
        """
        print("Initializing trainable parameters...")
        self.requires_grad_(False)

        cogvideo_cfg = peft_config['cogvideo']
        transformer = self.cogvideo.transformer
        if cogvideo_cfg['enabled']:
            # The adapter is added after the freeze above.
            print("Adding LoRA adapters to CogVideo...")
            transformer.add_adapter(
                LoraConfig(
                    r=cogvideo_cfg['r'],
                    lora_alpha=cogvideo_cfg['lora_alpha'],
                    init_lora_weights=cogvideo_cfg['init_lora_weights'],
                    target_modules=cogvideo_cfg['target_modules'],
                )
            )

        # Gradient checkpointing on both the LLM and the video transformer.
        self.backbone_model.llm.gradient_checkpointing_enable()
        transformer.enable_gradient_checkpointing()

        # Imported here, at call time, rather than at module import.
        from src.utils.attention_processor import CogVideoXXFormersAttnProcessor2_0
        from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
        attn_processor = CogVideoXXFormersAttnProcessor2_0(MemoryEfficientAttentionFlashAttentionOp)
        transformer.set_attn_processor(attn_processor)
    
    @torch.no_grad()
    def evaluation(self, prompts: list[list[str]], images: list[list[PIL.Image.Image]], reference_images: list[PIL.Image.Image]):
        """Return conditional and unconditional conditioning for each sample's final clip.

        The unconditional branch uses empty captions and the blank
        ``empty_image`` in place of every image, with the same per-sample
        lengths as the real inputs.

        Args:
            prompts: Captions; outer list is the batch, inner list the clips of a sample.
            images: Frames for each sample, as expected by ``get_diffusion_conditioning``.
            reference_images: One reference image per sample.

        Returns:
            Tuple ``(conditions, negative_conditions)``; each stacks one entry
            per sample, taken at that sample's last clip.
        """
        blank = self.empty_image
        blank_prompts = [[""] * len(sample_prompts) for sample_prompts in prompts]
        blank_images = [[blank] * len(sample_images) for sample_images in images]
        blank_references = [blank] * len(reference_images)

        positive_all, clip_num = self.get_diffusion_conditioning(prompts, images, reference_images)
        negative_all, _ = self.get_diffusion_conditioning(blank_prompts, blank_images, blank_references)

        # Clips are laid out sample after sample, so the last clip of a sample
        # sits just before the running total of clips seen so far.
        last_rows = []
        clips_seen = 0
        for count in clip_num:
            clips_seen += count
            last_rows.append(clips_seen - 1)

        positive_last = torch.stack([positive_all[row] for row in last_rows])
        negative_last = torch.stack([negative_all[row] for row in last_rows])
        return positive_last, negative_last
        