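'''
Stage-2 face parsing-to-video inference script.

Resizes the reference image and renders a face parsing map for it, dumps the
target video's frames, aligns each target face to the reference face and
renders its parsing map, then runs ``Parsing2VideoPipeline`` and saves a
``result.mp4`` showing [reference | aligned target | generated] side by side.
'''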


import argparse
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import List

# Append the parent of this script's directory to sys.path so the project-local
# imports below (configs, src, FaRL) resolve when the script is run directly.
# NOTE(review): presumably that parent is the repository root — confirm.
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
sys.path.append(parent_dir)

import numpy as np
import torch
import torchvision
from tqdm import tqdm
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
from einops import repeat
from omegaconf import OmegaConf
from PIL import Image
import cv2
from torchvision import transforms
from transformers import CLIPVisionModelWithProjection
import warnings

from configs.prompts.test_cases import TestCasesDict
from src.models.pose_guider import PoseGuider
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.unet_3d import UNet3DConditionModel
# from src.pipelines.pipeline_parsing2vid_long import Parsing2VideoPipeline
from src.pipelines.pipeline_parsing2vid_long_test import Parsing2VideoPipeline
from src.utils.util import get_fps, read_frames, save_videos_grid, save_videos_grid_cv
from src.models.adapter import VAEAdatper
from FaRL import facer


# File / folder names used inside a run's save directory.
parsing_name = 'parsing.png'  # parsing map of the reference face (get_reference_parsing)
panc_name = 'parsing_align_no_contour'  # target pose folder read by main() when alignment is off; no live code in this file writes it
cropped_ref_name = 'reference_cropped.png'  # resized reference image (written by main() when --crop)
cropped_tgt_name = 'target_cropped'  # aligned target frames, used as ground truth by main()
aligned_folder = 'aligned'  # parsing maps of the aligned target frames, used as poses by main()
# Clip+cvHBji1HgJ8+P0+C1+F5208-5344_Clip+d0isy9xi5rQ+P0+C1+F12819-12995
def parse_args():
    """Parse the command-line options for stage-2 inference.

    Notes:
        * ``-W`` / ``-H`` set the frame size handed to the generation pipeline,
          while ``--width`` / ``--height`` set the size the reference image is
          resized to and the canvas size used for aligning target frames.
        * ``--crop`` and ``--align`` default to True. A ``store_true`` flag
          whose default is already True can never be switched off, so
          ``--no_crop`` / ``--no_align`` are provided for that; passing
          ``--crop`` / ``--align`` is still accepted for compatibility.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default='./configs/inference/inference_stage2.yaml')
    parser.add_argument("--reference_img", default='images/others_new/002.jpg')
    parser.add_argument("--target_video", default='data/videos/test_camera/004.mp4')
    parser.add_argument("--save_dir", default=None)
    parser.add_argument("-W", type=int, default=512)
    parser.add_argument("-H", type=int, default=512)
    parser.add_argument("-L", type=int, default=24)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--cfg", type=float, default=3.5)
    parser.add_argument("--steps", type=int, default=30)
    parser.add_argument("--fps", type=int)
    parser.add_argument("--crop", action='store_true', default=True)
    # Shares dest with --crop: the namespace default comes from --crop (True).
    parser.add_argument("--no_crop", dest="crop", action='store_false',
                        help="do not resize the reference image before inference")
    parser.add_argument("--align", action='store_true', default=True)
    parser.add_argument("--no_align", dest="align", action='store_false',
                        help="skip aligning target frames to the reference face")
    parser.add_argument("--width", type=int, default=512)
    parser.add_argument("--height", type=int, default=512)
    return parser.parse_args()

def only_one_face(faces):
    """Reduce a detection dict to its largest face, in place.

    Args:
        faces: detection dict as produced by ``face_detector`` in this file.
            ``faces['rects']`` is indexed as rows of ``x1, y1, x2, y2``; every
            value is indexed by face along dim 0.
            NOTE(review): assumes all values are tensors with one row per
            face — confirm for every detector key.

    Returns:
        The same ``faces`` dict (always; it is also mutated in place). With
        zero or one face it is left unchanged. Otherwise every entry is
        replaced by the row of the largest-area box, keeping a leading
        dimension of size 1. Ties pick the first such face.
    """
    rects = faces['rects']
    if rects.size(0) <= 1:
        return faces
    areas = (rects[:, 2] - rects[:, 0]) * (rects[:, 3] - rects[:, 1])
    # argmax always yields a valid index (first maximum on ties), even when
    # every area is <= 0; the previous scan fell back to index -1 (last face).
    best = int(torch.argmax(areas))
    for key, value in faces.items():
        faces[key] = value[best].unsqueeze(0)
    return faces

def get_reference_parsing(jpg_path, save_folder, face_detector, face_parser, face_aligner, device):
    """Detect and parse the reference face, saving its parsing map.

    The per-pixel class index (scaled to 0-255 and post-processed by
    ``facer.get_bhw``) is resized to 512x512 and written to
    ``<save_folder>/parsing.png``.

    Args:
        jpg_path: path of the reference image.
        save_folder: folder that receives the parsing image.
        face_detector, face_parser, face_aligner: facer models.
        device: device the image tensor is moved to.

    Returns:
        The detection dict reduced to the largest face, or None when no face
        is detected (a warning is emitted in that case). The dict is the one
        handed to the parser/aligner.
        NOTE(review): facer may add its outputs to it in place — confirm.
    """
    print('get reference parsings')

    batch = facer.hwc2bchw(facer.read_hwc(jpg_path)).to(device=device)  # 1 x 3 x h x w
    with torch.inference_mode():
        detections = face_detector(batch)
    if detections['rects'].size(0) == 0:
        warnings.warn(f'{jpg_path} has no face', Warning)
        return None
    only_one_face(detections)

    with torch.inference_mode():
        parsed = face_parser(batch, detections)
        # The landmark result is not used here; the call is kept so the same
        # models run on `detections` as before.
        face_aligner(batch, detections)

    probs = parsed['seg']['logits'].softmax(dim=1)  # nfaces x nclasses x h x w
    label_map = probs.argmax(dim=1).float() / probs.size(1) * 255
    parse_map = facer.get_bhw(label_map.sum(0, keepdim=True))
    out_path = os.path.join(save_folder, parsing_name)
    Image.fromarray(parse_map.cpu().numpy()).resize((512, 512)).save(out_path)
    torch.cuda.empty_cache()
    return detections

def get_target_parsing(jpg_path, face_detector, device):
    """Detect the largest face in one target frame.

    Despite the name, this only runs face detection; target parsing maps are
    produced later from the aligned frames (see ``align``).

    Args:
        jpg_path: path of the frame image.
        face_detector: facer detector.
        device: device the image tensor is moved to.

    Returns:
        The detection dict reduced to one face by ``only_one_face``, or None
        when no face is found.
    """
    pixels = facer.read_hwc(jpg_path)
    batch = facer.hwc2bchw(pixels).to(device=device)  # 1 x 3 x h x w
    with torch.inference_mode():
        detections = face_detector(batch)
    found = detections['rects'].size(0) != 0
    if not found:
        return None
    only_one_face(detections)
    return detections

def read_frames(video_path, save_folder, limit=10000000000):
    """Decode a video and write its frames as numbered PNG files.

    Note: this local definition shadows the ``read_frames`` imported from
    ``src.utils.util`` at the top of the file.

    Args:
        video_path: path of the video to decode.
        save_folder: output folder, created if missing. Frames are written in
            decode order as ``0000.png``, ``0001.png``, ...
        limit: maximum number of frames to write.

    Returns:
        ``(fps, height, width)`` of the source video. ``fps`` is rounded to the
        nearest integer (29.97 -> 30) rather than truncated.

    Raises:
        ValueError: if the video cannot be opened. Previously this silently
            produced zero frames and a later division by zero.
    """
    os.makedirs(save_folder, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise ValueError(f'{video_path} open failed')
    try:
        fps = int(round(cap.get(cv2.CAP_PROP_FPS)))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        idx = 0
        is_limited = False
        while True:
            if idx >= limit:
                is_limited = True
                break
            ret, frame = cap.read()
            if not ret:
                break
            cv2.imwrite(os.path.join(save_folder, str(idx).zfill(4) + '.png'), frame)
            idx += 1
    finally:
        # Release the capture even if decoding or writing raises.
        cap.release()
    limit_str = f'limited by {limit}' if is_limited else 'full frames'
    print(f'read {idx} frames, {limit_str}')
    return fps, height, width

def get_parsing_4_crop(image_ori, face_detector, device):
    """Detect the largest face in an in-memory HWC image tensor.

    Currently not called by live code in this file.

    Args:
        image_ori: HWC image tensor, e.g.
            ``torch.from_numpy(np.array(Image.open(path).convert('RGB')))``.
        face_detector: facer detector.
        device: device the image tensor is moved to.

    Returns:
        The detection dict reduced to one face, or None (after printing a
        message) when no face is found.
    """
    batch = facer.hwc2bchw(image_ori).to(device=device)  # 1 x 3 x h x w
    with torch.inference_mode():
        detected = face_detector(batch)
    if detected['rects'].size(0) != 0:
        only_one_face(detected)
        return detected
    print('detected no faces, resize directly')
    return None

def crop(img_path, width=512, height=512):
    """Load an image and resize it to ``width`` x ``height``.

    No face-aware cropping is performed; the image is simply resized, which
    may change its aspect ratio.

    The source file is opened in a ``with`` block, so its handle is released
    as soon as the resized copy has been produced.

    Args:
        img_path: path of the image to load.
        width: output width in pixels.
        height: output height in pixels.

    Returns:
        A new PIL image of size ``(width, height)``.
    """
    with Image.open(img_path) as img:
        return img.resize((width, height))

def _landmark_extent(faces):
    """Return ``(min_x, max_x, min_y, max_y)`` of the first face's ``points``, each truncated to int."""
    points = faces['points']
    x_coords = points[0, :, 0]
    y_coords = points[0, :, 1]
    return (
        int(torch.min(x_coords).item()),
        int(torch.max(x_coords).item()),
        int(torch.min(y_coords).item()),
        int(torch.max(y_coords).item()),
    )


def _clamped_rect(faces, height, width):
    """Return the first face's box as ints, clamped to ``[0, width-1] x [0, height-1]``."""
    x1, y1, x2, y2 = [int(v) for v in faces['rects'][0]]
    x1, x2 = [max(min(v, width - 1), 0) for v in (x1, x2)]
    y1, y2 = [max(min(v, height - 1), 0) for v in (y1, y2)]
    return x1, y1, x2, y2


def _paste_clipped(canvas, patch, top, left):
    """Copy ``patch`` into ``canvas`` with its top-left corner at ``(top, left)``.

    Any part of ``patch`` falling outside ``canvas``, on any side, is dropped.
    The patch is never shifted to fit.
    """
    canvas_h, canvas_w = canvas.shape[:2]
    patch_h, patch_w = patch.shape[:2]
    y0, x0 = max(top, 0), max(left, 0)
    y1, x1 = min(top + patch_h, canvas_h), min(left + patch_w, canvas_w)
    if y1 > y0 and x1 > x0:
        canvas[y0:y1, x0:x1] = patch[y0 - top:y1 - top, x0 - left:x1 - left]
    return canvas


def _render_parsing_with_eyes(image, face_detector, face_parser, face_aligner):
    """Parse the largest face in ``image`` and draw its eye landmarks on the parsing map.

    Returns the rendered parsing map, or None when no face is detected.
    """
    with torch.inference_mode():
        faces = face_detector(image)
    if faces['rects'].size(0) == 0:
        return None
    only_one_face(faces)
    with torch.inference_mode():
        parsings = face_parser(image, faces)
        alignments = face_aligner(image, faces)
    seg_probs = parsings['seg']['logits'].softmax(dim=1)  # nfaces x nclasses x h x w
    n_classes = seg_probs.size(1)
    vis_seg_probs = seg_probs.argmax(dim=1).float() / n_classes * 255
    img = facer.get_bhw_no_contour(vis_seg_probs.sum(0, keepdim=True))
    for pts in alignments['alignment']:
        # Earlier models were trained without `color`; newly trained ones need it.
        img = facer.draw_landmarks_only_eyes(img, None, pts.cpu().numpy(), color=(105, 105, 105))
    return img


def align(faces_ref, faces_tgt, frames_tgt, tgt_shape, save_folder, face_detector, face_parser, face_aligner, w=512, h=512, device='cuda'):
    """Align target frames to the reference face and render their parsing maps.

    The union of all target face boxes is cropped from every frame. It is
    scaled so that the largest target landmark extent matches the reference
    landmark extent (the smaller of the x and y ratios), then pasted onto a
    black ``h`` x ``w`` canvas centred on the reference face box. Overflow on
    any side is clipped.

    For every frame whose aligned image still contains a detectable face, two
    files named ``NNNN.png`` are written, where ``NNNN`` is the index into
    ``frames_tgt``:

    * the aligned frame, to ``<save_folder>/target_cropped``;
    * its parsing map with eye landmarks drawn, to ``<save_folder>/aligned``.

    Frames without a face go to neither folder, so both folders hold the same
    frame names (``main`` stacks them frame by frame).

    Args:
        faces_ref: detection dict of the reference face (``points``, ``rects``).
        faces_tgt: detection dicts, one per entry of ``frames_tgt``.
        frames_tgt: target frames, anything ``np.array`` turns into an
            H x W x 3 uint8 image.
        tgt_shape: ``(height, width)`` of the target frames.
        save_folder: run output directory.
        face_detector, face_parser, face_aligner: facer models.
        w, h: output canvas width and height; the reference box is clamped to them.
        device: device the aligned image is moved to for detection/parsing.

    Raises:
        ValueError: if ``faces_tgt`` is empty or the target geometry is
            degenerate (zero landmark extent or empty union box).
    """
    if not faces_tgt:
        raise ValueError('align: no target frame with a detected face')

    min_x_r, max_x_r, min_y_r, max_y_r = _landmark_extent(faces_ref)
    x1_r, y1_r, x2_r, y2_r = _clamped_rect(faces_ref, h, w)

    extents_t = [_landmark_extent(faces) for faces in faces_tgt]
    max_w = max(e[1] - e[0] for e in extents_t)
    max_h = max(e[3] - e[2] for e in extents_t)

    # Union of all target boxes; each box is seeded with its own corners
    # (the old loop seeded y2 with y1 by mistake).
    rects_t = [_clamped_rect(faces, tgt_shape[0], tgt_shape[1]) for faces in faces_tgt]
    x1_t = min(r[0] for r in rects_t)
    y1_t = min(r[1] for r in rects_t)
    x2_t = max(r[2] for r in rects_t)
    y2_t = max(r[3] for r in rects_t)

    if max_w <= 0 or max_h <= 0:
        raise ValueError('align: degenerate target landmarks, cannot compute scale')
    if x2_t <= x1_t or y2_t <= y1_t:
        raise ValueError('align: empty target face region')

    scale = min((max_x_r - min_x_r) / max_w, (max_y_r - min_y_r) / max_h)
    center_x = (x1_r + x2_r) / 2
    center_y = (y1_r + y2_r) / 2

    print(x1_t, y1_t, x2_t, y2_t)
    print(scale)

    save_aligned_folder = os.path.join(save_folder, aligned_folder)
    os.makedirs(save_aligned_folder, exist_ok=True)
    saved_tgt_folder = os.path.join(save_folder, cropped_tgt_name)
    os.makedirs(saved_tgt_folder, exist_ok=True)

    # Every frame is cropped to the same union box, so the resized size is
    # loop-invariant; keep it at least 1 px so PIL's resize never gets 0.
    region_size = (max(1, int((x2_t - x1_t) * scale)), max(1, int((y2_t - y1_t) * scale)))
    for i, frame in enumerate(frames_tgt):
        region = np.array(frame)[y1_t:y2_t, x1_t:x2_t]
        region = np.array(Image.fromarray(region).resize(region_size))
        new_h, new_w = region.shape[0], region.shape[1]
        # floor keeps the old placement for non-negative offsets; negative
        # offsets are now clipped instead of shifting the face off-centre.
        top = int(np.floor(center_y - new_h // 2))
        left = int(np.floor(center_x - new_w // 2))
        canvas = _paste_clipped(np.zeros((h, w, 3), dtype=np.uint8), region, top, left)

        frame_name = str(i).zfill(4) + '.png'
        tgt_path = os.path.join(saved_tgt_folder, frame_name)
        Image.fromarray(canvas).save(tgt_path)
        # uint8 HWC, the same kind of tensor facer.read_hwc feeds the models elsewhere in this file.
        image = facer.hwc2bchw(torch.from_numpy(canvas)).to(device=device)  # image: 1 x 3 x h x w
        parsing = _render_parsing_with_eyes(image, face_detector, face_parser, face_aligner)
        if parsing is None:
            # Keep target_cropped and aligned in lock-step for main().
            os.remove(tgt_path)
            continue
        Image.fromarray(parsing).save(os.path.join(save_aligned_folder, frame_name))
        torch.cuda.empty_cache()

def main():
    """Run stage-2 inference end to end and write ``result.mp4``.

    Steps:
      1. Build the models and ``Parsing2VideoPipeline`` described by ``--config``.
      2. Optionally resize the reference image (``--crop``) and parse its face.
      3. Dump the target video's frames and detect a face in each one.
      4. Align target frames to the reference (``--align``) and render their
         parsing maps.
      5. Generate the video and save [reference | target | generated] side by
         side in ``<save_dir>/result.mp4``.
    """
    args = parse_args()
    config = OmegaConf.load(args.config)

    weight_dtype = torch.float16 if config.weight_dtype == "fp16" else torch.float32

    vae = AutoencoderKL.from_pretrained(
        config.pretrained_vae_path,
    ).to("cuda", dtype=weight_dtype)

    reference_unet = UNet2DConditionModel.from_pretrained(
        config.pretrained_base_model_path,
        subfolder="unet",
    ).to(dtype=weight_dtype, device="cuda")

    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
        config.pretrained_base_model_path,
        config.motion_module_path,
        subfolder="unet",
        unet_additional_kwargs=config.unet_additional_kwargs,
    ).to(dtype=weight_dtype, device="cuda")

    pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(
        dtype=weight_dtype, device="cuda"
    )

    # NOTE: pipeline_parsing2vid_adapter.Parsing2VideoPipeline additionally needs a
    # VAEAdatper(config.pretrained_vae_path, weight_dtype) loaded from
    # config.adapter_path and passed to the pipeline call as `adapter=`.

    image_enc = CLIPVisionModelWithProjection.from_pretrained(
        config.image_encoder_path
    ).to(dtype=weight_dtype, device="cuda")

    sched_kwargs = OmegaConf.to_container(config.noise_scheduler_kwargs)
    scheduler = DDIMScheduler(**sched_kwargs)

    generator = torch.manual_seed(args.seed)

    width, height = args.W, args.H

    # Load pretrained weights. The denoising UNet is loaded with strict=False,
    # so missing/unexpected keys are ignored.
    denoising_unet.load_state_dict(
        torch.load(config.denoising_unet_path, map_location="cpu"),
        strict=False,
    )
    reference_unet.load_state_dict(
        torch.load(config.reference_unet_path, map_location="cpu"),
    )
    pose_guider.load_state_dict(
        torch.load(config.pose_guider_path, map_location="cpu"),
    )

    pipe = Parsing2VideoPipeline(
        vae=vae,
        image_encoder=image_enc,
        reference_unet=reference_unet,
        denoising_unet=denoising_unet,
        pose_guider=pose_guider,
        scheduler=scheduler,
    )
    pipe = pipe.to("cuda", dtype=weight_dtype)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    face_detector = facer.face_detector('retinaface/mobilenet', device=device)
    face_parser = facer.face_parser('farl/lapa/448', device=device)  # optional "farl/celebm/448"
    face_aligner = facer.face_aligner('farl/wflw/448', device=device)

    # Create the save folder. A user-supplied --save_dir is created too;
    # previously only the auto-generated one was.
    if args.save_dir is None:
        now = datetime.now()  # read once so date and time cannot straddle midnight
        date_str = now.strftime("%Y%m%d")
        time_str = now.strftime("%H%M")
        save_dir_name = f"{time_str}--seed_{args.seed}-{args.W}x{args.H}"
        save_dir = Path(f"output/{date_str}/{save_dir_name}")
        save_dir.mkdir(exist_ok=True, parents=True)
    else:
        save_dir = args.save_dir
        os.makedirs(save_dir, exist_ok=True)

    if args.crop:
        img_cropped = crop(args.reference_img, args.width, args.height)
        img_cropped.save(os.path.join(save_dir, cropped_ref_name))
        args.reference_img = os.path.join(save_dir, cropped_ref_name)

    # Reference parsing, target frames and per-frame face detections.
    faces_ref = get_reference_parsing(args.reference_img, save_dir, face_detector, face_parser, face_aligner, device)
    frames_dir = os.path.join(save_dir, 'pngs')
    fps, tgt_h, tgt_w = read_frames(args.target_video, frames_dir)
    frame_names = sorted(os.listdir(frames_dir))
    print('get target parsings')
    faces_tgt = []
    frames_tgt = []
    for frame_name in tqdm(frame_names):
        frame_path = os.path.join(frames_dir, frame_name)
        faces = get_target_parsing(frame_path, face_detector, device)
        if faces is not None:
            faces_tgt.append(faces)
            # Load eagerly and close the file: a lazily opened PIL image keeps
            # its file handle open, so long videos would exhaust the fd limit.
            with Image.open(frame_path) as frame_img:
                frames_tgt.append(frame_img.copy())
    if args.align and faces_ref is not None:
        align(faces_ref, faces_tgt, frames_tgt, [tgt_h, tgt_w], save_dir, face_detector, face_parser, face_aligner, args.width, args.height, device)

    del faces_ref, faces_tgt

    # Inference.
    ref_img_pil = Image.open(args.reference_img).convert("RGB")
    ref_parsing_path = os.path.join(save_dir, parsing_name)
    if os.path.exists(ref_parsing_path):
        ref_pose_pil = Image.open(ref_parsing_path).resize((512, 512))
    else:
        ref_pose_pil = None

    # Only frames that kept a face have a pose image, so the pose list may be
    # shorter than the decoded video.
    tgt_pose_folder = aligned_folder if args.align else panc_name
    tgt_pose_dir = os.path.join(save_dir, tgt_pose_folder)
    tgt_poses_pil = []
    for pose_name in sorted(os.listdir(tgt_pose_dir)):
        with Image.open(os.path.join(tgt_pose_dir, pose_name)) as pose_img:
            tgt_poses_pil.append(pose_img.copy())
    n_frames = len(tgt_poses_pil)

    video = pipe(
        ref_img_pil,
        ref_pose_pil,
        tgt_poses_pil,
        width,
        height,
        n_frames,
        args.steps,
        args.cfg,
        generator=generator,
    ).videos

    to_tensor = transforms.Compose([
        transforms.ToTensor()
    ])
    # Reference image repeated along a new frame axis: (1, c, f, h, w).
    ref_img_tensor = to_tensor(ref_img_pil).unsqueeze(1).unsqueeze(0)
    ref_img_tensor = repeat(ref_img_tensor, 'b c f h w -> b c (repeat f) h w', repeat=n_frames)
    pose_tensor = torch.stack([to_tensor(p) for p in tgt_poses_pil], dim=0)  # (f, c, h, w)
    pose_tensor = pose_tensor.transpose(0, 1).unsqueeze(0)

    # Ground-truth (aligned target) frames, stacked the same way.
    gt_dir = os.path.join(save_dir, cropped_tgt_name)
    gt_frames = []
    for gt_name in sorted(os.listdir(gt_dir)):
        with Image.open(os.path.join(gt_dir, gt_name)) as gt_pil:
            gt_frames.append(to_tensor(gt_pil))
    gt = torch.stack(gt_frames, dim=0).transpose(0, 1).unsqueeze(0)

    print(ref_img_tensor.shape, pose_tensor.shape, video.shape, gt.shape)
    # Side-by-side along width: [reference | ground truth | generated].
    video = torch.cat([ref_img_tensor, gt, video], dim=-1)
    save_videos_grid_cv(
        video,
        f"{save_dir}/result.mp4",
        fps=fps if args.fps is None else args.fps,
    )
    print(f'finish to {save_dir}')

if __name__ == "__main__":
    # Run inference only when executed as a script, not when imported.
    main()
