import argparse
from collections import OrderedDict

import os
from pathlib import Path
import json
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio.functional
import torchvision.io
import torchvision.transforms as T
import torchvision.transforms.functional as transforms_f
import torchvision.utils as ttf
from PIL import Image
from diffusers import (
    AutoencoderKL,
    KDPM2DiscreteScheduler,
    PNDMScheduler,
    ControlNetModel,
)
from diffusers.utils.torch_utils import randn_tensor
from insightface.app import FaceAnalysis
from omegaconf import OmegaConf
from safetensors.torch import load_file

from transformers import (
    CLIPVisionModelWithProjection,
    Wav2Vec2Model,
    Wav2Vec2Processor,
    CLIPTextModel,
    CLIPTokenizer,
    CLIPVisionModel,
    CLIPImageProcessor,
)

from modules import (
    UNet2DConditionModel,
    UNet3DConditionModel,
    VKpsGuider,
    AudioProjection,
    UNetMotionModel,
    MotionAdapter,
    T2IAdapter,
    DDIMScheduler,
)
from pipelines.v_express_pipeline_prefix_meanvar_face import VExpressPipelinePrefixMeanVarFace
from pipelines.utils import (
    draw_kps_image,
    save_video,
    video_to_pil_images,
    extract_kps_img,
)
from pipelines.context import compute_num_context, compute_context_indices
from utils.utils import (
    check_zero_initialization,
    get_module_params,
    print_highlighted_block_log,
    load_img,
    load_masked_image,
)
from utils.adapter_utils import load_motion_modules, t2i_adapter_map_keys, load_dreambooth_weights
from utils.hook_utils import register_attention_hook

import modules.adapter.face_adapter.model_seg_unet as model_seg_unet
from modules.adapter.face_adapter.model_to_token import Image2Token, ID2Token
from modules.adapter.face_adapter.utils import draw_pts3_batch, draw_pts70_batch, pil2tensor, transformation_from_points, get_box_lm4p, mean_box_lm4p_512, get_affine_transform, mean_face_lm5p_256

import modules.third_party.model_resnet_d3dfr as model_resnet_d3dfr
import modules.third_party.d3dfr.bfm as bfm
import modules.third_party.insightface_backbone_conv as model_insightface_backbone

def parse_args():
    """Parse the command-line arguments for this inference script.

    Post-processing applied after ``argparse``:
      * integer 0/1 switches (``--invert_ref``, ``--disable_audio``, ``--disable_motion``,
        ``--save_clip``, ``--store_attn``, ``--store_qk``, ``--ctrl_kps``,
        ``--apply_t2i_adapter``, ``--separate_feature_channels``) become booleans
        (only the value 1 maps to ``True``);
      * ``--t2i_adapter_control_type`` defaults to an empty list;
      * unless ``--disable_ipa`` is given, ``--num_tokens`` is required, and both
        ``--num_tokens`` and ``--ipa_scale`` collapse to a scalar when a single value
        is supplied (``--ipa_scale`` defaults to 1.0).

    Returns:
        argparse.Namespace with the processed arguments.

    Raises:
        ValueError: if IP-Adapter is enabled and ``--num_tokens`` is missing.
    """
    parser = argparse.ArgumentParser()

    # Fixed: the previous default carried a trailing space, so the path never existed.
    parser.add_argument("--unet_config_path", type=str, default="/dockerdata/models/stable-diffusion-v1-5/unet/config.json")
    parser.add_argument("--vae_path", type=str, default="/dockerdata/models/sd-vae-ft-mse/")
    parser.add_argument("--audio_encoder_path", type=str, default="./model_ckpts/wav2vec2-base-960h/")
    parser.add_argument("--insightface_model_path", type=str, default="/root/models/insightface_models/",)
    parser.add_argument("--denoising_unet_path", type=str, default=None)
    parser.add_argument("--v_kps_guider_path", type=str, default=None)
    parser.add_argument("--audio_projection_path", type=str, default=None)
    parser.add_argument("--motion_module_path", type=str, default="")
    parser.add_argument("--lora_setting_path", type=str, default="./configs/infer/extra_lora_settings.json")
    parser.add_argument("--ip_ckpt", type=str, default=None)
    parser.add_argument("--image_encoder_path", type=str, default="/root/models/IP-Adapter/models/image_encoder/",)
    parser.add_argument("--sd_model_name", type=str, default="/dockerdata/models/stable-diffusion-v1-5/")
    parser.add_argument("--motion_adapter_path", type=str, default=None)
    parser.add_argument("--t2i_adapter_model_path", type=str, default=None)
    parser.add_argument("--dreambooth_path", type=str, default=None)
    parser.add_argument('-f_ckpt', '--face_checkpoint', type=str, default='/root/models/FaceAdapter')
    parser.add_argument('-r', '--crop_ratio', type=float, default=0.81)

    parser.add_argument(
        "--retarget_strategy",
        type=str,
        default="fix_face",
        help="choose one from {fix_face, no_retarget, offset_retarget, naive_retarget}",
    )

    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--gpu_id", type=int, default=0)
    parser.add_argument("--dtype", type=str, default="fp16")

    parser.add_argument("--num_pad_audio_frames", type=int, default=2)
    parser.add_argument("--aud_depth", type=int, default=4)
    parser.add_argument("--standard_audio_sampling_rate", type=int, default=16000)

    parser.add_argument("--reference_image_path", type=str, default=None,)
    parser.add_argument("--refbg_image_path", type=str, default=None,)
    parser.add_argument("--audio_path", type=str, default=None)
    parser.add_argument("--kps_path", type=str, default=None)
    parser.add_argument("--prompt", type=str, default="best quality, high quality")
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default="(low quality, worst quality:1.5), (3d, render, cgi, doll, painting, fake, cartoon, 3d modeling:1.4),(mole:1.2) (worst quality, low quality:1.4), monochrome, deformed, malformed,(dark circles:1.2) deformed face, bad teeth, bad hands, bad fingers, bad eyes, long body,torn clothes, blurry, duplicate, cloned, duplicate body parts, disfigured, extra limbs, fused fingers, extra fingers, twisted, distorted, malformed hands, mutated hands and fingers, conjoined, missing limbs, bad anatomy, bad proportions, logo, watermark, text, copyright, signature, lowres, mutated, mutilated, artifacts, gross, ugly,large head, large face, ",
    )  # Monochrome, lowres, blurry, bad anatomy, distortions, poor lighting, dull colors, off-theme elements, inappropriate content, unwanted artifacts, unsettling moods
    parser.add_argument("--output_path", type=str, default=None)

    parser.add_argument("--ip_mode", default=None, type=str, help='The loaded IP-Adapter mode')
    parser.add_argument(
        "--t2i_adapter_control_type",
        default=None,
        nargs='*',
        type=str,
        help="choose one from {kps, openpose, mask, reference}",
    )

    parser.add_argument("--image_width", type=int, default=512)
    parser.add_argument("--image_height", type=int, default=512)
    parser.add_argument("--fps", type=float, default=30.0)
    parser.add_argument("--seed", type=int, default=4993)
    parser.add_argument("--num_inference_steps", type=int, default=30)
    parser.add_argument("--guidance_scale", type=float, default=3.5)
    parser.add_argument("--context_frames", type=int, default=12)
    parser.add_argument("--context_stride", type=int, default=1)
    parser.add_argument("--context_overlap", type=int, default=4)
    parser.add_argument("--n_motion_frames", type=int, default=4)
    parser.add_argument("--motion_scale", type=float, default=1.0)
    parser.add_argument("--text_attention_weight", default=None, type=float)
    parser.add_argument("--audio_attention_weight", default=None, type=float)
    parser.add_argument("--num_tokens", default=None, nargs='*', type=int)
    parser.add_argument("--ipa_scale", default=None, nargs='*', type=float)
    parser.add_argument("--t2i_adapter_conditioning_scale", default=1.0, type=float)
    parser.add_argument("--eta", default=0.0, type=float)
    parser.add_argument("--align_color_alpha", default=0.6, type=float)
    parser.add_argument("--invert_skip", default=50, type=int)
    parser.add_argument("--invert_noise_add_step", default=1, type=int)
    parser.add_argument("--invert_inference_steps", type=int, default=None)
    parser.add_argument(
        "--invert_ref", type=int, default=0, help="0: False, 1: True. If True, applying inversion to obtain the start code",
    )
    parser.add_argument(
        "--vae_mask_strategy", type=str, default=None, help="Choices: {randn_bg_repeat, weak_bg_inverted}",
    )

    # Module Loading Settings
    parser.add_argument(
        "--disable_motion", type=int, default=0, help="0: False, 1: True. If True, disable applying Motion module",
    )
    parser.add_argument(
        "--disable_audio", type=int, default=0, help="0: False, 1: True. If True, disable applying Audio module",
    )
    parser.add_argument(
        "--disable_ipa", action="store_true", help="disable applying IP-Adapter module"
    )
    parser.add_argument(
        "--disable_kps", action="store_true", help="disable applying KPS module"
    )
    parser.add_argument(
        "--apply_animatediff", action="store_true", help="apply AnimateDiff Motion Adapter",
    )
    parser.add_argument(
        "--apply_t2i_adapter", type=int, default=0,  help="enable applying T2I-Adapter module",
    )
    parser.add_argument(
        "--save_clip", type=int, default=0, help="0: False, 1: True. If True, save each clip",
    )

    # LoRA Settings
    parser.add_argument("--lora_scale", default=1.0, type=float)
    parser.add_argument("--lora_path", type=str, default=None)
    # FreeU Settings
    parser.add_argument("--b1", default=1.2, type=float)
    parser.add_argument("--b2", default=1.4, type=float)
    parser.add_argument("--s1", default=0.9, type=float)
    parser.add_argument("--s2", default=0.2, type=float)
    parser.add_argument("--threshold", default=1, type=int)
    # Dynamic Threshold Settings
    parser.add_argument('--mimic_scale', type=float, default=3.5, help='Mimic scale value for DynThresh')
    parser.add_argument('--threshold_percentile', type=float, default=0.9, help='Threshold percentile for DynThresh')
    parser.add_argument('--mimic_mode', type=str, default='Constant', help='Mimic mode for DynThresh')
    parser.add_argument('--mimic_scale_min', type=float, default=0.0, help='Minimum mimic scale value for DynThresh')
    parser.add_argument('--cfg_mode', type=str, default='Constant', help='CFG mode for DynThresh')
    parser.add_argument('--cfg_scale_min', type=float, default=0.0, help='Minimum CFG scale value for DynThresh')
    parser.add_argument('--sched_val', type=float, default=1.0, help='Schedule value for DynThresh')
    parser.add_argument('--experiment_mode', type=int, default=0, help='Experiment mode for DynThresh')
    parser.add_argument('--separate_feature_channels', type=int, default=1, help='1: Enable; 0: Disable; Separate feature channels for DynThresh')
    parser.add_argument('--scaling_startpoint', type=str, default='MEAN', help='Scaling start point for DynThresh')
    parser.add_argument('--variability_measure', type=str, default='STD', help='Variability measure for DynThresh')
    parser.add_argument('--interpolate_phi', type=float, default=1.0, help='Interpolation factor for DynThresh')
    # ControlNet Settings
    parser.add_argument("--controlnet_conditioning_scale", default=0.1, type=float)
    parser.add_argument("--ctrl_kps", type=int, default=0, help="0: False, 1: True. If True, use kps instead of face-masked image")
    # Hook Settings
    parser.add_argument(
        "--store_attn", type=int, default=0, help="0: False, 1: True. If True, save attn maps",
    )
    parser.add_argument('--store_attn_key', type=str, default='attn2')
    parser.add_argument(
        "--store_qk", type=int, default=0, help="0: False, 1: True. If True, save qk",
    )
    parser.add_argument('--store_qk_key', type=str, default='attn1')

    args = parser.parse_args()

    # Integer 0/1 switches -> booleans; only the value 1 means True.
    for int_flag in (
        "invert_ref",
        "disable_audio",
        "disable_motion",
        "save_clip",
        "store_attn",
        "store_qk",
        "ctrl_kps",
        "apply_t2i_adapter",
        "separate_feature_channels",
    ):
        setattr(args, int_flag, getattr(args, int_flag) == 1)

    if args.t2i_adapter_control_type is None:
        args.t2i_adapter_control_type = []

    if not args.disable_ipa:
        if args.num_tokens is None:
            raise ValueError('Number of IPA tokens required!')
        elif len(args.num_tokens) == 1:
            # A single value is passed downstream as a scalar rather than a 1-element list.
            args.num_tokens = args.num_tokens[0]
        if args.ipa_scale is None:
            args.ipa_scale = 1.0
        elif len(args.ipa_scale) == 1:
            args.ipa_scale = args.ipa_scale[0]

    return args


def load_denoising_unet(args, inference_config, dtype, device):
    """Build the 3D denoising U-Net and load its weights.

    Steps:
      1. Load the main denoising U-Net checkpoint from ``args.denoising_unet_path``.
      2. Create ``transformer_blocks.0.attn1_7`` / ``norm1_7`` entries by cloning the
         matching ``attn2`` / ``norm2`` tensors from
         ``<args.sd_model_name>/unet/diffusion_pytorch_model.bin`` (or, when that file
         does not exist, from the denoising U-Net checkpoint itself).
      3. If ``args.motion_module_path`` points to an existing file, merge its
         motion-module weights (``.pth``/``.pt``/``.ckpt`` or ``.safetensors``).
      4. Load the merged state dict with ``strict=False`` and print the
         missing / unexpected keys.

    Args:
        args: parsed CLI namespace (reads ``unet_config_path``, ``denoising_unet_path``,
            ``sd_model_name``, ``motion_module_path``).
        inference_config: OmegaConf config providing ``unet_additional_kwargs``.
        dtype: torch dtype for the model.
        device: device for the model.

    Returns:
        The ``UNet3DConditionModel`` with the weights loaded.

    Raises:
        RuntimeError: if the motion-module file has an unsupported extension.
    """
    print_highlighted_block_log(
        title="Load Denoising UNet",
        message="""
            1. load main denoising unet;\n 
            2. Replicate Attn2 to Attn1_7;\n 
            3. Load Motion Module (Adapter)
        """,
        title_color="\033[1;31m",  # Bold red for title
        text_color="\033[1;33m",  # Bold yellow for message
    )
    # (source substring, target substring) pairs: keys containing the source get
    # duplicated under a renamed key.
    patterns_to_copy = [
        ("transformer_blocks.0.attn2", "transformer_blocks.0.attn1_7"),
        ("transformer_blocks.0.norm2", "transformer_blocks.0.norm1_7"),
    ]

    denoising_unet = UNet3DConditionModel.from_config_2d(
        args.unet_config_path,
        unet_additional_kwargs=inference_config.unet_additional_kwargs,
    ).to(dtype=dtype, device=device)

    # 1. load main denoising unet
    unet_state_dict = torch.load(args.denoising_unet_path, map_location="cpu")
    print(f"[INFO] Loaded weights of Denoising U-Net from {args.denoising_unet_path}.")

    # 2. replicate the attn1_7 weights from the attn2 in original sd15
    sd_unet_ckpt_path = os.path.join(args.sd_model_name, 'unet/diffusion_pytorch_model.bin')
    if os.path.exists(sd_unet_ckpt_path):
        sd_state_dict = torch.load(sd_unet_ckpt_path, map_location="cpu",)
    else:
        sd_state_dict = unet_state_dict
    # .clone() so the new keys do not alias the source tensors.
    faceid_state_dict = {
        key.replace(pattern, new_pattern): sd_state_dict[key].clone()
        for pattern, new_pattern in patterns_to_copy
        for key in sd_state_dict
        if pattern in key
    }
    # Drop our reference to the (possibly separate, full) SD checkpoint early to
    # lower peak CPU memory; only the cloned tensors are still needed.
    del sd_state_dict
    # NOTE(review): this overwrites any attn1_7/norm1_7 entries already present in
    # the denoising U-Net checkpoint — confirm that is intended.
    unet_state_dict.update(faceid_state_dict)
    print(f"[INFO] Replicate the attn1_7 weights from the attn2 in pretrained model.")

    # 3. load motion module (skipped when the path is empty/None or not a file)
    if args.motion_module_path:
        motion_module_path = Path(args.motion_module_path)
        if motion_module_path.is_file():
            suffix = motion_module_path.suffix.lower()
            if suffix in [".pth", ".pt", ".ckpt"]:
                print(f"Load motion module params from {motion_module_path}")
                motion_state_dict = torch.load(motion_module_path, map_location="cpu", weights_only=True)
            elif suffix == ".safetensors":
                motion_state_dict = load_file(motion_module_path, device="cpu")
            else:
                raise RuntimeError(f"unknown file format for motion module weights: {motion_module_path.suffix}")
            unet_state_dict.update(motion_state_dict)  #! update the state dict with motion modules
            print(f"Loaded weights of Denoising U-Net Motion Module from {motion_module_path}.")

    # Counts the parameters of the model's "temporal" submodules (architecture size),
    # independent of whether a motion-module file was actually merged above.
    num_temporal_params = sum(
        p.numel() for n, p in denoising_unet.named_parameters() if "temporal" in n
    )
    print(f"[INFO] Loaded {num_temporal_params / 1e6}M-parameter motion module")

    # 4. load the weights into the model (non-strict: report what did not match)
    m, u = denoising_unet.load_state_dict(unet_state_dict, strict=False)
    print(f"### missing keys: {m}; \n### unexpected keys: {u};")

    return denoising_unet


def load_v_kps_guider(v_kps_guider_path, dtype, device):
    """Create the V-Kps guider and restore its weights.

    Args:
        v_kps_guider_path: checkpoint file holding the guider's state dict.
        dtype: torch dtype for the guider's parameters.
        device: device to place the guider on.

    Returns:
        The ``VKpsGuider`` with the checkpoint weights loaded (strict).
    """
    guider = VKpsGuider(320, block_out_channels=(16, 32, 96, 256))
    guider = guider.to(dtype=dtype, device=device)

    checkpoint = torch.load(v_kps_guider_path, map_location="cpu")
    guider.load_state_dict(checkpoint)

    print(f"[INFO] Loaded weights of V-Kps Guider from {v_kps_guider_path}.")
    return guider


def load_t2i_adapter(t2i_adapter_model_path, dtype, device):
    """Build a T2I-Adapter and load its weights.

    Args:
        t2i_adapter_model_path: either a directory (loaded with
            ``T2IAdapter.from_pretrained``) or a ``.pth`` checkpoint whose keys are
            renamed by ``t2i_adapter_map_keys`` before a strict load.
        dtype: torch dtype for the adapter.
        device: device to place the adapter on.

    Returns:
        The loaded ``T2IAdapter``.

    Raises:
        NotImplementedError: if the path is neither a directory nor a ``.pth`` file.
    """
    print_highlighted_block_log(
        title="Load T2I Adapter",
        message=f"Loading weights of T2I Adapter from {t2i_adapter_model_path}.",
        title_color="\033[1;31m",  # Bold red for title
        text_color="\033[1;33m",  # Bold yellow for message
    )

    if os.path.isdir(t2i_adapter_model_path):
        adapter = T2IAdapter.from_pretrained(t2i_adapter_model_path, torch_dtype=dtype).to(device)
    elif t2i_adapter_model_path.endswith(".pth"):
        # Condition channel count is inferred from the path: sketch/canny inputs
        # are single-channel, everything else is treated as 3-channel.
        single_channel = any(tag in t2i_adapter_model_path for tag in ("sketch", "canny"))
        adapter = T2IAdapter(
            in_channels=1 if single_channel else 3,
            channels=[320, 640, 1280, 1280],
            num_res_blocks=2,
            downscale_factor=8,
            adapter_type="full_adapter",
        ).to(dtype=dtype, device=device)
        raw_state_dict = torch.load(t2i_adapter_model_path, map_location="cpu")
        # Rename the checkpoint keys to this module's parameter names, then load strictly.
        adapter.load_state_dict(t2i_adapter_map_keys(raw_state_dict), strict=True)
    else:
        raise NotImplementedError("T2I Adapter checkpoint incorrect configuration!")

    print("[INFO] loaded T2I Adapter!")
    return adapter


def load_image_encoder(image_encoder_path, device, dtype=torch.float16):
    """Load the CLIP vision encoder (with projection head) used for IP-Adapter.

    Args:
        image_encoder_path: path accepted by ``CLIPVisionModelWithProjection.from_pretrained``.
        device: device to move the encoder to.
        dtype: weight dtype. Defaults to ``torch.float16``, the value that was
            previously hard-coded, so existing callers are unaffected.

    Returns:
        The loaded ``CLIPVisionModelWithProjection``.
    """
    print_highlighted_block_log(
        title="Load Image Encoders",
        message=f"Loading weights of Image Encoder from {image_encoder_path}",
        title_color="\033[1;31m",  # Bold red for title
        text_color="\033[1;33m",  # Bold yellow for message
    )
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        image_encoder_path
    ).to(device, dtype=dtype)
    print(f"[INFO] loaded Image Encoders!")
    return image_encoder


def load_text_encoder(sd_model_name, device, dtype=torch.float16):
    """Load the CLIP tokenizer and text encoder from a Stable Diffusion model folder.

    Args:
        sd_model_name: model root containing ``tokenizer/`` and ``text_encoder/`` subfolders.
        device: device to move the text encoder to.
        dtype: text-encoder weight dtype. Defaults to ``torch.float16``, the value
            that was previously hard-coded.

    Returns:
        ``(tokenizer, text_encoder)``.
    """
    print_highlighted_block_log(
        title="Load Text Encoders",
        message=f"Loading weights of Text Encoder from {sd_model_name}.",
        title_color="\033[1;31m",  # Bold red for title
        text_color="\033[1;33m",  # Bold yellow for message
    )
    # Tokenizers hold no tensors, so no dtype is passed to them (the previous
    # torch_dtype argument here had no effect).
    tokenizer = CLIPTokenizer.from_pretrained(
        sd_model_name,
        subfolder="tokenizer",
    )
    text_encoder = CLIPTextModel.from_pretrained(
        sd_model_name,
        subfolder="text_encoder",
        torch_dtype=dtype,
    ).to(device)
    print(f"[INFO] loaded Text Encoders!")
    return tokenizer, text_encoder


def load_audio_projection(
    audio_projection_path,
    dtype,
    device,
    inp_dim: int,
    mid_dim: int,
    out_dim: int,
    inp_seq_len: int,
    out_seq_len: int,
    aud_depth: int,
):
    """Construct the audio projection module and load its checkpoint.

    Args:
        audio_projection_path: checkpoint file with the module's state dict.
        dtype: torch dtype for the module.
        device: device to place the module on.
        inp_dim: input embedding width (``embedding_dim``).
        mid_dim: internal width (``dim``).
        out_dim: output width (``output_dim``).
        inp_seq_len: maximum input sequence length (``max_seq_len``).
        out_seq_len: number of output queries (``num_queries``).
        aud_depth: number of layers (``depth``).

    Returns:
        The ``AudioProjection`` with weights loaded (strict).
    """
    print_highlighted_block_log(
        title="Load Audio Projection",
        message=f"Loading weights of Audio Projection from {audio_projection_path}.",
        title_color="\033[1;31m",  # Bold red for title
        text_color="\033[1;33m",  # Bold yellow for message
    )
    # Fixed hyper-parameters (head size/count, FF multiplier) plus the caller's dims.
    projection_kwargs = dict(
        dim=mid_dim,
        depth=aud_depth,
        dim_head=64,
        heads=12,
        num_queries=out_seq_len,
        embedding_dim=inp_dim,
        output_dim=out_dim,
        ff_mult=4,
        max_seq_len=inp_seq_len,
    )
    projection = AudioProjection(**projection_kwargs).to(dtype=dtype, device=device)

    checkpoint = torch.load(audio_projection_path, map_location="cpu")
    projection.load_state_dict(checkpoint)

    print(f"[INFO] Loaded weights of Audio Projection from {audio_projection_path}.")
    return projection


def get_scheduler(inference_config_path="././configs/infer/inference_v2.yaml"):
    """Instantiate the noise scheduler named by the inference config.

    Args:
        inference_config_path: YAML file with ``sampler`` (one of ``DDIM``, ``KDPM``,
            ``PNDM``) and ``noise_scheduler_kwargs``. Defaults to the path the script
            always used.

    Returns:
        The scheduler instance built with ``noise_scheduler_kwargs``.

    Raises:
        ValueError: if ``sampler`` is not a supported name (previously this fell
            through and failed with an ``UnboundLocalError``).
    """
    inference_config = OmegaConf.load(inference_config_path)
    scheduler_kwargs = OmegaConf.to_container(inference_config.noise_scheduler_kwargs)

    scheduler_classes = {
        'DDIM': DDIMScheduler,
        'KDPM': KDPM2DiscreteScheduler,
        'PNDM': PNDMScheduler,
    }
    sampler = inference_config.sampler
    if sampler not in scheduler_classes:
        raise ValueError(
            f"Unsupported sampler {sampler!r} in {inference_config_path}; "
            f"expected one of {sorted(scheduler_classes)}"
        )
    return scheduler_classes[sampler](**scheduler_kwargs)


def count_params(pipe):
    """Print total and trainable parameter counts for each sub-module set on ``pipe``.

    Counts come from ``get_module_params`` and are printed with the "M" suffix.
    Sub-modules that are ``None`` (or absent) are skipped.
    """
    # (pipeline attribute, label used in the printed message)
    modules_to_report = (
        ("unet", "Denoising U-Net"),
        ("v_kps_guider", "V-Kps Guider"),
        ("adapter", "T2I-Adapter"),
        ("audio_projection", "Audio Projection"),
        ("image_encoder", "FaceID Image Encoder"),
        ("text_encoder", "Text Encoder"),
    )
    for attr_name, label in modules_to_report:
        module = getattr(pipe, attr_name, None)
        # Explicit None check: module truthiness is not a presence test
        # (a module defining __len__ can be falsy).
        if module is None:
            continue
        num_params, num_trainable_params = get_module_params(module)
        print(
            f"#parameters of {label} is {num_params:.3f} M ({num_trainable_params:.3f} M is trainable)."
        )


def _largest_face(faces):
    """Return the detection with the largest bounding-box area, or None if there is none.

    Each detection is indexed as ``face['bbox']`` = ``(x1, y1, x2, y2)``, matching how
    the detections from ``fa_app.get`` are used in this file.
    """
    if not faces:
        return None
    # Area must be (x2 - x1) * (y2 - y1); the previous key lacked parentheses and
    # computed width * y2 - y1, which can pick the wrong face.
    return max(
        faces,
        key=lambda face: (face['bbox'][2] - face['bbox'][0]) * (face['bbox'][3] - face['bbox'][1]),
    )


def prepare_face_adapter_inputs(
    src_im_pil, fa_app, net_d3dfr, bfm_facemodel, net_seg_res18, clip_image_processor, image_size=512, crop_ratio=0.81, dtype=torch.float16, device='cuda'
):
    """Build the Face-Adapter ControlNet condition image and CLIP input for one image.

    Pipeline: detect the largest face -> square crop around it warped to a 512x512
    template -> re-detect 5 keypoints on the crop -> 256x256 aligned face -> 3DMM
    coefficients (``net_d3dfr``) -> 68 landmarks (``bfm_facemodel``) drawn as a
    landmark image -> face mask (``net_seg_res18``) -> composite where the masked
    face region is replaced by the landmark drawing.

    Args:
        src_im_pil: source RGB PIL image.
        fa_app: face analysis app whose ``get`` takes a BGR array.
        net_d3dfr, bfm_facemodel, net_seg_res18: landmark / 3DMM / segmentation models.
        clip_image_processor: processor producing 224x224 CLIP pixel values.
        image_size: side of the crop tensors. NOTE(review): the crop is always warped
            to 512x512 (``mean_box_lm4p_512``), so the ``.view`` below needs
            ``image_size == 512``.
        crop_ratio: half-size of the square crop relative to the face box size.
        dtype: dtype of the returned ControlNet image.
        device: device for the returned tensors.

    Returns:
        ``(controlnet_image, clip_input_src_tensors)``. When no face is found (in the
        source image or on the aligned crop), ``controlnet_image`` is ``None`` and
        the CLIP input is computed from the uncropped ``src_im_pil``.
    """
    def _clip_pixels(image):
        # CLIP pixel values reshaped to (N, 3, 224, 224) on the target device.
        return clip_image_processor(images=image, return_tensors="pt").pixel_values.view(-1, 3, 224, 224).to(device)

    # ===== insightface detection; only the largest face is used
    face = _largest_face(fa_app.get(cv2.cvtColor(np.array(src_im_pil), cv2.COLOR_RGB2BGR)))
    if face is None:
        return None, _clip_pixels(src_im_pil)

    # ===== crop image: square box centred on the detected face
    bbox = face['bbox'][0:4]
    # NOTE(review): the original expression was max(width, width), i.e. width only.
    # Using max(width, height) was probably intended, but the crop_ratio default may
    # have been tuned to the width-only box, so behavior is kept — confirm before changing.
    bbox_size = bbox[2] - bbox[0]
    bbox_x = 0.5 * (bbox[2] + bbox[0])
    bbox_y = 0.5 * (bbox[3] + bbox[1])
    x1 = bbox_x - bbox_size * crop_ratio
    x2 = bbox_x + bbox_size * crop_ratio
    y1 = bbox_y - bbox_size * crop_ratio
    y2 = bbox_y + bbox_size * crop_ratio
    bbox_pts4 = np.array([[x1, y1], [x1, y2], [x2, y2], [x2, y1]], dtype=np.float32)

    # ===== Affine mapping of the source image onto the 512x512 box template
    warp_mat_crop = transformation_from_points(bbox_pts4, mean_box_lm4p_512)
    src_im_crop512 = cv2.warpAffine(np.array(src_im_pil), warp_mat_crop, (512, 512), flags=cv2.INTER_LINEAR)
    crop_im_pil = Image.fromarray(src_im_crop512)

    # Re-detect on the crop to get 5-point keypoints in crop coordinates.
    face = _largest_face(fa_app.get(cv2.cvtColor(np.array(crop_im_pil), cv2.COLOR_RGB2BGR)))
    if face is None:
        # Previously an IndexError; fall back exactly like the no-face case.
        return None, _clip_pixels(src_im_pil)
    pts5 = face['kps']
    warp_mat = get_affine_transform(pts5, mean_face_lm5p_256)
    image_src_warpmat256 = warp_mat.reshape((1, 2, 3))
    src_im_crop256 = cv2.warpAffine(np.array(crop_im_pil), warp_mat, (256, 256), flags=cv2.INTER_LINEAR)

    # ====== Obtain landmarks
    image_src_crop256 = pil2tensor(Image.fromarray(src_im_crop256)).view(1, 3, 256, 256).to(device)
    images_src = pil2tensor(crop_im_pil).view(1, 3, image_size, image_size).to(device)
    clip_input_src_tensors = _clip_pixels(crop_im_pil)

    # Inference only: avoid building autograd graphs through the frozen networks.
    with torch.no_grad():
        src_d3d_coeff = net_d3dfr(image_src_crop256)
        src_pts68 = bfm_facemodel.get_lm68(src_d3d_coeff)
        im_pts70 = draw_pts70_batch(src_pts68, src_d3d_coeff[:, 257:], image_src_warpmat256, image_size, return_pt=True)
        im_pts70 = im_pts70.to(images_src)

        # ====== ControlNet image: replace the segmented face region with the landmark drawing
        face_masks_src = (net_seg_res18(torch.cat([images_src, im_pts70], dim=1)) > 0.5).float()
        controlnet_image = im_pts70 * face_masks_src + images_src * (1 - face_masks_src)

    return controlnet_image.to(dtype=dtype), clip_input_src_tensors


def main():
    """
    # ======================================================
    # ============ 1. inference configurations ============
    # ======================================================
    """
    args = parse_args()
    device = torch.device(
        f"{args.device}:{args.gpu_id}" if args.device == "cuda" else args.device
    )
    dtype = torch.float16 if args.dtype == "fp16" else torch.float32

    inference_config_path = "././configs/infer/inference_v2.yaml"
    inference_config = OmegaConf.load(inference_config_path)
    if args.disable_motion:
        inference_config.unet_additional_kwargs.use_motion_module = False
    if args.text_attention_weight is not None:
        inference_config.unet_additional_kwargs.text_attention_weight = (
            args.text_attention_weight
        )
    if args.audio_attention_weight is not None:
        inference_config.unet_additional_kwargs.audio_attention_weight = (
            args.audio_attention_weight
        )
    if args.motion_scale is not None:
        inference_config.unet_additional_kwargs.motion_module_kwargs.motion_scale = (
            args.motion_scale
        )

    """
    # ======================================================
    # ============ 2. load modules for inference pipeline ============
    # ======================================================
    """
    vae_path = args.vae_path
    audio_encoder_path = args.audio_encoder_path
    v_kps_guider_path = args.v_kps_guider_path
    audio_projection_path = args.audio_projection_path
    image_encoder_path = args.image_encoder_path

    # 2.1 InsightFace App for Keypoints Extraction
    FaceAPP_H, FaceAPP_W = 512, 512
    app = FaceAnalysis(
        providers=[
            "CUDAExecutionProvider" if args.device == "cuda" else "CPUExecutionProvider"
        ],
        provider_options=[{"device_id": args.gpu_id}] if args.device == "cuda" else [],
        root=args.insightface_model_path,
    )
    app.prepare(ctx_id=0, det_size=(FaceAPP_H, FaceAPP_W))

    # 2.2 Diffusion Model
    vae = AutoencoderKL.from_pretrained(vae_path).to(dtype=dtype, device=device)
    scheduler = get_scheduler()
    scheduler.set_timesteps(args.num_inference_steps)
    denoising_unet = load_denoising_unet(args, inference_config, dtype, device)
    print_highlighted_block_log(
        title="Scheduler Config",
        message=OmegaConf.to_container(inference_config.noise_scheduler_kwargs),
    )

    # 2.3 Audio Projection and Encoder
    if not args.disable_audio:
        audio_encoder = Wav2Vec2Model.from_pretrained(audio_encoder_path).to(
            dtype=dtype, device=device
        )
        audio_processor = Wav2Vec2Processor.from_pretrained(audio_encoder_path)
        audio_projection = load_audio_projection(
            audio_projection_path,
            dtype,
            device,
            inp_dim=denoising_unet.config.cross_attention_dim,
            mid_dim=denoising_unet.config.cross_attention_dim,
            out_dim=denoising_unet.config.cross_attention_dim,
            inp_seq_len=2 * (2 * args.num_pad_audio_frames + 1),
            out_seq_len=2 * args.num_pad_audio_frames + 1,
            aud_depth=args.aud_depth,
        )
    else:
        audio_encoder, audio_processor, audio_projection = None, None, None

    # 2.4 Text and Image Encoders
    if not args.disable_ipa:
        image_encoder = load_image_encoder(image_encoder_path, device)
    else:
        image_encoder = None

    tokenizer, text_encoder = load_text_encoder(args.sd_model_name, device)

    # 2.5 Residual Condition Encoder
    if args.apply_t2i_adapter:
        t2i_adapter = load_t2i_adapter(args.t2i_adapter_model_path, dtype, device)
    else:
        t2i_adapter = None
    if not args.disable_kps:
        v_kps_guider = load_v_kps_guider(v_kps_guider_path, dtype, device)
    else:
        v_kps_guider = None

    # 2.6 Load Face Adapter Models
    controlnet = ControlNetModel.from_pretrained(os.path.join(args.face_checkpoint, 'controlnet'), torch_dtype=dtype).to(device)
    # Landmark Model
    net_d3dfr = model_resnet_d3dfr.getd3dfr_res50(os.path.join(args.face_checkpoint, 'third_party/d3dfr_res50_nofc.pth')).eval().to(device)
    bfm_facemodel = bfm.BFM(focal=1015*256/224, image_size=256, bfm_model_path=os.path.join(args.face_checkpoint, 'third_party/BFM_model_front.mat')).to(device)
    # Vision Encoder
    net_vision_encoder = CLIPVisionModel.from_pretrained(os.path.join(args.face_checkpoint, 'vision_encoder')).to(device)
    net_image2token = Image2Token(visual_hidden_size=net_vision_encoder.vision_model.config.hidden_size, text_hidden_size=768, max_length=77, num_layers=3).to(device)
    net_image2token.load_state_dict(torch.load(os.path.join(args.face_checkpoint, 'net_image2token.pth')))
    clip_image_processor = CLIPImageProcessor()
    # ID Encoder
    net_id2token = ID2Token(id_dim=512, text_hidden_size=768, max_length=77, num_layers=3).to(device)
    net_id2token.load_state_dict(torch.load(os.path.join(args.face_checkpoint, 'net_id2token.pth')))
    # Seg Model and FaceAnalysis
    net_seg_res18 = model_seg_unet.UNet().eval().to(device)
    net_seg_res18.load_state_dict(torch.load(os.path.join(args.face_checkpoint, 'net_seg_res18.pth')))
    fa_app = FaceAnalysis(name='antelopev2', root=os.path.join(args.face_checkpoint, 'third_party'), providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
    fa_app.prepare(ctx_id=0, det_size=(640, 640))

    """
    # ======================================================
    # ============ 3. initialize inference pipeline ============
    # ======================================================
    """
    vae.requires_grad_(False)
    denoising_unet.requires_grad_(False)
    if image_encoder:
        image_encoder.requires_grad_(False)
    if text_encoder:
        text_encoder.requires_grad_(False)
    if audio_encoder and audio_projection:
        audio_encoder.requires_grad_(False)
        audio_projection.requires_grad_(False)
    if v_kps_guider:
        v_kps_guider.requires_grad_(False)
    if t2i_adapter:
        t2i_adapter.requires_grad_(False)

    generator = torch.manual_seed(args.seed)
    lora_setting = json.load(open(args.lora_setting_path, "r"))

    pipeline = VExpressPipelinePrefixMeanVarFace(
        vae=vae,
        unet=denoising_unet,
        v_kps_guider=v_kps_guider,
        audio_processor=audio_processor,
        audio_encoder=audio_encoder,
        audio_projection=audio_projection,
        scheduler=scheduler,
        image_encoder=image_encoder,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        face_analysis_app=app,
        ip_ckpt=args.ip_ckpt,
        ip_mode=args.ip_mode,
        num_tokens=args.num_tokens,
        adapter=t2i_adapter,
        lora_path=args.lora_path,
        lora_scale=args.lora_scale,
        controlnet=controlnet,
        extra_lora_setting=lora_setting,
        store_attn=args.store_attn,
        store_attn_key=args.store_attn_key,
        store_qk=args.store_qk,
        store_qk_key=args.store_qk_key
    ).to(dtype=dtype, device=device)

    if args.dreambooth_path != "" and args.dreambooth_path is not None:
        pipeline = load_dreambooth_weights(pipeline, args.dreambooth_path, dtype, device)

    # count parameters and check zero-initialization
    count_params(pipeline)
    check_zero_initialization(pipeline.unet, "temporal_transformer.proj_out", logger=None)

    if args.store_attn:
        pipeline.unet = register_attention_hook(pipeline.unet, attn_key=args.store_attn_key)
    if args.store_qk:
        pipeline.unet = register_attention_hook(pipeline.unet, attn_key=args.store_qk_key)
    """
    # ======================================================
    # ============ 4. data preprocessing ============
    # ======================================================
    """
    # 4.1 extract face embedding from reference image for IP-Adapter
    reference_image = Image.open(args.reference_image_path)
    reference_image = reference_image.resize((FaceAPP_H, FaceAPP_W))
    ref_image_cv2 = cv2.cvtColor(np.array(reference_image), cv2.COLOR_RGB2BGR)
    faces = app.get(ref_image_cv2)
    print(f"Number of faces: {len(faces)}")
    try:
        face_embeds = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
    except:
        raise NotImplementedError("failed to capture face information")

    # 4.2 audio processing
    _, audio_waveform, meta_info = torchvision.io.read_video(
        args.audio_path, pts_unit="sec"
    )
    audio_sampling_rate = meta_info["audio_fps"]
    if audio_sampling_rate != args.standard_audio_sampling_rate:
        audio_waveform = torchaudio.functional.resample(
            audio_waveform,
            orig_freq=audio_sampling_rate,
            new_freq=args.standard_audio_sampling_rate,
        )
    audio_waveform = audio_waveform.mean(dim=0)

    duration = audio_waveform.shape[0] / args.standard_audio_sampling_rate
    video_length = int(duration * args.fps)

    num_contexts = compute_num_context(
        video_length, args.context_frames + args.n_motion_frames, args.n_motion_frames
    )
    context_indices = compute_context_indices(
        num_context=num_contexts,
        context_size=args.context_frames + args.n_motion_frames,
        context_overlap=args.n_motion_frames,
    )
    video_length = context_indices[-1][1] + 1
    args.fps = video_length / duration

    # 4.3 residual conditions
    # keypoints prior
    if not args.disable_kps or ("kps" in args.t2i_adapter_control_type) or ("openpose" in args.t2i_adapter_control_type):
        reference_kps = faces[0].kps[:3]
        if args.kps_path.endswith(".pth"):
            kps_images = extract_kps_img(
                ref_image_cv2, reference_kps, video_len=video_length, kps_path=args.kps_path, retarget_strategy=args.retarget_strategy,
            )
            point_kps_images = extract_kps_img(
                ref_image_cv2, reference_kps, video_len=video_length, kps_path=args.kps_path, retarget_strategy=args.retarget_strategy, stick_width=1
            )
        elif args.kps_path.endswith(".mp4"):
            kps_images = video_to_pil_images(args.kps_path)
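            # TODO: when kps come from a video, interpolate so that the number of kps images matches video_length
            # NOTE(review): point_kps_images is not assigned in this branch (unless set earlier in main), but it is passed to the pipeline below — confirm.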
            kps_images = kps_images[:video_length]
    else:
        point_kps_images = None
        kps_images = None

    # 4.4 face adapter inputs
    controlnet_image, clip_input_src_tensors = prepare_face_adapter_inputs(
        reference_image,
        fa_app,
        net_d3dfr,
        bfm_facemodel,
        net_seg_res18,
        clip_image_processor,
        image_size=args.image_height,
        crop_ratio=args.crop_ratio,
        dtype=dtype,
    )

    last_hidden_state = net_vision_encoder(clip_input_src_tensors).last_hidden_state
    controlnet_encoder_hidden_states_src = net_image2token(last_hidden_state).to(dtype=dtype)
    empty_prompt_token = torch.load('/root/Code/Personalized/Face-Adapter/empty_prompt_embedding.pth').view(1, 77,768).to(dtype=dtype).to(device)

    print_highlighted_block_log(
        title="Prepare Data",
        message=f"""
        1.Loading reference image from {args.reference_image_path}; BG Image {args.refbg_image_path}\n
        2.Extrating Face Info and Face Emebds\n
        3.Loading Audio from {args.audio_path}\n
        Length of audio is {audio_waveform.shape[0]} with the sampling rate of {audio_sampling_rate}\n
        The corresponding video length is {video_length} with FPS {args.fps}\n
        4.Load T2I-Adapter Conditions -- {args.t2i_adapter_control_type}""",
    )

    """
    # ======================================================
    #! ============ 5. INFERENCE START ! ============
    # ======================================================
    """
    vae_scale_factor = 8
    latent_height = args.image_height // vae_scale_factor
    latent_width = args.image_width // vae_scale_factor

    print_highlighted_block_log(
        title="Inference Start!!",
        message=f"""
        num_inference_steps: {args.num_inference_steps}; Guidance Scale: {args.guidance_scale};\n
        Disable Motion: {args.disable_motion}; Disable Audio: {args.disable_audio}
        Disable KPS: {args.disable_kps}; Disable IPA: {args.disable_ipa}\n
        Text Attention Weight: {inference_config.unet_additional_kwargs.text_attention_weight}, Audio Attention Weight: {inference_config.unet_additional_kwargs.audio_attention_weight}
        Context Frames: {args.context_frames}; Context Overlap: {args.context_overlap}\n
        Inversion: {args.invert_ref}; Invert Skip: {args.invert_skip}/{args.invert_inference_steps}; Noise Add Step: {args.invert_noise_add_step}\n
        """,
    )

    latent_shape = (1, 4, video_length, latent_height, latent_width)
    vae_latents = randn_tensor(
        latent_shape, generator=generator, device=device, dtype=dtype
    )

    # prepare extra cross-attention kwargs
    cross_attention_kwargs = {}
    if args.n_motion_frames:
        cross_attention_kwargs["n_motion_frames"] = args.n_motion_frames

    # Dynamic Threshold Setting Parameters
    dynthresh_kwargs = {
        "mimic_scale": args.mimic_scale,
        "threshold_percentile": args.threshold_percentile,
        "mimic_mode": args.mimic_mode,
        "mimic_scale_min": args.mimic_scale_min,
        "cfg_mode": args.cfg_mode,
        "cfg_scale_min": args.cfg_scale_min,
        "sched_val": args.sched_val,
        "experiment_mode": args.experiment_mode,
        "max_steps": args.num_inference_steps,
        "separate_feature_channels": args.separate_feature_channels,
        "scaling_startpoint": args.scaling_startpoint,
        "variability_measure": args.variability_measure,
        "interpolate_phi": args.interpolate_phi,
    }

    pipe_res = pipeline(
        vae_latents=vae_latents,
        kps_images=kps_images,
        reference_image_path=args.reference_image_path,
        refbg_image_path=args.refbg_image_path,
        audio_waveform=audio_waveform,
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        width=args.image_width,
        height=args.image_height,
        video_length=video_length,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        context_frames=args.context_frames,
        context_stride=args.context_stride,
        context_overlap=args.context_overlap,
        num_pad_audio_frames=args.num_pad_audio_frames,
        generator=generator,
        face_embeds=face_embeds,
        face_info=faces,
        ipa_scale=args.ipa_scale,
        cross_attention_kwargs=cross_attention_kwargs,
        t2i_adapter_control_type=args.t2i_adapter_control_type,
        t2i_adapter_conditioning_scale=args.t2i_adapter_conditioning_scale,
        strength=1.0,  # TEST
        eta=args.eta,
        align_color_alpha=args.align_color_alpha,
        b1=args.b1,
        b2=args.b2,
        s1=args.s1,
        s2=args.s2,
        threshold=args.threshold,
        dynthresh_kwargs=dynthresh_kwargs,
        controlnet_image=controlnet_image,
        controlnet_prompt_embeds=controlnet_encoder_hidden_states_src,
        controlnet_negative_prompt_embeds=empty_prompt_token,
        controlnet_conditioning_scale=args.controlnet_conditioning_scale,
        point_kps_images=point_kps_images,
        ctrl_kps=args.ctrl_kps,
    )

    video_latents = pipe_res.video_latents
    video_tensor = pipeline.decode_latents(video_latents)
    if isinstance(video_tensor, np.ndarray):
        video_tensor = torch.from_numpy(video_tensor)
    os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
    save_video(video_tensor, args.output_path, audio_path=args.audio_path, fps=args.fps)
    print(f"The generated video has been saved at {args.output_path}.")

    # if args.store_attn or args.store_qk:
    #     attn_maps = get_net_attn_map((args.image_height, args.image_width))
    #     qs, ks = get_net_qk()
    #     import pdb; pdb.set_trace()

    if args.save_clip:
        all_latents = pipe_res.all_latents
        video_prefix_res = []
        for idx, latents_w_prefix in enumerate(all_latents):
            latents_w_prefix = pipeline.decode_latents(latents_w_prefix)
            if isinstance(latents_w_prefix, np.ndarray):
                latents_w_prefix = torch.from_numpy(latents_w_prefix)
            video_prefix_res.append(latents_w_prefix)
            output_path_basename = os.path.basename(args.output_path)
            output_path_dirname = os.path.dirname(args.output_path)
            save_dir = os.path.join(output_path_dirname, 'prefix_clips')
            os.makedirs(save_dir, exist_ok=True)
            save_basename = f"{output_path_basename.replace('.mp4', '')}-clip_{idx}.mp4"
            save_video(latents_w_prefix, os.path.join(save_dir, save_basename), fps=10.0)
            print(f"The clip video {idx} has been saved at {os.path.join(save_dir, save_basename)}.")


if __name__ == "__main__":
    # Run the inference flow only when this file is executed directly,
    # not when it is imported as a module.
    main()
