import numpy as np
import pandas as pd
import torch

import os, sys

from .datasets import dataset_register


@dataset_register
class JigsawDataset(torch.utils.data.Dataset):
    """Map-style dataset over one split of a Jigsaw toxicity CSV.

    Each item is a dict with the comment text, the binarized label read from
    the CSV ("raw_label") and a possibly curated label ("clean_label").
    Without a curation file the clean labels are a copy of the raw labels.
    """

    def __init__(self, root_dir, csv_file, label_key, *, split='train', clean_label_file=None, threshold=0.5):
        """
        Parameters:
        - root_dir:         the root directory of dataset files
        - csv_file:         the csv file of all data, relative to ``root_dir``;
                            must contain ``comment_text``, ``split`` and
                            ``label_key`` columns
        - label_key:        the column holding the labels
        - split:            data split, could be either train or test
        - clean_label_file: optional numpy file (relative to ``root_dir``)
                            holding rows ``(row_index, clean_label, confidence)``
        - threshold:        labels >= threshold are binarized to 1, labels
                            below it to 0; it is also the minimum confidence
                            for a curated label to be applied
        """
        # load all_data.csv
        data = pd.read_csv(os.path.join(root_dir, csv_file))
        in_split = data["split"] == split

        # Assign the result rather than fillna(inplace=True) on a selection,
        # which triggers SettingWithCopyWarning / breaks under copy-on-write.
        self.text = data.loc[in_split, "comment_text"].fillna("NULL")  # impute missing values

        raw = data.loc[in_split, label_key]
        # Both conditions are evaluated on the ORIGINAL values: masking in
        # sequence would turn freshly written 1s into 0s when threshold > 1.
        # Missing labels (NaN) match neither condition and stay NaN.
        self.raw_labels = raw.mask(raw >= threshold, 1).mask(raw < threshold, 0)
        self.clean_labels = self.raw_labels.copy(deep=True)

        if clean_label_file is not None:
            with open(os.path.join(root_dir, clean_label_file), 'rb') as f:
                label_curation = np.load(f)

            # Keep only the curated labels whose confidence reaches the threshold.
            confident = label_curation[label_curation[:, 2] >= threshold]
            # NOTE(review): column 0 is assumed to be the row label in the csv's
            # index (the default RangeIndex, i.e. the 0-based row number).
            idx = confident[:, 0].astype(np.int64)
            curated = confident[:, 1]

            # The curation file may cover every split; apply only the entries
            # whose row label belongs to this split, matching by label (.loc)
            # rather than by position. Shapes of raw/clean labels thus agree.
            applicable = np.isin(idx, self.clean_labels.index.to_numpy())
            if applicable.any():
                self.clean_labels.loc[idx[applicable]] = curated[applicable]

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.clean_labels)

    def __getitem__(self, idx):
        """Return the sample(s) at position ``idx`` as a dict.

        ``idx`` is positional (``iloc``); a tensor index is converted to a
        Python int/list first.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample = {
            "text": self.text.iloc[idx],
            "raw_label": self.raw_labels.iloc[idx],
            "clean_label": self.clean_labels.iloc[idx]
        }

        return sample

    @staticmethod
    def download_data(root_dir):
        """Not implemented: the data must already be present in ``root_dir``."""
        raise NotImplementedError