import os
from tqdm import tqdm
import numpy as np
import pandas as pd
import torch

from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union

from .datasets import dataset_register

@dataset_register
class DictDataset(torch.utils.data.Dataset):
    """Text dataset backed by a dictionary saved with ``torch.save``.

    The loaded dictionary is read through four keys: ``'feature'`` (the
    texts), ``'label'`` (the original labels), ``'label_curated'`` (the
    alternative labels) and ``'label_conf'`` (a per-sample confidence).

    For every sample the "clean" label is the curated label when its
    confidence is at least ``threshold``, and the original label otherwise.

    ``raw_idx`` holds the positions (into the loaded arrays) of the samples
    that are currently exposed; it starts as every position and can be
    shrunk with :meth:`drop_disagreement`.
    """

    def __init__(
        self,
        root_dir: str,
        file_name: str,
        threshold: float = 0.,
        mode: str = 'replace'
    ) -> None:
        """Load the dictionary and derive the clean labels.

        Args:
            root_dir: Directory that contains the saved dictionary.
            file_name: Name of the file inside ``root_dir``.
            threshold: Minimum ``'label_conf'`` value for which the curated
                label replaces the original one.
            mode: Accepted for interface compatibility; this class does not
                read it.
        """
        # NOTE(review): torch.load unpickles the file, so only load data
        # from a trusted source.
        dataset = torch.load(os.path.join(root_dir, file_name))
        self.text = dataset['feature']

        # Convert the original labels once and reuse the array below.
        self.raw_labels = np.array(dataset['label'])
        confidence = np.array(dataset['label_conf'])
        curated = np.array(dataset['label_curated'])
        self.clean_labels = np.where(
            confidence >= threshold, curated, self.raw_labels
        )
        print(f"threshold for confidence is {threshold}")

        self.raw_idx = np.arange(len(self.text))
        # Show one example, but do not crash on an empty dataset.
        if len(self.text) > 0:
            print(self.text[0])

    def __len__(self) -> int:
        """Return the number of samples currently exposed."""
        return len(self.raw_idx)

    def __getitem__(self, idx) -> Dict[str, Any]:
        """Return the sample at position ``idx`` of the exposed subset.

        Args:
            idx: Position within the exposed subset; a tensor is converted
                with ``tolist()`` first.

        Returns:
            A dict with the keys ``"raw_idx"`` (position in the loaded
            data), ``"text"``, ``"raw_label"`` and ``"clean_label"``.
        """
        if torch.is_tensor(idx):
            # NOTE(review): a multi-element tensor becomes a list here, which
            # presumably only works if self.text supports array indexing.
            idx = idx.tolist()

        # Translate the subset position into a position in the loaded data.
        idx = self.raw_idx[idx]
        sample = {
            "raw_idx": idx,
            "text": self.text[idx],
            "raw_label": self.raw_labels[idx],
            "clean_label": self.clean_labels[idx]
        }
        return sample

    def drop_disagreement(self) -> None:
        """Keep only the exposed samples whose raw and clean labels match.

        The comparison is restricted to the positions still held in
        ``raw_idx``, so the method can be called repeatedly (or after
        ``raw_idx`` has been narrowed some other way) without the mask and
        the index array differing in length.
        """
        print("Drop disagreement!")
        kept = self.raw_idx
        agree = self.raw_labels[kept] == self.clean_labels[kept]
        self.raw_idx = kept[agree]