import numpy as np
import pandas as pd
import feather
import random
import torch
import h5py
from torch.utils.data import Dataset
import concurrent.futures
import time


# Visual modality
def load_pose_gaze_aus(path, start_list, t_len):
    """Cut fixed-length windows out of a pose/gaze/AU feather file.

    Args:
        path: Path of the feather file, or a missing value (``None``/NaN)
            when the sample has no such file.
        start_list: Array-like of relative start positions, one per window.
            Each value is multiplied by ``max(n_rows - t_len, 0)`` and
            truncated to an int to obtain the first row of its window.
            Callers in this file pass values in ``[0, 1)``.
        t_len: Number of rows per window.

    Returns:
        A float32 array of shape ``(len(start_list), t_len, 47)``. It is all
        zeros when ``path`` is missing. When the file has fewer than
        ``t_len`` rows, the tail of every window stays zero. NaN and
        infinite values are replaced by 0.
    """
    if pd.isna(path):
        return np.zeros((len(start_list), t_len, 47), dtype=np.float32)

    frame = pd.read_feather(path).fillna(0)
    n_rows = len(frame)
    # The first four columns are skipped; presumably frame metadata rather
    # than features -- TODO confirm against the feature extractor output.
    features = frame.iloc[:, 4:].to_numpy().astype(np.float32)

    windows = np.zeros((len(start_list), t_len, 47), dtype=np.float32)
    # np.asarray lets plain lists work here too: ``list * int`` would repeat
    # the list instead of scaling it and then fail on ``.astype``.
    starts = (np.asarray(start_list) * max(n_rows - t_len, 0)).astype(int)
    span = min(t_len, n_rows)  # rows actually available for one window
    windows[:, :span, :] = features[starts[:, None] + np.arange(span)]
    windows[np.isinf(windows)] = 0
    return windows


def load_landmark(path, start_list, t_len):
    """Cut fixed-length windows out of a landmark feather file.

    Args:
        path: Path of the feather file, or a missing value (``None``/NaN)
            when the sample has no such file.
        start_list: Array-like of relative start positions, one per window.
            Each value is multiplied by ``max(n_rows - t_len, 0)`` and
            truncated to an int to obtain the first row of its window.
        t_len: Number of rows per window.

    Returns:
        A float32 array of shape ``(len(start_list), t_len, 136)``. It is
        all zeros when ``path`` is missing and zero-padded at the end when
        the file has fewer than ``t_len`` rows. NaN and infinite values are
        replaced by 0.
    """
    if pd.isna(path):
        return np.zeros((len(start_list), t_len, 136), dtype=np.float32)

    frame = pd.read_feather(path).fillna(0)
    n_rows = len(frame)
    # Columns 0-3 are not used as features (their content is not visible
    # from this file).
    features = frame.iloc[:, 4:].to_numpy().astype(np.float32)

    windows = np.zeros((len(start_list), t_len, 136), dtype=np.float32)
    # Coerce to an array so a plain Python list is scaled element-wise
    # instead of being repeated by ``list * int``.
    starts = (np.asarray(start_list) * max(n_rows - t_len, 0)).astype(int)
    span = min(t_len, n_rows)
    windows[:, :span, :] = features[starts[:, None] + np.arange(span)]
    windows[np.isinf(windows)] = 0
    return windows


def load_hog(path, start_list, t_len):
    """Cut fixed-length windows out of a raw float32 HOG feature file.

    The file is read as a flat float32 buffer and viewed as rows of 4465
    values; the first value of every row is dropped, leaving 4464 features.

    Args:
        path: Path of the binary file, or a missing value (``None``/NaN)
            when the sample has no such file.
        start_list: Array-like of relative start positions, one per window.
            Each value is multiplied by ``max(n_rows - t_len, 0)`` and
            truncated to an int to obtain the first row of its window.
        t_len: Number of rows per window.

    Returns:
        A float32 array of shape ``(len(start_list), t_len, 4464)``. It is
        all zeros when ``path`` is missing and zero-padded at the end when
        the file has fewer than ``t_len`` rows. Infinite values are
        replaced by 0.
    """
    if pd.isna(path):
        return np.zeros((len(start_list), t_len, 4464), dtype=np.float32)

    # Memory-map so only the rows that are actually sampled get read.
    flat = np.memmap(path, dtype='float32', mode='r')
    table = flat.reshape((flat.shape[0] // 4465, 4465))
    n_rows = table.shape[0]
    features = table[:, 1:]  # column 0 is not a feature

    windows = np.zeros((len(start_list), t_len, 4464), dtype=np.float32)
    # np.asarray lets plain lists work here too: ``list * int`` would repeat
    # the list instead of scaling it and then fail on ``.astype``.
    starts = (np.asarray(start_list) * max(n_rows - t_len, 0)).astype(int)
    span = min(t_len, n_rows)  # rows actually available for one window
    windows[:, :span, :] = features[starts[:, None] + np.arange(span)]
    windows[np.isinf(windows)] = 0
    return windows


# Textual modality
def load_bert(path, start_list, t_len):
    """Load every row of a raw float32 BERT feature file, once per window.

    The file is read as a flat float32 buffer and viewed as rows of 769
    values; the first value of every row is dropped, leaving 768 features.
    Unlike the other loaders, no temporal cropping is done: the whole
    sequence is returned for each entry of ``start_list``.

    Args:
        path: Path of the binary file, or a missing value (``None``/NaN)
            when the sample has no such file.
        start_list: Sized collection with one entry per window. Only its
            length is used.
        t_len: Unused; kept so all loaders share the same signature.

    Returns:
        A float32 array of shape ``(len(start_list), n_rows, 768)``, or an
        all-zero array of shape ``(len(start_list), 1, 768)`` when ``path``
        is missing. Infinite values are replaced by 0.
    """
    n_windows = len(start_list)
    if pd.isna(path):
        return np.zeros((n_windows, 1, 768), dtype=np.float32)

    flat = np.memmap(path, dtype='float32', mode='r')
    table = flat.reshape((flat.shape[0] // 769, 769))
    features = table[:, 1:]  # column 0 is not a feature

    tensors = np.empty((n_windows, table.shape[0], 768), dtype=np.float32)
    # Every window starts at row 0, so broadcasting the feature matrix over
    # the window axis is enough; no per-window index arithmetic is needed
    # (and plain-list ``start_list`` values no longer break it).
    tensors[:] = features
    tensors[np.isinf(tensors)] = 0
    return tensors


# Acoustic modality
def load_mfcc(path, start_list, t_len):
    """Cut fixed-length windows out of an MFCC feather file.

    Args:
        path: Path of the feather file, or a missing value (``None``/NaN)
            when the sample has no such file.
        start_list: Array-like of relative start positions, one per window.
            Each value is multiplied by ``max(n_rows - t_len, 0)`` and
            truncated to an int to obtain the first row of its window.
        t_len: Number of rows per window.

    Returns:
        A float32 array of shape ``(len(start_list), t_len, 39)``. It is all
        zeros when ``path`` is missing and zero-padded at the end when the
        file has fewer than ``t_len`` rows. NaN and infinite values are
        replaced by 0.
    """
    if pd.isna(path):
        return np.zeros((len(start_list), t_len, 39), dtype=np.float32)

    frame = pd.read_feather(path).fillna(0).drop(columns='name')
    n_rows = len(frame)
    # After removing 'name', the first remaining column is skipped as well.
    features = frame.iloc[:, 1:].to_numpy().astype(np.float32)

    windows = np.zeros((len(start_list), t_len, 39), dtype=np.float32)
    # Coerce to an array so a plain Python list is scaled element-wise
    # instead of being repeated by ``list * int``.
    starts = (np.asarray(start_list) * max(n_rows - t_len, 0)).astype(int)
    span = min(t_len, n_rows)
    windows[:, :span, :] = features[starts[:, None] + np.arange(span)]
    windows[np.isinf(windows)] = 0
    return windows


def load_egemaps(path, start_list, t_len):
    """Cut fixed-length windows out of an eGeMAPS feather file.

    Args:
        path: Path of the feather file, or a missing value (``None``/NaN)
            when the sample has no such file.
        start_list: Array-like of relative start positions, one per window.
            Each value is multiplied by ``max(n_rows - t_len, 0)`` and
            truncated to an int to obtain the first row of its window.
        t_len: Number of rows per window.

    Returns:
        A float32 array of shape ``(len(start_list), t_len, 88)``. It is all
        zeros when ``path`` is missing and zero-padded at the end when the
        file has fewer than ``t_len`` rows. NaN and infinite values are
        replaced by 0.
    """
    if pd.isna(path):
        return np.zeros((len(start_list), t_len, 88), dtype=np.float32)

    frame = pd.read_feather(path).fillna(0).drop(columns='name')
    n_rows = len(frame)
    # After removing 'name', the first remaining column is skipped as well.
    features = frame.iloc[:, 1:].to_numpy().astype(np.float32)

    windows = np.zeros((len(start_list), t_len, 88), dtype=np.float32)
    # Coerce to an array so a plain Python list is scaled element-wise
    # instead of being repeated by ``list * int``.
    starts = (np.asarray(start_list) * max(n_rows - t_len, 0)).astype(int)
    span = min(t_len, n_rows)
    windows[:, :span, :] = features[starts[:, None] + np.arange(span)]
    windows[np.isinf(windows)] = 0
    return windows


def load_covarep(path, start_list, t_len):
    """Cut fixed-length windows out of a COVAREP feather file.

    Args:
        path: Path of the feather file, or a missing value (``None``/NaN)
            when the sample has no such file.
        start_list: Array-like of relative start positions, one per window.
            Each value is multiplied by ``max(n_rows - t_len, 0)`` and
            truncated to an int to obtain the first row of its window.
        t_len: Number of rows per window.

    Returns:
        A float32 array of shape ``(len(start_list), t_len, 81)``. It is all
        zeros when ``path`` is missing and zero-padded at the end when the
        file has fewer than ``t_len`` rows. NaN and infinite values are
        replaced by 0.
    """
    if pd.isna(path):
        return np.zeros((len(start_list), t_len, 81), dtype=np.float32)

    frame = pd.read_feather(path).fillna(0)
    n_rows = len(frame)
    features = frame.iloc[:, 1:].to_numpy().astype(np.float32)  # column 0 skipped
    # Optional per-column rescaling, left disabled by the original author:
    # features[:, -6:-1] = 0.0001 * features[:, -6:-1]
    # features[:, 0] = 0.001 * features[:, 0]

    windows = np.zeros((len(start_list), t_len, 81), dtype=np.float32)
    # Coerce to an array so a plain Python list is scaled element-wise
    # instead of being repeated by ``list * int``.
    starts = (np.asarray(start_list) * max(n_rows - t_len, 0)).astype(int)
    span = min(t_len, n_rows)
    windows[:, :span, :] = features[starts[:, None] + np.arange(span)]
    windows[np.isinf(windows)] = 0
    return windows


def load_files_parallel(filepaths, load_funcs, rd_flags, tlens, max_workers=7):
    """Run several loader calls concurrently in a thread pool.

    The i-th call is ``load_funcs[i](filepaths[i], rd_flags[i], tlens[i])``.

    Args:
        filepaths: One path (or missing value) per loader call.
        load_funcs: One callable per loader call, taking
            ``(path, rd_flag, tlen)``.
        rd_flags: One start-position argument per loader call.
        tlens: One window-length argument per loader call.
        max_workers: Size of the thread pool; must be positive. Defaults to
            7, the number of modalities loaded by the datasets in this file.

    Returns:
        A list with the result of each call, in input order. As with the
        built-in ``map``, iteration stops at the shortest input sequence.
        An exception raised by a loader propagates to the caller.
    """
    def _call_loader(load_func, path, rd_flag, tlen):
        return load_func(path, rd_flag, tlen)

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # executor.map yields results in submission order, not completion order.
        return list(executor.map(_call_loader, load_funcs, filepaths, rd_flags, tlens))


class MultimodalDataset(Dataset):
    """Dataset yielding one randomly positioned window per CSV row.

    The CSV must provide the columns ``train_split``, ``dataset``,
    ``label``, ``label_norm`` and one path column per feature file
    (``pose_gaze_aus``, ``landmark``, ``hog``, ``mfcc``, ``egemaps``,
    ``covarep``, ``bert``). Only rows whose ``train_split`` equals ``split``
    are kept.
    """

    def __init__(self, csv_path, split=None, t_len=64):
        """Read ``csv_path`` and keep the rows of ``split`` (default 'Training').

        ``t_len`` is the visual window length; acoustic windows use
        ``int(t_len * 3.33)`` rows.
        """
        super().__init__()
        split = 'Training' if split is None else split
        table = pd.read_csv(csv_path)
        in_split = table['train_split'].isin([split])
        self.csv_file = table[in_split].reset_index(drop=True)
        self.t_len = t_len

    def __len__(self):
        """Return the number of rows in the selected split."""
        return len(self.csv_file)

    def __getitem__(self, idx):
        """Load all modalities of row ``idx`` at one random start position.

        Returns:
            ``(data_v, data_a, data_t, label_norm, label, dataset_flag)``:
            ``data_v`` is a float32 tensor ``[t_len, 4647]`` (47 + 136 + 4464),
            ``data_a`` is ``[int(t_len * 3.33), 208]`` (39 + 88 + 81),
            ``data_t`` is ``[n_text_rows, 768]``, the labels are scalar float
            tensors and ``dataset_flag`` is an int.
        """
        record = self.csv_file.loc[idx]
        # Per-corpus integer code; what the values 63/27/24 stand for is not
        # visible in this file.
        dataset_flag = {'avec14': 63, 'cmdc': 27}.get(record.loc['dataset'], 24)

        feature_columns = ('pose_gaze_aus', 'landmark', 'hog',
                           'mfcc', 'egemaps', 'covarep', 'bert')
        filepaths = [record.loc[column] for column in feature_columns]
        raw_label = record.loc['label']
        raw_label_norm = record.loc['label_norm']
        label = torch.tensor(raw_label, dtype=torch.float)
        label_norm = torch.tensor(raw_label_norm, dtype=torch.float)

        # A single relative start position shared by every modality.
        rd_flag = np.array([random.random()])
        visual_len = self.t_len
        # Original notes: vision runs at 30 fps, audio at 100 fps.
        audio_len = int(visual_len * 3.33)

        load_funcs = [load_pose_gaze_aus, load_landmark, load_hog,
                      load_mfcc, load_egemaps, load_covarep, load_bert]
        rd_flags = [rd_flag] * 7
        # The text loader ignores its length argument, hence the trailing 0.
        tlens = [visual_len] * 3 + [audio_len] * 3 + [0]
        results_list = load_files_parallel(filepaths, load_funcs, rd_flags, tlens)

        # Each loader returns a leading window axis of size 1; drop it.
        pga, lm, hog, mfcc, egemaps, covarep, bert = (
            result.squeeze(0) for result in results_list
        )

        data_v = np.concatenate((pga, lm, hog), axis=1).astype(np.float32)
        data_a = np.concatenate((mfcc, egemaps, covarep), axis=1).astype(np.float32)
        data_t = bert.astype(np.float32)

        return (torch.from_numpy(data_v), torch.from_numpy(data_a),
                torch.from_numpy(data_t), label_norm, label, dataset_flag)


class MultimodalValDataset(Dataset):
    """Dataset yielding ten evenly spaced windows per CSV row.

    The CSV must provide the columns ``train_split``, ``dataset``,
    ``label``, ``label_norm`` and one path column per feature file
    (``pose_gaze_aus``, ``landmark``, ``hog``, ``mfcc``, ``egemaps``,
    ``covarep``, ``bert``). Only rows whose ``train_split`` equals ``split``
    are kept.
    """

    def __init__(self, csv_path, split=None, t_len=64):
        """Read ``csv_path`` and keep the rows of ``split`` (default 'Testing').

        ``t_len`` is the visual window length; acoustic windows use
        ``int(t_len * 3.33)`` rows.
        """
        super().__init__()
        split = 'Testing' if split is None else split
        table = pd.read_csv(csv_path)
        in_split = table['train_split'].isin([split])
        self.csv_file = table[in_split].reset_index(drop=True)
        self.t_len = t_len

    def __len__(self):
        """Return the number of rows in the selected split."""
        return len(self.csv_file)

    def __getitem__(self, idx):
        """Load all modalities of row ``idx`` at fixed start positions.

        Windows start at the relative positions 0.0, 0.1, ..., 0.9.

        Returns:
            ``(data_v, data_a, data_t, label_norm, label, dataset_flag)``:
            ``data_v`` is a float32 tensor ``[10, t_len, 4647]``,
            ``data_a`` is ``[10, int(t_len * 3.33), 208]``,
            ``data_t`` is ``[10, n_text_rows, 768]``, the labels are scalar
            float tensors and ``dataset_flag`` is an int.
        """
        record = self.csv_file.loc[idx]
        # Per-corpus integer code; what the values 63/27/24 stand for is not
        # visible in this file.
        dataset_flag = {'avec14': 63, 'cmdc': 27}.get(record.loc['dataset'], 24)

        feature_columns = ('pose_gaze_aus', 'landmark', 'hog',
                           'mfcc', 'egemaps', 'covarep', 'bert')
        filepaths = [record.loc[column] for column in feature_columns]
        raw_label = record.loc['label']
        raw_label_norm = record.loc['label_norm']
        label = torch.tensor(raw_label, dtype=torch.float)
        label_norm = torch.tensor(raw_label_norm, dtype=torch.float)

        # Deterministic start positions so evaluation is repeatable.
        rd_flag = np.arange(0, 1, 0.1)
        visual_len = self.t_len
        # Original notes: vision runs at 30 fps, audio at 100 fps.
        audio_len = int(visual_len * 3.33)

        load_funcs = [load_pose_gaze_aus, load_landmark, load_hog,
                      load_mfcc, load_egemaps, load_covarep, load_bert]
        rd_flags = [rd_flag] * 7
        # The text loader ignores its length argument, hence the trailing 0.
        tlens = [visual_len] * 3 + [audio_len] * 3 + [0]
        results_list = load_files_parallel(filepaths, load_funcs, rd_flags, tlens)

        # Window axis is kept, so features are joined along the last axis.
        pga, lm, hog, mfcc, egemaps, covarep, bert = results_list

        data_v = np.concatenate((pga, lm, hog), axis=2).astype(np.float32)
        data_a = np.concatenate((mfcc, egemaps, covarep), axis=2).astype(np.float32)
        data_t = bert.astype(np.float32)

        return (torch.from_numpy(data_v), torch.from_numpy(data_a),
                torch.from_numpy(data_t), label_norm, label, dataset_flag)


if __name__ == '__main__':
    # Manual smoke test: iterate the 'Testing' split once and print the shape
    # of the collated text batch.
    from torch.utils.data.dataloader import DataLoader
    from torch.nn.utils.rnn import pad_packed_sequence
    import os

    from utils import collate_fn_val

    test_dataset = MultimodalValDataset(csv_path='dataset/all_data_path.csv',
                                        split='Testing',
                                        t_len=64)
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=2,
        shuffle=False,
        drop_last=False,
        collate_fn=collate_fn_val,
    )

    flag = 0  # number of batches seen so far
    for step, batch in enumerate(test_dataloader):
        test_v_pack, test_a_pack, test_t_pack, _, score, d_flag = batch
        # Debugging aid for packed text batches:
        # unpacked_data_t, lengths = pad_packed_sequence(test_t_pack, batch_first=True)
        # print(lengths)
        print(test_t_pack.shape)
        flag += 1
