import os
import pickle
import torch
import torch.nn.functional as F
import numpy as np
import pyarrow as pa
import random

from sklearn.model_selection import train_test_split
from torch.utils.data.dataset import Dataset

class HatememesDataset(Dataset):
    """Map-style dataset over pre-extracted Hatememes latent representations.

    Per-sample metadata (``label`` and ``latent_path`` columns) is read from the
    Arrow IPC file ``{dataset_path}/hatememes_{split_type}.arrow``. Each item
    loads one file from ``{dataset_path}/latent_reps`` with ``torch.load`` and
    yields its ``'image'`` and ``'text'`` entries together with the label and a
    boolean "is labelled" flag.

    Args:
        dataset_path: Directory holding the Arrow files and ``latent_reps/``.
        data: Stored as ``self.data``; not read anywhere in this class.
        split_type: Split name used to choose the Arrow file; label masking is
            only applied when this is ``'train'``.
        device: Stored as ``self.device``; not used by this class (loaded
            tensors are not moved to it).
        labeled_ratio: Fraction of training samples to mark as labelled.
            ``None`` or any value ``>= 1.0`` marks every sample as labelled.
        train_classifier_only: With ``labeled_ratio == 0.0``, a truthy value
            marks every sample as labelled instead of none.
    """

    def __init__(self, dataset_path, data='mosei_senti', split_type='train',
                 device=torch.device('cuda'), labeled_ratio=None,
                 train_classifier_only=False):
        super().__init__()

        # Bookkeeping attributes that do not depend on the files on disk.
        self.data = data
        self.n_modalities = 2  # image + text
        self.device = device
        self.split_type = split_type
        self.data_dir = dataset_path
        self.train_classifier_only = train_classifier_only

        arrow_source = pa.memory_map(f"{dataset_path}/hatememes_{split_type}.arrow", "r")
        self.table = pa.ipc.RecordBatchFileReader(arrow_source).read_all()

        # Arrow column -> Python list -> a single numpy array indexed by sample.
        raw_labels = self.table['label'].to_pandas().to_list()
        self.labels = np.array([np.array(value) for value in raw_labels])

        # Only a training split with an explicit ratio below 1 gets a sampled
        # mask; every other configuration treats all samples as labelled.
        needs_mask = (
            labeled_ratio is not None
            and split_type == 'train'
            and labeled_ratio < 1.0
        )
        if needs_mask:
            self.masks = self._generate_mask(labeled_ratio)
        else:
            self.masks = torch.ones(self.labels.shape[0], dtype=torch.bool)

    def get_n_modalities(self):
        """Return how many input modalities each sample carries."""
        return self.n_modalities

    def __len__(self):
        """Return the number of samples (one per label)."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``((image, text), (label, is_labelled))`` for sample ``index``."""
        stored_path = self.table['latent_path'][index].as_py()
        file_name = stored_path.rpartition('/')[2]  # keep only the last path component
        # NOTE: torch.load is pickle-based; only point this at trusted files.
        latents = torch.load(os.path.join(self.data_dir, 'latent_reps', file_name))

        features = (latents['image'].detach(), latents['text'].detach())
        target = (torch.tensor(self.labels[index]), self.masks[index])
        return features, target

    def _generate_mask(self, mask_ratio):
        """Build a boolean mask over all samples marking which ones are labelled.

        A ratio of exactly ``0.0`` marks every sample labelled when
        ``self.train_classifier_only`` is truthy and none otherwise. Any other
        ratio draws a random subset with ``train_test_split`` (stratified on
        ``self.labels``), prints that subset's label distribution, and marks
        it as labelled.
        """
        n_samples = self.labels.shape[0]

        if mask_ratio == 0.0:
            # Classifier-only training keeps every label; otherwise the run is
            # fully unlabelled. NOTE: an earlier variant forced a 0.05 ratio
            # here "to ensure fairness"; that override is disabled.
            fill = torch.ones if self.train_classifier_only else torch.zeros
            return fill(n_samples, dtype=torch.bool)

        all_indices = np.arange(n_samples)
        picked, _ = train_test_split(all_indices, train_size=mask_ratio, stratify=self.labels)
        picked = picked.astype(int)

        mask = torch.zeros(n_samples, dtype=torch.bool)
        mask[picked] = True
        print("Labelled distribution: ", np.unique(self.labels[picked], return_counts=True))
        return mask