import av
import ffmpeg
from joblib import Parallel, delayed
from multiprocessing import Manager
import os
import pickle
import random
import torch
import torchvision
import torch.utils.data
import glob

try: 
    from decoder import decode
except:
    from .decoder import decode

try: 
    from video_transforms import (
        random_short_side_scale_jitter, 
        random_crop, 
        horizontal_flip, 
        grayscale, 
        color_jitter, 
        uniform_crop, 
        resize, 
        normalize
    )
except:
    from .video_transforms import (
        random_short_side_scale_jitter, 
        random_crop, 
        horizontal_flip, 
        grayscale, 
        color_jitter, 
        uniform_crop, 
        resize, 
        normalize
    )


# Enable multi thread decoding.
# When True, get_video_container sets the first video stream's PyAV
# `thread_type` to "AUTO" (pyav backend only).
ENABLE_MULTI_THREAD_DECODE = True

# Decoding backend, options include `pyav` or `torchvision`
DECODING_BACKEND = 'pyav'

# Per-channel mean / std used by the datasets to normalize frames after the
# pixel values have been divided by 255.
MEAN=[0.45, 0.45, 0.45]
STD=[0.225, 0.225, 0.225]

# Root directory of each supported dataset on the local filesystem.
# NOTE(review): the 'vggsound' entry 'XXX' looks like a placeholder; set it to
# the real dataset root before use.
ROOT_DIR = {
    'kinetics': '/datasets01_101/kinetics/070618/',
    'kinetics600': '/datasets01_101/kinetics/070618/600/',
    'audioset': '/datasets01_101/audioset/042319/data/',
    'vggsound': 'XXX',
}

# Split sub-directory (joined onto ROOT_DIR[ds_name]) for each dataset.
# Only 'train' and 'val' are listed; there is no 'test' entry.
MODE_DIR = {
    'kinetics': {
        'train': 'train_avi-480p',
        'val': 'val_avi-480p'
    },
    'kinetics600': {
        'train': 'train',
        'val': 'val'
    },
    'audioset': {
        'train': 'unbalanced_train_segments/video',
        'val': 'eval_segments/video'
    },
    'vggsound': {
        'train': 'train',
        'val': 'test'
    },
}


def valid_video(vid_idx, vid_path, min_duration=1.1):
    """
    Check whether a video file is usable for audio-video training.

    A file is valid when ffprobe can read it, it has both a video and an audio
    stream, and both streams are longer than `min_duration` seconds. A progress
    line "<vid_idx>: True/False..." is printed for every file.

    Args:
        vid_idx (int): index of the video; only used in the printed log line.
        vid_path (str): path to the video file.
        min_duration (float): minimum duration in seconds required for both
            the video and the audio stream (default 1.1).
    Returns:
        bool: True if the file passes all checks, False otherwise (including
            when probing fails or a stream has no 'duration' field).
    """
    try:
        streams = ffmpeg.probe(vid_path)['streams']
        video_stream = next(
            (stream for stream in streams if stream['codec_type'] == 'video'),
            None,
        )
        audio_stream = next(
            (stream for stream in streams if stream['codec_type'] == 'audio'),
            None,
        )
        is_valid = bool(
            audio_stream
            and video_stream
            and float(video_stream['duration']) > min_duration
            and float(audio_stream['duration']) > min_duration
        )
    except Exception:
        # Any probe/parse failure (unreadable file, missing 'duration', ...)
        # marks the video as invalid. Unlike a bare `except:`, this still lets
        # KeyboardInterrupt / SystemExit stop a long filtering run.
        print(f"{vid_idx}: False", flush=True)
        return False

    if is_valid:
        print(f"{vid_idx}: True", flush=True)
    else:
        print(f"{vid_idx}: False (duration short/ no audio)", flush=True)
    return is_valid


def filter_videos(vid_paths, n_jobs=30):
    """
    Probe every video in parallel and return the positions of the valid ones.

    Args:
        vid_paths (sequence of str): paths to probe with `valid_video`.
        n_jobs (int): number of joblib workers (default 30).
    Returns:
        list of int: indices into `vid_paths` whose file passed `valid_video`,
            in ascending order.
    """
    results = Parallel(n_jobs=n_jobs)(
        delayed(valid_video)(vid_idx, vid_path)
        for vid_idx, vid_path in enumerate(vid_paths)
    )
    return [vid_idx for vid_idx, is_valid in enumerate(results) if is_valid]


def get_video_container(path_to_vid, multi_thread_decode=False, backend="pyav"):
    """
    Given the path to the video, return the container used by the decoder.

    Args:
        path_to_vid (str): path to the video.
        multi_thread_decode (bool): if True, set the first video stream's
            PyAV `thread_type` to "AUTO" (pyav backend only).
        backend (str): decoder backend, options include `pyav` and
            `torchvision`, default is `pyav`.
    Returns:
        container: the raw file bytes for `torchvision`, or an opened PyAV
            container for `pyav`.
    Raises:
        NotImplementedError: if `backend` is not one of the options above.
    """
    if backend == "torchvision":
        with open(path_to_vid, "rb") as fp:
            return fp.read()

    if backend == "pyav":
        try:
            container = av.open(path_to_vid)
        except Exception:
            # Retry with `metadata_errors="ignore"` so files with undecodable
            # metadata can still be opened; a second failure propagates.
            container = av.open(path_to_vid, metadata_errors="ignore")
        if multi_thread_decode:
            # Enable multiple threads for decoding.
            container.streams.video[0].thread_type = "AUTO"
        return container

    raise NotImplementedError("Unknown backend {}".format(backend))


class AVideoDataset(torch.utils.data.Dataset):
    """
    Audio-video loader. Construct the video loader, then sample
    clips from the videos. For training and validation, `num_train_clips`
    clips (default 1) are randomly sampled from every video with random
    cropping, scaling, and flipping. For testing, multiple clips are uniformly
    sampled from every video with uniform cropping. For uniform cropping, we
    take the left, center, and right crop if the width is larger than height,
    or take top, center, and bottom crop if the height is larger than the
    width.

    Items are `(frames, audio, label, vid_idx, index)` when `decode_audio` is
    true and `(frames, label, vid_idx, index)` otherwise.
    """
    def __init__(
        self,
        ds_name='kinetics',
        mode='train',
        num_frames=30,
        sample_rate=1,
        num_train_clips=1,
        train_crop_size=112,
        test_crop_size=112,
        num_spatial_crops=3,
        num_ensemble_views=10,
        path_to_data_dir='datasets/data',
        seed=None,
        num_data_samples=None,
        colorjitter=False,
        temp_jitter=True,
        center_crop=False,
        target_fps=30,
        decode_audio=True,
        aug_audio=None,
        num_sec=1,
        aud_sample_rate=48000,
        aud_spec_type=1,
        use_volume_jittering=False,
        use_temporal_jittering=False,
        z_normalize=False
    ):
        """
        Args:
            ds_name (str): dataset key into ROOT_DIR / MODE_DIR.
            mode (str): 'train', 'val' or 'test'. MODE_DIR has no 'test'
                entry, so test mode reads the 'val' split directory.
            num_frames (int): frames per clip, forwarded to `decode`.
            sample_rate (int): frame sampling rate, forwarded to `decode`.
            num_train_clips (int): clips per video in train/val mode.
            train_crop_size (int): crop size in train/val mode; 112 selects
                scale-jitter range (128, 160), any other value (256, 320).
            test_crop_size (int): scale and crop size in test mode.
            num_spatial_crops (int): spatial crops per clip in test mode.
            num_ensemble_views (int): temporal clips per video in test mode.
            path_to_data_dir (str): directory holding the cached file list
                `{ds_name}_{mode}.txt` and `{ds_name}_{mode}_valid.pkl`.
            seed: unused; kept for interface compatibility.
            num_data_samples (int or None): if set, keep only the first
                `num_data_samples` valid entries.
            colorjitter (bool): apply color jitter after spatial sampling.
            temp_jitter (bool): if False, `decode` always receives clip index
                500 of 1000 instead of the sampled index.
            center_crop (bool): in train/val mode, use a deterministic crop of
                size `train_crop_size` with spatial index 1.
            decode_audio (bool): if True, items include the audio output of
                `decode`.
            aug_audio (list or None): forwarded to `decode`; None means [].
            target_fps, num_sec, aud_sample_rate, aud_spec_type,
            use_volume_jittering, use_temporal_jittering, z_normalize:
                forwarded unchanged to `decode`.
        """
        # Only support train, val, and test mode.
        assert mode in [
            "train",
            "val",
            "test",
        ], "Split '{}' not supported for '{}'".format(mode, ds_name)
        self.ds_name = ds_name
        self.mode = mode
        self.num_frames = num_frames
        self.sample_rate = sample_rate
        self.train_crop_size = train_crop_size
        self.test_crop_size = test_crop_size
        # Short-side scale range for random scale jitter in train/val mode.
        self.train_jitter_scles = (
            (128, 160) if train_crop_size == 112 else (256, 320)
        )
        self.num_ensemble_views = num_ensemble_views
        self.num_spatial_crops = num_spatial_crops
        self.num_train_clips = num_train_clips
        # MODE_DIR only lists 'train' and 'val'; evaluate test mode on the
        # validation split instead of failing with a KeyError.
        split_dirs = MODE_DIR[ds_name]
        self.data_prefix = os.path.join(
            ROOT_DIR[ds_name], split_dirs.get(mode, split_dirs['val'])
        )
        self.path_to_data_dir = path_to_data_dir
        self.num_data_samples = num_data_samples
        self.colorjitter = colorjitter
        self.temp_jitter = temp_jitter
        self.center_crop = center_crop
        self.target_fps = target_fps
        self.decode_audio = decode_audio
        self.aug_audio = [] if aug_audio is None else aug_audio
        self.num_sec = num_sec
        self.aud_sample_rate = aud_sample_rate
        self.aud_spec_type = aud_spec_type
        self.use_volume_jittering = use_volume_jittering
        self.use_temporal_jittering = use_temporal_jittering
        self.z_normalize = z_normalize
        self._video_meta = {}

        # Class index = position of the class directory in sorted order.
        if self.ds_name != 'audioset':
            classes = sorted(glob.glob(os.path.join(self.data_prefix, '*')))
            self.class_to_idx = {
                os.path.basename(cls): i for i, cls in enumerate(classes)
            }

        # For training or validation mode, NUM_TRAIN_CLIPS clips are sampled
        # from every video. For testing, NUM_ENSEMBLE_VIEWS clips are sampled
        # from every video, and NUM_SPATIAL_CROPS crops are taken per clip.
        if self.mode in ["train", "val"]:
            self._num_clips = self.num_train_clips
        else:
            self._num_clips = self.num_ensemble_views * self.num_spatial_crops

        # Manager-backed lists are created at the end of _construct_loader.
        self.manager = Manager()
        print(f"Constructing {self.ds_name} {self.mode}...")
        self._construct_loader()

    def _construct_loader(self):
        """
        Construct the video loader.

        Creates (if missing, except for audioset) and reads the file list
        `{ds_name}_{mode}.txt` in `path_to_data_dir`, expands every video into
        `self._num_clips` entries, and keeps the entries whose video passes
        `valid_video`. Per-video validity is cached in
        `{ds_name}_{mode}_valid.pkl`, so splits with different file lists or
        clip counts never read each other's cache.
        """
        os.makedirs(self.path_to_data_dir, exist_ok=True)
        path_to_file = os.path.join(
            self.path_to_data_dir, f"{self.ds_name}_{self.mode}.txt"
        )
        if not os.path.exists(path_to_file) and self.ds_name != 'audioset':
            files = sorted(glob.glob(os.path.join(self.data_prefix, '*', '*')))
            with open(path_to_file, 'w') as f:
                f.writelines(f"{item}\n" for item in files)

        with open(path_to_file, "r") as f:
            video_list = f.read().splitlines()

        # Entry `vid_idx * self._num_clips + clip_idx` is clip `clip_idx` of
        # video `vid_idx`.
        video_paths = []
        self._path_to_videos = []
        self._labels = []
        self._spatial_temporal_idx = []
        self._vid_indices = []
        for vid_idx, path in enumerate(video_list):
            full_path = os.path.join(self.data_prefix, path)
            video_paths.append(full_path)
            if self.ds_name != 'audioset':
                # The label is the name of the video's parent directory.
                label = self.class_to_idx[path.split('/')[-2]]
            else:
                # Audioset paths carry no class directory; use -1 as a
                # placeholder label instead of an unbound variable.
                label = -1
            for clip_idx in range(self._num_clips):
                self._path_to_videos.append(full_path)
                self._labels.append(int(label))
                self._spatial_temporal_idx.append(clip_idx)
                self._vid_indices.append(vid_idx)
                self._video_meta[vid_idx * self._num_clips + clip_idx] = {}
        assert (
            len(self._path_to_videos) > 0
        ), "Failed to load {} split {} from {}".format(
            self.ds_name, self.mode, path_to_file
        )
        print(
            "Constructing {} dataloader (size: {}) from {}".format(
                self.ds_name, len(self._path_to_videos), path_to_file
            )
        )

        # Create / Load valid video indices (has audio), one entry per video.
        vid_valid_file = os.path.join(
            self.path_to_data_dir, f"{self.ds_name}_{self.mode}_valid.pkl"
        )
        if os.path.exists(vid_valid_file):
            # Local cache written by this method (not untrusted input).
            with open(vid_valid_file, 'rb') as handle:
                valid_videos = pickle.load(handle)
        else:
            # Probe each video once rather than once per clip.
            valid_videos = filter_videos(video_paths)
            with open(vid_valid_file, 'wb') as handle:
                pickle.dump(
                    valid_videos,
                    handle,
                    protocol=pickle.HIGHEST_PROTOCOL
                )
        valid_videos = set(valid_videos)
        self.valid_indices = [
            i for i, vid_idx in enumerate(self._vid_indices)
            if vid_idx in valid_videos
        ]
        if self.num_data_samples is not None:
            self.valid_indices = self.valid_indices[:self.num_data_samples]
        print(f"Total number of videos: {len(self._path_to_videos)}, Valid videos: {len(self.valid_indices)}", flush=True)

        # Make lists a Manager objects
        self._path_to_videos = self.manager.list(self._path_to_videos)
        self.valid_indices = self.manager.list(self.valid_indices)

    def __getitem__(self, index):
        """
        Decode, augment and normalize one clip.

        Args:
            index (int): position in `self.valid_indices` (from the sampler);
                it is mapped to an entry of the full clip list.
        Returns:
            frames (tensor): `channel` x `num frames` x `height` x `width`,
                normalized with MEAN / STD.
            audio: the second output of `decode` (only if `decode_audio`).
            label (int): label of the video (-1 for audioset).
            vid_idx (int): position of the video in the file list.
            index (int): the input `index`.
        """
        index_capped = index
        index = self.valid_indices[index_capped]
        if self.mode in ["train", "val"]:
            # -1 indicates random sampling.
            temporal_sample_index = -1
            spatial_sample_index = -1
            min_scale, max_scale = self.train_jitter_scles
            crop_size = self.train_crop_size
            if self.center_crop:
                spatial_sample_index = 1
                min_scale = max_scale = self.train_crop_size
        elif self.mode in ["test"]:
            temporal_sample_index = (
                self._spatial_temporal_idx[index] // self.num_spatial_crops
            )
            # spatial_sample_index is in [0, 1, 2]. Corresponding to left,
            # center, or right if width is larger than height, and top, middle,
            # or bottom if height is larger than width.
            spatial_sample_index = (
                self._spatial_temporal_idx[index] % self.num_spatial_crops
            )
            # Testing is deterministic: no scale jitter.
            min_scale = max_scale = crop_size = self.test_crop_size
        else:
            raise NotImplementedError(
                "Does not support {} mode".format(self.mode)
            )

        # Try to decode and sample a clip from a video.
        video_container = get_video_container(
            self._path_to_videos[index],
            ENABLE_MULTI_THREAD_DECODE,
            DECODING_BACKEND,
        )

        # Decode video. Meta info is used to perform selective decoding.
        frames, spec = decode(
            self._path_to_videos[index],
            video_container,
            self.sample_rate,
            self.num_frames,
            temporal_sample_index if self.temp_jitter else 500,
            self.num_ensemble_views if self.temp_jitter else 1000,
            video_meta=self._video_meta[index],
            target_fps=self.target_fps,
            backend=DECODING_BACKEND,
            max_spatial_scale=max_scale,
            decode_audio=self.decode_audio,
            aug_audio=self.aug_audio,
            num_sec=self.num_sec,
            aud_sample_rate=self.aud_sample_rate,
            aud_spec_type=self.aud_spec_type,
            use_volume_jittering=self.use_volume_jittering,
            use_temporal_jittering=self.use_temporal_jittering,
            z_normalize=self.z_normalize,
        )

        frames = frames.float() / 255.0
        # T H W C -> C T H W.
        frames = frames.permute(3, 0, 1, 2)
        # Perform data augmentation.
        frames = self.spatial_sampling(
            frames,
            spatial_idx=spatial_sample_index,
            min_scale=min_scale,
            max_scale=max_scale,
            crop_size=crop_size,
        )
        if self.colorjitter:
            frames = color_jitter(frames, 0.4, 0.4, 0.4)

        # Perform color normalization (per channel, broadcast over T, H, W).
        frames = frames - torch.tensor(MEAN).view(-1, 1, 1, 1)
        frames = frames / torch.tensor(STD).view(-1, 1, 1, 1)

        label = self._labels[index]
        vid_idx = self._vid_indices[index]
        if self.decode_audio:
            return frames, spec, label, vid_idx, index_capped
        return frames, label, vid_idx, index_capped

    def __len__(self):
        """
        Returns:
            (int): the number of valid entries in the dataset.
        """
        return len(self.valid_indices)

    def spatial_sampling(
        self,
        frames,
        spatial_idx=-1,
        min_scale=256,
        max_scale=320,
        crop_size=224,
    ):
        """
        Perform spatial sampling on the given video frames. If spatial_idx is
        -1, perform random scale, random crop, and random flip on the given
        frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling
        with the given spatial_idx.
        Args:
            frames (tensor): frames of images sampled from the video. As
                called from __getitem__, the dimension is `channel` x
                `num frames` x `height` x `width`.
            spatial_idx (int): if -1, perform random spatial sampling. If 0, 1,
                or 2, perform left, center, right crop if width is larger than
                height, and perform top, center, bottom crop if height is larger
                than width.
            min_scale (int): the minimal size of scaling.
            max_scale (int): the maximal size of scaling.
            crop_size (int): the size of height and width used to crop the
                frames.
        Returns:
            frames (tensor): spatially sampled frames.
        """
        assert spatial_idx in [-1, 0, 1, 2]
        frames, _ = random_short_side_scale_jitter(frames, min_scale, max_scale)
        if spatial_idx == -1:
            frames, _ = random_crop(frames, crop_size)
            frames, _ = horizontal_flip(0.5, frames)
        else:
            frames, _ = uniform_crop(frames, crop_size, spatial_idx)
        return frames


class AVideoDataset2(torch.utils.data.Dataset):
    """
    Audio-video loader that returns two views of the same video. Construct
    the video loader, then sample clips from the videos. For training and
    validation, two clips are sampled from every video (at the same temporal
    position when `synced`) with random cropping, scaling, and flipping
    applied independently. For testing, clips are uniformly sampled from every
    video with uniform cropping. For uniform cropping, we take the left,
    center, and right crop if the width is larger than height, or take top,
    center, and bottom crop if the height is larger than the width.
    """

    def __init__(
        self,
        ds_name='kinetics',
        mode='train',
        num_frames=30,
        sample_rate=1,
        train_crop_size=112,
        path_to_data_dir='datasets/data',
        seed=0,
        num_ensemble_views=10,
        num_spatial_crops=1,
        num_data_samples=None,
        colorjitter=False,
        synced=True,
        target_fps=30,
        decode_audio=True,
        aug_audio=None,
        num_sec=1,
        aud_sample_rate=48000,
        aud_spec_type=1,
        use_volume_jittering=False,
        use_temporal_jittering=False,
        z_normalize=False,
        test_crop_size=112,
    ):
        """
        Args:
            ds_name (str): dataset key into ROOT_DIR / MODE_DIR.
            mode (str): 'train', 'val' or 'test'. MODE_DIR has no 'test'
                entry, so test mode reads the 'val' split directory.
            num_frames (int): frames per clip, forwarded to `decode`.
            sample_rate (int): frame sampling rate, forwarded to `decode`.
            train_crop_size (int): crop size in train/val mode; 112 selects
                scale-jitter range (128, 160), any other value (256, 320).
            path_to_data_dir (str): directory holding the cached file list
                `{ds_name}_{mode}.txt` and `{ds_name}_{mode}_valid.pkl`.
            seed: unused; kept for interface compatibility.
            num_ensemble_views (int): temporal clips per video in test mode.
            num_spatial_crops (int): spatial crops per clip in test mode.
            num_data_samples (int or None): if set, keep only the first
                `num_data_samples` valid entries.
            colorjitter (bool): apply color jitter to each view.
            synced (bool): if True, both views use the same temporal position.
            decode_audio (bool): if True, items include the concatenated audio
                outputs of `decode`.
            aug_audio (list or None): forwarded to `decode`; None means [].
            target_fps, num_sec, aud_sample_rate, aud_spec_type,
            use_volume_jittering, use_temporal_jittering, z_normalize:
                forwarded unchanged to `decode`.
            test_crop_size (int): scale and crop size in test mode.
        """
        # Only support train, val, and test mode.
        assert mode in [
            "train",
            "val",
            "test",
        ], "Split '{}' not supported for {}".format(mode, ds_name)
        self.ds_name = ds_name
        self.mode = mode
        self.num_frames = num_frames
        self.sample_rate = sample_rate
        self.train_crop_size = train_crop_size
        self.test_crop_size = test_crop_size
        # Short-side scale range for random scale jitter in train/val mode.
        self.train_jitter_scles = (
            (128, 160) if train_crop_size == 112 else (256, 320)
        )
        self.num_ensemble_views = num_ensemble_views
        self.num_spatial_crops = num_spatial_crops
        # MODE_DIR only lists 'train' and 'val'; evaluate test mode on the
        # validation split instead of failing with a KeyError.
        split_dirs = MODE_DIR[ds_name]
        self.data_prefix = os.path.join(
            ROOT_DIR[ds_name], split_dirs.get(mode, split_dirs['val'])
        )
        self.path_to_data_dir = path_to_data_dir
        self.num_data_samples = num_data_samples
        self.colorjitter = colorjitter
        self.sync = synced
        self.target_fps = target_fps
        self.decode_audio = decode_audio
        self.aug_audio = [] if aug_audio is None else aug_audio
        self.num_sec = num_sec
        self.aud_sample_rate = aud_sample_rate
        self.aud_spec_type = aud_spec_type
        self.use_volume_jittering = use_volume_jittering
        self.use_temporal_jittering = use_temporal_jittering
        self.z_normalize = z_normalize

        self._video_meta = {}

        # Class index = position of the class directory in sorted order.
        if self.ds_name != 'audioset':
            classes = sorted(glob.glob(os.path.join(self.data_prefix, '*')))
            self.class_to_idx = {
                os.path.basename(cls): i for i, cls in enumerate(classes)
            }

        # For training or validation mode, one entry per video (both views are
        # drawn from it). For testing, NUM_ENSEMBLE_VIEWS clips are sampled
        # from every video, and NUM_SPATIAL_CROPS crops are taken per clip.
        if self.mode in ["train", "val"]:
            self._num_clips = 1
        else:
            self._num_clips = self.num_ensemble_views * self.num_spatial_crops

        # Manager-backed lists are created at the end of _construct_loader.
        self.manager = Manager()
        print(f"Constructing {self.ds_name} {self.mode}...")
        self._construct_loader()

    def _construct_loader(self):
        """
        Construct the video loader.

        Creates (if missing, except for audioset) and reads the file list
        `{ds_name}_{mode}.txt` in `path_to_data_dir`, expands every video into
        `self._num_clips` entries, and keeps the entries whose video passes
        `valid_video`. Per-video validity is cached in
        `{ds_name}_{mode}_valid.pkl`, so splits with different file lists or
        clip counts never read each other's cache.
        """
        os.makedirs(self.path_to_data_dir, exist_ok=True)
        path_to_file = os.path.join(
            self.path_to_data_dir, f"{self.ds_name}_{self.mode}.txt"
        )
        if not os.path.exists(path_to_file) and self.ds_name != 'audioset':
            files = sorted(glob.glob(os.path.join(self.data_prefix, '*', '*')))
            with open(path_to_file, 'w') as f:
                f.writelines(f"{item}\n" for item in files)

        with open(path_to_file, "r") as f:
            video_list = f.read().splitlines()

        # Entry `vid_idx * self._num_clips + clip_idx` is clip `clip_idx` of
        # video `vid_idx`.
        video_paths = []
        self._path_to_videos = []
        self._labels = []
        self._spatial_temporal_idx = []
        self._vid_indices = []
        for vid_idx, path in enumerate(video_list):
            full_path = os.path.join(self.data_prefix, path)
            video_paths.append(full_path)
            if self.ds_name != 'audioset':
                # The label is the name of the video's parent directory.
                label = self.class_to_idx[path.split('/')[-2]]
            else:
                # Audioset paths carry no class directory; use -1 as a
                # placeholder label instead of an unbound variable.
                label = -1
            for clip_idx in range(self._num_clips):
                self._path_to_videos.append(full_path)
                self._labels.append(int(label))
                self._spatial_temporal_idx.append(clip_idx)
                self._vid_indices.append(vid_idx)
                self._video_meta[vid_idx * self._num_clips + clip_idx] = {}
        assert (
            len(self._path_to_videos) > 0
        ), "Failed to load {} split {} from {}".format(
            self.ds_name, self.mode, path_to_file
        )
        print(
            "Constructing {} dataloader (size: {}) from {}".format(
                self.ds_name, len(self._path_to_videos), path_to_file
            )
        )

        # Create / Load valid video indices (has audio), one entry per video.
        vid_valid_file = os.path.join(
            self.path_to_data_dir, f"{self.ds_name}_{self.mode}_valid.pkl"
        )
        if os.path.exists(vid_valid_file):
            # Local cache written by this method (not untrusted input).
            with open(vid_valid_file, 'rb') as handle:
                valid_videos = pickle.load(handle)
        else:
            # Probe each video once rather than once per clip.
            valid_videos = filter_videos(video_paths)
            with open(vid_valid_file, 'wb') as handle:
                pickle.dump(
                    valid_videos,
                    handle,
                    protocol=pickle.HIGHEST_PROTOCOL
                )
        valid_videos = set(valid_videos)
        self.valid_indices = [
            i for i, vid_idx in enumerate(self._vid_indices)
            if vid_idx in valid_videos
        ]
        if self.num_data_samples is not None:
            self.valid_indices = self.valid_indices[:self.num_data_samples]
        print(f"Total number of videos: {len(self._path_to_videos)}, Valid videos: {len(self.valid_indices)}", flush=True)

        # Make lists a Manager objects
        self._path_to_videos = self.manager.list(self._path_to_videos)
        self.valid_indices = self.manager.list(self.valid_indices)

    def _load_view(
        self,
        index,
        clip_idx,
        num_clips,
        spatial_sample_index,
        min_scale,
        max_scale,
        crop_size,
    ):
        """
        Decode clip `clip_idx` (of `num_clips`) from entry `index`, then apply
        spatial sampling, optional color jitter and MEAN/STD normalization.

        Returns:
            frames (tensor): `channel` x `num frames` x `height` x `width`.
            spec: the second output of `decode`.
        """
        # A fresh container is opened for every view, as the original code did
        # for its second clip.
        video_container = get_video_container(
            self._path_to_videos[index],
            ENABLE_MULTI_THREAD_DECODE,
            DECODING_BACKEND,
        )
        frames, spec = decode(
            self._path_to_videos[index],
            video_container,
            self.sample_rate,
            self.num_frames,
            clip_idx,
            num_clips=num_clips,
            video_meta=self._video_meta[index],
            target_fps=self.target_fps,
            backend=DECODING_BACKEND,
            max_spatial_scale=max_scale,
            decode_audio=self.decode_audio,
            aug_audio=self.aug_audio,
            num_sec=self.num_sec,
            aud_sample_rate=self.aud_sample_rate,
            aud_spec_type=self.aud_spec_type,
            use_volume_jittering=self.use_volume_jittering,
            use_temporal_jittering=self.use_temporal_jittering,
            z_normalize=self.z_normalize,
        )

        frames = frames.float() / 255.0
        # T H W C -> C T H W.
        frames = frames.permute(3, 0, 1, 2)
        # Perform data augmentation.
        frames = self.spatial_sampling(
            frames,
            spatial_idx=spatial_sample_index,
            min_scale=min_scale,
            max_scale=max_scale,
            crop_size=crop_size,
        )
        if self.colorjitter:
            frames = color_jitter(frames, 0.4, 0.4, 0.4)

        # Perform color normalization (per channel, broadcast over T, H, W).
        frames = frames - torch.tensor(MEAN).view(-1, 1, 1, 1)
        frames = frames / torch.tensor(STD).view(-1, 1, 1, 1)
        return frames, spec

    def __getitem__(self, index):
        """
        Return two views of one video, concatenated along the channel dim.

        Args:
            index (int): position in `self.valid_indices` provided by the
                pytorch sampler; mapped to an entry of the full clip list.
        Returns:
            frames (tensor): both views concatenated along dim 0, i.e.
                `2 * channel` x `num frames` x `height` x `width`.
            spec: both audio outputs of `decode` concatenated along dim 0
                (only when `decode_audio` is true).
            label (int): the label of the video (-1 for audioset).
            vid_idx (int): position of the video in the file list.
            index (int): the input `index`.
        """
        index_capped = index
        index = self.valid_indices[index_capped]
        if self.mode in ["train", "val"]:
            # Pick a temporal position among uniformly spaced positions.
            # NOTE(review): randint is inclusive, so the index can equal
            # num_clips; kept as before — confirm `decode` expects this.
            num_clips = 1000
            clip_idx1 = random.randint(0, num_clips)
            # With `synced`, both views start at the same position.
            clip_idx2 = clip_idx1 if self.sync else random.randint(0, num_clips)
            spatial_sample_index = -1
            min_scale, max_scale = self.train_jitter_scles
            crop_size = self.train_crop_size
        elif self.mode in ["test"]:
            # Deterministic: both views use the same uniform temporal clip.
            num_clips = self.num_ensemble_views
            clip_idx1 = clip_idx2 = (
                self._spatial_temporal_idx[index] // self.num_spatial_crops
            )
            # spatial_sample_index is in [0, 1, 2] (left/top, center,
            # right/bottom); with a single spatial crop use the center crop.
            if self.num_spatial_crops > 1:
                spatial_sample_index = (
                    self._spatial_temporal_idx[index] % self.num_spatial_crops
                )
            else:
                spatial_sample_index = 1
            min_scale = max_scale = crop_size = self.test_crop_size
        else:
            raise NotImplementedError(
                "Does not support {} mode".format(self.mode)
            )

        # Sample twice from the same video (potentially different start times).
        frames1, spec1 = self._load_view(
            index, clip_idx1, num_clips, spatial_sample_index,
            min_scale, max_scale, crop_size,
        )
        frames2, spec2 = self._load_view(
            index, clip_idx2, num_clips, spatial_sample_index,
            min_scale, max_scale, crop_size,
        )
        frames = torch.cat([frames1, frames2], dim=0)

        label = self._labels[index]
        vid_idx = self._vid_indices[index]
        if self.decode_audio:
            spec = torch.cat([spec1, spec2], dim=0)
            return frames, spec, label, vid_idx, index_capped
        return frames, label, vid_idx, index_capped

    def __len__(self):
        """
        Returns:
            (int): the number of valid entries in the dataset.
        """
        return len(self.valid_indices)

    def spatial_sampling(
        self,
        frames,
        spatial_idx=-1,
        min_scale=256,
        max_scale=320,
        crop_size=224,
    ):
        """
        Perform spatial sampling on the given video frames. If spatial_idx is
        -1, perform random scale, random crop, and random flip on the given
        frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling
        with the given spatial_idx.
        Args:
            frames (tensor): frames of images sampled from the video. As
                called from this class, the dimension is `channel` x
                `num frames` x `height` x `width`.
            spatial_idx (int): if -1, perform random spatial sampling. If 0, 1,
                or 2, perform left, center, right crop if width is larger than
                height, and perform top, center, bottom crop if height is larger
                than width.
            min_scale (int): the minimal size of scaling.
            max_scale (int): the maximal size of scaling.
            crop_size (int): the size of height and width used to crop the
                frames.
        Returns:
            frames (tensor): spatially sampled frames.
        """
        assert spatial_idx in [-1, 0, 1, 2]
        if spatial_idx == -1:
            frames, _ = random_short_side_scale_jitter(
                frames, min_scale, max_scale
            )
            frames, _ = random_crop(frames, crop_size)
            frames, _ = horizontal_flip(0.5, frames)
        else:
            # The testing is deterministic and no jitter should be performed.
            # min_scale, max_scale, and crop_size are expect to be the same.
            assert len({min_scale, max_scale, crop_size}) == 1
            frames, _ = random_short_side_scale_jitter(
                frames, min_scale, max_scale
            )
            frames, _ = uniform_crop(frames, crop_size, spatial_idx)
        return frames



if __name__ == '__main__':

    # Smoke test: iterate AVideoDataset2 on the Kinetics train split and print
    # batch shapes and per-batch loading time. `random`, `torch` and
    # `torchvision` are already imported at module level.
    import time
    from torch.utils.data import DataLoader

    print("="*60)
    print('Testing AVideoDataset2')
    train_dataset = AVideoDataset2(
        ds_name='kinetics',
        mode='train',
        colorjitter=True
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=16,
        num_workers=0,
        shuffle=True,
        collate_fn=None
    )
    tic = time.time()
    for batch_idx, batch in enumerate(train_loader):
        video, spec, label, vid_idx, idx = batch
        print(len(video), video[0].size(), flush=True)
        print(video.size(), spec.size(), flush=True)
        # Each item concatenates two 3-channel views (and two audio entries)
        # along dim 0, which becomes dim 1 after batching.
        video1, video2 = torch.split(video, [3, 3], dim=1)
        audio1, audio2 = torch.split(spec, [1, 1], dim=1)
        print(
            batch_idx,
            video1.size(),
            audio1.size(),
            label,
            idx,
            vid_idx,
            time.time() - tic
        )
        print(f'Batch time (s): {time.time() - tic}')
        tic = time.time()
