import os
import numpy as np
import pandas as pd

import torch
from torch.utils.data import Dataset
import math


class UnAV_dataset(Dataset):
    """Window-sampling dataset over pre-extracted per-video features and labels.

    Every video listed in the CSV contributes one or more entries to
    ``self.inputs``; each entry names a frame index ``curr_idx`` at which a
    window of ``curr_m`` frames ends. ``__getitem__`` loads the video's
    ``.npy`` arrays (on every call), prepends zero frames so that windows near
    the start of a video are still full length, and slices the aligned
    audio / visual / label windows. In ``'train'`` mode each item also carries
    the ``f_lens`` label frames that follow the window.
    """

    # Audio feature rows per visual frame. A window of 10 visual frames maps to
    # 64 audio rows; audio arrays are presumably ~6.4*T rows for T frames.
    AUDIO_RATE = 6.4
    # Number of window ends emitted per video outside of train mode.
    EVAL_STEPS = 60

    def __init__(self, mode, pd_dir, audio_dir, visual_dir, label_dir, f_label_dir, f_lens):
        """
        Args:
            mode: ``'train'`` draws one random window end per ``curr_m``-frame
                chunk and adds future labels to each item; any other value
                emits ``EVAL_STEPS`` consecutive window ends per video.
            pd_dir: path of a comma-separated file with a header row. Column 0
                is the video name (``<name>.npy`` in each feature directory),
                column 1 its length in frames; a ``filename`` column must exist.
            audio_dir, visual_dir, label_dir, f_label_dir: directories holding
                one ``<name>.npy`` array per video.
            f_lens: number of future label frames returned in train mode.
        """
        self.mode = mode
        self.video_list = pd.read_csv(pd_dir, header=0, sep=',')
        self.video_name = self.video_list["filename"]
        self.video_num = len(self.video_name)
        self.curr_m = 10    # frames in the current (observed) window
        self.future_m = 10  # frames reserved after the last train window end

        self.audio_dir = audio_dir
        self.visual_dir = visual_dir
        self.label_dir = label_dir
        self.f_label_dir = f_label_dir
        self.f_lens = f_lens

        self._init_dataset()

    def _init_dataset(self):
        """Build ``self.inputs`` as a list of ``[name, length, curr_idx]``.

        ``curr_idx`` is the (unpadded) frame index at which the sample's
        current window ends.
        """
        self.inputs = []
        for row_idx in range(self.video_num):
            row = self.video_list.iloc[row_idx]
            # Positional access must go through .iloc: ``row[0]`` on a
            # label-indexed Series is deprecated in pandas >= 2.1.
            sample_name = row.iloc[0]
            sample_length = row.iloc[1]

            if self.mode != 'train':
                # Window ends 0..EVAL_STEPS-1, clamped to the last frame for
                # videos shorter than EVAL_STEPS.
                for step in range(self.EVAL_STEPS):
                    idx = step if step < sample_length else sample_length - 1
                    self.inputs.append([sample_name, sample_length, idx])
            else:
                # One random window end per curr_m-frame chunk of the usable
                # range. randint's upper bound is exclusive, so curr_idx never
                # exceeds sample_length - future_m - 1; that keeps the future
                # slice in __getitem__ full length whenever f_lens <= future_m.
                # NOTE(review): for non-final chunks this also never selects the
                # chunk's last frame (start + curr_m - 1) -- confirm intended.
                # Points are drawn here, so they stay fixed across epochs unless
                # _init_dataset is called again.
                num_chunks = math.ceil((sample_length - self.future_m) / self.curr_m)
                for chunk in range(num_chunks):
                    start = chunk * self.curr_m
                    end = min((chunk + 1) * self.curr_m,
                              sample_length - self.future_m + 1)
                    random_point = np.random.randint(start, end - 1)
                    self.inputs.append([sample_name, sample_length, random_point])

    def __len__(self):
        return len(self.inputs)

    @staticmethod
    def _load_padded(directory, name, pad_len):
        """Load ``<directory>/<name>.npy`` and prepend ``pad_len`` float32 zero rows."""
        arr = np.load(os.path.join(directory, name + '.npy'))
        pad = np.zeros((pad_len,) + arr.shape[1:], dtype=np.float32)
        return np.concatenate((pad, arr), axis=0)

    def __getitem__(self, idx):
        """Load the arrays for ``self.inputs[idx]`` and cut the aligned windows.

        Returns a dict with ``name``, ``length``, ``audio`` (the audio rows
        aligned to the window), and ``visual`` / ``curr_label`` /
        ``curr_f_label`` (the ``curr_m`` frames ending at ``curr_idx``). In
        train mode it also holds ``future_label`` / ``future_f_label`` (the
        ``f_lens`` frames after ``curr_idx``).
        """
        name, length, curr_idx = self.inputs[idx]
        win = self.curr_m
        audio_win = int(win * self.AUDIO_RATE)  # 64 audio rows for 10 frames

        # Expected layouts (from the data pipeline, not checked here):
        # audio ~(6.4*T, 768), visual (T, 768), label (T, 100), f_label (T, 1536).
        audio = self._load_padded(self.audio_dir, name, audio_win)
        visual = self._load_padded(self.visual_dir, name, win)
        label = self._load_padded(self.label_dir, name, win)
        f_label = self._load_padded(self.f_label_dir, name, win)

        # The ``win`` leading pad frames shift indices by ``win``: padded rows
        # [lo, hi) are the original frames curr_idx-win+1 .. curr_idx.
        lo = curr_idx + 1
        hi = lo + win
        audio_index = int(lo * self.AUDIO_RATE)

        sample = {
            'name': name,
            'length': length,
            'audio': audio[audio_index:audio_index + audio_win],
            'visual': visual[lo:hi],
            'curr_label': label[lo:hi],
            'curr_f_label': f_label[lo:hi],
        }

        if self.mode == 'train':
            # Original frames curr_idx+1 .. curr_idx+f_lens.
            sample['future_label'] = label[hi:hi + self.f_lens]
            sample['future_f_label'] = f_label[hi:hi + self.f_lens]

        return sample