import os
import pickle
import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.data.dataset import Dataset

from sklearn.model_selection import train_test_split

class AffectDataset(Dataset):
    """Map-style dataset over pickled text / audio / vision features and labels.

    The pickle at ``<dataset_path>/<dataset>_data.pkl`` is indexed by split name;
    the selected split must provide the keys ``'vision'``, ``'text'``,
    ``'audio'`` and ``'labels'`` as NumPy arrays. The feature arrays are read
    via ``shape[1]`` / ``shape[2]``, so they must be at least 3-D
    (presumably ``(N, seq_len, feature_dim)`` -- TODO confirm against the data).

    Args:
        dataset_path: Directory containing ``<dataset>_data.pkl``.
        dataset: Dataset name; selects the pickle file and the
            dataset-specific handling in ``__getitem__``.
        split_type: Key of the split to load from the pickle.
        device: Stored on the instance; not otherwise used by this class.
        if_align: Accepted for interface compatibility; unused here.
        labeled_ratio: Fraction of training samples flagged as labelled.
            ``None`` or a value >= 1.0 marks every sample as labelled. Only
            consulted when ``split_type == 'train'``.
        classification: When true, labels are replaced by ``floor(label) + 3``.
        train_classifier_only: With ``labeled_ratio == 0.0`` keeps every
            training sample labelled instead of none (see ``_generate_mask``).
        transfer: When true and ``dataset == 'mosei'``, audio / vision features
            are reduced with t-SNE (see ``_transfer_mosei_to_mosi``).

    Attributes:
        masks: Boolean tensor with one entry per sample; ``True`` marks a
            sample whose label is treated as available.
        data: The raw dict loaded for the split.
        dataset_name: The ``dataset`` argument.
        meta: Optional per-sample metadata; ``None`` unless assigned externally.
    """

    def __init__(self, dataset_path, dataset='mosei_senti', split_type='train', device=torch.device('cuda'), if_align=True, labeled_ratio=None, classification=False, train_classifier_only=False, transfer=False):
        super().__init__()
        self.split_type = split_type
        self.dataset_name = dataset

        # NOTE(security): pickle.load can execute arbitrary code; only point
        # dataset_path at trusted files.
        with open(os.path.join(dataset_path, f'{dataset}_data.pkl'), 'rb') as f:
            data = pickle.load(f)[split_type]
        # Run the (potentially slow) transfer after the source file is closed.
        if transfer and dataset == 'mosei':
            data = self._transfer_mosei_to_mosi(data, dataset_path)

        # These are torch tensors
        self.vision = torch.from_numpy(data['vision']).float()
        self.text = torch.from_numpy(data['text']).float()
        self.audio = torch.from_numpy(data['audio']).float()
        self.labels = torch.from_numpy(data['labels']).float()
        self.classification = classification
        self.train_classifier_only = train_classifier_only
        if classification:
            # Discretize the labels into 7 classes (0..6); assumes the raw
            # scores lie in [-3, 3] -- TODO confirm.
            self.labels = self.labels.floor().squeeze() + 3

        # A missing ratio means "fully labelled"; comparing None with a float
        # would otherwise raise TypeError for the default arguments.
        partially_labeled = (
            split_type == 'train'
            and labeled_ratio is not None
            and labeled_ratio < 1.0
        )
        if partially_labeled:
            self.masks = self._generate_mask(labeled_ratio)
        else:
            self.masks = torch.ones(self.labels.shape[0], dtype=torch.bool)

        self.meta = None
        self.data = data
        self.n_modalities = 3  # vision/ text/ audio
        self.device = device

    def get_n_modalities(self):
        """Return the number of modalities (always 3)."""
        return self.n_modalities

    def get_seq_len(self):
        """Return ``shape[1]`` of the text, audio and vision tensors."""
        return self.text.shape[1], self.audio.shape[1], self.vision.shape[1]

    def get_dim(self):
        """Return ``shape[2]`` of the text, audio and vision tensors."""
        return self.text.shape[2], self.audio.shape[2], self.vision.shape[2]

    def get_lbl_info(self):
        """Return ``(labels.shape[1], labels.shape[2])``.

        Requires a label tensor with at least three dimensions; the
        classification mode squeezes the labels, which can remove them.
        """
        # return number_of_labels, label_dim
        return self.labels.shape[1], self.labels.shape[2]

    def __len__(self):
        """Return the number of samples."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(X, Y, META)`` for one sample.

        ``X`` is ``(index, text, audio, vision)``, ``Y`` is ``(label, mask)``
        and ``META`` is ``(0, 0, 0)`` when no metadata is attached, otherwise
        the first three metadata fields of the sample.
        """
        label = self.labels[index]
        if self.dataset_name == 'iemocap':
            # Collapse the last label axis to class indices; the mask is kept
            # as-is (argmax over the whole (label, mask) tuple is not valid).
            label = torch.argmax(label, dim=-1)

        X = (index, self.text[index], self.audio[index], self.vision[index])
        Y = (label, self.masks[index])

        if self.meta is None:
            META = (0, 0, 0)
        else:
            entry = self.meta[index]
            META = (entry[0], entry[1], entry[2])
            if self.dataset_name == 'mosi':
                # Decode byte-string metadata; leave other values untouched.
                META = tuple(
                    m.decode('UTF-8') if isinstance(m, bytes) else m
                    for m in META
                )
        return X, Y, META

    def _generate_mask(self, mask_ratio):
        """Build a boolean mask flagging which training samples stay labelled.

        Args:
            mask_ratio: Passed to ``train_test_split`` as ``train_size``; the
                selected indices become ``True`` in the mask.

        Returns:
            Boolean tensor with one entry per sample.
        """
        n_samples = self.labels.shape[0]

        if mask_ratio == 0.0:
            if self.train_classifier_only:
                # Supervised classifier-only setting: keep every label.
                return torch.ones(n_samples, dtype=torch.bool)
            return torch.zeros(n_samples, dtype=torch.bool)

        # Stratify on the discretized labels so the labelled subset follows
        # the class distribution.
        if self.classification:
            labels = self.labels.squeeze()
        else:
            labels = self.labels.squeeze().floor() + 3

        idx = np.arange(n_samples)
        id_labels, _ = train_test_split(idx, train_size=mask_ratio, stratify=labels)
        chosen = id_labels.astype(int)
        mask = torch.zeros(n_samples, dtype=torch.bool)
        mask[chosen] = True
        print("Labelled distribution: ", np.unique(labels[chosen], return_counts=True))
        return mask

    @staticmethod
    def _tsne_per_sequence(tsne, features, n_components):
        """Apply ``tsne.fit_transform`` to every ``features[i]`` independently."""
        n_seq, seq_len, _ = features.shape
        reduced = np.zeros((n_seq, seq_len, n_components))
        for i in range(n_seq):
            reduced[i] = tsne.fit_transform(features[i])
        return reduced

    def _transfer_mosei_to_mosi(self, data, save_path):
        """Reduce audio to 5 and vision to 20 features per step with t-SNE.

        The result is cached at ``<save_path>/mosei_trans_<split>.pkl`` and
        loaded from there on later calls. ``data`` is modified in place when
        the cache has to be built. The target sizes presumably match the MOSI
        feature dimensions -- TODO confirm.

        Args:
            data: Split dict with 3-D ``'audio'`` and ``'vision'`` arrays.
            save_path: Directory holding the cache file.

        Returns:
            The dict with transformed ``'audio'`` and ``'vision'`` entries.
        """
        cache_path = os.path.join(save_path, f'mosei_trans_{self.split_type}.pkl')

        if os.path.exists(cache_path):
            # NOTE(security): same pickle caveat as in __init__.
            with open(cache_path, 'rb') as f:
                return pickle.load(f)

        # Imported lazily: only needed when the cache has to be built.
        from sklearn.manifold import TSNE

        audio_tsne = TSNE(n_components=5, verbose=1, perplexity=40, n_iter=250, method='exact')
        data['audio'] = self._tsne_per_sequence(audio_tsne, data['audio'], 5)
        print("Audio shape: ", data['audio'].shape)

        vision_tsne = TSNE(n_components=20, verbose=1, perplexity=40, n_iter=250, method='exact')
        data['vision'] = self._tsne_per_sequence(vision_tsne, data['vision'], 20)
        print("Vision shape: ", data['vision'].shape)

        with open(cache_path, 'wb') as f:
            pickle.dump(data, f)
        return data