import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset

import numpy as np
import pandas as pd
import pickle
import json

# import cv2
from PIL import Image
from torchvision import transforms
import torchaudio

from config import cfg


def get_v2_pallete(label_to_idx_path, num_cls=71):
    """Return the RGB colour palette shared by AVSBench-object (V1) and AVSBench-semantic (V2).

    Colour ``j`` is built from the bits of ``j``: for every group of three bits
    (bit ``3k``, ``3k+1``, ``3k+2``), bit ``7 - k`` of R, G and B respectively is
    set. ``num_cls`` defaults to 71, the total category number of the V2
    dataset, and should not be changed.

    Args:
        label_to_idx_path: path to a JSON file of label -> palette-index
            entries; only its length is used, as a consistency check.
        num_cls: number of palette entries to generate.

    Returns:
        np.ndarray of shape (num_cls, 3) with one RGB colour per class index.

    Raises:
        AssertionError: if the JSON file does not hold exactly ``num_cls``
            entries.
    """
    def _index_to_rgb(index):
        rgb = [0, 0, 0]
        bit_pos = 7
        while index > 0:
            for channel in range(3):
                rgb[channel] |= ((index >> channel) & 1) << bit_pos
            bit_pos -= 1
            index >>= 3
        return rgb

    with open(label_to_idx_path, 'r') as fr:
        label_to_pallete_idx = json.load(fr)
    pallete = np.array([_index_to_rgb(idx) for idx in range(num_cls)]).reshape(-1, 3)
    assert len(pallete) == len(label_to_pallete_idx)
    return pallete


def crop_resize_img(crop_size, img, img_is_mask=False):
    """Resize *img* so its shorter side equals *crop_size*, then centre-crop it.

    The resize keeps the aspect ratio (up to integer truncation); the crop
    then cuts a ``crop_size`` x ``crop_size`` square from the middle. Masks are
    resampled with nearest-neighbour (no interpolation between label colours),
    regular images bilinearly.

    Args:
        crop_size: side length of the square output.
        img: PIL image or mask to transform.
        img_is_mask: use nearest-neighbour instead of bilinear resampling.

    Returns:
        The cropped PIL image.
    """
    src_w, src_h = img.size
    if src_w > src_h:
        target = (int(1.0 * src_w * crop_size / src_h), crop_size)
    else:
        target = (crop_size, int(1.0 * src_h * crop_size / src_w))
    resample = Image.NEAREST if img_is_mask else Image.BILINEAR
    resized = img.resize(target, resample)

    # Centre crop to a crop_size square.
    res_w, res_h = resized.size
    left = int(round((res_w - crop_size) / 2.))
    top = int(round((res_h - crop_size) / 2.))
    return resized.crop((left, top, left + crop_size, top + crop_size))

def resize_img(crop_size, img, img_is_mask=False):
    """Resize *img* straight to a ``crop_size`` x ``crop_size`` square (no crop).

    Used for the val./test. splits. Masks use nearest-neighbour resampling
    (no interpolation between label colours); images use bilinear.

    Args:
        crop_size: side length of the square output.
        img: PIL image or mask to resize.
        img_is_mask: use nearest-neighbour instead of bilinear resampling.

    Returns:
        The resized PIL image.
    """
    resample = Image.NEAREST if img_is_mask else Image.BILINEAR
    return img.resize((crop_size, crop_size), resample)

def color_mask_to_label(mask, v_pallete):
    """Convert an RGB colour mask into a per-pixel class-index map.

    Every pixel receives the index of the palette colour it matches exactly.
    Pixels that match no palette colour get index 0. If a colour appears more
    than once in the palette, the lowest index wins, which is the same
    tie-breaking ``np.argmax`` over per-class one-hot maps would give.

    Labels are written into one H x W index array instead of stacking one
    float map per class, so peak memory no longer scales with the number of
    classes.

    Args:
        mask: RGB mask (PIL image or array-like) whose last axis holds the
            colour channels, e.g. (H, W, 3).
        v_pallete: iterable of RGB colours, one per class index.

    Returns:
        np.ndarray with the mask's spatial shape (e.g. (H, W)) holding
        integer class indices.
    """
    mask_array = np.array(mask).astype('int32')
    label = np.zeros(mask_array.shape[:-1], dtype=np.intp)
    # Walk the palette backwards so that lower indices overwrite higher ones,
    # reproducing argmax's first-match tie-breaking.
    for idx, colour in reversed(list(enumerate(v_pallete))):
        matches = np.all(np.equal(mask_array, colour), axis=-1)
        label[matches] = idx
    return label


def load_image_in_PIL_to_Tensor(path, split='train', mode='RGB', transform=None):
    """Load an image, optionally resize/crop it, and optionally transform it.

    When ``cfg.DATA.CROP_IMG_AND_MASK`` is set, ``'train'`` images are resized
    and centre-cropped to ``cfg.DATA.CROP_SIZE``; other splits are only resized
    to that square size.

    Args:
        path: image file path.
        split: dataset split; ``'train'`` selects crop-and-resize.
        mode: PIL mode to convert to (default ``'RGB'``).
        transform: optional callable applied to the PIL image.

    Returns:
        ``transform(image)`` when a transform is given, otherwise the PIL image.
    """
    # Context manager releases the file handle deterministically; ``convert``
    # returns a new image that does not depend on the open file.
    with Image.open(path) as img_file:
        img_PIL = img_file.convert(mode)
    if cfg.DATA.CROP_IMG_AND_MASK:
        if split == 'train':
            img_PIL = crop_resize_img(cfg.DATA.CROP_SIZE, img_PIL, img_is_mask=False)
        else:
            img_PIL = resize_img(cfg.DATA.CROP_SIZE, img_PIL, img_is_mask=False)
    if transform:
        return transform(img_PIL)
    return img_PIL


def load_color_mask_in_PIL_to_Tensor(path, v_pallete, split='train', mode='RGB'):
    """Load an RGB colour mask and convert it to a class-index tensor.

    When ``cfg.DATA.CROP_IMG_AND_MASK`` is set, the mask goes through the same
    geometric transform as images: crop-and-resize for ``'train'``, plain
    resize otherwise. Both use nearest-neighbour resampling for masks.

    Args:
        path: mask file path.
        v_pallete: palette of RGB colours, one per class index.
        split: dataset split; ``'train'`` selects crop-and-resize.
        mode: PIL mode to convert to before the colour lookup (default ``'RGB'``).

    Returns:
        torch.Tensor of shape [1, H, W] holding integer class indices.
    """
    # Context manager releases the file handle deterministically; ``convert``
    # returns a new image that does not depend on the open file.
    with Image.open(path) as mask_file:
        color_mask_PIL = mask_file.convert(mode)
    if cfg.DATA.CROP_IMG_AND_MASK:
        if split == 'train':
            color_mask_PIL = crop_resize_img(cfg.DATA.CROP_SIZE, color_mask_PIL, img_is_mask=True)
        else:
            color_mask_PIL = resize_img(cfg.DATA.CROP_SIZE, color_mask_PIL, img_is_mask=True)
    # Map each pixel's colour to its class index: [H, W] -> [1, H, W].
    color_label = color_mask_to_label(color_mask_PIL, v_pallete)
    return torch.from_numpy(color_label).unsqueeze(0)
    

def load_audio_lm(audio_lm_path):
    """Unpickle a pre-computed audio log-mel tensor and detach it from autograd.

    The original author noted the shape as [5, 1, 96, 64]; the meaning of each
    axis is not shown here (TODO(review): confirm against the feature
    extraction script).

    NOTE(review): ``pickle.load`` can execute arbitrary code stored in the
    file; only point this at trusted, locally generated feature files.
    """
    with open(audio_lm_path, 'rb') as pkl_file:
        features = pickle.load(pkl_file)
    return features.detach()


class V2Dataset(Dataset):
    """Dataset for audio visual semantic segmentation of AVSBench-semantic (V2).

    Videos come from the rows of ``cfg.DATA.META_CSV_PATH`` whose ``split``
    column equals *split*. Frames and audio are not decoded here; each item is::

        (img_base_path, labels_tensor, gt_temporal_mask_flag,
         video_name, audio_path, audio_clip_length[, text])

    ``labels_tensor`` stacks the per-frame ``[1, H, W]`` label maps, zero-padded
    to 10 frames. ``text`` is only returned when *text_path* is given: the
    video's text rows ordered by frame and joined with ``"<sep>"``.
    """

    def __init__(self, split='train', frame_shift=10, debug_flag=False, text_path=''):
        super(V2Dataset, self).__init__()
        self.split = split
        self.frame_shift = frame_shift  # stored only; not used by this class
        self.mask_num = cfg.MASK_NUM
        df_all = pd.read_csv(cfg.DATA.META_CSV_PATH, sep=',')
        self.df_split = df_all[df_all['split'] == split]
        if debug_flag:
            self.df_split = self.df_split[:100]  # small subset for debugging
        print("{}/{} videos are used for {}.".format(len(self.df_split), len(df_all), self.split))
        # Not applied in __getitem__ (frames are not loaded there).
        self.img_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])
        self.v2_pallete = get_v2_pallete(cfg.DATA.LABEL_IDX_PATH, num_cls=cfg.NUM_CLASSES)
        self.text_out = False
        if text_path and text_path != 'no':
            df_text = pd.read_csv(text_path)
            text_col = 'object' if 'object' in df_text else 'text'
            df_text[text_col] = df_text[text_col].fillna('sounding object')
            self.df_text = df_text
            # Sort once here instead of re-sorting the whole table per item.
            self._df_text_sorted = df_text.sort_values(by=['name', 'frame'])
            self.text_out = True
            self.text_col = text_col

    @staticmethod
    def _temporal_flags(subset):
        """Return ``(vid_temporal_mask_flag, gt_temporal_mask_flag)`` for *subset*.

        Both are length-10 float tensors: the first marks frames that exist in
        the video, the second frames that carry a ground-truth mask.

        Raises:
            ValueError: for a subset other than 'v1s', 'v1m' or 'v2'.
        """
        if subset == 'v1s':  # AVSBench-object single-source: 5s, gt is only the first annotated frame
            return (torch.Tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0]),
                    torch.Tensor([1, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
        if subset == 'v1m':  # AVSBench-object multi-sources: 5s, all 5 extracted frames are annotated
            return (torch.Tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0]),
                    torch.Tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0]))
        if subset == 'v2':  # new AVSBench-semantic videos: 10s, all 10 extracted frames are annotated
            return torch.ones(10), torch.ones(10)
        raise ValueError("Unknown subset {!r}; expected 'v1s', 'v1m' or 'v2'.".format(subset))

    def _load_labels(self, color_mask_base_path, subset):
        """Load one video's label maps as a tensor zero-padded to 10 frames.

        Masks are read as ``0.png``, ``1.png``, ... up to the number of ``.png``
        files in *color_mask_base_path*.

        Raises:
            AssertionError: outside 'train', if the mask count is not 10 for
                'v2' or 5 otherwise.
            FileNotFoundError: if the directory holds no ``.png`` masks.
        """
        # Filter into a new list; removing from a list while iterating over it
        # skips the element that follows each removal.
        mask_names = [name for name in sorted(os.listdir(color_mask_base_path))
                      if name.endswith(".png")]
        mask_num = len(mask_names)
        if self.split != 'train':
            assert mask_num == (10 if subset == 'v2' else 5)
        if mask_num == 0:
            raise FileNotFoundError("No .png masks found in {}".format(color_mask_base_path))

        labels = [
            load_color_mask_in_PIL_to_Tensor(
                os.path.join(color_mask_base_path, "%d.png" % mask_id),
                v_pallete=self.v2_pallete, split=self.split)
            for mask_id in range(mask_num)
        ]
        # Pad with all-zero maps so every video yields 10 frames.
        labels.extend(torch.zeros_like(labels[-1]) for _ in range(10 - mask_num))
        return torch.stack(labels, dim=0)

    def _joined_text(self, video_name):
        """Return *video_name*'s text rows, ordered by frame, joined with "<sep>"."""
        sorted_text = self._df_text_sorted
        rows = sorted_text.loc[sorted_text['name'] == video_name, self.text_col]
        return "<sep>".join(rows.values.tolist())

    def __getitem__(self, index):
        df_one_video = self.df_split.iloc[index]
        video_name, subset = df_one_video['uid'], df_one_video['label']
        # Frames are not loaded here; callers receive the frame directory.
        img_base_path = os.path.join(cfg.DATA.DIR_BASE, subset, video_name, 'frames')
        audio_path = os.path.join(cfg.DATA.DIR_BASE, subset, video_name)
        color_mask_base_path = os.path.join(cfg.DATA.DIR_BASE, subset, video_name, 'labels_rgb')

        _, gt_temporal_mask_flag = self._temporal_flags(subset)
        labels_tensor = self._load_labels(color_mask_base_path, subset)
        audio_clip_length = 10 if subset == 'v2' else 5

        if self.text_out:
            return (img_base_path, labels_tensor, gt_temporal_mask_flag, video_name,
                    audio_path, audio_clip_length, self._joined_text(video_name))
        return img_base_path, labels_tensor, gt_temporal_mask_flag, video_name, audio_path, audio_clip_length

    def __len__(self):
        return len(self.df_split)

    @property
    def num_classes(self):
        """Number of categories (including background)."""
        return cfg.NUM_CLASSES

    @property
    def classes(self):
        """Category names, in the key order of the label-index JSON file."""
        with open(cfg.DATA.LABEL_IDX_PATH, 'r') as fr:
            classes = json.load(fr)
        return list(classes.keys())


class AVSSDataset(Dataset):
    """Dataset for audio visual semantic segmentation of AVSBench-semantic (V3/V2).

    With ``subdomain='V3'`` the split metadata is read from the per-split CSVs
    (train/val from the ``META_SEEN_*`` paths, test from ``META_UNSEEN_PATH``).
    With ``subdomain='AVSS'`` it is read from ``cfg.DATA.META_CSV_PATH``,
    filtered on the ``split`` column. Frames and audio are not decoded here;
    each item is::

        (img_base_path, labels_tensor, vid_temporal_mask_flag,
         video_name, audio_path, audio_clip_length[, text])

    ``labels_tensor`` stacks the per-frame ``[1, H, W]`` label maps, zero-padded
    to 10 frames. ``text`` is only returned when *text_path* is given.

    NOTE(review): unlike ``V2Dataset`` and ``V3Dataset_zs``, this returns
    ``vid_temporal_mask_flag`` rather than ``gt_temporal_mask_flag``. Confirm
    this is intended for 'v1s' videos in training.
    """

    def __init__(self, split='train', frame_shift=10, debug_flag=False, device='cpu', subdomain='V3', text_path=None):
        # ``device`` is accepted for call-site compatibility but is unused.
        super(AVSSDataset, self).__init__()
        self.split = split
        self.frame_shift = frame_shift  # stored only; not used by this class
        self.mask_num = cfg.MASK_NUM
        self.subdomain = subdomain
        if subdomain == 'V3':
            split_file_path = {
                'train': cfg.DATA.META_SEEN_TRAIN_PATH,
                'val': cfg.DATA.META_SEEN_VAL_PATH,
                'test': cfg.DATA.META_UNSEEN_PATH
            }
            self.df_split = pd.read_csv(split_file_path[split], sep=',')
        elif subdomain == 'AVSS':
            df_all = pd.read_csv(cfg.DATA.META_CSV_PATH, sep=',')
            self.df_split = df_all[df_all['split'] == split]
        else:
            raise NotImplementedError("Unsupported subdomain: {!r}".format(subdomain))

        if debug_flag:
            self.df_split = self.df_split[:100]  # small subset for debugging
        print("{} videos are used for {}.".format(len(self.df_split), self.split))
        # Not applied in __getitem__ (frames are not loaded there).
        self.img_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])
        self.v2_pallete = get_v2_pallete(cfg.DATA.LABEL_IDX_PATH, num_cls=cfg.NUM_CLASSES)
        self.text_out = False
        if text_path and text_path != 'no':
            df_text = pd.read_csv(text_path)
            text_col = 'object' if 'object' in df_text else 'text'
            df_text[text_col] = df_text[text_col].fillna('sound emit object')
            self.df_text = df_text
            # Sort once here instead of re-sorting the whole table per item.
            self._df_text_sorted = df_text.sort_values(by=['name', 'frame'])
            self.text_out = True
            self.text_col = text_col

    @staticmethod
    def _temporal_flags(subset):
        """Return ``(vid_temporal_mask_flag, gt_temporal_mask_flag)`` for *subset*.

        Both are length-10 float tensors: the first marks frames that exist in
        the video, the second frames that carry a ground-truth mask.

        Raises:
            ValueError: for a subset other than 'v1s', 'v1m' or 'v2'.
        """
        if subset == 'v1s':  # AVSBench-object single-source: 5s, gt is only the first annotated frame
            return (torch.Tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0]),
                    torch.Tensor([1, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
        if subset == 'v1m':  # AVSBench-object multi-sources: 5s, all 5 extracted frames are annotated
            return (torch.Tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0]),
                    torch.Tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0]))
        if subset == 'v2':  # new AVSBench-semantic videos: 10s, all 10 extracted frames are annotated
            return torch.ones(10), torch.ones(10)
        raise ValueError("Unknown subset {!r}; expected 'v1s', 'v1m' or 'v2'.".format(subset))

    def _load_labels(self, color_mask_base_path, subset):
        """Load one video's label maps as a tensor zero-padded to 10 frames.

        Masks are read as ``0.png``, ``1.png``, ... up to the number of ``.png``
        files in *color_mask_base_path*.

        Raises:
            AssertionError: outside 'train', if the mask count is not 10 for
                'v2' or 5 otherwise.
            FileNotFoundError: if the directory holds no ``.png`` masks.
        """
        # Filter into a new list; removing from a list while iterating over it
        # skips the element that follows each removal.
        mask_names = [name for name in sorted(os.listdir(color_mask_base_path))
                      if name.endswith(".png")]
        mask_num = len(mask_names)
        if self.split != 'train':
            assert mask_num == (10 if subset == 'v2' else 5)
        if mask_num == 0:
            raise FileNotFoundError("No .png masks found in {}".format(color_mask_base_path))

        labels = [
            load_color_mask_in_PIL_to_Tensor(
                os.path.join(color_mask_base_path, "%d.png" % mask_id),
                v_pallete=self.v2_pallete, split=self.split)
            for mask_id in range(mask_num)
        ]
        # Pad with all-zero maps so every video yields 10 frames.
        labels.extend(torch.zeros_like(labels[-1]) for _ in range(10 - mask_num))
        return torch.stack(labels, dim=0)

    def _joined_text(self, video_name):
        """Return *video_name*'s text rows, ordered by frame, joined with "<sep>"."""
        sorted_text = self._df_text_sorted
        rows = sorted_text.loc[sorted_text['name'] == video_name, self.text_col]
        return "<sep>".join(rows.values.tolist())

    def __getitem__(self, index):
        df_one_video = self.df_split.iloc[index]
        video_name, subset = df_one_video['uid'], df_one_video['label']
        # Frames are not loaded here; callers receive the frame directory.
        img_base_path = os.path.join(cfg.DATA.DIR_BASE, subset, video_name, 'frames')
        audio_path = os.path.join(cfg.DATA.DIR_BASE, subset, video_name)
        color_mask_base_path = os.path.join(cfg.DATA.DIR_BASE, subset, video_name, 'labels_rgb')

        vid_temporal_mask_flag, _ = self._temporal_flags(subset)
        labels_tensor = self._load_labels(color_mask_base_path, subset)
        audio_clip_length = 10 if subset == 'v2' else 5

        if self.text_out:
            return (img_base_path, labels_tensor, vid_temporal_mask_flag, video_name,
                    audio_path, audio_clip_length, self._joined_text(video_name))
        return img_base_path, labels_tensor, vid_temporal_mask_flag, video_name, audio_path, audio_clip_length

    def __len__(self):
        return len(self.df_split)

    @property
    def num_classes(self):
        """Number of categories (including background)."""
        return cfg.NUM_CLASSES

    @property
    def classes(self):
        """Category names, in the key order of the label-index JSON file."""
        with open(cfg.DATA.LABEL_IDX_PATH, 'r') as fr:
            classes = json.load(fr)
        return list(classes.keys())


class V3Dataset_zs(Dataset):
    """Dataset for audio visual semantic segmentation of AVSBench-semantic (V3).

    Split metadata comes from per-split CSVs: train/val from the
    ``META_SEEN_*`` paths, test from ``META_UNSEEN_PATH``. The ``_zs`` suffix
    presumably means zero-shot (seen vs. unseen classes); TODO confirm. Frames
    and audio are not decoded here; each item is::

        (img_base_path, labels_tensor, gt_temporal_mask_flag,
         video_name, audio_path, audio_clip_length[, text])

    ``labels_tensor`` stacks the per-frame ``[1, H, W]`` label maps, zero-padded
    to 10 frames. ``text`` is only returned when *text_path* is given.
    """

    def __init__(self, split='train', frame_shift=10, debug_flag=False, text_path=None):
        super(V3Dataset_zs, self).__init__()
        self.split = split
        self.frame_shift = frame_shift  # stored only; not used by this class
        self.mask_num = cfg.MASK_NUM
        split_file_path = {
            'train': cfg.DATA.META_SEEN_TRAIN_PATH,
            'val': cfg.DATA.META_SEEN_VAL_PATH,
            'test': cfg.DATA.META_UNSEEN_PATH
        }
        self.df_split = pd.read_csv(split_file_path[split], sep=',')
        if debug_flag:
            self.df_split = self.df_split[:100]  # small subset for debugging
        print("{} videos are used for {}.".format(len(self.df_split), self.split))
        # Not applied in __getitem__ (frames are not loaded there).
        self.img_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])
        self.v2_pallete = get_v2_pallete(cfg.DATA.LABEL_IDX_PATH, num_cls=cfg.NUM_CLASSES)
        self.text_out = False
        if text_path and text_path != 'no':
            df_text = pd.read_csv(text_path)
            text_col = 'object' if 'object' in df_text else 'text'
            df_text[text_col] = df_text[text_col].fillna('sounding object')
            self.df_text = df_text
            # Sort once here instead of re-sorting the whole table per item.
            self._df_text_sorted = df_text.sort_values(by=['name', 'frame'])
            self.text_out = True
            self.text_col = text_col

    @staticmethod
    def _temporal_flags(subset):
        """Return ``(vid_temporal_mask_flag, gt_temporal_mask_flag)`` for *subset*.

        Both are length-10 float tensors: the first marks frames that exist in
        the video, the second frames that carry a ground-truth mask.

        Raises:
            ValueError: for a subset other than 'v1s', 'v1m' or 'v2'.
        """
        if subset == 'v1s':  # AVSBench-object single-source: 5s, gt is only the first annotated frame
            return (torch.Tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0]),
                    torch.Tensor([1, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
        if subset == 'v1m':  # AVSBench-object multi-sources: 5s, all 5 extracted frames are annotated
            return (torch.Tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0]),
                    torch.Tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0]))
        if subset == 'v2':  # new AVSBench-semantic videos: 10s, all 10 extracted frames are annotated
            return torch.ones(10), torch.ones(10)
        raise ValueError("Unknown subset {!r}; expected 'v1s', 'v1m' or 'v2'.".format(subset))

    def _load_labels(self, color_mask_base_path, subset):
        """Load one video's label maps as a tensor zero-padded to 10 frames.

        Masks are read as ``0.png``, ``1.png``, ... up to the number of ``.png``
        files in *color_mask_base_path*.

        Raises:
            AssertionError: outside 'train', if the mask count is not 10 for
                'v2' or 5 otherwise.
            FileNotFoundError: if the directory holds no ``.png`` masks.
        """
        # Filter into a new list; removing from a list while iterating over it
        # skips the element that follows each removal.
        mask_names = [name for name in sorted(os.listdir(color_mask_base_path))
                      if name.endswith(".png")]
        mask_num = len(mask_names)
        if self.split != 'train':
            assert mask_num == (10 if subset == 'v2' else 5)
        if mask_num == 0:
            raise FileNotFoundError("No .png masks found in {}".format(color_mask_base_path))

        labels = [
            load_color_mask_in_PIL_to_Tensor(
                os.path.join(color_mask_base_path, "%d.png" % mask_id),
                v_pallete=self.v2_pallete, split=self.split)
            for mask_id in range(mask_num)
        ]
        # Pad with all-zero maps so every video yields 10 frames.
        labels.extend(torch.zeros_like(labels[-1]) for _ in range(10 - mask_num))
        return torch.stack(labels, dim=0)

    def _joined_text(self, video_name):
        """Return *video_name*'s text rows, ordered by frame, joined with "<sep>"."""
        sorted_text = self._df_text_sorted
        rows = sorted_text.loc[sorted_text['name'] == video_name, self.text_col]
        return "<sep>".join(rows.values.tolist())

    def __getitem__(self, index):
        df_one_video = self.df_split.iloc[index]
        video_name, subset = df_one_video['uid'], df_one_video['label']
        # Frames are not loaded here; callers receive the frame directory.
        img_base_path = os.path.join(cfg.DATA.DIR_BASE, subset, video_name, 'frames')
        audio_path = os.path.join(cfg.DATA.DIR_BASE, subset, video_name)
        color_mask_base_path = os.path.join(cfg.DATA.DIR_BASE, subset, video_name, 'labels_rgb')

        _, gt_temporal_mask_flag = self._temporal_flags(subset)
        labels_tensor = self._load_labels(color_mask_base_path, subset)
        audio_clip_length = 10 if subset == 'v2' else 5

        if self.text_out:
            return (img_base_path, labels_tensor, gt_temporal_mask_flag, video_name,
                    audio_path, audio_clip_length, self._joined_text(video_name))
        return img_base_path, labels_tensor, gt_temporal_mask_flag, video_name, audio_path, audio_clip_length

    def __len__(self):
        return len(self.df_split)

    @property
    def num_classes(self):
        """Number of categories (including background)."""
        return cfg.NUM_CLASSES

    @property
    def classes(self):
        """Category names, in the key order of the label-index JSON file."""
        with open(cfg.DATA.LABEL_IDX_PATH, 'r') as fr:
            classes = json.load(fr)
        return list(classes.keys())


class AudioMapper:
    """Turn one audio file into fixed-size, normalised Kaldi filterbank chunks.

    ``get()`` returns a float tensor of shape ``(10, melbins, target_length)``.
    The clip's fbank is split into 10 (subset ``'v2'``) or 5 (other subsets)
    consecutive chunks. Each chunk is zero-padded along time to
    ``target_length`` frames, and 5-chunk clips are then padded with all-zero
    chunks up to 10. If the file is missing or cannot be processed, an
    all-zero tensor of the same shape is returned, so callers always receive a
    consistent shape.
    """

    def __init__(self, audio_path, melbins, target_length, frame_shift, subset):
        # Fixed normalisation statistics for the fbank features; their origin
        # is not recorded here (TODO(review): document where they come from).
        mean, std = -4.2677393, 4.5689974
        self.audio_path = audio_path
        self.melbins = melbins
        self.target_length = target_length
        self.mean = mean
        self.std = std
        self.subset = subset
        self.frame_shift = frame_shift  # passed as frame_shift to kaldi.fbank

    def _zeros(self):
        """Return all-zero features shaped like a successful ``get()`` result."""
        return torch.zeros(10, self.melbins, self.target_length)

    def get(self):
        """Load ``self.audio_path`` and return its chunked fbank features.

        Returns:
            torch.Tensor of shape ``(10, melbins, target_length)``. It is all
            zeros when the audio is missing or fails to load or process.
        """
        wav_file = self.audio_path
        if not os.path.exists(wav_file):
            # Try an .mkv container with the same base name. Only the extension
            # is swapped, so directories whose names contain "wav" are untouched.
            root, ext = os.path.splitext(wav_file)
            if ext == '.wav':
                wav_file = root + '.mkv'
        if not os.path.exists(wav_file):
            return self._zeros()

        try:
            waveform, sr = torchaudio.load(wav_file)
            waveform = waveform - waveform.mean()
            # fbank shape: (num_frames, melbins)
            fbank = torchaudio.compliance.kaldi.fbank(
                waveform, htk_compat=True, sample_frequency=sr, use_energy=False,
                window_type='hanning', num_mel_bins=self.melbins, dither=0.0,
                frame_shift=self.frame_shift)

            num_splits = 10 if self.subset == 'v2' else 5
            chunk_len = int(np.ceil(fbank.shape[0] / num_splits))
            output_slices = []
            for i in range(num_splits):
                src = fbank[i * chunk_len:min((i + 1) * chunk_len, fbank.shape[0])]
                pad_len = self.target_length - src.shape[0]
                output_slices.append(torch.nn.ZeroPad2d((0, 0, 0, pad_len))(src))

            # Pad 5-chunk clips with silent chunks so every clip has 10.
            output_slices.extend(
                torch.zeros(self.target_length, self.melbins)
                for _ in range(10 - num_splits)
            )

            # (10, target_length, melbins) -> (10, melbins, target_length)
            fbank = torch.stack(output_slices, dim=0).permute(0, 2, 1)
            return (fbank - self.mean) / (self.std * 2)
        except Exception as e:
            # Best effort (e.g. a file without an audio channel): report the
            # error and fall back to silence instead of returning None.
            print(e)
            return self._zeros()


def split(frame_name_lists, sample_num):
    """Divide a frame sequence into ``sample_num`` contiguous, near-equal chunks.

    A sequence shorter than ``sample_num`` is first padded by repeating its
    last element, so every chunk is non-empty. The caller's sequence is never
    modified.

    Args:
        frame_name_lists: sequence (or any iterable) of frame names.
        sample_num: number of chunks to produce.

    Returns:
        list of ``sample_num`` lists. The first ``len % sample_num`` chunks
        hold one extra element.

    Raises:
        ValueError: if the sequence is empty but ``sample_num`` > 0.
    """
    frames = list(frame_name_lists)  # copy so padding never mutates the caller's list
    if len(frames) < sample_num:
        if not frames:
            raise ValueError("cannot pad an empty frame list to {} samples".format(sample_num))
        frames += [frames[-1]] * (sample_num - len(frames))
    k, m = divmod(len(frames), sample_num)
    return [frames[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(sample_num)]




if __name__ == '__main__':
    # Smoke test: build the 100-video debug subset of the default split and
    # fetch its first item.
    debug_ds = V2Dataset(debug_flag=True)
    first_item = debug_ds[0]