import os
import sys
from typing import *
import pickle
import h5py
import numpy as np
from numpy.core.numeric import zeros_like
from torch.nn.functional import pad
from torch.nn import functional as F
from pathlib import Path

# Append the directory three levels above this file to sys.path so that the
# `config` import below resolves (presumably the project root - this lets the
# module be run from there without installing the package).
project_root = Path(__file__).parent.parent.parent.absolute()
sys.path.append(str(project_root))

from config import MULTIBENCH_MOSI_PATH

import torch
try:
    import torchtext as text
except Exception:
    text = None
from collections import defaultdict
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset, default_collate

# optional robustness modules; if missing, related functionality will be disabled
try:
    from robustness.text_robust import add_text_noise
    from robustness.timeseries_robust import add_timeseries_noise
except Exception:
    add_text_noise = None
    add_timeseries_noise = None

# Globally silence NumPy's divide-by-zero and invalid-operation (e.g. 0/0) warnings;
# the z-normalisation code in this module divides by a standard deviation that can be
# zero and cleans the result up with nan_to_num afterwards.
np.seterr(divide='ignore', invalid='ignore')


class MMIMDbDataset(Dataset):
    """Implements MM-IMDb dataset for multimodal learning with text and image features.

    A dataset instance exposes one contiguous index range ``[start_ind, end_ind)``
    of the HDF5 file, so several instances over the same file act as splits.

    Dataset contains:
    - Text features: 300-dim (pre-computed)
    - VGG features: 4096-dim (from VGG-16)
    - Multi-label genres: 23 classes (binary indicators)

    Genre labels (in order):
    0. Drama, 1. Comedy, 2. Romance, 3. Thriller, 4. Crime, 5. Action,
    6. Adventure, 7. Horror, 8. Documentary, 9. Mystery, 10. Sci-Fi, 11. Fantasy,
    12. Family, 13. Biography, 14. War, 15. History, 16. Music, 17. Animation,
    18. Musical, 19. Western, 20. Sport, 21. Short, 22. Film-Noir
    """

    # Class-level genre names for reference
    GENRE_NAMES = [
        "Drama", "Comedy", "Romance", "Thriller", "Crime", "Action", "Adventure",
        "Horror", "Documentary", "Mystery", "Sci-Fi", "Fantasy", "Family", "Biography",
        "War", "History", "Music", "Animation", "Musical", "Western", "Sport", "Short", "Film-Noir"
    ]

    def __init__(self, h5_file: str, start_ind: int, end_ind: int) -> None:
        """Initialize MMIMDbDataset.

        The file is not opened here; see ``__getitem__``.

        Args:
            h5_file: Path to h5py file containing features, vgg_features, and genres
            start_ind: Starting index for this split (inclusive)
            end_ind: Ending index for this split (exclusive)
        """
        self.h5_file = h5_file
        self.start_ind = start_ind
        self.end_ind = end_ind
        self.size = end_ind - start_ind

    def __getitem__(self, ind):
        """Get one sample of this split.

        Args:
            ind: Position inside the split. Negative values count from the end
                of the split, like list indexing.

        Returns:
            tuple: (text_features, vgg_features, label), each a float32 tensor.

        Raises:
            IndexError: If ``ind`` falls outside the split. Without this check an
                out-of-range index would silently read rows that belong to a
                neighbouring split, and plain iteration over the dataset would
                not stop at the split boundary.
        """
        position = ind + self.size if ind < 0 else ind
        if not 0 <= position < self.size:
            raise IndexError(
                f"index {ind} is out of range for a split of size {self.size}")

        # The HDF5 handle is created on first access rather than in __init__
        # (presumably so each DataLoader worker process opens its own handle).
        if not hasattr(self, 'dataset'):
            self.dataset = h5py.File(self.h5_file, 'r')

        row = position + self.start_ind
        text_feat = torch.tensor(np.asarray(self.dataset["features"][row]), dtype=torch.float32)
        image_feat = torch.tensor(np.asarray(self.dataset["vgg_features"][row]), dtype=torch.float32)
        genres = torch.tensor(np.asarray(self.dataset["genres"][row]), dtype=torch.float32)

        return text_feat, image_feat, genres

    def __len__(self):
        """Return the number of samples in this split."""
        return self.size


def drop_entry(dataset):
    """Remove every sample whose text features sum to zero.

    The rows are removed along axis 0 from *all* entries of ``dataset`` so the
    modalities stay aligned. ``dataset`` is updated in place (each value is
    replaced by a new array) and also returned.
    """
    empty_rows = [row for row, sample in enumerate(dataset["text"]) if sample.sum() == 0]
    for key in tuple(dataset.keys()):
        dataset[key] = np.delete(dataset[key], empty_rows, 0)
    return dataset


def z_norm(dataset, max_seq_len=50):
    """Z-normalize the vision, audio and text features of a dataset split.

    Each sample is truncated to its first ``max_seq_len`` steps along axis 1 and
    every feature column is standardized over that sample's remaining steps
    (zero mean, unit population standard deviation). Columns with zero
    variance produce 0/0 and are mapped to 0 by ``nan_to_num``.

    The input arrays are left untouched: normalization runs on copies, so the
    caller's ``dataset`` is not overwritten. Non-floating inputs are promoted
    to float64 so the normalized values are not truncated.

    Args:
        dataset: Mapping with ``'vision'``, ``'audio'``, ``'text'`` arrays that
            are indexed as ``[sample, step, feature]``, plus ``'labels'``.
        max_seq_len: Maximum number of steps kept per sample.

    Returns:
        dict with normalized ``'vision'``, ``'audio'``, ``'text'`` arrays and the
        original ``'labels'`` object.
    """
    def _standardize(features):
        # np.array copies the truncated slice; a plain slice would be a view
        # and the normalization would leak back into the caller's data.
        feats = np.array(features[:, :max_seq_len, :])
        if not np.issubdtype(feats.dtype, np.floating):
            feats = feats.astype(np.float64)
        mean = feats.mean(axis=1, keepdims=True)
        std = feats.std(axis=1, keepdims=True)
        # Zero-variance columns divide by zero on purpose; nan_to_num cleans up.
        with np.errstate(divide='ignore', invalid='ignore'):
            return np.nan_to_num((feats - mean) / std)

    return {
        'vision': _standardize(dataset['vision']),
        'audio': _standardize(dataset['audio']),
        'text': _standardize(dataset['text']),
        'labels': dataset['labels'],
    }


def get_rawtext(path, data_kind, vids):
    """Collect the raw transcript of each requested video segment.

    Args:
        path: Path to the transcript file.
        data_kind: ``'hdf5'`` to read ``path`` with h5py; any other value makes
            the file be loaded with pickle and indexed like a mapping from
            video id to a list of words.
        vids: Iterable of video ids. A NumPy array entry is reduced to
            ``int(entry[0])``; other entries are used as-is.

    Returns:
        tuple ``(text_data, new_vids)``: the space-joined words of every video
        that could be read (the ``sp`` pause token is skipped) and the matching
        ids, in input order. Ids that cannot be read are reported on stdout
        with a ``missing`` message and left out of both lists.
    """
    is_hdf5 = data_kind == 'hdf5'
    if is_hdf5:
        source = h5py.File(path, 'r')
    else:
        # NOTE: pickle.load can execute arbitrary code - only use trusted files.
        with open(path, 'rb') as f_r:
            source = pickle.load(f_r)

    text_data = []
    new_vids = []
    try:
        for vid in vids:
            vid_id = int(vid[0]) if isinstance(vid, np.ndarray) else vid
            # Best-effort lookup: any failure for one id (unknown key, wrong key
            # type, undecodable word, ...) just skips that id.
            try:
                if is_hdf5:
                    words = [word[0].decode('utf-8')
                             for word in source['words'][vid_id]['features']
                             if word[0] != b'sp']
                else:
                    words = [word for word in source[vid_id] if word != 'sp']
                sentence = ' '.join(words)
            except Exception:
                print("missing", vid, vid_id)
                continue
            text_data.append(sentence)
            new_vids.append(vid_id)
    finally:
        # The HDF5 handle used to be left open; release it once reading is done.
        if is_hdf5:
            source.close()
    return text_data, new_vids


def _get_word2id(text_data, vids):
    word2id = defaultdict(lambda: len(word2id))
    UNK = word2id['unk']
    data_processed = dict()
    for i, segment in enumerate(text_data):
        words = []
        _words = segment.split()
        for word in _words:
            words.append(word2id[word])
        words = np.asarray(words)
        data_processed[vids[i]] = words

    def _return_unk():
        return UNK

    word2id.default_factory = _return_unk
    return data_processed, word2id


def _get_word_embeddings(word2id, save=False):
    """Look up GloVe (840B, 300-dim) vectors for every word in ``word2id``.

    Args:
        word2id: Vocabulary mapping; only its keys are used, in iteration order.
        save: Unused.

    Returns:
        The result of ``GloVe.get_vecs_by_tokens`` for the vocabulary words,
        queried with ``lower_case_backup=True``.

    Raises:
        RuntimeError: If torchtext could not be imported.
    """
    if text is None:
        raise RuntimeError("torchtext is required for GloVe embeddings but it's not available")
    glove = text.vocab.GloVe(name='840B', dim=300)
    vocabulary = list(word2id)
    return glove.get_vecs_by_tokens(vocabulary, lower_case_backup=True)


def _glove_embeddings(text_data, vids, paddings=50):
    """Turn raw text segments into fixed-length sequences of GloVe vectors.

    Each segment is encoded with a vocabulary built from ``text_data`` and
    mapped to its word embeddings. Segments longer than ``paddings`` words keep
    only their first ``paddings`` words; shorter ones are left-padded with
    300-dim zero vectors so the real words sit at the end.

    Returns:
        NumPy array stacking one embedded sequence per entry of ``vids``.
    """
    encoded, vocab = _get_word2id(text_data, vids)
    lookup = _get_word_embeddings(vocab).numpy()

    embedded = []
    for vid in vids:
        word_ids = encoded[vid]
        if len(word_ids) > paddings:
            rows = [lookup[idx] for idx in word_ids[:paddings]]
        else:
            # Zeros go in front of the words.
            rows = [np.zeros(300, ) for _ in range(paddings - len(word_ids))]
            rows.extend(lookup[idx] for idx in word_ids)
        embedded.append(np.array(rows))
    return np.array(embedded)


class Affectdataset(Dataset):
    """Torch dataset over a dict of affect features.

    ``data`` must provide ``'vision'``, ``'audio'``, ``'text'`` and ``'labels'``
    entries indexable by sample. Leading all-zero steps are trimmed from every
    sample, the features are optionally z-normalized, and the label is encoded
    according to ``task`` and ``data_type``.
    """
    def __init__(self, data: Dict, flatten_time_series: bool, aligned: bool = True, task: str = None, max_pad=False, max_pad_num=50, data_type='mosi', z_norm=False) -> None:
        """Store the configuration and clean the audio features.

        Args:
            data: Feature dict (see class docstring). Its ``'audio'`` array is
                modified in place: ``-inf`` entries become 0.
            flatten_time_series: Return each modality flattened to 1-D.
            aligned: If true, all modalities are trimmed at the first non-zero
                text step; otherwise each modality is trimmed at its own first
                non-zero step.
            task: ``"classification"`` yields integer class labels; anything
                else yields float labels.
            max_pad: Truncate/zero-pad every modality to ``max_pad_num`` steps
                (only used when not flattening).
            max_pad_num: Target number of steps for ``max_pad``.
            data_type: Dataset name; controls the label encoding.
            z_norm: Standardize each feature column per sample.
        """
        self.dataset = data
        self.flatten = flatten_time_series
        self.aligned = aligned
        self.task = task
        self.max_pad = max_pad
        self.max_pad_num = max_pad_num
        self.data_type = data_type
        self.z_norm = z_norm
        audio_feats = self.dataset['audio']
        audio_feats[audio_feats == -np.inf] = 0.0

    def __getitem__(self, ind):
        """Return sample ``ind``.

        Returns:
            list: ``[vision, audio, text, ind, label]``, or
            ``[vision, audio, text, label]`` (without the index) when
            ``max_pad`` is set and the series are not flattened.
        """
        vision, audio, text = (
            torch.tensor(self.dataset[key][ind]) for key in ('vision', 'audio', 'text'))

        if self.aligned:
            # One shared cut point: the row of the first non-zero text entry.
            try:
                start = text.nonzero(as_tuple=False)[0][0]
            except Exception:
                print(text, ind)
                raise
            vision, audio, text = (seq[start:].float() for seq in (vision, audio, text))
        else:
            # Each modality is cut at its own first non-zero row.
            vision, audio, text = (
                seq[seq.nonzero()[0][0]:].float() for seq in (vision, audio, text))

        if self.z_norm:
            # Zero-variance columns give NaN, which nan_to_num turns into 0.
            vision, audio, text = (
                torch.nan_to_num((seq - seq.mean(0, keepdim=True)) / torch.std(seq, dim=0, keepdim=True))
                for seq in (vision, audio, text))

        raw_label = self.dataset['labels'][ind]
        if self.data_type in ('humor', 'sarcasm') and self.task in (None, 'regression'):
            # Binary datasets used without classification are mapped to -1 / +1.
            raw_label = [[-1]] if raw_label < 1 else [[1]]

        if self.task == "classification":
            if self.data_type in ('mosi', 'mosei', 'sarcasm'):
                # Sign of the score decides the class.
                class_label = [[1]] if raw_label > 0 else [[0]]
            else:
                class_label = [raw_label]
            label = torch.tensor(class_label).long()
        else:
            label = torch.tensor(raw_label).float()

        if self.flatten:
            return [vision.flatten(), audio.flatten(), text.flatten(), ind, label]

        if not self.max_pad:
            return [vision, audio, text, ind, label]

        fixed_length = []
        for seq in (vision, audio, text):
            seq = seq[:self.max_pad_num]
            # Pad only the step axis (second-to-last) at its end.
            fixed_length.append(F.pad(seq, (0, 0, 0, self.max_pad_num - seq.shape[0])))
        return fixed_length + [label]

    def __len__(self):
        """Number of samples, taken from the vision array."""
        return self.dataset['vision'].shape[0]


def get_dataloader(
        filepath: str, batch_size: int = 32, max_seq_len=50, max_pad=False, train_shuffle: bool = True,
        num_workers: int = 2, flatten_time_series: bool = False, task=None, robust_test: bool = False, data_type='mosi',
        raw_path=None, z_norm=False) -> Tuple[DataLoader, DataLoader, Union[DataLoader, Dict[str, List[DataLoader]]]]:
    """Get dataloaders for affect data.

    The pickled pack at ``filepath`` must contain ``'train'``, ``'valid'`` and
    ``'test'`` splits. Samples without text are dropped from every split before
    the loaders are built.

    Args:
        filepath: Path to the pickled dataset pack.
        batch_size: Batch size of every loader.
        max_seq_len: Forwarded to ``Affectdataset`` as ``max_pad_num``.
        max_pad: Pad/truncate samples to ``max_seq_len``; also selects
            ``_process_2`` instead of ``_process_1`` as the collate function.
        train_shuffle: Shuffle the training loader.
        num_workers: Worker count of every loader.
        flatten_time_series: Forwarded to ``Affectdataset``.
        task: Forwarded to ``Affectdataset``.
        robust_test: Build noisy test loaders instead of a single test loader.
        data_type: Forwarded to ``Affectdataset``.
        raw_path: Raw transcript file used for text noise; defaults to
            ``MULTIBENCH_MOSI_PATH``.
        z_norm: Forwarded to ``Affectdataset``.

    Returns:
        tuple ``(train, valid, test)``. With ``robust_test=True`` the third item
        is a dict with the keys ``'robust_text'``, ``'robust_vision'``,
        ``'robust_audio'`` and ``'robust_timeseries'``, each holding a list of
        ten test loaders built with increasing noise levels.

    Raises:
        RuntimeError: If ``robust_test`` is requested but the optional
            robustness modules could not be imported.
    """
    # Use config path if raw_path not provided
    if raw_path is None:
        raw_path = MULTIBENCH_MOSI_PATH

    # Fail early with a clear message instead of "NoneType is not callable".
    if robust_test and (add_text_noise is None or add_timeseries_noise is None):
        raise RuntimeError("robust_test=True requires the optional 'robustness' modules, which are not available")

    # NOTE: pickle.load can execute arbitrary code - only load trusted pack files.
    with open(filepath, "rb") as f:
        alldata = pickle.load(f)

    for split in ('train', 'valid', 'test'):
        alldata[split] = drop_entry(alldata[split])

    collate = _process_2 if max_pad else _process_1

    def _make_loader(split_data, shuffle=False):
        # Every loader in this function shares the same dataset/loader settings.
        return DataLoader(
            Affectdataset(split_data, flatten_time_series, task=task, max_pad=max_pad,
                          max_pad_num=max_seq_len, data_type=data_type, z_norm=z_norm),
            shuffle=shuffle, num_workers=num_workers, batch_size=batch_size,
            collate_fn=collate)

    train = _make_loader(alldata['train'], shuffle=train_shuffle)
    valid = _make_loader(alldata['valid'])

    if not robust_test:
        return train, valid, _make_loader(alldata['test'])

    test_split = alldata['test']

    def _noisy_loader(vision, audio, text_feats, report=True):
        # Noise can blank out a sample's text, so empty entries are dropped again.
        noisy = drop_entry({'vision': vision, 'audio': audio, 'text': text_feats,
                            'labels': test_split['labels']})
        if report:
            print('test entries: {}'.format(noisy['vision'].shape))
        return _make_loader(noisy)

    # The file extension decides how get_rawtext reads the transcripts (e.g. hdf5).
    file_type = str(raw_path).split('.')[-1]
    rawtext, vids = get_rawtext(raw_path, file_type, list(test_split['id']))

    # Add text noises
    robust_text = []
    for i in range(10):
        noisy_text = _glove_embeddings(add_text_noise(rawtext, noise_level=i / 10), vids)
        robust_text.append(
            _noisy_loader(test_split['vision'], test_split['audio'], noisy_text, report=False))

    # Add visual noises
    robust_vision = []
    for i in range(10):
        noisy_vision = add_timeseries_noise(
            [test_split['vision'].copy()], noise_level=i / 10, rand_drop=False)[0]
        robust_vision.append(
            _noisy_loader(noisy_vision, test_split['audio'].copy(), test_split['text'].copy()))

    # Add audio noises
    robust_audio = []
    for i in range(10):
        noisy_audio = add_timeseries_noise(
            [test_split['audio'].copy()], noise_level=i / 10, rand_drop=False)[0]
        robust_audio.append(
            _noisy_loader(test_split['vision'].copy(), noisy_audio, test_split['text'].copy()))

    # Add noise to all three modalities at once (at a third of the per-modality level)
    robust_timeseries = []
    for i in range(10):
        noisy_all = add_timeseries_noise(
            [test_split['vision'].copy(), test_split['audio'].copy(), test_split['text'].copy()],
            noise_level=i / (10 * 3), rand_drop=False)
        robust_timeseries.append(_noisy_loader(noisy_all[0], noisy_all[1], noisy_all[2]))

    test_robust_data = {
        'robust_text': robust_text,
        'robust_vision': robust_vision,
        'robust_audio': robust_audio,
        'robust_timeseries': robust_timeseries,
    }
    return train, valid, test_robust_data


def _process_1(inputs: List):
    processed_input = []
    processed_input_lengths = []
    inds = []
    labels = []

    for i in range(len(inputs[0]) - 2):
        feature = []
        for sample in inputs:
            feature.append(sample[i])
        processed_input_lengths.append(torch.as_tensor([v.size(0) for v in feature]))
        pad_seq = pad_sequence(feature, batch_first=True)
        processed_input.append(pad_seq)

    for sample in inputs:
        inds.append(sample[-2])
        # Be robust to label tensor shapes: allow 0-D, 1-D, or 2-D tensors
        lab = sample[-1]
        try:
            lab_ndim = lab.dim()
        except Exception:
            # Fallback for non-tensor labels
            labels.append(lab)
            continue

        if lab_ndim >= 2:
            # If label is shaped (1, K) or (K, 1), try to flatten to a single value
            if lab.shape[1] > 1:
                labels.append(lab.reshape(lab.shape[1], lab.shape[0])[0])
            else:
                labels.append(lab)
        else:
            labels.append(lab)

    return processed_input, processed_input_lengths, \
           torch.tensor(inds).view(len(inputs), 1), torch.tensor(labels).view(len(inputs), 1)


def _process_2(inputs: List):
    processed_input = []
    processed_input_lengths = []
    labels = []

    for i in range(len(inputs[0]) - 1):
        feature = []
        for sample in inputs:
            feature.append(sample[i])
        # compute per-sample lengths and pad to the max length in the batch
        processed_input_lengths.append(torch.as_tensor([v.size(0) for v in feature]))
        # use pad_sequence so variable-length time series are padded instead of failing on stack
        processed_input.append(pad_sequence(feature, batch_first=True))

    for sample in inputs:
        # Be robust to label tensor shapes: allow 0-D, 1-D, or 2-D tensors
        lab = sample[-1]
        try:
            lab_ndim = lab.dim()
        except Exception:
            labels.append(lab)
            continue

        if lab_ndim >= 2:
            if lab.shape[1] > 1:
                labels.append(lab.reshape(lab.shape[1], lab.shape[0])[0])
            else:
                labels.append(lab)
        else:
            labels.append(lab)

    return processed_input[0], processed_input[1], processed_input[2], torch.tensor(labels).view(len(inputs), 1)


def _process_mmimdb(inputs: List):
    """Collate function for MM-IMDb dataset (text, image, multi-label).
    
    Returns:
        tuple: (text_batch, image_batch, label_batch)
    """
    text_features = []
    image_features = []
    labels = []
    
    for sample in inputs:
        text_features.append(sample[0])
        image_features.append(sample[1])
        labels.append(sample[2])
    
    # Stack into batches (MM-IMDb features are already fixed-size vectors)
    text_batch = torch.stack(text_features, dim=0)
    image_batch = torch.stack(image_features, dim=0)
    label_batch = torch.stack(labels, dim=0)
    
    return text_batch, image_batch, label_batch


# --- Compatibility wrapper for repo API used in experiments ---
def get_dataloaders(dataset_name: str, data_dir: str, aggregation=None, batch_size: int = 32, num_workers: int = 4):
    """Compatibility wrapper: resolve a dataset name to its file and build the loaders.

    Args:
        dataset_name: 'mosi', 'mosei', 'ur_funny' (or 'humor'), 'sarcasm' or
            'mmimdb'; matched case-insensitively.
        data_dir: Directory holding 'multimodal_imdb.hdf5' (MM-IMDb) or a
            'pack' sub-folder with the pickled affect datasets.
        aggregation: Ignored by this loader (padding is handled by the collate
            functions).
        batch_size: Batch size of every loader.
        num_workers: Worker count of every loader.

    Returns:
        tuple: (train_loader, valid_loader, test_loader)

    Raises:
        ValueError: If ``dataset_name`` is not one of the known names.
        FileNotFoundError: If the expected data file does not exist.
    """
    name_key = dataset_name.lower()

    # MM-IMDb lives in a single h5 file that is split by fixed index ranges.
    if name_key == 'mmimdb':
        h5_path = os.path.join(data_dir, 'multimodal_imdb.hdf5')
        if not os.path.exists(h5_path):
            raise FileNotFoundError(f"MM-IMDb h5 file not found: {h5_path}")

        # Standard splits from original repo: train=0-15552, val=15552-18160, test=18160-25959
        split_bounds = ((0, 15552), (15552, 18160), (18160, 25959))
        train_loader, valid_loader, test_loader = (
            DataLoader(MMIMDbDataset(h5_path, first, last), batch_size=batch_size,
                       shuffle=(position == 0),  # only the training split is shuffled
                       num_workers=num_workers, collate_fn=_process_mmimdb)
            for position, (first, last) in enumerate(split_bounds)
        )
        return train_loader, valid_loader, test_loader

    # Affect datasets (MOSI, MOSEI, UR_FUNNY, ...) are pickled packs.
    file_map = {
        'mosi': 'mosi_raw.pkl',
        'mosei': 'mosei_raw.pkl',
        'ur_funny': 'humor.pkl',
        'humor': 'humor.pkl',
        'sarcasm': 'sarcasm.pkl'
    }
    if name_key not in file_map:
        raise ValueError(f"Unknown dataset: {dataset_name}")
    full_path = os.path.join(data_dir, 'pack', file_map[name_key])
    if not os.path.exists(full_path):
        raise FileNotFoundError(f"Expected pack file not found: {full_path}")

    # The pack is inspected here only to find a padding length; get_dataloader
    # below reads the same file again.
    # NOTE(review): pickle.load can execute arbitrary code - trusted files only.
    with open(full_path, 'rb') as f:
        packed = pickle.load(f)

    # Largest second dimension seen per modality across the three splits.
    modal_max = {}
    for split in ('train', 'valid', 'test'):
        if split not in packed:
            continue
        split_data = packed[split]
        if not isinstance(split_data, dict):
            continue
        for modal_key, arr in split_data.items():
            if modal_key in ('labels', 'id'):
                continue
            try:
                # Arrays with at least two dims contribute shape[1] (this
                # includes 2-D arrays); anything else counts as length 1.
                seq = arr.shape[1] if hasattr(arr, 'shape') and len(arr.shape) >= 2 else 1
            except Exception:
                seq = 1
            modal_max[modal_key] = max(modal_max.get(modal_key, 0), seq)

    # One padding length for all modalities; 50 when nothing could be measured.
    global_max_seq = max(modal_max.values()) if modal_max else 50

    # NOTE(review): data_type is not forwarded, so 'humor'/'sarcasm' packs are
    # loaded with get_dataloader's default data_type - confirm this is intended.
    train_loader, valid_loader, test_loader = get_dataloader(full_path, batch_size=batch_size, num_workers=num_workers, flatten_time_series=False, task=None, max_pad=True, max_seq_len=global_max_seq)
    return train_loader, valid_loader, test_loader


if __name__ == '__main__':
    # Smoke test: build the MOSI loaders and print the shapes of one training batch.
    DATA_DIR = './multibench_data'
    print('--- Loading MOSI sample ---')
    train_loader_mosi, valid_loader_mosi, test_loader_mosi = get_dataloaders('mosi', DATA_DIR, aggregation='mean', batch_size=32, num_workers=0)
    # get_dataloaders pads with max_pad=True, so each batch is the 4-tuple
    # (vision, audio, text, labels); unpacking it into two names raised
    # "too many values to unpack". Peel the labels off the end instead.
    *sample_modalities, sample_labels = next(iter(train_loader_mosi))
    print('Sample batch shapes:')
    print([s.shape for s in sample_modalities])