"""Implements dataloaders for AFFECT data."""
from typing import *
import pickle
import h5py
import numpy as np
import torch
from collections import defaultdict
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from torchtext.vocab import GloVe
# Local import
from dataset.robustness.text_robust import add_text_noise
from dataset.robustness.timeseries_robust import add_timeseries_noise


def drop_entry(dataset):
    """Remove every sample whose text features sum to zero (i.e. no text).

    Args:
        dataset: Mapping from modality name to an array indexed by sample
            along axis 0. It must contain the key "text".

    Returns:
        The same mapping object; each of its values has been replaced by a
        new array from which the text-less samples were removed.
    """
    # Indices of the samples to discard, decided from the text modality only.
    empty_rows = [idx for idx, sample in enumerate(dataset["text"]) if sample.sum() == 0]
    # The same rows are removed from every entry so they stay aligned.
    for key in list(dataset.keys()):
        dataset[key] = np.delete(dataset[key], empty_rows, 0)
    return dataset


def get_rawtext(path, data_kind, vids):
    """Collect the raw transcript of each requested segment.

    Args:
        path: Location of the raw-text file.
        data_kind: 'hdf5' to read `path` with h5py; any other value makes the
            file be loaded with pickle.
        vids: Iterable of segment ids. An id given as a numpy array is
            reduced to `int(vid[0])`; anything else is used as is.

    Returns:
        A tuple `(text_data, new_vids)`: the space-joined words of every
        segment that could be read, and the matching ids in the same order.
        Segments that cannot be read are reported on stdout and skipped, so
        both lists may be shorter than `vids`.
    """
    is_hdf5 = data_kind == 'hdf5'
    if is_hdf5:
        f = h5py.File(path, 'r')
    else:
        # NOTE(security): pickle.load executes arbitrary code embedded in the
        # file; only use it on raw-text files from a trusted source.
        with open(path, 'rb') as f_r:
            f = pickle.load(f_r)
    text_data = []
    new_vids = []

    try:
        for vid in vids:
            # If data IDs are NOT the same as the raw ids, add some code to
            # match them here, e.g. from vanvan_10 to vanvan[10].
            vid_id = int(vid[0]) if isinstance(vid, np.ndarray) else vid
            try:
                if is_hdf5:
                    # Each entry's first element is a bytes token; b'sp'
                    # tokens are left out of the transcript.
                    words = [word[0].decode('utf-8')
                             for word in f['words'][vid_id]['features']
                             if word[0] != b'sp']
                else:
                    words = [word for word in f[vid_id] if word != 'sp']
                sentence = ' '.join(words)
            except Exception:
                # Best effort: a segment that cannot be read is reported and
                # skipped. KeyboardInterrupt/SystemExit still propagate.
                print("missing", vid, vid_id)
            else:
                text_data.append(sentence)
                new_vids.append(vid_id)
    finally:
        # Release the HDF5 handle even if iterating over `vids` fails.
        if is_hdf5:
            f.close()
    return text_data, new_vids


def _get_word2id(text_data, vids):
    """Build a vocabulary from the segments and encode each one as word ids.

    Args:
        text_data: Sequence of strings; each is split on whitespace.
        vids: Sequence of ids, `vids[i]` being the id of `text_data[i]`.

    Returns:
        A tuple `(data_processed, word2id)` where `data_processed` maps each
        id to a numpy array of word ids and `word2id` maps words to ids. Once
        returned, looking up an unseen word in `word2id` yields the id of
        'unk' instead of creating a new id.
    """
    # While the vocabulary is being built, an unseen word receives the next
    # free id, which is the current size of the mapping.
    word2id = defaultdict(lambda: len(word2id))
    unk_id = word2id['unk']  # registered first, so 'unk' gets id 0
    data_processed = {}
    for i, segment in enumerate(text_data):
        data_processed[vids[i]] = np.asarray([word2id[token] for token in segment.split()])

    # Freeze the vocabulary: unseen words now fall back to the 'unk' id.
    word2id.default_factory = lambda: unk_id
    return data_processed, word2id


def _get_word_embeddings(word2id, save=False):
    """Look up the GloVe (840B, 300-d) vector of every word in `word2id`.

    Args:
        word2id: Mapping whose keys are the words to embed.
        save: Unused.

    Returns:
        The result of `GloVe.get_vecs_by_tokens` for the keys of `word2id`,
        taken in the mapping's iteration order, with `lower_case_backup=True`.
    """
    glove = GloVe(name='840B', dim=300)
    vocabulary = list(word2id.keys())
    return glove.get_vecs_by_tokens(vocabulary, lower_case_backup=True)


def _glove_embeddings(text_data, vids, paddings=50):
    """Embed each segment with GloVe vectors as a fixed-length sequence.

    Args:
        text_data: Sequence of strings, one per segment.
        vids: Ids of the segments, in the same order as `text_data`.
        paddings: Number of time steps every embedded segment must have.

    Returns:
        A numpy array stacking one block of `paddings` rows of size 300 per
        id in `vids`. Segments longer than `paddings` words keep only their
        first `paddings` words; shorter ones are padded with zero rows at
        the front.
    """
    token_ids, vocab = _get_word2id(text_data, vids)
    table = _get_word_embeddings(vocab).numpy()
    embedded = []
    for vid in vids:
        # Truncation is a no-op for segments that already fit.
        kept = token_ids[vid][:paddings]
        # Zero rows go BEFORE the words; none are added when nothing is missing.
        front_pad = [np.zeros(300, )] * (paddings - len(kept))
        embedded.append(np.array(front_pad + [table[idx] for idx in kept]))
    return np.array(embedded)


class Affect(Dataset):
    """Affect dataset pre-processed as part of MultiBench [1]
    It implements 4 dataset in one class: CMU-MOSEI, CMU-MOSI, MUSTARD, UR-FUNNY.
    All 4 dataset have 3 modalities coming from audiovisual inputs (i.e. videos):
        - vision: shape (*, T, pv) where (*) is batch dimension
        - audio: shape (*, T, pa)
        - text: shape (*, T, pt)
        where T is sequence length (different for each sample) and pv, pa, pt the features size.
    It comes with one task for each dataset related to sentiments/emotions (e.g. humor, sarcasm, fear, etc.)
    [1] MultiBench: Multiscale Benchmarks for Multimodal Representation Learning, Liang et al., NeurIPS Benchmarks 2021"""

    def __init__(self, data_path: str,
                 dataset: str,
                 split: str = "train",
                 modalities: Union[str, Tuple[str, ...]] = ("vision", "audio", "text"),
                 task: str = "classification",
                 flatten_time_series: bool = False,
                 align: bool = True,
                 transform_vision: Optional[Callable] = None,
                 transform_audio: Optional[Callable] = None,
                 transform_text: Optional[Callable] = None,
                 noise_level: Optional[float] = None,
                 modalities_to_noise: Optional[Tuple[str, ...]] = None,
                 raw_data_path: Optional[str] = None,
                 z_norm: bool = False):
        """
        Args:
            data_path: Datafile location (a pickle file indexed by split name).
            dataset: Dataset to be loaded, in {"mosei", "mosi", "humor", "sarcasm"}.
                NB: "sarcasm" == MUSTARD, "humor" == UR-FUNNY
            split: in {"train", "val", "test"}
            modalities: Modalities to return. NB: the order is preserved.
            task: either "classification" or "regression".
                If "classification", label is binarized in {0, 1} (except for "humor",
                whose label is only wrapped in a list), otherwise it is left unchanged.
            flatten_time_series: Whether to flatten time series data or not.
            align: Whether to align data or not across modalities
            transform_vision: Vision transformations to apply
            transform_audio: Audio transformations to apply
            transform_text: Text transformations to apply
            noise_level: If not None and > 0, add noise to modalities defined in `modalities_to_noise`
            modalities_to_noise: Which modality to noise.
            raw_data_path: Path to raw data to retrieve raw text (before pre-processing) when `noise_level` is set.
            z_norm: Whether to z-normalize (zero mean, unit std) every feature of a sample
                along its first (time) axis. Defaults to False.
        """
        self.data_path = data_path
        # The name is kept in its own attribute: `self.dataset` holds the loaded
        # data, so storing the name there too would lose it and break the
        # "humor" special case in `_get_class`.
        self.dataset_name = dataset
        self.split = split
        self.modalities = modalities
        self.task = task
        self.align = align
        self.flatten_time_series = flatten_time_series
        self.transform_vision = transform_vision
        self.transform_audio = transform_audio
        self.transform_text = transform_text
        self.noise_level = noise_level
        self.modalities_to_noise = modalities_to_noise
        self.raw_data_path = raw_data_path
        self.z_norm = z_norm

        if isinstance(self.modalities, str):
            self.modalities = (self.modalities,)

        # NOTE(security): pickle.load can execute arbitrary code; only load trusted files.
        with open(data_path, "rb") as f:
            data = pickle.load(f)
        split_ = "valid" if split == "val" else split  # "val" -> "valid"
        data_split = data[split_]
        # Eventually add noise to data
        if noise_level is not None and noise_level > 0. and modalities_to_noise is not None:
            if "text" in modalities_to_noise:
                # NOTE(review): ids missing from the raw-text file are skipped by
                # `get_rawtext`, so the noisy text may have fewer rows than the other
                # modalities -- confirm that all ids are present in the raw data.
                data_split["text"] = self.get_noisy_text(data_split["id"])
            if "vision" in modalities_to_noise:
                data_split["vision"] = self.get_noisy_vision(data_split["vision"])
            if "audio" in modalities_to_noise:
                data_split["audio"] = self.get_noisy_audio(data_split["audio"])
        # Drop samples without text
        self.dataset = drop_entry(data_split)
        # Removes `-inf`
        self.dataset['audio'][self.dataset['audio'] == -np.inf] = 0.0

    def get_noisy_text(self, ids):
        """Rebuild the text embeddings of `ids` from noised raw transcripts.

        Raises:
            ValueError: if `raw_data_path` was not provided.
        """
        if self.raw_data_path is None:
            raise ValueError("`raw_data_path` must be set to add noise to the text modality")
        file_type = self.raw_data_path.split('.')[-1]  # hdf5
        rawtext, ids = get_rawtext(self.raw_data_path, file_type, ids)
        noisy_text = add_text_noise(rawtext, noise_level=self.noise_level)
        return _glove_embeddings(noisy_text, ids)

    def get_noisy_vision(self, vision):
        """Return a noised version of `vision`; a copy is passed to the noise routine."""
        return add_timeseries_noise([vision.copy()], noise_level=self.noise_level, rand_drop=False)[0]

    def get_noisy_audio(self, audio):
        """Return a noised version of `audio`; a copy is passed to the noise routine."""
        return add_timeseries_noise([audio.copy()], noise_level=self.noise_level, rand_drop=False)[0]

    @staticmethod
    def _z_normalize(x: np.ndarray) -> np.ndarray:
        """Standardize `x` along its first axis; NaN/inf results (e.g. zero std) are replaced."""
        normalized = (x - x.mean(axis=0, keepdims=True)) / np.std(x, axis=0, keepdims=True)
        return np.nan_to_num(normalized)

    def _get_class(self, flag):
        """Turn a raw label into a class label: binarized on `flag > 0`, except for "humor"."""
        if self.dataset_name != "humor":
            return [[1]] if flag > 0 else [[0]]
        return [flag]

    def __getitem__(self, ind):
        """Return `(X, y)` for sample `ind`.

        `X` is the list of the requested modalities (in the order of
        `self.modalities`) as float32 arrays, and `y` is the label.
        """
        vision = self.dataset['vision'][ind]
        audio = self.dataset['audio'][ind]
        text = self.dataset['text'][ind]

        if self.align:
            # All modalities start at the first time step where text is non-zero.
            start = text.nonzero()[0][0]
            vision = vision[start:].astype(np.float32)
            audio = audio[start:].astype(np.float32)
            text = text[start:].astype(np.float32)
        else:
            # Each modality starts at its own first non-zero time step.
            vision = vision[vision.nonzero()[0][0]:].astype(np.float32)
            audio = audio[audio.nonzero()[0][0]:].astype(np.float32)
            text = text[text.nonzero()[0][0]:].astype(np.float32)

        if self.z_norm:
            vision = self._z_normalize(vision)
            audio = self._z_normalize(audio)
            text = self._z_normalize(text)

        label = self.dataset['labels'][ind]
        if self.task == "classification":
            label = self._get_class(label)

        modalities = dict(vision=vision, audio=audio, text=text)
        transforms = dict(vision=self.transform_vision, audio=self.transform_audio, text=self.transform_text)
        X, y = [], label
        for mod in self.modalities:
            if transforms[mod] is not None:
                modalities[mod] = transforms[mod](modalities[mod])
            if self.flatten_time_series:
                modalities[mod] = modalities[mod].flatten()
            X.append(modalities[mod])
        return X, y

    def __len__(self):
        """Get length of dataset."""
        return len(self.dataset['vision'])


def collate_fn_timeseries(inputs: List, max_seq_length: Optional[int] = None):
    """Handles a list of timeseries data with eventually different lengths.
        Args:
             inputs: list of X where X is a list of modalities
             max_seq_length: if set, sequences are zero-padded up to `max_seq_length`.
                Otherwise, all sequences are padded to the maximum sequence length in this batch.
        Output:
            X_ where X_ is a list of modalities with shape (*, T, p), or (*, T) for
            flattened (1-D) samples. Sequences are never truncated: if `max_seq_length`
            is set then T == max(`max_seq_length`, longest sequence in the batch).
            An empty `inputs` gives an empty list.
    """
    X_padded = []  # List of padded modalities
    if len(inputs) == 0:
        return X_padded
    for i in range(len(inputs[0])):
        # Tensors are used directly (detached, as `torch.tensor` would do) instead of
        # being re-wrapped, which triggers a copy-construct warning.
        Xi = [X[i].detach() if isinstance(X[i], torch.Tensor) else torch.tensor(X[i])
              for X in inputs]
        Xi_padded = pad_sequence(Xi, batch_first=True)  # shape (*, T, ...)
        if max_seq_length is not None and max_seq_length > Xi_padded.shape[1]:
            missing = max_seq_length - Xi_padded.shape[1]
            # `F.pad` reads (before, after) pairs starting from the LAST dimension:
            # leave every trailing feature dimension untouched and extend only the
            # time dimension (dim 1), whatever the rank of the batch.
            pad_spec = (0, 0) * (Xi_padded.dim() - 2) + (0, missing)
            Xi_padded = F.pad(Xi_padded, pad_spec, "constant", 0)
        X_padded.append(Xi_padded)
    return X_padded



if __name__ == "__main__":
    # Smoke test: build the "mosei" test split from the path listed in the
    # `catalog.json` file located one directory above this file's directory.
    import os
    import json
    from torch.utils.data import DataLoader

    parent_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
    catalog_path = os.path.join(parent_dir, "catalog.json")
    with open(catalog_path) as f:
        catalog = json.load(f)
    datasets, samples = {}, {}
    for d in ["mosei"]:
        datasets[d] = Affect(catalog[d]["path"], d, "test")
