import argparse
from typing import Any, Dict, List, Literal, Tuple
import pandas as pd
import os
import sys

import torch
from diffusers import (
    CogVideoXPipeline,
    CogVideoXDDIMScheduler,
    CogVideoXDPMScheduler,
    CogVideoXImageToVideoPipeline,
    CogVideoXVideoToVideoPipeline,
)

from diffusers.utils import export_to_video, load_image, load_video

import numpy as np
import random
import cv2
from pathlib import Path
import decord
from torchvision import transforms
from torchvision.transforms.functional import resize

import PIL.Image
from PIL import Image

current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(current_dir, '..'))
from models.cogvideox_event import CogVideoXVideoToVideoPipelineEvent_CFG
from training.dataset import VideoDatasetWithResizingEvent
import torch.nn.functional as F

def _plan_windows(seq_length: int, window: int) -> List[Tuple[int, int]]:
    """Plan overlapping fixed-length windows that together cover a whole sequence.

    Consecutive windows share one frame (stride ``window - 1``); the final window is
    shifted back so that it ends exactly on the last frame of the sequence.

    Args:
        seq_length: Number of frames in the sequence.
        window: Number of frames per window; must be >= 2.

    Returns:
        One ``(start_idx, first_frame_idx)`` pair per window. ``start_idx`` is the
        global index of the window's frame 0. ``first_frame_idx`` is the local index
        just past the overlap with earlier windows: 1 for every window that is not
        shifted back (including the first), larger for a shifted final window.

    Raises:
        ValueError: if ``window < 2`` or the sequence is shorter than one window.
    """
    if window < 2:
        raise ValueError(f"window must be >= 2, got {window}")
    if seq_length < window:
        raise ValueError(f"sequence has {seq_length} frames, need at least {window}")

    stride = window - 1
    plan = []
    # Stop before the last frame: a window starting on it would add no new frames
    # and its first_frame_idx would equal `window`, i.e. be out of range.
    for i in range(0, seq_length - 1, stride):
        if i + window > seq_length:
            start_idx = seq_length - window
            plan.append((start_idx, i - start_idx + 1))
        else:
            plan.append((i, 1))
    return plan


def sample_from_dataset(
    data_root: str,
    num_samples: int = -1,
    random_seed: int = 42,
):
    """Load the evaluation dataset and split each sequence into overlapping 49-frame windows.

    Args:
        data_root: Root directory handed to ``VideoDatasetWithResizingEvent``.
        num_samples: Number of sequences drawn at random; a negative value
            (conventionally -1) processes every sequence in dataset order.
        random_seed: Seed for Python's global ``random`` module, used to pick sequences.

    Returns:
        ``(samples, video_transforms)``. ``samples`` maps each ``seq_name`` to a list of
        per-window dicts with keys ``prompt``, ``image``, ``video_low``, ``video_normal``,
        ``video_event``, ``height``, ``width``, ``first_frame_idx`` and ``seq_length``.
        ``video_transforms`` is the dataset's frame transform.

    Raises:
        ValueError: if a selected sequence is shorter than one window.
    """
    window = 49  # frames per generation call; matches the dataset's frame bucket below

    dataset = VideoDatasetWithResizingEvent(
        data_root=data_root,
        max_num_frames=window,
        height_buckets=[480],
        width_buckets=[720],
        load_tensors=False,
        random_flip=None,
        frame_buckets=[window],
        image_to_video=True,
        eval=True,
    )

    # Seed the global RNG that random.sample draws from below.
    random.seed(random_seed)

    total_samples = len(dataset)
    if num_samples < 0:
        selected_indices = range(total_samples)
    else:
        selected_indices = random.sample(range(total_samples), min(num_samples, total_samples))

    samples = {}
    for idx in selected_indices:
        sample = dataset[idx]
        seq_name = sample["seq_name"]
        video_low = sample["video_low"]  # indexed along dim 0 as frames
        video_normal = sample["video_normal"]
        video_event = sample["video_event"]
        seq_length = video_low.shape[0]

        try:
            plan = _plan_windows(seq_length, window)
        except ValueError as err:
            raise ValueError(f"Sequence {seq_name!r}: {err}") from err

        seq_windows = []
        for start_idx, first_frame_idx in plan:
            stop_idx = start_idx + window
            video_normal_i = video_normal[start_idx:stop_idx]
            seq_windows.append({
                "prompt": sample["prompt"],
                # Ground-truth frame at the first index not covered by earlier windows.
                "image": video_normal_i[first_frame_idx].clone(),
                "video_low": video_low[start_idx:stop_idx],
                "video_normal": video_normal_i,
                "video_event": video_event[start_idx:stop_idx],
                "height": sample["video_metadata"]["height"],
                "width": sample["video_metadata"]["width"],
                "first_frame_idx": first_frame_idx,
                "seq_length": seq_length,
            })

        samples[seq_name] = seq_windows

    return samples, dataset.video_transforms

def _to_vae_layout(frames: torch.Tensor) -> torch.Tensor:
    """Turn a ``[F, C, H, W]`` frame stack into a one-clip ``[1, C, F, H, W]`` batch for the VAE."""
    return frames.unsqueeze(0).permute(0, 2, 1, 3, 4)


def _encode_to_latents(pipe, video: torch.Tensor) -> torch.Tensor:
    """Sample scaled VAE latents for a ``[B, C, F, H, W]`` clip, returned as ``[B, F, C, H, W]``.

    ``latent_dist.sample()`` is called without an explicit generator, so (per diffusers)
    it draws from the global torch RNG.
    """
    latents = pipe.vae.encode(video).latent_dist.sample() * pipe.vae.config.scaling_factor
    return latents.permute(0, 2, 1, 3, 4)


def generate_video(
    model_path: str,
    image_or_video_path: str = "",
    num_inference_steps: int = 50,
    guidance_scale: float = 6.0,
    dtype: torch.dtype = torch.bfloat16,
    seed: int = 42,
    data_root: str = None,
    num_samples: int = -1,
    evaluation_dir: str = "evaluations",
    fps: int = 8,
):
    """Enhance every sampled dataset sequence window by window and export the results.

    Each sequence from ``sample_from_dataset`` arrives as overlapping windows. For every
    window after the first, the conditioning latent is the VAE encoding of the already
    generated frame aligned with the window's frame 0 (presumably to keep consecutive
    windows consistent); the first window is conditioned on an all-zero latent. The
    stitched generated frames, low-light inputs and event maps are passed to
    ``export_concat_video`` with path ``<data_root>/<evaluation_dir>/<seq_name>.mp4``.

    Args:
        model_path: Path or hub id for ``CogVideoXVideoToVideoPipelineEvent_CFG.from_pretrained``.
        image_or_video_path: Unused; kept for backward compatibility. Only dataset
            sequences are generated.
        num_inference_steps: Denoising steps per window.
        guidance_scale: Classifier-free guidance scale.
        dtype: Dtype for the pipeline and all input tensors.
        seed: Seeds sequence selection, the global torch RNG (re-seeded per sequence)
            and the per-window pipeline generator.
        data_root: Dataset root. Nothing is generated (and no model is loaded) when it
            is empty or yields no samples.
        num_samples: Number of sequences to process (-1 = all).
        evaluation_dir: Output sub-directory under ``data_root``.
        fps: Frame rate of the exported videos.
    """
    from tqdm import tqdm

    device = "cuda" if torch.cuda.is_available() else "cpu"

    samples = {}
    if data_root:
        samples, video_transforms = sample_from_dataset(
            data_root=data_root,
            num_samples=num_samples,
            random_seed=seed,
        )
    if not samples:
        # Only dataset-driven generation is implemented; bail out before loading the model.
        print("No dataset samples to process; nothing to generate.")
        return

    # Load and configure the model.
    pipe = CogVideoXVideoToVideoPipelineEvent_CFG.from_pretrained(model_path, torch_dtype=dtype)
    pipe.to(device, dtype=dtype)
    pipe.vae.enable_slicing()
    pipe.vae.enable_tiling()
    pipe.transformer.eval()
    pipe.text_encoder.eval()
    pipe.vae.eval()
    pipe.transformer.gradient_checkpointing = False
    pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")

    negative_prompt = (
        "The video is not of a high quality, it has a low resolution. Watermark present in each frame. "
        "The background is solid. Strange body and strange trajectory. Distortion."
    )

    for seq_name, windows in tqdm(samples.items(), total=len(samples), desc="Samples Num:"):
        # VAE posterior sampling uses the global RNG; re-seeding per sequence makes each
        # sequence's output reproducible and independent of which other sequences ran.
        torch.manual_seed(seed)
        all_generated_frames = []
        all_video_lows = []
        all_video_events = []

        for j, window in enumerate(windows):
            print(f"Prompt: {window['prompt'][:30]}")
            first_frame_idx = window["first_frame_idx"]
            # Window 0 is kept whole; later windows overlap earlier ones, so keep only new frames.
            keep_from = 0 if j == 0 else first_frame_idx

            video_low = window["video_low"].to(device=device, dtype=dtype)  # [F, C, H, W]
            video_normal = window["video_normal"].to(device=device, dtype=dtype)
            video_event = window["video_event"].to(device=device, dtype=dtype)
            all_video_lows.append(video_low[keep_from:])
            all_video_events.append(video_event[keep_from:])

            image = video_normal[:1].clone()
            event_image = video_event[:1].clone()

            # Frame 0 of the low-light clip is replaced by a conditioning latent below.
            video_low = _to_vae_layout(video_low[1:])  # [1, C, F-1, H, W]
            video_event = _to_vae_layout(video_event)
            # Channel-wise max of the low-light frames: [1, 1, F-1, H, W].
            illumination_map = video_low.max(dim=1, keepdim=True).values

            with torch.no_grad():
                video_low = _encode_to_latents(pipe, video_low)  # [B, F', C, h, w]
                video_event = _encode_to_latents(pipe, video_event)

                if j == 0:
                    # No earlier output yet: condition on an all-zero latent frame.
                    first_frame = torch.zeros(
                        (video_event.shape[0], 1, video_event.shape[2], video_event.shape[3], video_event.shape[4]),
                        device=device,
                        dtype=dtype,
                    )
                else:
                    # all_generated_frames ends just before this window's local frame
                    # `first_frame_idx`, so index -first_frame_idx is its local frame 0.
                    prev_frame = all_generated_frames[-first_frame_idx]
                    prev_frame = transforms.ToTensor()(prev_frame) * 255
                    prev_frame = video_transforms(prev_frame).unsqueeze(0)
                    prev_frame = _to_vae_layout(prev_frame).to(device=device, dtype=dtype)
                    first_frame = _encode_to_latents(pipe, prev_frame)
                video_low = torch.cat([first_frame, video_low], dim=1)

            # Resize the illumination map to the latent grid, excluding the conditioning frame.
            illumination_map = F.interpolate(
                illumination_map,
                size=(video_low.shape[1] - 1, video_low.shape[3], video_low.shape[4]),
                mode="trilinear",
            )

            pipeline_args = {
                "prompt": window["prompt"],
                "negative_prompt": negative_prompt,
                "num_inference_steps": num_inference_steps,
                "num_frames": 49,
                "use_dynamic_cfg": True,
                "guidance_scale": guidance_scale,
                "generator": torch.Generator(device=device).manual_seed(seed),
                "height": window["height"],
                "width": window["width"],
                "image": image,
                "video_event": video_event,
                "video_low": video_low,
                "event_image": event_image,
                "illumination_map": illumination_map,
            }

            with torch.no_grad():
                video_generate = pipe(**pipeline_args).frames[0]
            all_generated_frames.extend(video_generate[keep_from:])

        output_dir = os.path.join(data_root, evaluation_dir)
        output_file = os.path.join(output_dir, f"{seq_name}.mp4")
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        export_concat_video(
            all_generated_frames,
            torch.cat(all_video_lows, dim=0),
            torch.cat(all_video_events, dim=0),
            output_file,
            fps=fps,
        )
            

def create_frame_grid(frames: List[np.ndarray], interval: int = 9, max_cols: int = 7) -> np.ndarray:
    """
    Tile every ``interval``-th video frame into a row-major grid image.

    Args:
        frames: Equally sized frames, either ``H x W`` (grayscale) or ``H x W x C``.
        interval: Step between sampled frames (``frames[::interval]``); must be >= 1.
        max_cols: Maximum number of frames per row; must be >= 1.

    Returns:
        uint8 grid of shape ``(H * rows, W * cols)`` for 2-D frames, otherwise
        ``(H * rows, W * cols, 3)``. Unused cells in the last row stay black.

    Raises:
        ValueError: if ``interval`` or ``max_cols`` is < 1, or no frames are given.
    """
    if interval < 1:
        raise ValueError(f"interval must be >= 1, got {interval}")
    if max_cols < 1:
        raise ValueError(f"max_cols must be >= 1, got {max_cols}")

    sampled_frames = frames[::interval]
    n_frames = len(sampled_frames)
    if n_frames == 0:
        raise ValueError("frames must contain at least one frame")

    n_cols = min(max_cols, n_frames)
    n_rows = -(-n_frames // n_cols)  # ceiling division

    first = sampled_frames[0]
    frame_height, frame_width = first.shape[:2]
    # Grayscale frames get a 2-D canvas; everything else keeps the 3-channel canvas.
    channel_dims = () if first.ndim == 2 else (3,)
    grid = np.zeros((frame_height * n_rows, frame_width * n_cols, *channel_dims), dtype=np.uint8)

    for idx, frame in enumerate(sampled_frames):
        row, col = divmod(idx, n_cols)
        grid[row * frame_height:(row + 1) * frame_height, col * frame_width:(col + 1) * frame_width] = frame

    return grid

def _tensor_frames_to_uint8(frames: torch.Tensor) -> List[np.ndarray]:
    """Convert ``[C, H, W]`` frames (iterated along dim 0) to ``H x W x C`` uint8 arrays.

    Values are mapped with ``(x + 1) * 127.5`` (i.e. ``[-1, 1]`` -> ``[0, 255]``) and clipped
    before the truncating cast, so out-of-range values saturate instead of wrapping around.
    """
    converted = []
    for frame in frames:
        arr = frame.permute(1, 2, 0).to(dtype=torch.float32, device="cpu").numpy()
        converted.append(np.clip((arr + 1.0) * 127.5, 0, 255).astype(np.uint8))
    return converted


def export_concat_video(
    generated_frames: List[PIL.Image.Image],
    original_video: torch.Tensor,
    event_maps: torch.Tensor = None,
    output_video_path: str = None,
    fps: int = 8,
    frame_save_interval: int = 9,
) -> str:
    """
    Export the generated video, a side-by-side comparison video and sampled PNG frames.

    Everything is written relative to ``base_dir = dirname(output_video_path)``:

    - ``generated/<name>_generated.mp4``: the generated frames alone.
    - ``group/<name>.mp4``: each generated frame next to a half-size column holding the
      input frame on top and the event map below (black when ``event_maps`` is None).
    - ``frames/<name>/{generated,input_rgb,input_event}/NNNN.png``: every
      ``frame_save_interval``-th frame of each stream.

    All streams are truncated to the length of the shortest one.

    Args:
        generated_frames: Generated frames (PIL images).
        original_video: Input frames, iterated along dim 0 as ``[C, H, W]`` tensors
            expected in ``[-1, 1]``.
        event_maps: Optional event frames in the same layout as ``original_video``.
        output_video_path: Target path; a fresh temporary directory is used when None.
        fps: Frame rate of both written videos.
        frame_save_interval: Save a PNG for every frame whose index is a multiple of this.

    Returns:
        Path of the comparison video in ``group/``.

    Raises:
        ValueError: if ``frame_save_interval`` < 1.
    """
    import tempfile

    import imageio

    if frame_save_interval < 1:
        raise ValueError(f"frame_save_interval must be >= 1, got {frame_save_interval}")

    if output_video_path is None:
        # Own directory so the sub-folders below do not land in the shared temp dir.
        output_video_path = os.path.join(tempfile.mkdtemp(), "output.mp4")

    base_dir = os.path.dirname(output_video_path)
    filename = os.path.basename(output_video_path)
    name_without_ext = os.path.splitext(filename)[0]

    generated_dir = os.path.join(base_dir, "generated")  # generated-only videos
    group_dir = os.path.join(base_dir, "group")  # side-by-side comparison videos
    video_frames_dir = os.path.join(base_dir, "frames", name_without_ext)  # frames/video_name/
    groundtruth_dir = os.path.join(video_frames_dir, "input_rgb")
    generated_frames_dir = os.path.join(video_frames_dir, "generated")
    event_dir = os.path.join(video_frames_dir, "input_event")
    for directory in (generated_dir, group_dir, groundtruth_dir, generated_frames_dir, event_dir):
        os.makedirs(directory, exist_ok=True)

    has_events = event_maps is not None
    original_frames = _tensor_frames_to_uint8(original_video)
    event_frames = _tensor_frames_to_uint8(event_maps) if has_events else []

    num_frames = min(len(generated_frames), len(original_frames))
    if has_events:
        num_frames = min(num_frames, len(event_frames))

    generated_frames_np = [np.array(frame) for frame in generated_frames[:num_frames]]

    gen_video_path = os.path.join(generated_dir, f"{name_without_ext}_generated.mp4")
    with imageio.get_writer(gen_video_path, fps=fps) as writer:
        for frame in generated_frames_np:
            writer.append_data(frame)

    group_video_path = os.path.join(group_dir, filename)
    with imageio.get_writer(group_video_path, fps=fps) as group_writer:
        for i in range(num_frames):
            orig_frame = original_frames[i]
            height = orig_frame.shape[0]
            width = min(generated_frames_np[i].shape[1], orig_frame.shape[1])

            gen_frame = np.array(Image.fromarray(generated_frames_np[i]).resize((width, height)))
            orig_frame = np.array(Image.fromarray(orig_frame).resize((width, height)))

            if has_events:
                track_frame = np.array(Image.fromarray(event_frames[i]).resize((width, height)))
                right_column = np.concatenate([orig_frame, track_frame], axis=0)
            else:
                # Black placeholder where the event map would go, so the half-size column is
                # exactly `height` tall and can sit next to the full-size generated frame.
                right_column = np.concatenate([orig_frame, np.zeros_like(orig_frame)], axis=0)

            right_column = np.array(
                Image.fromarray(right_column).resize((right_column.shape[1] // 2, right_column.shape[0] // 2))
            )
            group_writer.append_data(np.concatenate([gen_frame, right_column], axis=1))

            if i % frame_save_interval == 0:
                png_name = f"{i:04d}.png"
                Image.fromarray(gen_frame).save(os.path.join(generated_frames_dir, png_name))
                Image.fromarray(orig_frame).save(os.path.join(groundtruth_dir, png_name))
                if has_events:
                    Image.fromarray(track_frame).save(os.path.join(event_dir, png_name))

    return group_video_path

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
    # NOTE(review): --prompt, --output_path, --num_videos_per_prompt and --generate_type are
    # parsed but not forwarded to generate_video, so they currently have no effect.
    parser.add_argument("--prompt", type=str, help="Optional: override the prompt from dataset")
    parser.add_argument(
        "--image_or_video_path",
        type=str,
        default=None,
        help="The path of the image to be used as the background of the video",
    )
    parser.add_argument(
        "--model_path", type=str, default="THUDM/CogVideoX-5b", help="The path of the pre-trained model to be used"
    )
    parser.add_argument(
        "--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved"
    )
    parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
    parser.add_argument(
        "--num_inference_steps", type=int, default=50, help="Number of steps for the inference process"
    )
    parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
    parser.add_argument(
        "--generate_type", type=str, default="i2v", help="The type of video generation (e.g., 'i2v', 'i2vo')"
    )

    # Reject unknown dtype names instead of silently falling back to bfloat16.
    dtype_by_name = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=sorted(dtype_by_name),
        help="The data type for computation (e.g., 'float16' or 'bfloat16')",
    )
    parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility")

    # Dataset related parameters are required
    parser.add_argument("--data_root", type=str, required=True, help="Root directory of the dataset")
    parser.add_argument("--num_samples", type=int, default=-1,
                        help="Number of samples to process. -1 means process all samples")
    parser.add_argument("--evaluation_dir", type=str, default="evaluations",
                        help="Name of the directory to store evaluation results")
    parser.add_argument("--fps", type=int, default=8,
                        help="Frames per second for the output video")

    args = parser.parse_args()
    dtype = dtype_by_name[args.dtype]

    # Prompts always come from the dataset samples.
    generate_video(
        model_path=args.model_path,
        image_or_video_path=args.image_or_video_path,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        dtype=dtype,
        seed=args.seed,
        data_root=args.data_root,
        num_samples=args.num_samples,
        evaluation_dir=args.evaluation_dir,
        fps=args.fps,
    )