import os
import pickle
import torch
import torch.nn.functional as F
import numpy as np
import pyarrow as pa
import random

from skmultilearn.model_selection import iterative_train_test_split
from torch.utils.data.dataset import Dataset

class MMIMDBDataset(Dataset):
    """MM-IMDb dataset backed by an Arrow table and per-sample latent files.

    The Arrow file ``{dataset_path}/mmimdb_{split_type}.arrow`` supplies the
    ``label`` and ``latent_path`` columns.  Each item is loaded lazily with
    ``torch.load`` from ``{dataset_path}/latent_reps/<basename of latent_path>``
    and is expected to hold an ``'image'`` entry and an indexable ``'text'``
    collection.

    Args:
        dataset_path: Directory containing the Arrow file and ``latent_reps/``.
        data: Stored on ``self.data``; not otherwise used by this class.
        split_type: Split name used to build the Arrow file name.  A label
            mask is only generated when this equals ``'train'``.
        device: Stored on ``self.device``; not otherwise used by this class.
        if_align: Accepted for interface compatibility; unused here.
        labeled_ratio: When not ``None`` (and ``split_type == 'train'``) it is
            forwarded to ``_generate_mask``; otherwise every sample is marked
            as labelled.
    """

    # Number of random picks made among the samples that share the label
    # vector of each row returned by the stratified split.
    _DRAWS_PER_LABELLED_ROW = 5

    def __init__(self, dataset_path, data='mosei_senti', split_type='train', device=torch.device('cuda'), if_align=True, labeled_ratio=None):
        super().__init__()

        self.data_dir = dataset_path
        # Memory-map the Arrow file instead of reading it through a buffer.
        self.table = pa.ipc.RecordBatchFileReader(
            pa.memory_map(f"{dataset_path}/mmimdb_{split_type}.arrow", "r")
        ).read_all()

        # Labels are kept as a NumPy array (one row per sample); they are
        # only turned into tensors in __getitem__.
        label_rows = self.table['label'].to_pandas().to_list()
        self.labels = np.array([np.array(row) for row in label_rows])

        # Boolean mask telling which samples count as labelled.
        if labeled_ratio is not None and split_type == 'train':
            self.masks = self._generate_mask(labeled_ratio)
        else:
            self.masks = torch.ones(self.labels.shape[0], dtype=torch.bool)

        self.data = data
        self.n_modalities = 2  # text / image
        self.device = device
        self.split_type = split_type

    def get_n_modalities(self):
        """Return the number of modalities (2: text and image)."""
        return self.n_modalities

    def __len__(self):
        """Return the number of samples (rows of the label array)."""
        return len(self.labels)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            ``(X, Y)`` where ``X = (image, text)`` with ``text`` being one
            randomly chosen entry of the stored ``'text'`` collection, and
            ``Y = (label_tensor, mask_entry)``.
        """
        # Only the file name of the stored path is used; the directory is
        # always ``<data_dir>/latent_reps``.
        file_name = self.table['latent_path'][index].as_py().split('/')[-1]
        # NOTE: torch.load unpickles its input; only point data_dir at
        # trusted files.
        sample = torch.load(os.path.join(self.data_dir, 'latent_reps', file_name))

        # A sample can carry several text entries (plots); pick one at random.
        text_idx = random.randrange(len(sample['text']))
        X = (sample['image'], sample['text'][text_idx])
        Y = (torch.from_numpy(self.labels[index]), self.masks[index])

        return X, Y

    def _generate_mask(self, mask_ratio):
        """Build the boolean "is labelled" mask for the training split.

        ``iterative_train_test_split`` is run on the label matrix with
        ``test_size = 1 - mask_ratio``.  For every label row on its train
        side, ``_DRAWS_PER_LABELLED_ROW`` samples are drawn (with
        replacement) among all dataset rows carrying exactly that label
        vector, and those samples are marked ``True``.

        NOTE(review): the mask therefore marks random samples that share a
        selected label vector, not the very samples picked by the split, so
        the labelled fraction is not guaranteed to equal ``mask_ratio`` --
        confirm this is intended.

        Args:
            mask_ratio: Fraction handed to the stratified split as the
                labelled (train) share.

        Returns:
            ``torch.BoolTensor`` with one entry per sample.
        """
        _, y_labelled, _, _ = iterative_train_test_split(
            self.labels, self.labels, test_size=1 - mask_ratio
        )
        mask = torch.zeros(self.labels.shape[0], dtype=torch.bool)

        # Index the dataset rows by label vector once, instead of rescanning
        # the whole label matrix for every labelled row.  Indices are stored
        # in ascending order, matching what np.where would return.
        rows_by_label = {}
        for row_idx, label_row in enumerate(self.labels):
            rows_by_label.setdefault(label_row.tobytes(), []).append(row_idx)

        for label_row in y_labelled:
            candidates = rows_by_label[label_row.tobytes()]
            for _ in range(self._DRAWS_PER_LABELLED_ROW):
                mask[random.choice(candidates)] = True
        return mask