import os
import json
from abc import *

import torch
import csv
from torch.utils.data import TensorDataset
import numpy as np
import pandas as pd

from datasets import load_dataset
# Directory under which the per-split ``<data_name>_{train,val,test}.pth``
# cache files are looked up and stored.
DATA_PATH = './datasets'

def create_tensor_dataset(inputs, labels, index):
    """Bundle per-example tensors and their indices into a ``TensorDataset``.

    Args:
        inputs: Sequence of tensors, one per example. They are combined with
            ``torch.stack``, so they must all share the same shape.
        labels: Sequence of tensors, one per example, stacked the same way.
        index: Sequence of integer example indices, one per example.

    Returns:
        A ``TensorDataset`` whose items are ``(input, label, index)`` triples;
        the index component is an int64 tensor.

    Raises:
        AssertionError: If the three sequences do not have the same length.
        RuntimeError: Propagated from ``torch.stack`` (e.g. for empty input
            or tensors of differing shapes).
    """
    assert len(inputs) == len(labels)
    assert len(inputs) == len(index)

    # Stacking adds a new leading dimension with one entry per example.
    stacked_inputs = torch.stack(inputs)
    stacked_labels = torch.stack(labels)

    # Build the index tensor as int64 directly. Going through the float32
    # ``torch.Tensor`` constructor first would round indices above 2**24.
    # ``torch.tensor`` always copies, so the caller's data is never aliased.
    index_tensor = torch.tensor(np.asarray(index), dtype=torch.long)

    return TensorDataset(stacked_inputs, stacked_labels, index_tensor)

class BaseDataset(metaclass=ABCMeta):
    """Abstract base for datasets cached on disk as three ``.pth`` files.

    On construction the train/val/test cache files are looked up under
    ``DATA_PATH``. If any of them is missing, ``_preprocess`` is called and is
    expected to create them. All three files are then read with
    ``torch.load``.

    Subclasses must implement ``_preprocess`` and ``_load_dataset``.
    """

    def __init__(self, data_name, tokenizer, seed=0):
        """Load the cached splits, building them first when necessary.

        Args:
            data_name: Prefix of the cache file names
                (``<data_name>_train.pth`` etc.).
            tokenizer: Tokenizer object stored for use by subclasses.
            seed: Stored on the instance; not used by this base class.
        """
        self.data_name = data_name
        self.tokenizer = tokenizer
        self.seed = seed

        if not self._check_exists():
            # torch.save fails when the target directory is missing, so make
            # sure it exists before a subclass starts writing cache files.
            self._ensure_cache_dirs()
            self._preprocess()

        # NOTE(review): torch.load unpickles arbitrary Python objects. Only
        # load cache files this code produced itself. Depending on the
        # installed PyTorch version, its default ``weights_only`` setting may
        # need to be revisited for non-tensor payloads -- confirm.
        self.train_dataset = torch.load(self._train_path)
        self.val_dataset = torch.load(self._val_path)
        self.test_dataset = torch.load(self._test_path)

    @property
    def _train_path(self):
        """Path of the cached training split."""
        return os.path.join(DATA_PATH, self.data_name + '_train.pth')

    @property
    def _val_path(self):
        """Path of the cached validation split."""
        return os.path.join(DATA_PATH, self.data_name + '_val.pth')

    @property
    def _test_path(self):
        """Path of the cached test split."""
        return os.path.join(DATA_PATH, self.data_name + '_test.pth')

    def _cache_paths(self):
        """Return the three cache file paths in train/val/test order."""
        return (self._train_path, self._val_path, self._test_path)

    def _ensure_cache_dirs(self):
        """Create the parent directory of every cache file if it is missing."""
        for path in self._cache_paths():
            parent = os.path.dirname(path)
            # An empty parent means the current directory; nothing to create.
            if parent:
                os.makedirs(parent, exist_ok=True)

    def _check_exists(self):
        """Return True only when all three cache files are present."""
        return all(os.path.exists(path) for path in self._cache_paths())

    @abstractmethod
    def _preprocess(self):
        """Build the dataset and write the three cache files."""
        pass

    @abstractmethod
    def _load_dataset(self, *args, **kwargs):
        """Load raw data for a split; the signature is up to the subclass."""
        pass

class DynaSent2(BaseDataset):
    """Tokenized sentiment dataset from the ``dynabench.dynasent.r2.all`` config.

    Each split is fetched with ``datasets.load_dataset``, tokenized with the
    tokenizer passed to the constructor, and cached as a ``TensorDataset`` of
    ``(tokens, label, index)`` triples.
    """

    # Value forwarded to the tokenizer as ``max_length``.
    MAX_LENGTH = 256
    # Gold-label strings that have a dedicated class id.
    LABEL_TO_ID = {'negative': 0, 'positive': 1}
    # Class id assigned to every gold label not listed in LABEL_TO_ID.
    OTHER_LABEL_ID = 2

    def __init__(self, data_name, tokenizer, seed=0):
        """Build or load the cached splits; see ``BaseDataset.__init__``."""
        # The base initializer already stores data_name, tokenizer and seed.
        super().__init__(data_name, tokenizer, seed)

    def _preprocess(self):
        """Tokenize the train/validation/test splits and cache them on disk."""
        print('Pre-processing {} dataset...'.format(self.data_name))
        train_dataset = self._load_dataset('train')
        val_dataset = self._load_dataset('validation')
        test_dataset = self._load_dataset('test')

        # Each split is saved to its own cache file. torch.save fails when the
        # target directory is missing, so create it first.
        to_save = (
            (train_dataset, self._train_path),
            (val_dataset, self._val_path),
            (test_dataset, self._test_path),
        )
        for split_dataset, path in to_save:
            parent = os.path.dirname(path)
            if parent:
                os.makedirs(parent, exist_ok=True)
            torch.save(split_dataset, path)

    def _load_dataset(self, mode='train'):
        """Return one split as a ``TensorDataset``.

        Args:
            mode: One of ``'train'``, ``'validation'`` or ``'test'``.

        Returns:
            The dataset produced by ``create_tensor_dataset`` from the token
            tensors, the label tensors and the running example index.

        Raises:
            AssertionError: If ``mode`` is not one of the accepted names.
        """
        assert mode in ['train', 'validation', 'test']

        data_set = load_dataset("dynabench/dynasent", "dynabench.dynasent.r2.all")[mode]

        # Pull the two columns out once instead of indexing them per row.
        sentences = data_set['sentence']
        gold_labels = data_set['gold_label']

        inputs, labels, indices = [], [], []
        for i, (sent, gold) in enumerate(zip(sentences, gold_labels)):
            # [0] drops the leading dimension of the returned 'pt' tensor.
            toks = self.tokenizer.encode(
                sent,
                add_special_tokens=True,
                max_length=self.MAX_LENGTH,
                pad_to_max_length=True,
                return_tensors='pt',
            )[0]

            # NOTE(review): every label other than 'negative'/'positive',
            # including a missing one, falls into OTHER_LABEL_ID -- confirm
            # this is intended for the "all" config.
            label_id = self.LABEL_TO_ID.get(gold, self.OTHER_LABEL_ID)
            label = torch.tensor([label_id], dtype=torch.long)

            inputs.append(toks)
            labels.append(label)
            indices.append(i)

        return create_tensor_dataset(inputs, labels, indices)