import sys
import os
sys.path.insert(0, os.getcwd())
sys.path.append('.')
sys.path.append('..')
import argparse
import os

import torch
from transformers import T5EncoderModel, T5Tokenizer
from diffusers import (
    CogVideoXDDIMScheduler,
    CogVideoXDPMScheduler,
    AutoencoderKLCogVideoX
)
from diffusers.utils import export_to_video, load_video

from controlnet_pipeline import ControlnetCogVideoXImageToVideoPCDPipeline
from cogvideo_transformer import CustomCogVideoXTransformer3DModel
from cogvideo_controlnet_pcd import CogVideoXControlnetPCD
from training.controlnet_datasets_camera_pcd_mask import RealEstate10KPCDRenderDataset
from torchvision.transforms.functional import to_pil_image

from inference.utils import stack_images_horizontally
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
import cv2

import cv2
import numpy as np
import torch
import torch.nn.functional as F

import cv2
import numpy as np
import torch

def get_black_region_mask_tensor(video_tensor, threshold=2, kernel_size=15):
    """
    Build per-frame binary masks marking the (near-)black regions of a video.

    Each frame is rescaled from [-1, 1] to uint8 [0, 255], converted to
    grayscale, thresholded, and cleaned with a morphological closing that uses
    an elliptical structuring element.

    Args:
        video_tensor (torch.Tensor): shape (T, C, H, W), RGB channels, values
            expected in [-1, 1] (out-of-range values are clamped after rescaling).
            May live on any device and may require grad.
        threshold (int): grayscale value (0-255) at or below which a pixel is
            treated as black (default: 2).
        kernel_size (int): size of the elliptical kernel used to close small
            gaps in each mask (default: 15).

    Returns:
        torch.Tensor: uint8 tensor of shape (T, H, W), where 1 indicates a black
        region and 0 everything else.
    """
    # [-1, 1] -> [0, 255] uint8, channels-last for OpenCV. detach()/cpu() let the
    # function accept GPU or grad-tracking tensors (Tensor.numpy() rejects both),
    # and contiguous() hands OpenCV a C-contiguous buffer instead of a permuted view.
    video_uint8 = (
        ((video_tensor.detach() + 1.0) * 127.5)
        .clamp(0, 255)
        .to(torch.uint8)
        .permute(0, 2, 3, 1)  # (T, H, W, C)
        .contiguous()
        .cpu()
    )
    video_np = video_uint8.numpy()

    T, H, W, _ = video_np.shape
    masks = np.empty((T, H, W), dtype=np.uint8)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))

    for t in range(T):
        gray = cv2.cvtColor(video_np[t], cv2.COLOR_RGB2GRAY)
        # THRESH_BINARY_INV: pixels <= threshold become 255 (black region), others 0.
        _, mask = cv2.threshold(gray, threshold, 255, cv2.THRESH_BINARY_INV)
        # Closing fills small holes/gaps inside the detected black areas.
        mask_cleaned = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
        masks[t] = (mask_cleaned > 0).astype(np.uint8)
    return torch.from_numpy(masks)

def maxpool_mask_tensor(mask_tensor, num_latent_frames=12, latent_height=30, latent_width=45):
    """
    Downsample a per-pixel binary mask to a latent grid with max pooling, then invert it.

    A latent cell is set if ANY pixel in its spatial window, in any frame of its
    temporal group, is set. A zero frame is prepended and the whole result is
    inverted. As a consequence:

    - output frame 0 is all ones;
    - in the other frames, 0 marks cells touched by the input mask (e.g. black
      regions from ``get_black_region_mask_tensor``) and 1 marks untouched cells.

    Args:
        mask_tensor (torch.Tensor): shape (T, H, W), binary mask (0 or 1).
        num_latent_frames (int): number of temporal groups; T must be a positive
            multiple of it (default: 12).
        latent_height (int): pooled height; H must be a positive multiple of it (default: 30).
        latent_width (int): pooled width; W must be a positive multiple of it (default: 45).

    Returns:
        torch.Tensor: int32 tensor of shape
        (num_latent_frames + 1, latent_height, latent_width).

    Raises:
        ValueError: if T, H or W is zero or not divisible by the target size.
    """
    T, H, W = mask_tensor.shape
    if T == 0 or T % num_latent_frames != 0:
        raise ValueError(f"T must be a positive multiple of {num_latent_frames} (got {T})")
    if H == 0 or W == 0 or H % latent_height != 0 or W % latent_width != 0:
        raise ValueError(
            f"H and W must be positive multiples of {latent_height} and {latent_width} (got {H}x{W})"
        )

    # Spatial pooling: treat frames as the batch dim -> (T, 1, latent_height, latent_width).
    x = mask_tensor.unsqueeze(1).float()
    x_pooled = F.max_pool2d(x, kernel_size=(H // latent_height, W // latent_width))

    # Temporal pooling: consecutive frames form groups of T // num_latent_frames.
    x_pooled = x_pooled.view(num_latent_frames, T // num_latent_frames, latent_height, latent_width)
    pooled_mask = torch.amax(x_pooled, dim=1)

    # Prepend a zero frame; after the inversion below it becomes an all-ones frame.
    zero_frame = torch.zeros_like(pooled_mask[0:1])
    pooled_mask = torch.cat([zero_frame, pooled_mask], dim=0)

    return 1 - pooled_mask.int()

def avgpool_mask_tensor(mask_tensor, num_latent_frames=12, latent_height=30, latent_width=45, threshold=0.5):
    """
    Downsample a per-pixel binary mask to a latent grid with average pooling, then invert it.

    Each latent cell's value is the mean of the input over its spatial window and
    temporal group. Cells whose mean is strictly greater than ``threshold`` are
    set. A zero frame is prepended and the whole result is inverted. As a
    consequence:

    - output frame 0 is all ones;
    - in the other frames, 0 marks cells mostly covered by the input mask and
      1 marks the remaining cells.

    Args:
        mask_tensor (torch.Tensor): shape (T, H, W), binary mask (0 or 1).
        num_latent_frames (int): number of temporal groups; T must be a positive
            multiple of it (default: 12).
        latent_height (int): pooled height; H must be a positive multiple of it (default: 30).
        latent_width (int): pooled width; W must be a positive multiple of it (default: 45).
        threshold (float): a cell is set when its mean is > threshold (default: 0.5).

    Returns:
        torch.Tensor: int32 tensor of shape
        (num_latent_frames + 1, latent_height, latent_width).

    Raises:
        ValueError: if T, H or W is zero or not divisible by the target size.
    """
    T, H, W = mask_tensor.shape
    if T == 0 or T % num_latent_frames != 0:
        raise ValueError(f"T must be a positive multiple of {num_latent_frames} (got {T})")
    if H == 0 or W == 0 or H % latent_height != 0 or W % latent_width != 0:
        raise ValueError(
            f"H and W must be positive multiples of {latent_height} and {latent_width} (got {H}x{W})"
        )

    # Spatial average pooling: frames as batch -> (T, 1, latent_height, latent_width).
    x = mask_tensor.unsqueeze(1).float()
    x_pooled = F.avg_pool2d(x, kernel_size=(H // latent_height, W // latent_width))

    # Temporal averaging over consecutive groups of T // num_latent_frames frames.
    x_pooled = x_pooled.view(num_latent_frames, T // num_latent_frames, latent_height, latent_width)
    pooled_avg = torch.mean(x_pooled, dim=1)

    # Strictly-greater threshold: a cell exactly at the threshold is NOT set.
    pooled_mask = (pooled_avg > threshold).int()

    # Prepend a zero frame; after the inversion below it becomes an all-ones frame.
    zero_frame = torch.zeros_like(pooled_mask[0:1])
    pooled_mask = torch.cat([zero_frame, pooled_mask], dim=0)

    return 1 - pooled_mask

def _load_condition_image(image_path, height, width):
    """
    Load an image file as the conditioning frame for the pipeline.

    Args:
        image_path (str): path to an image readable by PIL.
        height (int): output height.
        width (int): output width.

    Returns:
        torch.Tensor: float tensor of shape (1, 3, height, width), mapped from
        [0, 1] to [-1, 1] by Normalize(mean=0.5, std=0.5).
    """
    # convert("RGB") so RGBA, grayscale and palette images also yield exactly
    # three channels (the three-channel Normalize below fails otherwise); the
    # context manager closes the underlying file handle.
    with Image.open(image_path) as img:
        pixels = np.array(img.convert("RGB"))
    frame = torch.tensor(pixels).permute(2, 0, 1).unsqueeze(0) / 255
    frame = transforms.Resize((height, width))(frame)
    return transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)(frame)


def _build_controlnet_output_mask(eval_dataset, camera_idx, anchor_video, pool_style):
    """
    Build the latent-resolution mask passed to the pipeline as `controlnet_output_mask`.

    A precomputed mask is read from `<root_path>/masks/<sample>.npz` (key 'mask').
    If that fails for any reason, a mask is derived from the black regions of
    `anchor_video`. Frame 0 is dropped before pooling.

    Returns:
        torch.Tensor: CUDA tensor of shape (1, N, 1), N = number of pooled cells.

    Raises:
        ValueError: if pool_style is not 'max' or 'avg'.
    """
    if pool_style == 'max':
        pool_fn = maxpool_mask_tensor
    elif pool_style == 'avg':
        pool_fn = avgpool_mask_tensor
    else:
        raise ValueError(f"pool_style must be 'max' or 'avg', got {pool_style!r}")

    try:
        mask_path = os.path.join(eval_dataset.root_path, 'masks', eval_dataset.dataset[camera_idx] + '.npz')
        with np.load(mask_path) as mask_file:  # closes the npz archive
            stored_mask = mask_file['mask']
        video_mask = 1 - torch.from_numpy(stored_mask * 1)
    except Exception as err:
        # Deliberate best-effort fallback (missing file/key, unreadable archive, ...),
        # but report the reason instead of swallowing it like a bare `except:`.
        print(f'using derived mask ({type(err).__name__}: {err})')
        video_mask = get_black_region_mask_tensor(anchor_video)

    return pool_fn(video_mask[1:]).flatten().unsqueeze(0).unsqueeze(-1).to('cuda')


@torch.no_grad()
def generate_video(
    prompt,
    image,
    video_root_dir: str,
    base_model_path: str,
    use_zero_conv: bool,
    controlnet_model_path: str,
    controlnet_weights: float = 1.0,
    controlnet_guidance_start: float = 0.0,
    controlnet_guidance_end: float = 1.0,
    use_dynamic_cfg: bool = True,
    lora_path: str = None,
    lora_rank: int = 128,
    output_path: str = "./output/",
    num_inference_steps: int = 50,
    guidance_scale: float = 6.0,
    num_videos_per_prompt: int = 1,
    dtype: torch.dtype = torch.bfloat16,
    seed: int = 42,
    num_frames: int = 49,
    height: int = 480,
    width: int = 720,
    start_camera_idx: int = 0,
    end_camera_idx: int = 1,
    controlnet_transformer_num_attn_heads: int = None,
    controlnet_transformer_attention_head_dim: int = None,
    controlnet_transformer_out_proj_dim_factor: int = None,
    controlnet_transformer_out_proj_dim_zero_init: bool = False,
    controlnet_transformer_num_layers: int = 8,
    downscale_coef: int = 8,
    controlnet_input_channels: int = 6,
    infer_with_mask: bool = False,
    pool_style: str = 'avg',
    pipe_cpu_offload: bool = False,
):
    """
    Generate videos for a range of dataset samples with the ControlNet
    image-to-video pipeline and save them as MP4 files (8 fps).

    For every index in [start_camera_idx, end_camera_idx), three files are written
    to `output_path`:
    - the generated video (`*_out.mp4`);
    - the dataset reference video (`*_reference.mp4`);
    - a comparison video built with `stack_images_horizontally` from the
      reference, generated and anchor frames (`*_out_reference.mp4`).

    Parameters:
    - prompt (str | None): Text prompt. If empty/None, each sample's dataset caption is used.
    - image (str | None): Path to a conditioning image. If None, the first frame of each
      sample's reference video is used.
    - video_root_dir (str): Dataset root passed to RealEstate10KPCDRenderDataset.
    - base_model_path (str): Path/ID of the pre-trained CogVideoX model.
    - use_zero_conv (bool): Forwarded to CogVideoXControlnetPCD.
    - controlnet_model_path (str): ControlNet checkpoint loaded with torch.load; must hold a
      'state_dict' entry. If falsy, no ControlNet weights are loaded.
    - controlnet_weights (float): Strength of the ControlNet.
    - controlnet_guidance_start (float): Stage at which the ControlNet starts being applied.
    - controlnet_guidance_end (float): Stage at which the ControlNet stops being applied.
    - use_dynamic_cfg (bool): Forwarded to the pipeline.
    - lora_path (str): Path of LoRA weights to fuse, if any.
    - lora_rank (int): LoRA rank; the LoRA is fused with lora_scale = 1 / lora_rank.
    - output_path (str): Directory where the videos are written (created if missing).
    - num_inference_steps (int): Number of denoising steps.
    - guidance_scale (float): Classifier-free guidance scale.
    - num_videos_per_prompt (int): Videos generated per prompt; only the first one is saved.
    - dtype (torch.dtype): Computation dtype for the pipeline (default torch.bfloat16).
    - seed (int): Generator seed, reused for every sample.
    - num_frames, height, width (int): Output video size; also used to load the dataset.
    - start_camera_idx, end_camera_idx (int): Half-open range of dataset indices to process.
    - controlnet_transformer_* : ControlNet architecture overrides. num_attn_heads defaults to
      48 when "5b" appears in base_model_path, otherwise 30.
    - downscale_coef, controlnet_input_channels (int): Forwarded to CogVideoXControlnetPCD.
    - infer_with_mask (bool): Pass a pooled visibility mask to the pipeline.
    - pool_style (str): 'max' or 'avg' pooling for that mask.
    - pipe_cpu_offload (bool): Use enable_model_cpu_offload() instead of keeping the
      whole pipeline on CUDA.

    Raises:
    - ValueError: if infer_with_mask is set and pool_style is not 'max' or 'avg'.
    """
    if infer_with_mask and pool_style not in ('max', 'avg'):
        # Fail before loading any model. Previously an unknown style left
        # `controlnet_output_mask` unbound (or stale from the previous sample).
        raise ValueError(f"pool_style must be 'max' or 'avg', got {pool_style!r}")

    os.makedirs(output_path, exist_ok=True)

    # 1. Load the pre-trained CogVideoX components.
    tokenizer = T5Tokenizer.from_pretrained(base_model_path, subfolder="tokenizer")
    text_encoder = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder")
    transformer = CustomCogVideoXTransformer3DModel.from_pretrained(base_model_path, subfolder="transformer")
    vae = AutoencoderKLCogVideoX.from_pretrained(base_model_path, subfolder="vae")
    scheduler = CogVideoXDDIMScheduler.from_pretrained(base_model_path, subfolder="scheduler")

    # ControlNet
    num_attention_heads_orig = 48 if "5b" in base_model_path.lower() else 30
    controlnet_kwargs = {
        # Read the function argument, not the script-level `args` global, which
        # raised NameError when this function was imported and called directly.
        "num_attention_heads": (
            controlnet_transformer_num_attn_heads
            if controlnet_transformer_num_attn_heads is not None
            else num_attention_heads_orig
        ),
        "out_proj_dim_zero_init": controlnet_transformer_out_proj_dim_zero_init,
    }
    if controlnet_transformer_attention_head_dim is not None:
        controlnet_kwargs["attention_head_dim"] = controlnet_transformer_attention_head_dim
    if controlnet_transformer_out_proj_dim_factor is not None:
        controlnet_kwargs["out_proj_dim"] = num_attention_heads_orig * controlnet_transformer_out_proj_dim_factor
    controlnet = CogVideoXControlnetPCD(
        num_layers=controlnet_transformer_num_layers,
        downscale_coef=downscale_coef,
        in_channels=controlnet_input_channels,
        use_zero_conv=use_zero_conv,
        **controlnet_kwargs,
    )
    if controlnet_model_path:
        # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
        ckpt = torch.load(controlnet_model_path, map_location='cpu', weights_only=False)
        m, u = controlnet.load_state_dict(dict(ckpt['state_dict']), strict=False)
        print(f'[ Weights from pretrained controlnet was loaded into controlnet ] [M: {len(m)} | U: {len(u)}]')

    # Full pipeline. With CPU offload requested, keep it on CPU: moving the whole
    # float32 pipeline to CUDA first defeats the purpose of offloading (and may
    # not fit in GPU memory at all).
    pipe = ControlnetCogVideoXImageToVideoPCDPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        transformer=transformer,
        vae=vae,
        controlnet=controlnet,
        scheduler=scheduler,
    )
    if not pipe_cpu_offload:
        pipe = pipe.to('cuda')
    if lora_path:
        pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1")
        pipe.fuse_lora(lora_scale=1 / lora_rank)

    # 2. Scheduler: DPM with trailing timestep spacing (DDIM is the recommended
    # alternative for CogVideoX-2B, DPM for CogVideoX-5B / 5B-I2V).
    pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")

    # 3. Precision and device placement.
    pipe = pipe.to(dtype=dtype)
    if pipe_cpu_offload:
        pipe.enable_model_cpu_offload()
    pipe.vae.enable_slicing()
    pipe.vae.enable_tiling()

    # 4. Load dataset
    eval_dataset = RealEstate10KPCDRenderDataset(
        video_root_dir=video_root_dir,
        image_size=(height, width),
        sample_n_frames=num_frames,
    )
    # An empty/None prompt means: use each sample's own caption.
    use_dataset_caption = not prompt
    print(eval_dataset.dataset)

    for camera_idx in range(start_camera_idx, end_camera_idx):
        data_dict = eval_dataset[camera_idx]
        reference_video = data_dict['video']
        anchor_video = data_dict['anchor_video']
        print(eval_dataset.dataset[camera_idx], seed)

        if use_dataset_caption:
            prompt = data_dict['caption']
            output_path_file = os.path.join(output_path, f"{camera_idx:05d}_{seed}_out.mp4")
        else:
            # A path separator in the prompt would otherwise point into a non-existent sub-directory.
            prompt_tag = prompt[:10].replace("/", "_").replace(os.sep, "_")
            output_path_file = os.path.join(output_path, f"{prompt_tag}_{camera_idx:05d}_{seed}_out.mp4")

        if image is None:
            input_images = reference_video[0].unsqueeze(0)
        else:
            input_images = _load_condition_image(image, height, width)

        reference_frames = [to_pil_image(frame) for frame in (reference_video / 2 + 0.5)]
        # ControlNet conditioning video with a leading batch dimension.
        controlnet_latents = data_dict['controlnet_video'].to('cuda')[None]

        output_path_file_reference = output_path_file.replace("_out.mp4", "_reference.mp4")
        output_path_file_out_reference = output_path_file.replace(".mp4", "_reference.mp4")

        if infer_with_mask:
            controlnet_output_mask = _build_controlnet_output_mask(
                eval_dataset, camera_idx, anchor_video, pool_style
            )
        else:
            controlnet_output_mask = None

        # 5. Generate the frames. The default num_frames=49 is 6 s at 8 fps plus the first frame.
        video_generate_all = pipe(
            image=input_images,
            anchor_video=anchor_video,
            controlnet_output_mask=controlnet_output_mask,
            prompt=prompt,
            controlnet_latents=controlnet_latents,
            num_videos_per_prompt=num_videos_per_prompt,
            num_inference_steps=num_inference_steps,
            num_frames=num_frames,
            use_dynamic_cfg=use_dynamic_cfg,  # intended for the DPM scheduler; should be False with DDIM
            guidance_scale=guidance_scale,
            generator=torch.Generator().manual_seed(seed),
            controlnet_weights=controlnet_weights,
            controlnet_guidance_start=controlnet_guidance_start,
            controlnet_guidance_end=controlnet_guidance_end,
        ).frames
        video_generate = video_generate_all[0]

        # 6. Export the videos at 8 fps.
        export_to_video(video_generate, output_path_file, fps=8)
        export_to_video(reference_frames, output_path_file_reference, fps=8)

        anchor_frames = [to_pil_image(frame) for frame in (anchor_video / 2 + 0.5)]
        comparison_frames = [
            stack_images_horizontally(frame_reference, frame_out)
            for frame_out, frame_reference in zip(video_generate, reference_frames)
        ]
        comparison_frames = [
            stack_images_horizontally(frame_pair, frame_anchor)
            for frame_pair, frame_anchor in zip(comparison_frames, anchor_frames)
        ]
        export_to_video(comparison_frames, output_path_file_out_reference, fps=8)


if __name__ == "__main__":

    def _str2bool(value):
        """Parse a command-line boolean.

        argparse's ``type=bool`` treats every non-empty string, including
        "False", as True, so an explicit parser is required.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        if lowered in ("0", "false", "f", "no", "n"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
    parser.add_argument("--prompt", type=str, default=None, help="The description of the video to be generated")
    parser.add_argument("--image", type=str, default=None, help="The reference image of the video to be generated")
    parser.add_argument(
        "--video_root_dir",
        type=str,
        required=True,
        help="The path of the video for controlnet processing.",
    )
    parser.add_argument(
        "--base_model_path", type=str, default="THUDM/CogVideoX-5b", help="The path of the pre-trained model to be used"
    )
    parser.add_argument(
        "--controlnet_model_path", type=str, default="TheDenk/cogvideox-5b-controlnet-hed-v1", help="The path of the controlnet pre-trained model to be used"
    )
    parser.add_argument("--controlnet_weights", type=float, default=0.5, help="Strenght of controlnet")
    parser.add_argument("--use_zero_conv", action="store_true", default=False, help="Use zero conv")
    parser.add_argument("--infer_with_mask", action="store_true", default=False, help="add mask to controlnet")
    parser.add_argument("--pool_style", default='max', choices=['max', 'avg'], help="max pool or avg pool")
    parser.add_argument("--controlnet_guidance_start", type=float, default=0.0, help="The stage when the controlnet starts to be applied")
    parser.add_argument("--controlnet_guidance_end", type=float, default=0.5, help="The stage when the controlnet end to be applied")
    parser.add_argument("--use_dynamic_cfg", type=_str2bool, default=True, help="Use dynamic cfg")
    parser.add_argument("--lora_path", type=str, default=None, help="The path of the LoRA weights to be used")
    parser.add_argument("--lora_rank", type=int, default=128, help="The rank of the LoRA weights")
    parser.add_argument(
        "--output_path", type=str, default="./output", help="The path where the generated video will be saved"
    )
    parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
    parser.add_argument(
        "--num_inference_steps", type=int, default=50, help="Number of steps for the inference process"
    )
    parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
    parser.add_argument(
        "--dtype", type=str, default="bfloat16", help="The data type for computation (e.g., 'float16' or 'bfloat16')"
    )
    parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility")
    parser.add_argument("--height", type=int, default=480)
    parser.add_argument("--width", type=int, default=720)
    parser.add_argument("--num_frames", type=int, default=49)
    parser.add_argument("--start_camera_idx", type=int, default=0)
    parser.add_argument("--end_camera_idx", type=int, default=1)
    parser.add_argument("--controlnet_transformer_num_attn_heads", type=int, default=None)
    parser.add_argument("--controlnet_transformer_attention_head_dim", type=int, default=None)
    parser.add_argument("--controlnet_transformer_out_proj_dim_factor", type=int, default=None)
    parser.add_argument("--controlnet_transformer_out_proj_dim_zero_init", action="store_true", default=False, help=("Init project zero."),
    )
    parser.add_argument("--downscale_coef", type=int, default=8)
    # NOTE: --vae_channels is accepted for CLI compatibility but is not used by generate_video.
    parser.add_argument("--vae_channels", type=int, default=16)
    parser.add_argument("--controlnet_input_channels", type=int, default=6)
    parser.add_argument("--controlnet_transformer_num_layers", type=int, default=8)
    parser.add_argument("--enable_model_cpu_offload", action="store_true", default=False, help="Enable model CPU offload")

    args = parser.parse_args()
    # Any value other than "float16" selects bfloat16.
    dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16
    generate_video(
        prompt=args.prompt,
        image=args.image,
        video_root_dir=args.video_root_dir,
        base_model_path=args.base_model_path,
        use_zero_conv=args.use_zero_conv,
        controlnet_model_path=args.controlnet_model_path,
        controlnet_weights=args.controlnet_weights,
        controlnet_guidance_start=args.controlnet_guidance_start,
        controlnet_guidance_end=args.controlnet_guidance_end,
        use_dynamic_cfg=args.use_dynamic_cfg,
        lora_path=args.lora_path,
        lora_rank=args.lora_rank,
        output_path=args.output_path,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        num_videos_per_prompt=args.num_videos_per_prompt,
        dtype=dtype,
        seed=args.seed,
        height=args.height,
        width=args.width,
        num_frames=args.num_frames,
        start_camera_idx=args.start_camera_idx,
        end_camera_idx=args.end_camera_idx,
        controlnet_transformer_num_attn_heads=args.controlnet_transformer_num_attn_heads,
        controlnet_transformer_attention_head_dim=args.controlnet_transformer_attention_head_dim,
        controlnet_transformer_out_proj_dim_factor=args.controlnet_transformer_out_proj_dim_factor,
        # Previously parsed but never forwarded, so the flag had no effect.
        controlnet_transformer_out_proj_dim_zero_init=args.controlnet_transformer_out_proj_dim_zero_init,
        controlnet_transformer_num_layers=args.controlnet_transformer_num_layers,
        downscale_coef=args.downscale_coef,
        controlnet_input_channels=args.controlnet_input_channels,
        infer_with_mask=args.infer_with_mask,
        pool_style=args.pool_style,
        pipe_cpu_offload=args.enable_model_cpu_offload,
    )
