import os, io, csv, math, random
from importlib.metadata import files
import os.path as osp

import numpy as np
import torch
from PIL import Image
from torch.utils.data.dataset import Dataset
from einops import rearrange
import cv2
import warnings


class LargeScaleAnimationVideos4(Dataset):
    def __init__(self, root_path, txt_path, width, height, n_sample_frames, sample_frame_rate=4, sample_margin=30,
                 app=None, handler_ante=None, face_helper=None, chunks_per_video=20, frames_per_chunk=70):
        self.root_path = root_path
        self.txt_path = txt_path
        self.width = width 
        self.height = height
        self.n_sample_frames = n_sample_frames
        self.sample_frame_rate = sample_frame_rate

        self.sample_margin = sample_margin
        self.chunks_per_video = chunks_per_video


        self.base_video_files = self._read_txt_file_images()
        

        self.video_chunks = self._create_video_chunks(chunks_per_video, frames_per_chunk)
        self.video_chunks_num = len(self.video_chunks)

        self.app = app
        self.handler_ante = handler_ante
        self.face_helper = face_helper
        
        print(f"Dataset created with {len(self.video_chunks)} chunks from {len(self.base_video_files)} videos")
        
    def _read_txt_file_images(self):
        with open(self.txt_path, 'r') as file:
            lines = file.readlines()
            video_files = []
            for line in lines:
                video_file = line.strip()
                video_files.append(video_file)
        return video_files


    def _create_video_chunks(self, chunks_per_video=64, frames_per_chunk=35):
        """
        Create chunks from each video, scaling the number of chunks based on video length
        but limited by chunks_per_video parameter.
        
        Each chunk requires at least 80 frames, and chunks are evenly distributed.
        
        Examples:
        - A video with 400 frames and chunks_per_video=20 will have 5 chunks (400/80)
        - A video with 8000 frames and chunks_per_video=20 will have 20 chunks (limited by parameter)
        """
        video_chunks = []
        
        for video_file in self.base_video_files:
            frames_path = os.path.join(video_file, "images")
            if not os.path.exists(frames_path):
                print(f"Skipping {video_file} - no images directory")
                continue
                
            try:

                all_frames = [f for f in os.listdir(frames_path) 
                            if f.endswith(('.jpg', '.png', '.jpeg'))]
                

                frame_files = []
                for frame in all_frames:

                    base_name = os.path.splitext(frame)[0]
                    if base_name.startswith('frame_'):
                        try:
                            frame_num = int(base_name.split('_')[1])
                            if frame_num <= 9999:
                                frame_files.append(frame)
                        except (ValueError, IndexError):

                            continue
                

                frame_files.sort(key=lambda x: int(os.path.splitext(x)[0].split('_')[1]))
                
                video_length = len(frame_files)
                

                if video_length < self.n_sample_frames:
                    print(f"Skipping {video_file} - too few frames ({video_length})")
                    continue
                    

                stride = self.sample_frame_rate
                

                required_frame_span = (self.n_sample_frames - 1) * stride + 1
                

                if video_length < required_frame_span:

                    stride = max(1, (video_length - 1) // (self.n_sample_frames - 1))
                    required_frame_span = (self.n_sample_frames - 1) * stride + 1
                    print(f"Reducing stride to {stride} for {video_file}")
                

                available_frames = video_length - required_frame_span + 1
                
                if available_frames <= 0:
                    print(f"Video {video_file} cannot fit any chunks with current settings")
                    continue


                length_based_chunks = max(1, video_length // frames_per_chunk)
                

                num_chunks = min(chunks_per_video, length_based_chunks) if chunks_per_video > 0 else length_based_chunks
                

                if num_chunks == 0:
                    num_chunks = 1
                    
                if num_chunks == 1:

                    start_positions = [available_frames // 2]
                else:


                    step = available_frames / (num_chunks - 1) if num_chunks > 1 else 0
                    start_positions = [int(i * step) for i in range(num_chunks)]
                

                for start_idx in start_positions:

                    start_idx = min(start_idx, available_frames - 1)
                    video_chunks.append((video_file, start_idx, stride))
                
                print(f"Created {num_chunks} chunks from video {os.path.basename(video_file)} with {video_length} frames")
                
            except Exception as e:
                print(f"Error processing {video_file}: {e}")
                continue
                
        print(f"Created {len(video_chunks)} total chunks from {len(self.base_video_files)} videos")
        return video_chunks
        

    def __len__(self):
        """Return the number of chunks, i.e. the number of samples."""
        chunk_total = len(self.video_chunks)
        return chunk_total

    def frame_count(self, frames_path):
        files = os.listdir(frames_path)
        png_files = [file for file in files if file.endswith('.png') or file.endswith('.jpg')]
        png_files_count = len(png_files)
        return png_files_count

    def find_frames_list(self, frames_path):
        files = os.listdir(frames_path)
        image_files = [file for file in files if file.endswith('.png') or file.endswith('.jpg')]
        if image_files[0].startswith('frame_'):
            image_files.sort(key=lambda x: int(x.split('_')[1].split('.')[0]))
        else:
            image_files.sort(key=lambda x: int(x.split('.')[0]))
        return image_files

    def get_face_masks(self, pil_img):
        """Build a face-region mask for ``pil_img``.

        Faces are first looked up with ``self.app.get``; if that returns
        nothing, ``self.face_helper.face_det.detect_faces`` is tried with a
        0.97 threshold. Every detected box is filled with 1.0 on a
        ``(self.height, self.width)`` canvas. If neither detector finds a
        face the mask is all ones.

        The result is always ``float64``. The original fallback returned
        ``uint8`` ones, so the dtype depended on whether a face was found.

        Args:
            pil_img: RGB image convertible with ``np.array``.
                NOTE(review): box coordinates are drawn unscaled onto a
                (height, width) canvas, which presumably requires the image
                to already have that size - confirm against callers.

        Returns:
            numpy.ndarray: ``(self.height, self.width)`` float64 mask.
        """
        rgb_image = np.array(pil_img)
        bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)

        # Primary detector: each result exposes its box under 'bbox'.
        boxes = [info['bbox'] for info in self.app.get(bgr_image)]

        if len(boxes) == 0:
            # Fallback detector; its boxes are indexed positionally.
            self.face_helper.clean_all()
            with torch.no_grad():
                boxes = self.face_helper.face_det.detect_faces(bgr_image, 0.97)

        if len(boxes) == 0:
            # No face anywhere: treat the whole frame as face region.
            return np.ones((self.height, self.width), dtype=np.float64)

        mask = np.zeros((self.height, self.width), dtype=np.uint8)
        for box in boxes:
            cv2.rectangle(mask, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (255),
                          thickness=cv2.FILLED)
        return mask.astype(np.float64) / 255.0

    def __getitem__(self, idx):
        """Load the sample for chunk ``idx``.

        A chunk is ``(video_file, start_idx, stride)``. ``n_sample_frames``
        target frames are taken from ``<video_file>/images`` starting at
        ``start_idx``. The reference frame is the frame closest to the
        chunk centre that is not itself a target frame. For both persons
        (``person_0`` / ``person_1``) the pose images, face masks and human
        masks of every target frame are loaded, plus the human mask and a
        face-identity embedding for the reference frame.

        Returns:
            dict with the keys

            - ``pixel_values``: ``(F, 3, H, W)`` float tensor in [-1, 1].
            - ``reference_image``: ``(3, H, W)`` float tensor in [-1, 1].
            - ``reference_human_mask0/1``: mask tensor scaled by 1/255;
              all-ones ``(H, W)`` when the file cannot be loaded.
            - ``pose_pixels0/1``: ``(F, 3, H, W)`` float tensors in [-1, 1];
              zeros when the pose images cannot be loaded.
            - ``faceid_embeds0/1``: identity embedding of each person in the
              reference frame; ``np.zeros((512,))`` when none is found.
            - ``tgt_face_masks0/1`` and ``tgt_human_masks0/1``:
              ``(F, C, H, W)`` float tensors scaled by 1/255 (C is 1 for
              single-channel mask files); all ones on load failure.

        Raises:
            IndexError: If ``idx`` is out of range or the frame directory
                contains no images.
        """
        warnings.filterwarnings('ignore', category=DeprecationWarning)
        warnings.filterwarnings('ignore', category=FutureWarning)

        video_file, start_idx, stride = self.video_chunks[idx]

        frames_path = os.path.join(video_file, "images")
        poses_path0 = os.path.join(video_file, "poses/person_0")
        face_masks_path0 = os.path.join(video_file, "faces/person_0")
        human_masks_path0 = os.path.join(video_file, "masks/person_0")
        poses_path1 = os.path.join(video_file, "poses/person_1")
        face_masks_path1 = os.path.join(video_file, "faces/person_1")
        human_masks_path1 = os.path.join(video_file, "masks/person_1")

        # One directory listing is enough: the frame count is the list length.
        frames_list = self.find_frames_list(frames_path)
        video_length = len(frames_list)

        # Shrink the stride if the window would run past the last frame.
        n_frames = self.n_sample_frames
        if start_idx + (n_frames - 1) * stride >= video_length:
            remaining = video_length - start_idx
            stride = max(1, (remaining - 1) // (n_frames - 1))
        batch_index = [start_idx + i * stride for i in range(n_frames)]

        # Reference frame: nearest non-target frame to the chunk centre
        # (ties resolve to the lower index).
        sampled = set(batch_index)
        chunk_center = (batch_index[0] + batch_index[-1]) // 2
        candidates = [i for i in range(video_length) if i not in sampled]
        if candidates:
            reference_frame_idx = min(candidates, key=lambda i: abs(i - chunk_center))
        else:
            print(f"No available reference frame in {frames_path}")
            reference_frame_idx = batch_index[0]

        reference_frame_path = os.path.join(frames_path, frames_list[reference_frame_idx])
        reference_stem = os.path.splitext(os.path.basename(reference_frame_path))[0]

        reference_human_mask0 = self._load_reference_human_mask(
            os.path.join(human_masks_path0, reference_stem + '.png'), 0)
        reference_human_mask1 = self._load_reference_human_mask(
            os.path.join(human_masks_path1, reference_stem + '.png'), 1)

        # Kept as (H, W, 3) until the end: it is the template for the
        # zero-filled pose fallback inside the frame loop.
        reference_pil_image = self._rgb_to_signed_tensor(Image.open(reference_frame_path).convert('RGB'))

        self.face_helper.clean_all()
        reference_frame_face = cv2.imread(reference_frame_path)
        reference_frame_face = cv2.resize(reference_frame_face, (self.width, self.height))

        reference_frame_id_ante_embedding0 = self._reference_face_embedding(
            os.path.join(face_masks_path0, reference_stem + '.jpg'), reference_frame_face, 0)
        reference_frame_id_ante_embedding1 = self._reference_face_embedding(
            os.path.join(face_masks_path1, reference_stem + '.jpg'), reference_frame_face, 1)

        pose_frames0, pose_frames1 = [], []
        tgt_frames = []
        face_masks0, face_masks1 = [], []
        human_masks0, human_masks1 = [], []

        for index in batch_index:
            if index >= len(frames_list):
                print(f"Index {index} out of bounds for {frames_path} with {len(frames_list)} frames")
                # Repeat the last frame rather than fail.
                index = len(frames_list) - 1

            tgt_img_path = os.path.join(frames_path, frames_list[index])
            stem = os.path.splitext(os.path.basename(tgt_img_path))[0]
            pose_name = stem + '.jpg'
            face_name = stem + '.jpg'
            mask_name = stem + '.png'

            try:
                tgt_img_pil = Image.open(tgt_img_path).convert('RGB')
            except Exception:
                print(f"Fail loading the image: {tgt_img_path}")
                # Black placeholder keeps the batch shape intact.
                tgt_img_pil = Image.new('RGB', (self.width, self.height), (0, 0, 0))

            face0, face1, human0, human1 = self._load_target_masks(
                os.path.join(face_masks_path0, face_name),
                os.path.join(face_masks_path1, face_name),
                os.path.join(human_masks_path0, mask_name),
                os.path.join(human_masks_path1, mask_name),
            )
            face_masks0.append(face0)
            face_masks1.append(face1)
            human_masks0.append(human0)
            human_masks1.append(human1)

            tgt_frames.append(self._rgb_to_signed_tensor(tgt_img_pil))

            pose0, pose1 = self._load_pose_pair(
                os.path.join(poses_path0, pose_name),
                os.path.join(poses_path1, pose_name),
                reference_pil_image,
            )
            pose_frames0.append(pose0)
            pose_frames1.append(pose1)

        sample = dict(
            pixel_values=rearrange(torch.stack(tgt_frames, dim=0), "f h w c -> f c h w"),
            reference_image=rearrange(reference_pil_image, "h w c -> c h w"),
            reference_human_mask0=reference_human_mask0,
            reference_human_mask1=reference_human_mask1,
            pose_pixels0=rearrange(torch.stack(pose_frames0, dim=0), "f h w c -> f c h w"),
            pose_pixels1=rearrange(torch.stack(pose_frames1, dim=0), "f h w c -> f c h w"),
            faceid_embeds0=reference_frame_id_ante_embedding0,
            faceid_embeds1=reference_frame_id_ante_embedding1,
            tgt_face_masks0=self._stack_masks(face_masks0),
            tgt_face_masks1=self._stack_masks(face_masks1),
            tgt_human_masks0=self._stack_masks(human_masks0),
            tgt_human_masks1=self._stack_masks(human_masks1),
        )

        return sample

    @staticmethod
    def _open_human_mask(mask_path):
        """Open a human mask, accepting ``frame_N.png`` or ``frameN.png`` naming."""
        if os.path.exists(mask_path):
            return Image.open(mask_path)
        # Rewrite only the file name: a directory that happens to contain
        # "frame_" must not be altered.
        directory, file_name = os.path.split(mask_path)
        alt_path = os.path.join(directory, file_name.replace('frame_', 'frame'))
        if os.path.exists(alt_path):
            return Image.open(alt_path)
        raise FileNotFoundError(f"Neither {mask_path} nor {alt_path} exists")

    def _rgb_to_signed_tensor(self, image):
        """Resize a PIL image to (width, height) and scale its pixels to [-1, 1]."""
        image = image.resize((self.width, self.height))
        return torch.from_numpy(np.array(image)).float() / 127.5 - 1

    def _mask_to_tensor(self, image):
        """Resize a PIL mask to (width, height) and scale its pixels by 1/255."""
        image = image.resize((self.width, self.height))
        return torch.from_numpy(np.array(image)).float() / 255

    def _load_reference_human_mask(self, mask_path, person):
        """Load one reference-frame human mask; fall back to all ones on any failure."""
        try:
            return self._mask_to_tensor(self._open_human_mask(mask_path))
        except Exception:
            print(f"Fail loading reference human mask for person {person}: {mask_path}")
            return torch.ones(self.height, self.width)

    def _reference_face_embedding(self, face_mask_path, reference_frame_face, person):
        """Extract the identity embedding of ``person`` from the reference frame.

        The face mask locates the person's face; its bounding box, padded by
        20 pixels, is cropped from ``reference_frame_face`` and passed to
        ``self.app``. If that finds no face, the crop is aligned with
        ``self.face_helper`` and embedded with ``self.handler_ante``. A
        zero vector of length 512 is returned when no embedding is found.
        """
        embedding = None

        if os.path.exists(face_mask_path):
            face_mask = cv2.imread(face_mask_path, cv2.IMREAD_GRAYSCALE)
            if face_mask is not None and np.sum(face_mask) > 0:
                face_mask = cv2.resize(face_mask, (self.width, self.height))

                ys, xs = np.where(face_mask > 0)
                if len(ys) > 0 and len(xs) > 0:
                    padding = 20
                    y_min = max(0, np.min(ys) - padding)
                    y_max = min(self.height, np.max(ys) + padding)
                    x_min = max(0, np.min(xs) - padding)
                    x_max = min(self.width, np.max(xs) + padding)

                    face_crop = reference_frame_face[y_min:y_max, x_min:x_max]

                    face_info = self.app.get(face_crop)
                    if len(face_info) > 0:
                        # Keep the detection with the largest box area.
                        largest = sorted(
                            face_info,
                            key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]),
                        )[-1]
                        embedding = largest['embedding']
                    else:
                        self.face_helper.clean_all()
                        self.face_helper.read_image(face_crop)
                        self.face_helper.get_face_landmarks_5(only_center_face=True)
                        self.face_helper.align_warp_face()

                        if len(self.face_helper.cropped_faces) > 0:
                            aligned_face = self.face_helper.cropped_faces[0]
                            embedding = self.handler_ante.get_feat(aligned_face)
                        else:
                            embedding = np.zeros((512,))
                            print(f'[ERROR]: reference_frame_id_ante_embedding{person} is np.zeros((512,)!')

        if embedding is None:
            embedding = np.zeros((512,))
            print(f'[ERROR]: reference_frame_id_ante_embedding{person} is None!')

        if isinstance(embedding, np.ndarray):
            # Flatten so both detectors yield a 1-D vector.
            embedding = embedding.reshape(-1)
        return embedding

    def _load_target_masks(self, face_mask_path0, face_mask_path1, human_mask_path0, human_mask_path1):
        """Load the face and human masks of both persons for one target frame.

        Returns ``(face0, face1, human0, human1)``.

        NOTE(review): the fallback is all-or-nothing - if any one of the
        four masks fails to load, all four become all-ones. That matches the
        original behaviour; confirm it is intended for single-person videos.
        """
        try:
            face0 = self._mask_to_tensor(Image.open(face_mask_path0))
            face1 = self._mask_to_tensor(Image.open(face_mask_path1))
            human_image0 = self._open_human_mask(human_mask_path0)
            human_image1 = self._open_human_mask(human_mask_path1)
            human0 = self._mask_to_tensor(human_image0)
            human1 = self._mask_to_tensor(human_image1)
        except Exception as e:
            print(f"Fail loading masks: {e}")
            return tuple(torch.ones(self.height, self.width, 1) for _ in range(4))
        return face0, face1, human0, human1

    def _load_pose_pair(self, pose_path0, pose_path1, like):
        """Load both persons' pose images; zeros shaped like ``like`` if either fails."""
        try:
            pose0 = self._rgb_to_signed_tensor(Image.open(pose_path0).convert('RGB'))
            pose1 = self._rgb_to_signed_tensor(Image.open(pose_path1).convert('RGB'))
        except Exception as e:
            print(f"Fail loading poses: {e}")
            return torch.zeros_like(like), torch.zeros_like(like)
        return pose0, pose1

    @staticmethod
    def _stack_masks(masks):
        """Stack per-frame masks into one ``(F, C, H, W)`` tensor.

        Masks loaded from single-channel files are ``(H, W)`` while the
        fallback masks are ``(H, W, 1)``; a channel axis is added to the
        former so that both kinds can be stacked together.
        """
        frames = [mask.unsqueeze(-1) if mask.ndim == 2 else mask for mask in masks]
        return rearrange(torch.stack(frames, dim=0), "f h w c -> f c h w")


'''
Expected dataset layout (one directory per video, as listed in the txt file).

Only images/, poses/person_*, faces/person_* and masks/person_* are read by the
loader in this file. The extensions shown for those directories are the ones
the loader looks up: .jpg for poses and faces, .png for masks. Human masks may
be named either frame_N.png or frameN.png.

dataset/
├── video1/
│   ├── alphamasks/
│   │   ├── frame_0.png
│   │   ├── frame_1.png
│   │   └── ...
│   ├── images/
│   │   ├── frame_0.png
│   │   ├── frame_1.png
│   │   └── ...
│   ├── poses/allperson/
│   │   ├── frame_0.png
│   │   ├── frame_1.png
│   │   └── ...
│   ├── poses/person_0/
│   │   ├── frame_0.jpg
│   │   ├── frame_1.jpg
│   │   └── ...
│   ├── poses/person_1/
│   │   ├── frame_0.jpg
│   │   ├── frame_1.jpg
│   │   └── ...
│   ├── faces/allperson/
│   │   ├── frame_0.png
│   │   ├── frame_1.png
│   │   └── ...
│   ├── faces/person_0/
│   │   ├── frame_0.jpg
│   │   ├── frame_1.jpg
│   │   └── ...
│   ├── faces/person_1/
│   │   ├── frame_0.jpg
│   │   ├── frame_1.jpg
│   │   └── ...
│   ├── masks/person_0/
│   │   ├── frame0.png
│   │   ├── frame1.png
│   │   └── ...
│   └── masks/person_1/
│       ├── frame0.png
│       ├── frame1.png
│       └── ...
└── video2/
    ├── images/
    ├── poses/person_0/
    └── ...
'''
