import os
import os
import torch
import PIL
import gc
import cv2
import json
import pandas as pd
import random
from omegaconf import OmegaConf
from PIL import Image
from torch import nn
from torch.utils.data import Dataset
import argparse
from typing import Optional, Literal, Union
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import (
    CogVideoXPipeline,
    AutoencoderKLCogVideoX,
    CogVideoXDPMScheduler,
    CogVideoXImageToVideoPipeline,
    CogVideoXTransformer3DModel,
)
from diffusers.utils import export_to_video

from src.model import AnimeShooterGen
from src.utils.videoreader_pyav import VideoReaderAV
from src.utils.process_reference import process_reference_image

@torch.no_grad()
def generate_video(
    prompt_embeds: torch.Tensor,
    negative_prompt_embeds: torch.Tensor,
    guidance_scale: float,
    pipe: Union[CogVideoXPipeline],
    output_path: str = "output.mp4",
    num_frames: int = 49,
    width: Optional[int] = 720,
    height: Optional[int] = 480,
    num_inference_steps: int = 50,
    num_videos_per_prompt: int = 1,
    fps: int = 16,
    seed: int = 42
):
    """
    Run the CogVideoX pipeline on pre-computed embeddings and write the result to disk.

    Args:
        prompt_embeds (torch.Tensor): Pre-computed positive prompt embeddings
        negative_prompt_embeds (torch.Tensor): Pre-computed negative prompt embeddings
        guidance_scale (float): Classifier-free guidance scale forwarded to the pipeline
        pipe: CogVideoX pipeline that is called to produce the frames
        output_path (str): Destination file for the exported video
        num_frames (int): Number of frames requested from the pipeline
        width (int, optional): Output video width
        height (int, optional): Output video height
        num_inference_steps (int): Number of denoising steps
        num_videos_per_prompt (int): Number of videos requested per prompt
        fps (int): Frame rate used when exporting the video
        seed (int): Value passed to ``torch.manual_seed`` before sampling

    Returns:
        The object returned by ``pipe(...)``. Only its first entry,
        ``frames[0]``, is exported to ``output_path``.
    """

    # Seeding goes through the global torch RNG; no explicit generator is
    # handed to the pipeline.
    torch.manual_seed(seed)

    # Collect the sampling arguments first so the pipeline call stays short.
    sampling_args = dict(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        num_frames=num_frames,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_videos_per_prompt=num_videos_per_prompt,
    )

    # Sampling runs under bfloat16 autocast on CUDA.
    with torch.autocast("cuda", dtype=torch.bfloat16):
        result = pipe(**sampling_args)

    # Export the first generated video only.
    first_video = result.frames[0]
    export_to_video(first_video, output_path, fps=fps)

    return result
# Command-line options: which inference config to load and which video id to run.
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="src/config/inference_config.yaml")
parser.add_argument("--video_id", type=str, default="1dCd6hCRoaQ")
args = parser.parse_args()

config = OmegaConf.load(args.config)

# The LoRA checkpoint path in the config embeds the config's own video id;
# swap every occurrence for the id requested on the command line, then record
# that id in the config as well.
lora_weight_template = config['model']['cogvideo_lora_weight']
config['model']['cogvideo_lora_weight'] = lora_weight_template.replace(config['video_id'], args.video_id)
config['video_id'] = args.video_id

# Keyword arguments later forwarded to generate_video().
video_kwargs = config['video_kwargs']

""" load model & pipeline"""
# Build the model and attach the LoRA adapters described by the peft config
# before any weights are loaded.
model = AnimeShooterGen(**config['model'])
model.adding_LLM_lora(config['peft'])
model.prepare_trainable_parameters_cogvideo_lora(config['peft'])

# Load the pretrained checkpoint first, then the CogVideoX LoRA checkpoint.
# Both go through strict=False, so keys missing from either file are tolerated.
for weight_kind, weight_key in (("pretrained", "pretrained_weight"), ("lora", "cogvideo_lora_weight")):
    weight_path = config['model'][weight_key]
    model.load_state_dict(torch.load(weight_path, map_location='cpu'), strict=False)
    print(f"{weight_kind} weight loaded from {weight_path}")

# Inference only: move to GPU, cast to bfloat16, switch to eval mode and
# disable gradients on every parameter.
model = model.cuda()
model = model.to(dtype=torch.bfloat16)
model.eval()
for param in model.parameters():
    param.requires_grad_(False)

# Assemble a CogVideoX pipeline around the components of the already-loaded model.
# The VAE, scheduler and transformer are the model's own objects (shared, not copied);
# the text encoder and tokenizer are loaded from the CogVideoX checkpoint directory.
pipe = CogVideoXPipeline(
    vae=model.cogvideo.vae,
    text_encoder=T5EncoderModel.from_pretrained(config['model']['cogvideo_weight'], subfolder="text_encoder", torch_dtype=torch.bfloat16),
    tokenizer=AutoTokenizer.from_pretrained(config['model']['cogvideo_weight'], subfolder="tokenizer"),
    scheduler=model.cogvideo.scheduler,
    transformer=model.cogvideo.transformer
)
# The text encoder attribute is dropped right after construction: generate_video()
# calls the pipeline with pre-computed prompt embeddings instead of raw text.
del pipe.text_encoder
# NOTE(review): offloading is enabled and the pipeline is then moved to "cuda" by hand
# a few lines below — confirm this combination is intended for the installed diffusers version.
pipe.enable_model_cpu_offload()
# Tiled and sliced VAE processing are both switched on — presumably to limit peak
# GPU memory; verify against the diffusers documentation.
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()
pipe = pipe.to("cuda")
print("CogVideoX pipeline loaded...")

# All generated clips for this video id are written below <output_dir>/<video_id>.
output_dir = os.path.join(config['output_dir'], args.video_id)
os.makedirs(output_dir, exist_ok=True)

# ---- generate videos ----
# Captions are indexed as captions[sample][clip]. The file is read inside a
# `with` block so the handle is closed deterministically, and decoded as UTF-8
# so the result does not depend on the platform's default locale encoding.
with open(f"demos/{args.video_id}.json", encoding="utf-8") as captions_file:
    captions = json.load(captions_file)

# Every sample uses the same reference image for this video id.
reference_image_paths = [f"datasets/references/{args.video_id}_ref.png"] * len(captions)
reference_images = [PIL.Image.open(path) for path in reference_image_paths]

# batch_size: number of samples; num_clips: number of clips in the first sample.
batch_size, num_clips = len(captions), len(captions[0])
# images[sample] collects the last frame of each clip generated so far for that sample.
images = [[] for _ in range(batch_size)]
empty_image = PIL.Image.new('RGB', (256, 256), color='white')

# generate videos
# The two context managers must be separated by a comma. Writing
# `torch.no_grad() and torch.autocast(...)` evaluates the `and` expression first,
# which yields only the autocast object, so no_grad would never be entered.
with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
    for tmp_clip_num in range(num_clips):
        # Each step conditions on all captions up to and including the current clip.
        tmp_captions = [caption[:tmp_clip_num+1] for caption in captions]
        print(f"\ntmp_captions: {tmp_captions}\nimages: {images}\n")

        # get LLM conditioning
        # The model is moved to the GPU only for this step and back to the CPU
        # afterwards — presumably to leave GPU memory to the diffusion pipeline.
        model = model.to("cuda")
        prompt_embeds, negative_prompt_embeds = model.evaluation(tmp_captions, images, reference_images)
        prompt_embeds = prompt_embeds.cpu()
        negative_prompt_embeds = negative_prompt_embeds.cpu()
        model = model.cpu()

        # Generate and save videos for each sample
        for bs in range(batch_size):
            print(f"\nSample {bs}:")
            print(f"Caption: {tmp_captions[bs]}")

            # Take this sample's embeddings, restore the batch dimension and move them to the GPU.
            current_prompt_embeds = prompt_embeds[bs].unsqueeze(0).cuda()
            current_negative_prompt_embeds = negative_prompt_embeds[bs].unsqueeze(0).cuda()

            output_path = os.path.join(output_dir, f"generated_video_{bs}_{tmp_clip_num}.mp4")
            print(f"Generating video to: {output_path}")

            generate_video(
                prompt_embeds=current_prompt_embeds,
                negative_prompt_embeds=current_negative_prompt_embeds,
                pipe=pipe,
                output_path=output_path,
                **video_kwargs
            )
            # Rebinding the names to CPU copies drops the references to the GPU tensors.
            current_prompt_embeds = current_prompt_embeds.cpu()
            current_negative_prompt_embeds = current_negative_prompt_embeds.cpu()

            # save last frame
            # Read the final frame back from the exported file; it is appended to
            # images[bs] and passed to model.evaluation() for the next clip.
            video_reader = VideoReaderAV(output_path)
            frames = video_reader.get_batch([video_kwargs['num_frames'] - 1])
            last_frame = frames[0].astype('uint8')
            last_frame_pil = Image.fromarray(last_frame)
            images[bs].append(last_frame_pil)
    