import os
# Pin this process to GPU 0. The assignment sits before `import torch` below so
# the value is already in the environment when the CUDA libraries are loaded.
# NOTE(review): this overrides any CUDA_VISIBLE_DEVICES exported by the caller —
# confirm that is intended when ulysses_degree * ring_degree > 1.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import sys

import cv2
import numpy as np
import torch
from diffusers import (CogVideoXDDIMScheduler, DDIMScheduler,
                       DPMSolverMultistepScheduler,
                       EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
                       PNDMScheduler)
from PIL import Image
from transformers import T5EncoderModel

# Make the project importable when this script is launched directly: the
# script's own directory and its two nearest ancestors are put at the front of
# sys.path (the outermost ancestor ends up first because it is inserted last).
current_file_path = os.path.abspath(__file__)
_script_dir = os.path.dirname(current_file_path)
_parent_dir = os.path.dirname(_script_dir)
_grandparent_dir = os.path.dirname(_parent_dir)
project_roots = [_script_dir, _parent_dir, _grandparent_dir]
for project_root in project_roots:
    # Skip directories that are already on the path so nothing is duplicated.
    if project_root not in sys.path:
        sys.path.insert(0, project_root)

from videox_fun.models import (AutoencoderKLCogVideoX,
                              CogVideoXTransformer3DModel, T5EncoderModel,
                              T5Tokenizer)
from videox_fun.pipeline import (CogVideoXFunControlPipeline,
                                CogVideoXFunInpaintPipeline,
                                CogVideoXFunHIAPipeline)
from videox_fun.utils.fp8_optimization import convert_weight_dtype_wrapper
from videox_fun.utils.lora_utils import merge_lora, unmerge_lora
from videox_fun.utils.utils import get_video_to_video_latent, save_videos_grid, get_image_latent
from videox_fun.dist import set_multi_gpus_devices

# ---------------------------------------------------------------------------
# Runtime configuration
# ---------------------------------------------------------------------------

# GPU memory mode. Choose one of:
#   "model_full_load"               - the entire model is moved to the GPU.
#   "model_cpu_offload"             - the entire model is moved back to the CPU
#                                     after use, which saves some GPU memory.
#   "model_cpu_offload_and_qfloat8" - like model_cpu_offload, with the
#                                     transformer quantized to float8 to save
#                                     more GPU memory.
#   "sequential_cpu_offload"        - each layer is moved back to the CPU after
#                                     use; slower, but saves a large amount of
#                                     GPU memory.
GPU_memory_mode = "model_cpu_offload"

# Multi-GPU configuration.
# The product of ulysses_degree and ring_degree must equal the number of GPUs
# used: with 8 GPUs you can set ulysses_degree = 2 and ring_degree = 4; with a
# single GPU set both to 1.
ulysses_degree = 1
ring_degree = 1

# Path of the base model.
model_name = "models/Diffusion_Transformer/CogVideoX-Fun-V1.5-5b-InP"

# Sampler: one of "Euler", "Euler A", "DPM++", "PNDM", "DDIM_Cog", "DDIM_Origin".
sampler_name = "DDIM_Origin"

# Optional checkpoints loaded on top of the base model; set a path to None to
# skip it (e.g. transformer_path = None keeps the base transformer weights).
transformer_path = "./output_dir-stage2/checkpoint-30000/output_dir-bf16/diffusion_pytorch_model.safetensors"
vae_path = None
lora_path = None

# Other parameters.
sample_size = [672, 384]
video_length = 81
fps = 16

# Use torch.float16 if the GPU does not support torch.bfloat16.
# Some graphics cards, such as the V100 and 2080 Ti, do not support torch.bfloat16.
weight_dtype = torch.bfloat16

# Input sample: every asset of one clip lives under DATA_ROOT/Unseen/<kind>/.
DATA_ROOT = "./dataset/stage2"
DATA_NAME = "video_20230927_7283652378434342150"
_unseen_dir = f"{DATA_ROOT}/Unseen"
driving_path = f"{_unseen_dir}/pseudo_driving/{DATA_NAME}.mp4"
driving_mask_path = f"{_unseen_dir}/driving_mask-fine_grained-refined/{DATA_NAME}.mp4"
driving_face_path = f"{_unseen_dir}/pseudo_driving_face/{DATA_NAME}.mp4"
driving_face_mask_path = f"{_unseen_dir}/driving_face_mask/{DATA_NAME}.mp4"
reference_file_path = f"{_unseen_dir}/reference/{DATA_NAME}.png"

# Prompts.
prompt = (
    "The image shows a person standing indoors, likely female, wearing a black "
    "long-sleeve top, gray sweatpants with white stripes, and a black baseball "
    "cap with a white \"I\" on it. The individual has long, wavy brown hair. "
    "The background features a white door and a window with blinds, partially "
    "covered by a gray blanket hanging on the left side of the frame. "
    "The setting appears to be a home interior."
)
negative_prompt = (
    "The video is not of a high quality, it has a low resolution. "
    "Watermark present in each frame. The background is solid. "
    "Strange body and strange trajectory. Distortion. "
)

# Sampling parameters and output location.
guidance_scale = 6.0
seed = 43
num_inference_steps = 50
lora_weight = 0.55
save_path = "samples/HIA_stage2"
# '''
# Resolve the device for this process from the configured parallel degrees.
device = set_multi_gpus_devices(ulysses_degree, ring_degree)


def _load_checkpoint_weights(model, checkpoint_path):
    """Overlay the weights stored at ``checkpoint_path`` onto ``model``.

    Files whose name ends in ``safetensors`` are read with
    ``safetensors.torch.load_file``; anything else goes through ``torch.load``
    mapped to the CPU. If the loaded mapping has a top-level ``"state_dict"``
    entry, that entry is used as the actual weights.

    The weights are applied with ``strict=False``, so mismatched keys do not
    raise; only their counts are printed.

    Args:
        model: Module exposing ``load_state_dict``.
        checkpoint_path: Path of the checkpoint file.
    """
    print(f"From checkpoint: {checkpoint_path}")
    if checkpoint_path.endswith("safetensors"):
        from safetensors.torch import load_file
        state_dict = load_file(checkpoint_path)
    else:
        # SECURITY: torch.load unpickles the file, which can execute arbitrary
        # code. Only point this at checkpoints from a trusted source.
        state_dict = torch.load(checkpoint_path, map_location="cpu")
    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict

    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(missing_keys)}, unexpected keys: {len(unexpected_keys)}")


# Get Transformer
# The weights are requested as float8 only in the qfloat8 memory mode; in every
# mode the loaded module is then cast to weight_dtype.
transformer = CogVideoXTransformer3DModel.from_pretrained(
    model_name,
    subfolder="transformer",
    low_cpu_mem_usage=True,
    torch_dtype=torch.float8_e4m3fn if GPU_memory_mode == "model_cpu_offload_and_qfloat8" else weight_dtype,
).to(weight_dtype)

if transformer_path is not None:
    _load_checkpoint_weights(transformer, transformer_path)

# Get Vae
vae = AutoencoderKLCogVideoX.from_pretrained(
    model_name,
    subfolder="vae",
).to(weight_dtype)

if vae_path is not None:
    _load_checkpoint_weights(vae, vae_path)

# Tokenizer and text encoder are both read from sub-folders of the base model.
# T5EncoderModel here is the videox_fun.models import, which rebinds the name
# imported earlier from transformers.
tokenizer = T5Tokenizer.from_pretrained(model_name, subfolder="tokenizer")
text_encoder = T5EncoderModel.from_pretrained(
    model_name,
    subfolder="text_encoder",
    torch_dtype=weight_dtype,
)

# Get Scheduler
# scheduler_dict maps every supported sampler_name to its scheduler class.
scheduler_dict = {
    "Euler": EulerDiscreteScheduler,
    "Euler A": EulerAncestralDiscreteScheduler,
    "DPM++": DPMSolverMultistepScheduler,
    "PNDM": PNDMScheduler,
    "DDIM_Cog": CogVideoXDDIMScheduler,
    "DDIM_Origin": DDIMScheduler,
}
if sampler_name not in scheduler_dict:
    # Still a KeyError, as a plain lookup would raise, but with the valid
    # choices spelled out.
    raise KeyError(
        f"Unknown sampler_name {sampler_name!r}; expected one of {sorted(scheduler_dict)}"
    )
Choosen_Scheduler = scheduler_dict[sampler_name]

# The scheduler is built from the configuration in the base model's
# "scheduler" sub-folder.
scheduler = Choosen_Scheduler.from_pretrained(
    model_name,
    subfolder="scheduler",
)

# Assemble the HIA pipeline from the individually loaded components.
pipeline = CogVideoXFunHIAPipeline(
    vae=vae,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    transformer=transformer,
    scheduler=scheduler,
)

# Multi-GPU inference is switched on as soon as either parallel degree exceeds 1.
if max(ulysses_degree, ring_degree) > 1:
    transformer.enable_multi_gpus_inference()

# Apply the requested GPU memory strategy; any unrecognized mode falls back to
# placing the whole pipeline on the device.
if GPU_memory_mode == "sequential_cpu_offload":
    pipeline.enable_sequential_cpu_offload(device=device)
elif GPU_memory_mode in ("model_cpu_offload_and_qfloat8", "model_cpu_offload"):
    if GPU_memory_mode == "model_cpu_offload_and_qfloat8":
        # NOTE(review): only the dtype wrapper is installed here, and the
        # transformer was cast to weight_dtype right after loading — confirm
        # whether an explicit float8 weight conversion is expected as well.
        convert_weight_dtype_wrapper(transformer, weight_dtype)
    pipeline.enable_model_cpu_offload(device=device)
else:
    pipeline.to(device=device)

# Seeded generator created on the target device.
generator = torch.Generator(device=device)
generator = generator.manual_seed(seed)

# Optionally merge LoRA weights into the pipeline before sampling.
if lora_path is not None:
    pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device)


# Round video_length down to the form k * temporal_compression_ratio + 1;
# a single-frame request stays at 1.
temporal_ratio = vae.config.temporal_compression_ratio
if video_length != 1:
    video_length = int((video_length - 1) // temporal_ratio * temporal_ratio) + 1
else:
    video_length = 1
latent_frames = (video_length - 1) // temporal_ratio + 1

# When the transformer declares a temporal patch size, extend video_length so
# the latent frame count becomes a multiple of it.
if video_length != 1:
    patch_size_t = transformer.config.patch_size_t
    if patch_size_t is not None and latent_frames % patch_size_t != 0:
        additional_frames = patch_size_t - latent_frames % patch_size_t
        video_length += additional_frames * temporal_ratio
# '''

def _load_clip(path):
    """Return the first of the four values get_video_to_video_latent yields for ``path``."""
    clip, _, _, _ = get_video_to_video_latent(
        path,
        video_length=video_length,
        sample_size=sample_size,
        fps=fps,
    )
    return clip


# Tip: assign a smaller video_length (e.g. 25) before this point for a quick run.
# Every clip is loaded with the same length, size and fps settings.
driving_video = _load_clip(driving_path)
# Masks keep only the first entry along dim 1.
driving_mask_video = _load_clip(driving_mask_path)[:, :1]
driving_face_video = _load_clip(driving_face_path)
driving_face_mask_video = _load_clip(driving_face_mask_path)[:, :1]

# Reference image: broadcast along dim 2 to video_length entries, copy so the
# result is writable, then zero everything after the first entry.
ref_image = get_image_latent(ref_image=reference_file_path, sample_size=sample_size)
ref_image = ref_image.expand(-1, -1, video_length, -1, -1).clone()
ref_image[:, :, 1:] = 0

# Keyword arguments for the pipeline call: generic sampling settings first,
# then the driving / reference conditioning inputs.
pipeline_kwargs = dict(
    num_frames=video_length,
    negative_prompt=negative_prompt,
    height=sample_size[0],
    width=sample_size[1],
    generator=generator,
    guidance_scale=guidance_scale,
    num_inference_steps=num_inference_steps,
    driving_video=driving_video,
    driving_mask_video=driving_mask_video,
    driving_face_video=driving_face_video,
    driving_face_mask_video=driving_face_mask_video,
    ref_image=ref_image,
)

# Sampling runs without gradient tracking; the result's `videos` attribute is
# what gets saved later.
with torch.no_grad():
    sample = pipeline(prompt, **pipeline_kwargs).videos

# Undo the LoRA merge once sampling has finished.
if lora_path is not None:
    pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device)

def save_results():
    """Write the generated ``sample`` into ``save_path`` under a sequential name.

    The file name is the number of entries already in ``save_path`` plus one,
    zero-padded to 8 digits. When ``video_length == 1`` the result is saved as
    a PNG image; otherwise it is written as an MP4 at ``fps`` frames per second
    through ``save_videos_grid``.

    Reads the module-level ``save_path``, ``sample``, ``video_length`` and
    ``fps``; returns nothing.
    """
    # exist_ok=True already tolerates an existing directory, so no separate
    # existence check is needed.
    os.makedirs(save_path, exist_ok=True)

    index = len(os.listdir(save_path)) + 1
    prefix = str(index).zfill(8)
    if video_length == 1:
        video_path = os.path.join(save_path, prefix + ".png")

        # Take index 0 of dims 0 and 2, then move the remaining leading dim last
        # (presumably channels-first -> height, width, channels — confirm
        # against the pipeline's output layout).
        image = sample[0, :, 0].permute(1, 2, 0)
        # .float().cpu() is a no-op for a float32 CPU tensor and lets .numpy()
        # work when the pipeline returns a CUDA or bfloat16 tensor.
        image = (image * 255).float().cpu().numpy().astype(np.uint8)
        Image.fromarray(image).save(video_path)
    else:
        video_path = os.path.join(save_path, prefix + ".mp4")
        save_videos_grid(sample, video_path, fps=fps)

# With more than one process (ulysses_degree * ring_degree > 1) only rank 0
# writes the result; a single-process run always saves.
if ulysses_degree * ring_degree > 1:
    import torch.distributed as dist
    should_save = dist.get_rank() == 0
else:
    should_save = True

if should_save:
    save_results()