import torch
from diffusers.utils import export_to_video
from module.attention_processor_cog import MyCogVideoXAttnProcessor2_0
from module.pipe_cog import myCogvideoXPipeline
from utils import plan_path, arg_to_bboxs, save_videos_with_bbox

import random 
import numpy as np

import argparse

# Pure inference script: autograd is never needed.
torch.set_grad_enabled(False)

# Default negative prompt handed to the pipeline (overridable via --negative_prompt).
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

# Alternative prompt presets kept for quick experimentation; swap one set in
# for the active prompts below to use it.
#
# bg_prompt = "Hyper-realistic photography, a lush garden bathed in soft afternoon sunlight. Vibrant roses in red, pink, and yellow bloom densely on climbing trellises, while green ivy creeps up weathered stone walls. A small stone fountain gurgles gently in the center, with water rippling and reflecting the sky. Butterflies flit between lavender bushes, and a honeybee hovers above a daisy. The grass is neatly trimmed, with a winding gravel path."
# fg_prompt = "Realistic photography style, a medium-sized gray-and-white dog with fluffy fur running to the right. The dog has bright black eyes, perked ears, and a wagging tail. Its legs are in mid-stride, paws lifting off the ground, mouth slightly open as if panting. The camera follows the dog in a smooth tracking shot, capturing its energetic movement. Close shot from a low angle, emphasizing the dog's speed and vitality."
# base_prompt = "Realistic photography style, a medium-sized gray-and-white dog with fluffy fur running to the right. The dog has bright black eyes, perked ears, and a wagging tail. Its legs are in mid-stride, paws lifting off the ground, mouth slightly open as if panting. The background is a sunlit green lawn with a few scattered flowers. The camera follows the dog in a smooth tracking shot, capturing its energetic movement. Medium shot from a low angle, emphasizing the dog's speed and vitality."
#
# fg_prompt = "Close-up shot of a fluffy corgi in motion, its small body filling the frame as it bounds forward with dynamic energy. The dog's coat shimmers with soft texture, its bright eyes gleaming with joy, ears perked, and tail wagging vigorously. Legs stretch forward in mid-stride, paws lifting off the ground, mouth slightly open in a playful grin. The camera captures every detail of its fluffy fur and expressive face, emphasizing its speed and spirited movement. Natural lighting highlights the dog's features, creating a vivid, lifelike moment of pure canine delight."
# bg_prompt = "Expansive wide-angle view of a vast open grassland under a clear, bright blue sky. Rolling hills stretch endlessly, their slopes dotted with clusters of wildflowers in shades of yellow, purple, and white. The sky transitions seamlessly into the golden hues of the horizon, with gentle clouds drifting overhead. The grass is neatly trimmed, creating a smooth, open expanse that contrasts with the dog's energetic motion. A faint breeze stirs the wildflowers, adding subtle movement to the serene, open landscape."
# base_prompt = "A corgi running across a vast open grassland under a bright blue sky. The small dog has a fluffy coat, a wagging tail, and a joyful expression. It is mid-stride, showing dynamic movement with legs stretched forward. The background features rolling hills and scattered wildflowers. The scene is captured in a wide-angle shot, emphasizing the dog's speed and the expansive landscape. Natural motion with smooth camera movement following the corgi's run.\n"

# Active defaults (each overridable from the command line).
fg_prompt = "a camel in desert, realistic photography style"
bg_prompt = "a vast desert with sand dunes under a clear blue sky, realistic photography style"
base_prompt = "a camel running across a vast desert with sand dunes under a clear blue sky"

# Keyframe boxes, one row of five numbers each. The first number looks like a
# frame index (0 and 48 for the default 49-frame video) and the remaining four
# are fractions presumably read by plan_path as (h_start, h_end, w_start, w_end)
# — TODO confirm against utils.plan_path.
bboxs = [
    [0, 0.2, 0.7, 0.0, 0.4],
    [48, 0.3, 0.8, 0.4, 0.8],
]

# Flatten the rows into strings, then join them with commas. This produces the
# same format --bboxs_arg accepts, so the literal above doubles as the CLI default.
bboxs_flat = []
for box in bboxs:
    bboxs_flat.extend(map(str, box))
bboxs_arg = ",".join(bboxs_flat)


# ---------------------------------------------------------------------------
# Command-line interface. The prompts and boxes defined above are only the
# defaults; each of them can be overridden per run. Arguments are registered
# in a fixed order so `--help` lists them consistently.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()

# Seed and masking schedule.
parser.add_argument("--seed", type=int, default=66)
parser.add_argument("--mask_step", type=int, default=50)
# NOTE(review): --base_ratio is parsed but never read in this script.
parser.add_argument("--base_ratio", type=float, default=0.0)

# Prompts.
parser.add_argument("--bg_prompt", type=str, default=bg_prompt)
parser.add_argument("--fg_prompt", type=str, default=fg_prompt)
parser.add_argument("--base_prompt", type=str, default=base_prompt)
parser.add_argument("--negative_prompt", type=str, default=negative_prompt)

# Output files: the plain video and the copy with the trajectory boxes drawn.
parser.add_argument("--output_path", type=str, default='output_cog.mp4')
parser.add_argument("--output_path_withbox", type=str, default='cog_output_box')

# Trajectory: keyframe boxes flattened into one comma-separated string, five
# numbers per box (parsed back into rows by arg_to_bboxs).
parser.add_argument("--bboxs_arg", type=str, default=bboxs_arg, help="bboxs参数，5个一组，逗号分隔，如0,0.5,0.8,0.2,0.5,48,0.5,0.8,0.5,0.8")
# NOTE(review): --initTraj is parsed but never read in this script.
parser.add_argument("--initTraj", action="store_true")

# Further step thresholds, forwarded (with mask_step) through attention_kwargs.
parser.add_argument("--fixRope_step", type=int, default=7)
parser.add_argument("--selfmask_step", type=int, default=4)

# Video geometry in pixel space.
parser.add_argument("--num_frame", type=int, default=49)
parser.add_argument("--height", type=int, default=480)
parser.add_argument("--width", type=int, default=720)

args = parser.parse_args()

# Re-export the parsed values as module-level names used by the rest of the script.
seed, mask_step = args.seed, args.mask_step
bg_prompt, fg_prompt, base_prompt = args.bg_prompt, args.fg_prompt, args.base_prompt
negative_prompt = args.negative_prompt
bboxs = arg_to_bboxs(args.bboxs_arg)
num_frame, height, width = args.num_frame, args.height, args.width
fixRope_step, selfmask_step = args.fixRope_step, args.selfmask_step

print(f"Using seed: {seed}, mask_step: {mask_step}, fixRope_step: {fixRope_step}, bboxs: {bboxs}")

output_path, output_path_withbox = args.output_path, args.output_path_withbox


# Load the pipeline weights in bfloat16. "CogVideoX-5b" is passed straight to
# from_pretrained, so it resolves as a local directory or hub id.
model_id = "CogVideoX-5b"
pipe = myCogvideoXPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)

# Seed every RNG the run may draw from (Python, NumPy, torch CPU and all CUDA
# devices). A dedicated CPU generator is also handed to the pipeline call.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
generator = torch.Generator(device='cpu').manual_seed(seed)

# Give every attention layer of the transformer its own fresh
# MyCogVideoXAttnProcessor2_0 instance, keyed by the same processor names.
attn_procs = {
    proc_name: MyCogVideoXAttnProcessor2_0()
    for proc_name in pipe.transformer.attn_processors
}
cnt = len(attn_procs)
pipe.transformer.set_attn_processor(attn_procs)
print(f"******{cnt} attn_procs changed.*******")

device = 'cuda'
pipe.to(device)


# ---------------------------------------------------------------------------
# Rasterise the box trajectory into a binary mask with one slice per latent
# frame. Shape: (latent_num_frame, 1, bbox_h, bbox_w), 1 inside the box.
# ---------------------------------------------------------------------------
# Latent frames: first frame kept, then one per 4 pixel frames (presumably the
# VAE's 4x temporal compression). Equals ceil(num_frame / 4).
latent_num_frame = (num_frame-1)//4 + 1
# Spatial grid is pixel size / 16 — presumably 8x VAE downsampling times the
# transformer's 2x2 patching; TODO confirm against the model config.
bbox_h = height // 16
bbox_w = width // 16

bbox_mask = torch.zeros([latent_num_frame, 1, bbox_h, bbox_w]).to(device)


# dynamic box: plan_path expands the keyframe boxes to `num_frame` per-frame
# boxes; keeping every 4th one yields one box per latent frame.
PATHS = plan_path(bboxs, video_length=num_frame)[::4]
# Explicit check rather than `assert`, which is stripped under `python -O` and
# would let a length mismatch surface later as an IndexError or silent gap.
if len(PATHS) != latent_num_frame:
    raise ValueError(
        f"latent_num_frame != len(PATHS): {latent_num_frame} != {len(PATHS)}"
    )
for i, path in enumerate(PATHS):
    # Each entry is read as (h_start, h_end, w_start, w_end) fractions of the
    # grid; int() truncates toward zero when mapping to grid cells.
    h_start = int(path[0] * bbox_h)
    h_end = int(path[1] * bbox_h)
    w_start = int(path[2] * bbox_w)
    w_end = int(path[3] * bbox_w)
    bbox_mask[i, :, h_start:h_end, w_start:w_end] = 1


# No explicit starting latents; the pipeline is left to create its own
# (presumably seeded noise drawn via `generator`).
latents = None

# 452-entry mask: 226 zeros followed by 226 ones. torch.Tensor on a bool list
# yields the default float dtype, so the values are 0.0 / 1.0. The two halves
# presumably correspond to the negative and positive text streams under CFG
# (226 being the text sequence length) — TODO confirm in the attn processor.
encoder_attention_mask = torch.Tensor([False] * 226 + [True] * 226).to(device)

# Extra state consumed by MyCogVideoXAttnProcessor2_0 through attention_kwargs.
attention_kwargs = {
    "bbox_mask": bbox_mask,
    "encoder_attention_mask": encoder_attention_mask,
    "bg_prompt": bg_prompt,
    "fg_prompt": fg_prompt,
    "fixRope_step": fixRope_step,
    "mask_step": mask_step,
    "selfmask_step": selfmask_step,
}

# Generate a single video and take its frame list.
output = pipe(
    prompt=base_prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=num_frame,
    guidance_scale=6,
    generator=generator,
    attention_kwargs=attention_kwargs,
    latents=latents,
).frames[0]

# Write at num_frame // 4 fps (12 fps for the 49-frame default).
export_to_video(output, output_path, fps=num_frame//4)



# Optionally save a second copy of the video with the trajectory boxes drawn in.
# argparse always yields a str here, so the old `is not None` test could never
# be False. A truthiness check lets `--output_path_withbox ""` skip the export.
if output_path_withbox:
    from torchvision import transforms
    to_tensor = transforms.ToTensor()

    # ToTensor maps each frame to a (C, H, W) float tensor; stacking gives (T, C, H, W).
    output_tensor = torch.stack([to_tensor(frame) for frame in output], dim=0)
    # Add two leading singleton dims, then swap T and C:
    # (T, C, H, W) -> (1, 1, T, C, H, W) -> (1, 1, C, T, H, W). This is presumably
    # the layout save_videos_with_bbox expects — TODO confirm in utils.
    save_videos_with_bbox(
        output_tensor.unsqueeze(0).unsqueeze(0).permute(0, 1, 3, 2, 4, 5),
        output_path_withbox,
        fps=num_frame//4,
        input_traj=bboxs,
    )