import os
import numpy as np
import pandas as pd

import torch
from torch.utils.data import Dataset
import math

# The 25 event classes, in the index order used by ``ids_to_multinomial``.
LLP_CATEGORIES = (
    'Speech', 'Car', 'Cheering', 'Dog', 'Cat', 'Frying_(food)',
    'Basketball_bounce', 'Fire_alarm', 'Chainsaw', 'Cello', 'Banjo',
    'Singing', 'Chicken_rooster', 'Violin_fiddle', 'Vacuum_cleaner',
    'Baby_laughter', 'Accordion', 'Lawn_mower', 'Motorcycle', 'Helicopter',
    'Acoustic_guitar', 'Telephone_bell_ringing', 'Baby_cry_infant_cry', 'Blender',
    'Clapping',
)
# Built once at import time instead of on every call.
_LLP_CATEGORY_TO_IDX = {name: index for index, name in enumerate(LLP_CATEGORIES)}


def ids_to_multinomial(ids):
    """Label encoding: turn category names into a multi-hot vector.

    Args:
      ids: iterable of category names, each one of ``LLP_CATEGORIES``.
    Returns:
      1d float array of length ``len(LLP_CATEGORIES)``, multinomial (multi-hot)
      representation, e.g. [1,0,1,0,0,...]; repeated names still give 1.
    Raises:
      KeyError: if a name is not in ``LLP_CATEGORIES``.
    """
    y = np.zeros(len(LLP_CATEGORIES))
    for category in ids:
        y[_LLP_CATEGORY_TO_IDX[category]] = 1
    return y

    

class on_LLP_dataset(Dataset):
    """Sliding-window dataset over per-video pre-extracted ``.npy`` feature files.

    Each entry of ``self.inputs`` is ``[sample_name, sample_length, curr_idx,
    event_label]``. For an entry, ``__getitem__`` returns the ``c_lens`` frames
    ending at frame ``curr_idx`` (inclusive), zero-padded when that window
    starts before frame 0. In ``'train'`` mode it also returns the ``f_lens``
    frames immediately after ``curr_idx`` (fewer if the feature file ends
    earlier, since numpy slicing truncates).

    Args:
        mode: ``'train'`` draws one random end frame per ``c_lens``-sized chunk
            of valid end frames. Any other value enumerates ``EVAL_STEPS`` end
            frames per video.
        pd_dir: path of a comma-separated file with a header row and a
            ``filename`` column. Positional column 1 (multiplied by 10) gives
            ``sample_length``, presumably the duration at 10 frames per unit
            (TODO confirm). Positional column 2 is the stem of the ``.npy``
            files.
        audio_dir, visual_dir, st_dir: directories containing ``<name>.npy``.
        label_dir, f_label_dir: directories with ``audio/`` and ``visual/``
            subdirectories containing ``<name>.npy``.
        f_lens: number of future frames returned in ``'train'`` mode.
        c_lens: number of current frames per window (stored as ``curr_m``).
    """

    # Audio feature rows per visual frame (audio files are laid out as 6.4*T rows).
    AUDIO_RATE = 6.4
    # End frames enumerated per video outside 'train' mode. Indices at or past
    # the video length are clamped to its last frame.
    EVAL_STEPS = 60

    def __init__(self, mode, pd_dir, audio_dir, visual_dir, st_dir, label_dir, f_label_dir, f_lens, c_lens):
        self.mode = mode
        self.video_list = pd.read_csv(pd_dir, header=0, sep=',')
        self.video_name = self.video_list["filename"]
        self.video_num = len(self.video_name)
        self.curr_m = c_lens

        self.audio_dir = audio_dir
        self.visual_dir = visual_dir
        self.st_dir = st_dir
        self.label_dir = label_dir
        self.f_label_dir = f_label_dir
        self.f_lens = f_lens

        self._init_dataset()

    def _init_dataset(self):
        """Build ``self.inputs``, one ``[name, length, curr_idx, ['none']]`` per item.

        In ``'train'`` mode the valid end frames ``0 .. length - f_lens - 1``
        (those whose whole future window fits, assuming the feature files hold
        ``length`` frames) are split into consecutive chunks of ``curr_m``. One
        end frame is drawn uniformly from each chunk. Otherwise ``EVAL_STEPS``
        end frames ``0, 1, 2, ...`` are listed, clamped to ``length - 1``.
        """
        self.inputs = []
        for row_idx in range(self.video_num):
            video = self.video_list.iloc[row_idx]
            # Positional access (plain ``video[2]`` positional indexing is
            # deprecated in pandas).
            sample_name = video.iloc[2]
            sample_length = video.iloc[1] * 10

            if self.mode == 'train':
                # Exclusive upper bound for curr_idx so that the future window
                # [curr_idx + 1, curr_idx + 1 + f_lens) stays inside the video.
                end_limit = sample_length - self.f_lens
                n_windows = math.ceil(end_limit / self.curr_m)
                for win_idx in range(n_windows):
                    start = win_idx * self.curr_m
                    stop = min(start + self.curr_m, end_limit)
                    # randint's high is exclusive, so this samples [start, stop - 1].
                    # stop > start always holds because start < end_limit for
                    # every win_idx < n_windows, so curr_m == 1 no longer raises.
                    random_point = np.random.randint(start, stop)
                    # Event labels are not read yet; ['none'] is a placeholder.
                    self.inputs.append([sample_name, sample_length, random_point, ['none']])
            else:
                for step in range(self.EVAL_STEPS):
                    idx = step if step < sample_length else sample_length - 1
                    self.inputs.append([sample_name, sample_length, idx, ['none']])

    def __len__(self):
        return len(self.inputs)

    @staticmethod
    def _pad_front(arr, n_rows):
        """Return ``arr`` with ``n_rows`` float32 zero rows prepended on axis 0.

        The pad copies ``arr``'s trailing shape, so any feature width works.
        """
        pad = np.zeros((n_rows,) + arr.shape[1:], dtype=np.float32)
        return np.concatenate((pad, arr), axis=0)

    def __getitem__(self, idx):
        """Return the feature/label window dict for ``self.inputs[idx]``.

        Every array is prefixed with ``curr_m`` zero frames (audio with
        ``int(AUDIO_RATE * curr_m)`` zero rows). After that, padded row
        ``curr_idx + 1`` is original frame ``curr_idx + 1 - curr_m``, which is
        the first frame of the current window.
        """
        sample_name, sample_length, curr_idx, _event_label = self.inputs[idx]
        m = self.curr_m
        fname = sample_name + '.npy'

        def load(*dirs):
            return np.load(os.path.join(*dirs, fname))

        audio_pad_rows = int(self.AUDIO_RATE * m)
        audio = self._pad_front(load(self.audio_dir), audio_pad_rows)
        visual = self._pad_front(load(self.visual_dir), m)
        st = self._pad_front(load(self.st_dir), m)
        label_a = self._pad_front(load(self.label_dir, 'audio'), m)
        label_v = self._pad_front(load(self.label_dir, 'visual'), m)
        f_label_a = self._pad_front(load(self.f_label_dir, 'audio'), m)
        f_label_v = self._pad_front(load(self.f_label_dir, 'visual'), m)

        # Current window: original frames [curr_idx + 1 - m, curr_idx + 1).
        cur = slice(curr_idx + 1, curr_idx + 1 + m)
        audio_start = int((curr_idx + 1) * self.AUDIO_RATE)
        sample = {
            'name': sample_name,
            'length': sample_length,
            'audio': audio[audio_start:audio_start + audio_pad_rows, :],
            'visual': visual[cur, :],
            'st': st[cur, :],
            'curr_label_a': label_a[cur, :],
            'curr_f_label_a': f_label_a[cur, :],
            'curr_label_v': label_v[cur, :],
            'curr_f_label_v': f_label_v[cur, :],
        }

        if self.mode == 'train':
            # Future window: original frames [curr_idx + 1, curr_idx + 1 + f_lens).
            fut = slice(curr_idx + 1 + m, curr_idx + 1 + m + self.f_lens)
            sample['futu_label_a'] = label_a[fut, :]
            sample['futu_f_label_a'] = f_label_a[fut, :]
            sample['futu_label_v'] = label_v[fut, :]
            sample['futu_f_label_v'] = f_label_v[fut, :]

        return sample
    

