import cv2
import glob
import json
import numpy as np
import os, sys, traceback
import pandas as pd
import random
import time
import torch
import decord
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as trans
from torchvision.transforms import v2
import zipfile

import config as cfg


# Make decord return torch tensors; the clip builders below call torch methods
# (e.g. .permute) directly on decoded frames.
decord.bridge.set_bridge('torch')

import matplotlib.pyplot as plt


# Baseline training dataloader.
class baseline_train_dataloader(Dataset):
    """Training dataset yielding randomly positioned, augmented video clips.

    Supported datasets are ``'ucf101'``, ``'hmdb51'`` and ``'k400'``. Each item
    is ``(clip, label, vid_path, frame_indices)``. When a video cannot be
    decoded, ``clip`` and ``frame_indices`` are ``None``; ``collate_fn_train``
    drops such items.

    Args:
        params: Config object; ``num_frames`` and ``fix_skip`` are read here.
        dataset: Name of the dataset to load.
        shuffle: Shuffle the sample list once, at construction time.
        data_percentage: Fraction of the (shuffled) sample list to keep.
        split: UCF101 train/test split number (1-3).

    Raises:
        ValueError: If ``dataset`` is unknown or ``split`` is not in 1-3.
    """

    def __init__(self, params, dataset='ucf101', shuffle=True, data_percentage=1.0, split=1):
        # Supported datasets: ucf101, hmdb51, k400.
        self.dataset = dataset
        self.params = params

        if self.dataset == 'ucf101':
            self.all_paths = self._read_ucf101_split('trainlist', split)
            self.classes = self._load_ucf101_classes()
        elif self.dataset == 'hmdb51':
            self.all_paths = self._read_hmdb51_csv('hmdb51_train_labels.csv')
        elif self.dataset == 'k400':
            self.all_paths = self._read_lines(
                os.path.join(cfg.kinetics_path, 'annotation_train_fullpath_resizedvids.txt'))
        else:
            # Fail fast: without a sample list the dataset is unusable.
            raise ValueError(f'{self.dataset} does not exist.')

        self.shuffle = shuffle
        if self.shuffle:
            random.shuffle(self.all_paths)

        self.data_percentage = data_percentage
        self.data_limit = int(len(self.all_paths) * self.data_percentage)
        self.data = self.all_paths[0:self.data_limit]
        self.augmentation = v2.Compose([
            v2.Resize(size=256, antialias=True),
            v2.RandomResizedCrop(size=(224, 224), antialias=True),
            v2.RandomHorizontalFlip(p=0.5),
            v2.ToDtype(torch.float32, scale=True),  # Normalize expects float input
            v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    @staticmethod
    def _read_lines(path):
        """Return the lines of the text file at ``path`` without newlines."""
        with open(path, 'r') as handle:
            return handle.read().splitlines()

    @classmethod
    def _read_ucf101_split(cls, prefix, split):
        """Read ``<prefix>0<split>.txt`` from the UCF101 split folder.

        Forward slashes in the entries are converted to ``os.sep``.
        """
        if not 1 <= split <= 3:
            raise ValueError(f'Invalid split input: {split}')
        lines = cls._read_lines(
            os.path.join(cfg.ucf101_path, 'ucfTrainTestlist', f'{prefix}0{split}.txt'))
        return [line.replace('/', os.sep) for line in lines]

    @staticmethod
    def _load_ucf101_classes():
        """Load the ``'classes'`` mapping from the UCF101 class-mapping JSON."""
        with open(cfg.ucf101_class_mapping) as handle:
            return json.load(handle)['classes']

    @staticmethod
    def _read_hmdb51_csv(filename):
        """Build ``'<video path> <label>'`` entries from an HMDB51 label CSV."""
        anno = pd.read_csv(os.path.join(cfg.hmdb_path, filename), index_col=None)
        return [
            f"{os.path.join(cfg.hmdb_path, 'Videos', cls_name, fname)} {lab}"
            for cls_name, fname, lab in zip(
                anno['class'].to_list(), anno['filename'].to_list(), anno['label'].to_list())
        ]

    @staticmethod
    def _short_side_resize_dims(width, height, short_side=256):
        """Return ``(new_width, new_height)`` with the shorter side set to ``short_side``."""
        if width <= 0 or height <= 0:
            # e.g. when cv2 could not read the file's dimensions.
            raise ValueError(f'Invalid video size {width}x{height}')
        aspect_ratio = width / height
        if height < width:
            return int(short_side * aspect_ratio), short_side
        return short_side, int(short_side / aspect_ratio)

    @staticmethod
    def _sample_frame_indices(frame_count, fix_skip, num_frames):
        """Pick ``num_frames`` frame indices, ``fix_skip`` apart, from a random start.

        If the clip does not fit, the stride is halved once. Indices that
        still run past the end are clamped to the last frame.

        Raises:
            ValueError: If ``frame_count`` is smaller than 1.
        """
        if frame_count < 1:
            raise ValueError('Video has no frames')
        skip = fix_skip
        left_over = frame_count - skip * num_frames
        if left_over <= 0:
            skip /= 2
            left_over = frame_count - skip * num_frames
        # randint's exclusive upper bound must be >= 1; after halving an odd
        # stride, left_over can lie in (0, 1).
        start = np.random.randint(0, int(left_over)) if int(left_over) > 0 else 0
        indices = start + np.arange(num_frames) * int(skip)
        return np.minimum(indices, frame_count - 1)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        clip, label, vid_path, frame_list = self.process_data(index)
        return clip, label, vid_path, frame_list

    def process_data(self, idx):
        """Resolve the video path and label for sample ``idx`` and build its clip."""
        entry = self.data[idx]
        if self.dataset == 'ucf101':
            vid_path = os.path.join(cfg.ucf101_path, 'Videos', entry.split(' ')[0])
            # The parent folder name is the activity name.
            label = int(self.classes[vid_path.split(os.sep)[-2]]) - 1
        elif self.dataset == 'hmdb51':
            # Split off only the trailing label so paths with spaces survive.
            vid_path, label = entry.rsplit(' ', 1)
            label = int(label)
        elif self.dataset == 'k400':
            vid_path = entry.split(' ')[0]
            label = int(entry.split(' ')[1]) - 1
        else:
            raise ValueError(f'{self.dataset} does not exist.')

        clip, frame_list = self.build_clip(vid_path)
        return clip, label, vid_path, frame_list

    def build_clip(self, vid_path):
        """Decode and augment a randomly positioned clip from ``vid_path``.

        Returns:
            ``(clip, frame_indices)``, where ``clip`` is the augmented clip
            tensor, or ``(None, None)`` if the video cannot be read.
        """
        try:
            cap = cv2.VideoCapture(vid_path)
            try:
                width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            finally:
                cap.release()
            new_width, new_height = self._short_side_resize_dims(width, height)

            vr = decord.VideoReader(vid_path, width=new_width, height=new_height, ctx=decord.cpu())
            indices = self._sample_frame_indices(len(vr), self.params.fix_skip, self.params.num_frames)

            # Frames are indexed as (T, H, W, C); the augmentation expects (T, C, H, W).
            frames = vr.get_batch(indices)
            self.ori_reso_h, self.ori_reso_w = frames.shape[1:3]
            self.min_size = min(self.ori_reso_h, self.ori_reso_w)
            full_clip = self.augmentation(frames.permute(0, 3, 1, 2))

            return full_clip, torch.from_numpy(indices)
        except Exception:
            # Best effort: unreadable videos are dropped by the collate function.
            return None, None


# Validation dataset.
class baseline_val_dataloader(Dataset):
    """Evaluation dataset yielding deterministically positioned clips.

    Supports ``'ucf101'``, ``'hmdb51'`` and ``'k400'``, plus the bias benchmarks
    named ``'<base>_<variant>'`` (e.g. ``'hmdb51_scufo_vqgan'``) stored under
    ``cfg.bias_path``. Each item is ``(clip_frames, label, frame_indices, idx)``,
    where ``clip_frames`` is a list of per-frame tensors. On failure,
    ``clip_frames`` and ``frame_indices`` are ``None``.

    Args:
        params: Config object; ``num_frames`` and ``fix_skip`` are read here.
        dataset: Dataset name.
        shuffle: Shuffle the sample list once, at construction time.
        data_percentage: Fraction of the (shuffled) list to keep.
        mode: Which of ``total_num_modes`` evenly spaced clip starts to use.
        split: UCF101 split number (1-3).
        total_num_modes: Number of evenly spaced start positions. ``1`` selects
            the centre position (mode 2 of 5).
        threeCrop: Stored as ``self.threecrop``; not used inside this class.

    Raises:
        ValueError: If ``dataset`` is unknown or ``split`` is not in 1-3.
    """

    # Folder-name pieces of the bias benchmarks under cfg.bias_path, e.g.
    # 'k400_conflfg_stripe' -> 'K400_ConflFG_Stripe'.
    _BIAS_BASE_DIRS = {'ucf101': 'UCF101', 'hmdb51': 'HMDB51', 'k400': 'K400'}
    _BIAS_VARIANT_DIRS = {
        'scuba_places365': 'SCUBA_Places365',
        'scuba_stripe': 'SCUBA_Stripe',
        'scuba_vqgan': 'SCUBA_VQGAN',
        'conflfg_stripe': 'ConflFG_Stripe',
        'scufo_places365': 'SCUFO_Places365',
        'scufo_stripe': 'SCUFO_Stripe',
        'scufo_vqgan': 'SCUFO_VQGAN',
    }

    def __init__(self, params, dataset='ucf101', shuffle=True, data_percentage=1.0, mode=0, split=1, total_num_modes=5, threeCrop=False):
        self.params = params
        self.dataset = dataset
        self.threecrop = threeCrop
        self.mode, self.total_num_modes = self._resolve_modes(mode, total_num_modes)

        if self.dataset == 'ucf101':
            self.classes = self._load_ucf101_classes()
            self.all_paths = self._read_ucf101_split('testlist', split)
        elif self.dataset == 'hmdb51':
            self.all_paths = self._read_hmdb51_csv('hmdb51_test_labels.csv')
        elif self.dataset == 'k400':
            self.all_paths = self._read_lines(
                os.path.join(cfg.kinetics_path, 'annotation_val_fullpath_resizedvids.txt'))
        else:
            self.all_paths = self._load_bias_dataset(self.dataset)

        self.shuffle = shuffle
        if self.shuffle:
            random.shuffle(self.all_paths)

        self.data_percentage = data_percentage
        self.data_limit = int(len(self.all_paths) * self.data_percentage)
        self.data = self.all_paths[0:self.data_limit]

        self.augmentation = v2.Compose([
            v2.Resize(size=256, antialias=True),
            v2.CenterCrop(size=(224, 224)),
            v2.ToDtype(torch.float32, scale=True),  # Normalize expects float input
            v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    @staticmethod
    def _resolve_modes(mode, total_num_modes):
        """Return the ``(mode, total_num_modes)`` actually used for clip placement.

        A request for a single mode is mapped to the centre of five evenly
        spaced start positions.
        """
        if total_num_modes == 1:
            return 2, 5
        return mode, total_num_modes

    @staticmethod
    def _read_lines(path):
        """Return the lines of the text file at ``path`` without newlines."""
        with open(path, 'r') as handle:
            return handle.read().splitlines()

    @classmethod
    def _read_ucf101_split(cls, prefix, split):
        """Read ``<prefix>0<split>.txt`` from the UCF101 split folder (OS separators)."""
        if not 1 <= split <= 3:
            raise ValueError(f'Invalid split input: {split}')
        lines = cls._read_lines(
            os.path.join(cfg.ucf101_path, 'ucfTrainTestlist', f'{prefix}0{split}.txt'))
        return [line.replace('/', os.sep) for line in lines]

    @staticmethod
    def _load_ucf101_classes():
        """Load the ``'classes'`` mapping from the UCF101 class-mapping JSON."""
        with open(cfg.ucf101_class_mapping) as handle:
            return json.load(handle)['classes']

    @staticmethod
    def _read_hmdb51_csv(filename):
        """Build ``'<video path> <label>'`` entries from an HMDB51 label CSV."""
        anno = pd.read_csv(os.path.join(cfg.hmdb_path, filename), index_col=None)
        return [
            f"{os.path.join(cfg.hmdb_path, 'Videos', cls_name, fname)} {lab}"
            for cls_name, fname, lab in zip(
                anno['class'].to_list(), anno['filename'].to_list(), anno['label'].to_list())
        ]

    def _load_bias_dataset(self, dataset):
        """Set ``self.classes`` for a bias benchmark and return its sorted entries."""
        base, _, variant = dataset.partition('_')
        if base not in self._BIAS_BASE_DIRS or variant not in self._BIAS_VARIANT_DIRS:
            raise ValueError(f'{dataset} does not exist.')

        if base == 'ucf101':
            self.classes = self._load_ucf101_classes()
        else:
            if base == 'hmdb51':
                csv_path = os.path.join(cfg.hmdb_file_path, 'hmdb51_labels.csv')
            else:
                csv_path = os.path.join(cfg.kinetics_path, 'k400_val_labels.csv')
            anno = pd.read_csv(csv_path, index_col=None)
            self.classes = {
                os.path.basename(k): v
                for k, v in zip(anno['filename'].to_list(), anno['label'].to_list())
            }

        folder = f'{self._BIAS_BASE_DIRS[base]}_{self._BIAS_VARIANT_DIRS[variant]}'
        return sorted(glob.glob(os.path.join(cfg.bias_path, folder, '*')))

    @staticmethod
    def _short_side_resize_dims(width, height, short_side=256):
        """Return ``(new_width, new_height)`` with the shorter side set to ``short_side``."""
        if width <= 0 or height <= 0:
            # e.g. when cv2 could not read the file's dimensions.
            raise ValueError(f'Invalid video size {width}x{height}')
        aspect_ratio = width / height
        if height < width:
            return int(short_side * aspect_ratio), short_side
        return short_side, int(short_side / aspect_ratio)

    @staticmethod
    def _uniform_frame_indices(frame_count, fix_skip, num_frames, mode, total_num_modes):
        """Pick ``num_frames`` indices starting at the ``mode``-th evenly spaced start.

        The stride is halved once if the clip does not fit. Indices past the
        end are clamped to the last frame.

        Raises:
            ValueError: If ``frame_count`` is smaller than 1.
        """
        if frame_count < 1:
            raise ValueError('Video has no frames')
        skip = fix_skip
        if skip * num_frames > frame_count:
            skip /= 2
        free = frame_count - skip * num_frames
        start = int(np.linspace(0, free - 10, total_num_modes)[mode])
        if start < 0:
            start = 0
        indices = start + np.arange(num_frames) * int(skip)
        return np.minimum(indices, frame_count - 1)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        clip, label, vid_path, frame_list = self.process_data(index)
        return clip, label, vid_path, frame_list

    def _bias_label(self, vid_path):
        """Look up the label of a bias-benchmark entry from its file name."""
        name = os.path.basename(vid_path)
        if 'ucf101' in self.dataset:
            return int(self.classes[vid_path.split(f'{os.sep}v_')[-1].split('_g0')[0]]) - 1
        if 'hmdb51' in self.dataset:
            # NOTE(review): the trimmed suffix lengths encode the benchmark's
            # file-naming scheme; confirm against the generated data.
            key = name[:-3] if 'scufo' in self.dataset else name[:-7]
            return self.classes[key + '.avi']
        if 'k400' in self.dataset:
            return self.classes[name if 'scufo' in self.dataset else name[:-4]]
        return None  # TODO: add label for other datasets

    def process_data(self, idx):
        """Resolve path and label for sample ``idx`` and build its clip.

        Returns:
            ``(clip_frames, label, frame_indices, idx)``.
        """
        entry = self.data[idx]
        if self.dataset == 'ucf101':
            vid_path = os.path.join(cfg.ucf101_path, 'Videos', entry.split(' ')[0])
            # The parent folder name is the activity name.
            label = int(self.classes[vid_path.split(os.sep)[-2]]) - 1
        elif self.dataset == 'hmdb51':
            # Split off only the trailing label so paths with spaces survive.
            vid_path, label = entry.rsplit(' ', 1)
            label = int(label)
        elif self.dataset == 'k400':
            vid_path = entry.split(' ')[0]
            label = int(entry.split(' ')[1]) - 1
        elif self.dataset == 'aras':
            # NOTE(review): __init__ does not load 'aras'; this only works if
            # self.data is filled with (video_name, label) pairs elsewhere.
            vid_name, label = entry
            vid_path = os.path.join(cfg.bias_path, 'ARAS', vid_name)
        elif any(tag in self.dataset for tag in ('scuba', 'conflfg', 'scufo')):
            vid_path = entry
            label = self._bias_label(vid_path)
            clip, frame_list = self.build_clip_scuba(vid_path)
            return clip, label, frame_list, idx
        else:
            raise ValueError(f'{self.dataset} does not exist.')

        clip, frame_list = self.build_clip(vid_path)
        return clip, label, frame_list, idx

    def build_clip_scuba(self, vid_path):
        """Build a clip for a bias-benchmark entry.

        SCUFO entries are folders of ``.jpg`` frames: the first frame is
        augmented and repeated ``num_frames`` times (all indices 0). Every
        other variant is decoded as a video via ``build_clip``.
        """
        if 'scufo' not in self.dataset:
            return self.build_clip(vid_path)
        try:
            frame_paths = sorted(glob.glob(os.path.join(vid_path, '*.jpg')))
            frame = torchvision.io.read_image(frame_paths[0])
            _, self.ori_reso_h, self.ori_reso_w = frame.shape
            self.min_size = min(self.ori_reso_h, self.ori_reso_w)
            frame = self.augmentation(frame)
            num_frames = self.params.num_frames
            return [frame] * num_frames, torch.zeros(num_frames, dtype=torch.long)
        except Exception:
            print(f'Clip {vid_path} Failed.')
            return None, None

    def build_clip(self, vid_path):
        """Decode the ``self.mode``-th evenly positioned clip of ``vid_path``.

        Returns:
            ``(frames, frame_indices)``, where ``frames`` is a list of augmented
            per-frame tensors, or ``(None, None)`` if decoding fails.
        """
        try:
            cap = cv2.VideoCapture(vid_path)
            try:
                width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            finally:
                cap.release()
            new_width, new_height = self._short_side_resize_dims(width, height)

            vr = decord.VideoReader(vid_path, width=new_width, height=new_height, ctx=decord.cpu())
            indices = self._uniform_frame_indices(
                len(vr), self.params.fix_skip, self.params.num_frames, self.mode, self.total_num_modes)

            # Frames are indexed as (T, H, W, C).
            frames = vr.get_batch(indices)
            self.ori_reso_h, self.ori_reso_w = frames.shape[1:3]
            self.min_size = min(self.ori_reso_h, self.ori_reso_w)
            full_clip = [self.augmentation(frame.permute(2, 0, 1)) for frame in frames]

            return full_clip, torch.from_numpy(indices)
        except Exception:
            # Best effort: unreadable videos are dropped by the collate function.
            return None, None


# Spatial training dataloader: each sample pairs a clip with a static copy built from one of its frames.
class spatial_train_dataloader(Dataset):
    """Training dataset returning each clip together with a static copy of it.

    Supported datasets are ``'ucf101'``, ``'hmdb51'`` and ``'k400'``. Each item
    is ``(temp_clip, spat_clip, label, vid_path, temp_frame_indices,
    spat_frame_indices)``. ``spat_clip`` repeats one randomly chosen frame of
    ``temp_clip`` ``num_frames`` times. If decoding fails, all six entries are
    ``None``.

    Args:
        params: Config object; ``num_frames`` and ``fix_skip`` are read here.
        dataset: Name of the dataset to load.
        shuffle: Shuffle the sample list once, at construction time.
        data_percentage: Fraction of the (shuffled) sample list to keep.
        split: UCF101 train/test split number (1-3).

    Raises:
        ValueError: If ``dataset`` is unknown or ``split`` is not in 1-3.
    """

    def __init__(self, params, dataset='ucf101', shuffle=True, data_percentage=1.0, split=1):
        # Supported datasets: ucf101, hmdb51, k400.
        self.dataset = dataset
        self.params = params

        if self.dataset == 'ucf101':
            self.all_paths = self._read_ucf101_split('trainlist', split)
            self.classes = self._load_ucf101_classes()
        elif self.dataset == 'hmdb51':
            self.all_paths = self._read_hmdb51_csv('hmdb51_train_labels.csv')
        elif self.dataset == 'k400':
            self.all_paths = self._read_lines(
                os.path.join(cfg.kinetics_path, 'annotation_train_fullpath_resizedvids.txt'))
        else:
            # Fail fast: without a sample list the dataset is unusable.
            raise ValueError(f'{self.dataset} does not exist.')

        self.shuffle = shuffle
        if self.shuffle:
            random.shuffle(self.all_paths)

        self.data_percentage = data_percentage
        self.data_limit = int(len(self.all_paths) * self.data_percentage)
        self.data = self.all_paths[0:self.data_limit]
        self.augmentation = v2.Compose([
            v2.Resize(size=256, antialias=True),
            v2.RandomResizedCrop(size=(224, 224), antialias=True),
            v2.RandomHorizontalFlip(p=0.5),
            v2.ToDtype(torch.float32, scale=True),  # Normalize expects float input
            v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    @staticmethod
    def _read_lines(path):
        """Return the lines of the text file at ``path`` without newlines."""
        with open(path, 'r') as handle:
            return handle.read().splitlines()

    @classmethod
    def _read_ucf101_split(cls, prefix, split):
        """Read ``<prefix>0<split>.txt`` from the UCF101 split folder (OS separators)."""
        if not 1 <= split <= 3:
            raise ValueError(f'Invalid split input: {split}')
        lines = cls._read_lines(
            os.path.join(cfg.ucf101_path, 'ucfTrainTestlist', f'{prefix}0{split}.txt'))
        return [line.replace('/', os.sep) for line in lines]

    @staticmethod
    def _load_ucf101_classes():
        """Load the ``'classes'`` mapping from the UCF101 class-mapping JSON."""
        with open(cfg.ucf101_class_mapping) as handle:
            return json.load(handle)['classes']

    @staticmethod
    def _read_hmdb51_csv(filename):
        """Build ``'<video path> <label>'`` entries from an HMDB51 label CSV."""
        anno = pd.read_csv(os.path.join(cfg.hmdb_path, filename), index_col=None)
        return [
            f"{os.path.join(cfg.hmdb_path, 'Videos', cls_name, fname)} {lab}"
            for cls_name, fname, lab in zip(
                anno['class'].to_list(), anno['filename'].to_list(), anno['label'].to_list())
        ]

    @staticmethod
    def _short_side_resize_dims(width, height, short_side=256):
        """Return ``(new_width, new_height)`` with the shorter side set to ``short_side``."""
        if width <= 0 or height <= 0:
            # e.g. when cv2 could not read the file's dimensions.
            raise ValueError(f'Invalid video size {width}x{height}')
        aspect_ratio = width / height
        if height < width:
            return int(short_side * aspect_ratio), short_side
        return short_side, int(short_side / aspect_ratio)

    @staticmethod
    def _sample_frame_indices(frame_count, fix_skip, num_frames):
        """Pick ``num_frames`` frame indices, ``fix_skip`` apart, from a random start.

        If the clip does not fit, the stride is halved once. Indices that
        still run past the end are clamped to the last frame.

        Raises:
            ValueError: If ``frame_count`` is smaller than 1.
        """
        if frame_count < 1:
            raise ValueError('Video has no frames')
        skip = fix_skip
        left_over = frame_count - skip * num_frames
        if left_over <= 0:
            skip /= 2
            left_over = frame_count - skip * num_frames
        # randint's exclusive upper bound must be >= 1; after halving an odd
        # stride, left_over can lie in (0, 1).
        start = np.random.randint(0, int(left_over)) if int(left_over) > 0 else 0
        indices = start + np.arange(num_frames) * int(skip)
        return np.minimum(indices, frame_count - 1)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.process_data(index)

    def process_data(self, idx):
        """Build the temporal clip of sample ``idx`` and its static counterpart."""
        entry = self.data[idx]
        if self.dataset == 'ucf101':
            vid_path = os.path.join(cfg.ucf101_path, 'Videos', entry.split(' ')[0])
            # The parent folder name is the activity name.
            label = int(self.classes[vid_path.split(os.sep)[-2]]) - 1
        elif self.dataset == 'hmdb51':
            # Split off only the trailing label so paths with spaces survive.
            vid_path, label = entry.rsplit(' ', 1)
            label = int(label)
        elif self.dataset == 'k400':
            vid_path = entry.split(' ')[0]
            label = int(entry.split(' ')[1]) - 1
        else:
            raise ValueError(f'{self.dataset} does not exist.')

        temp_clip, temp_frame_list = self.build_clip(vid_path)
        if temp_clip is None:
            return None, None, None, None, None, None

        # Static clip: one random frame of the temporal clip repeated num_frames times.
        num_frames = self.params.num_frames
        frame_idx = np.random.randint(0, num_frames)
        spat_clip = torch.tile(temp_clip[frame_idx], (num_frames, 1, 1, 1))
        spat_frame_list = temp_frame_list[frame_idx].repeat(num_frames)

        return temp_clip, spat_clip, label, vid_path, temp_frame_list, spat_frame_list

    def build_clip(self, vid_path):
        """Decode and augment a randomly positioned clip from ``vid_path``.

        Returns:
            ``(clip, frame_indices)``, or ``(None, None)`` if decoding fails.
        """
        try:
            cap = cv2.VideoCapture(vid_path)
            try:
                width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            finally:
                cap.release()
            new_width, new_height = self._short_side_resize_dims(width, height)

            vr = decord.VideoReader(vid_path, width=new_width, height=new_height, ctx=decord.cpu())
            indices = self._sample_frame_indices(len(vr), self.params.fix_skip, self.params.num_frames)

            # Frames are indexed as (T, H, W, C); the augmentation expects (T, C, H, W).
            frames = vr.get_batch(indices)
            full_clip = self.augmentation(frames.permute(0, 3, 1, 2))

            return full_clip, torch.from_numpy(indices)
        except Exception:
            # Best effort: unreadable videos are dropped by the collate function.
            return None, None


def collate_fn_val(batch):
    """Collate ``(clip_frames, label, frame_indices, idx)`` validation items.

    Items whose clip, label or frame indices are ``None`` (failed decodes) are
    skipped. ``clip_frames`` is a list of per-frame tensors and is stacked
    into one clip tensor per sample.

    Returns:
        ``(clips, labels, frame_indices, idxs)`` batched along a new first
        dimension. NOTE: ``torch.stack`` fails if every item was skipped.
    """
    valid = [item for item in batch if all(field is not None for field in item[:3])]

    f_clip = torch.stack([torch.stack(item[0], dim=0) for item in valid], dim=0)
    label = torch.tensor([item[1] for item in valid])
    frame_list = torch.stack([item[2] for item in valid], dim=0)
    idx = torch.tensor([item[3] for item in valid])

    return f_clip, label, frame_list, idx


def collate_fn_eval(batch):
    """Collate ``(clip_frames, label, frame_indices, vid_path)`` evaluation items.

    Items whose clip, label or frame indices are ``None`` are skipped.
    ``clip_frames`` is a list of per-frame tensors and is stacked into one
    clip tensor per sample. The fourth fields are returned as a plain list.

    Returns:
        ``(clips, labels, frame_indices, vid_paths)``. NOTE: ``torch.stack``
        fails if every item was skipped.
    """
    valid = [item for item in batch if all(field is not None for field in item[:3])]

    f_clip = torch.stack([torch.stack(item[0], dim=0) for item in valid], dim=0)
    label = torch.tensor([item[1] for item in valid])
    frame_list = torch.stack([item[2] for item in valid], dim=0)
    vid_path = [item[3] for item in valid]

    return f_clip, label, frame_list, vid_path


def collate_fn_train(batch):
    """Collate ``(clip, label, vid_path, frame_indices)`` training items.

    Items whose clip, label or path is ``None`` (failed decodes) are skipped.
    Each clip is already a single tensor and is stacked along a new batch
    dimension.

    Returns:
        ``(clips, labels, vid_paths, frame_indices)``. NOTE: ``torch.stack``
        fails if every item was skipped.
    """
    valid = [item for item in batch if all(field is not None for field in item[:3])]

    f_clip = torch.stack([item[0] for item in valid], dim=0)
    label = torch.tensor([item[1] for item in valid])
    vid_path = [item[2] for item in valid]
    frame_list = torch.stack([item[3] for item in valid], dim=0)

    return f_clip, label, vid_path, frame_list


def collate_fn_spatial(batch):
    """Collate items produced by ``spatial_train_dataloader``.

    Each item is ``(temp_clip, spat_clip, label, vid_path, temp_frame_indices,
    spat_frame_indices)``. Items with any ``None`` field are skipped.

    Returns:
        The same six fields batched: tensors stacked along a new first
        dimension, labels as a tensor, paths as a list. NOTE: ``torch.stack``
        fails if every item was skipped.
    """
    valid = [item for item in batch if all(field is not None for field in item[:6])]

    temp_clip = torch.stack([item[0] for item in valid], dim=0)
    spat_clip = torch.stack([item[1] for item in valid], dim=0)
    label = torch.tensor([item[2] for item in valid])
    vid_path = [item[3] for item in valid]
    temp_frame_list = torch.stack([item[4] for item in valid], dim=0)
    spat_frame_list = torch.stack([item[5] for item in valid], dim=0)

    return temp_clip, spat_clip, label, vid_path, temp_frame_list, spat_frame_list



if __name__ == '__main__':
    # Smoke test: build the loaders and print a few training batch shapes.
    import math

    import params_debias as params

    # Alternative: dataset = 'hmdb51_scuba_places365'
    dataset = 'hmdb51'

    train_dataset = spatial_train_dataloader(params=params, dataset=dataset, shuffle=True, data_percentage=1.0)
    train_dataloader = DataLoader(train_dataset, batch_size=params.batch_size, shuffle=False, collate_fn=collate_fn_spatial, num_workers=params.num_workers)

    val_dataset = baseline_val_dataloader(params=params, dataset=dataset, shuffle=True, data_percentage=1.0)
    val_dataloader = DataLoader(val_dataset, batch_size=params.batch_size, shuffle=False, collate_fn=collate_fn_val, num_workers=params.num_workers)

    print(f'Length of training dataset: {len(train_dataset)}')
    print(f'Length of validation dataset: {len(val_dataset)}')
    # Number of batches per epoch (the last batch may be partial).
    print(f'Steps involved: {math.ceil(len(train_dataset) / params.batch_size)}')
    t = time.time()

    for i, (clip, _stat_clip, label, _vid_path, _frame_list, _stat_frame_list) in enumerate(train_dataloader):
        if i % 10 == 0:
            print()
            print(f'Full_clip shape is {clip.shape}')
            print(f'Label is {label}', flush=True)
            print(time.time() - t)

    print(f'Time taken to load data is {time.time()-t}')
