import os
import json
from abc import *

import torch
import csv
from torch.utils.data import TensorDataset
import numpy as np

from common import DATA_PATH

def create_tensor_dataset(inputs, labels, index):
    """Bundle per-sample tensors and sample ids into one ``TensorDataset``.

    Args:
        inputs: sequence of tensors, one per sample; stacked along a new
            leading dimension.
        labels: sequence of tensors, one per sample; stacked and then given
            an extra dimension at position 1.
        index: sequence of integer sample ids, one per sample.

    Returns:
        A ``TensorDataset`` over ``(inputs, labels, index)``.

    Raises:
        AssertionError: if the three sequences do not have the same length.
    """
    assert len(inputs) == len(labels)
    assert len(inputs) == len(index)

    inputs = torch.stack(inputs)  # (N, T)
    labels = torch.stack(labels).unsqueeze(1)  # (N, 1)

    # Build the ids as int64 directly. Routing them through a float32 tensor
    # first (torch.Tensor(...)) would silently round ids larger than 2**24.
    index = torch.as_tensor(np.asarray(index), dtype=torch.long)

    return TensorDataset(inputs, labels, index)

class BaseDataset(metaclass=ABCMeta):
    """Abstract base for datasets whose splits are cached on disk as ``.pth`` files.

    On construction the train/val/test cache files are looked up under
    ``DATA_PATH/<data_name>``. If any of them is missing, the subclass's
    ``_preprocess`` is called (it is expected to write all three files), and
    then the three splits are loaded into ``train_dataset``, ``val_dataset``
    and ``test_dataset``.

    Args:
        data_name: dataset identifier; also the sub-directory under ``DATA_PATH``.
        total_class: number of classes (converted with ``int``).
        tokenizer: tokenizer object; its ``name`` attribute is embedded in
            the cache file names.
        data_ratio: values below 1.0 select the reduced-data configuration.
        seed: seed of the reduced-data configuration; part of the cache file
            name when ``data_ratio < 1.0``.
    """

    def __init__(self, data_name, total_class, tokenizer, data_ratio=1.0, seed=0):
        self.data_name = data_name
        self.total_class = total_class
        self.root_dir = os.path.join(DATA_PATH, data_name)

        self.tokenizer = tokenizer
        self.data_ratio = data_ratio
        self.seed = seed

        self.n_classes = int(self.total_class)  # Split a given data
        self.class_idx = list(range(self.n_classes))  # all classes
        self.max_class = 1000

        # Per-class sample budget: a fraction of ``max_class`` for reduced-data
        # runs, otherwise a large placeholder. It is not consumed inside this
        # class; it is only stored on the instance.
        if self.data_ratio < 1.0:
            per_class_budget = int(self.max_class * self.data_ratio)
        else:
            per_class_budget = 100000
        self.n_samples = [per_class_budget] * self.n_classes

        if not self._check_exists():
            self._preprocess()

        self.train_dataset = self._load_cached(self._train_path)
        self.val_dataset = self._load_cached(self._val_path)
        self.test_dataset = self._load_cached(self._test_path)

    @staticmethod
    def _load_cached(path):
        """Load one cached split from ``path``.

        NOTE(security): ``torch.load`` unpickles arbitrary Python objects, so
        this must only be pointed at cache files produced locally.
        """
        try:
            # Recent PyTorch releases default to weights_only=True, which is
            # meant for plain tensors / state dicts rather than whole pickled
            # objects, so full unpickling is requested explicitly.
            return torch.load(path, weights_only=False)
        except TypeError:
            # Older PyTorch releases do not accept the weights_only keyword.
            return torch.load(path)

    @property
    def base_path(self):
        """File-name stem shared by the three cache files.

        For reduced-data runs the stem carries both the ratio and the seed,
        so caches built with different seeds do not overwrite or shadow each
        other. (Previously ``seed`` was passed to ``format`` without a
        matching placeholder and was silently dropped.)
        """
        if self.data_ratio < 1.0:
            return '{}_{}_data_{:.3f}_seed_{}'.format(
                self.data_name, self.tokenizer.name, self.data_ratio, self.seed)

        return '{}_{}'.format(self.data_name, self.tokenizer.name)

    @property
    def _train_path(self):
        """Path of the cached training split."""
        return os.path.join(self.root_dir, self.base_path + '_train.pth')

    @property
    def _val_path(self):
        """Path of the cached validation split."""
        return os.path.join(self.root_dir, self.base_path + '_val.pth')

    @property
    def _test_path(self):
        """Path of the cached test split."""
        return os.path.join(self.root_dir, self.base_path + '_test.pth')

    def _check_exists(self):
        """Return True only if all three cache files exist."""
        cache_paths = (self._train_path, self._val_path, self._test_path)
        return all(os.path.exists(path) for path in cache_paths)

    @abstractmethod
    def _preprocess(self):
        """Create the cache files; called from ``__init__`` when any is missing."""
        pass

    @abstractmethod
    def _load_dataset(self, *args, **kwargs):
        """Read the raw data for one split; the signature is subclass-specific."""
        pass

class NewsDataset(BaseDataset):
    """News dataset with 20 classes, read from ``train.txt`` / ``test.txt`` listings.

    Each listing line has the form ``<relative file path>,<class id>``; the
    referenced file (relative to ``root_dir``) holds the document text.
    """

    def __init__(self, tokenizer, data_ratio=1.0, seed=0):
        super(NewsDataset, self).__init__('news', 20, tokenizer, data_ratio, seed)

    def _preprocess(self):
        """Build the train/val/test datasets and save them to the cache paths."""
        print('Pre-processing news dataset...')
        train_dataset, val_dataset = self._load_dataset('train')
        test_dataset = self._load_dataset('test')

        torch.save(train_dataset, self._train_path)
        torch.save(val_dataset, self._val_path)
        torch.save(test_dataset, self._test_path)

    def _read_entries(self, source_path):
        """Parse a listing file into ``(relative_path, label)`` pairs.

        Lines whose class id is not in ``self.class_idx`` are dropped; the
        returned label is the position of the class id inside ``class_idx``.
        """
        entries = []
        with open(source_path, encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines, e.g. at the end of the file

                toks = line.split(',')
                class_id = int(toks[1])
                if class_id not in self.class_idx:  # only selected classes
                    continue

                entries.append((toks[0], self.class_idx.index(class_id)))

        return entries

    def _load_dataset(self, mode='train', raw_text=False):
        """Load one split from its listing file.

        Args:
            mode: ``'train'`` or ``'test'``.
            raw_text: if True, keep each document as a string instead of
                encoding it with the tokenizer.

        Returns:
            For ``mode='test'``: one dataset. For ``mode='train'``: a
            ``(train, validation)`` pair. Each dataset is a ``TensorDataset``
            built by ``create_tensor_dataset``, or, when ``raw_text`` is True,
            a single-pass ``zip`` of ``(text, label, index)``.

        NOTE(review): when ``data_ratio != 1`` the validation quota is left at
        the 100000-per-class placeholder, so train-mode samples are routed to
        the validation split until that quota is used up. Confirm this is the
        intended behaviour for reduced-data runs.
        """
        assert mode in ['train', 'test']

        listing_name = 'test.txt' if mode == 'test' else 'train.txt'
        entries = self._read_entries(os.path.join(self.root_dir, listing_name))

        # Placeholder budgets that are large enough to select every sample.
        train_quota = [100000] * self.n_classes
        val_quota = [100000] * self.n_classes

        if mode == 'train' and self.data_ratio == 1:
            # Reserve 10% of each class for validation. The quota is computed
            # once, after all samples have been counted.
            per_class_count = np.zeros(self.n_classes)
            for _, label_id in entries:
                per_class_count[label_id] += 1
            val_quota = list(np.round(0.1 * per_class_count))

        inputs, labels, index = [], [], []
        v_inputs, v_labels, v_index = [], [], []

        for rel_path, label_id in entries:
            path = os.path.join(self.root_dir, rel_path)
            with open(path, encoding='utf-8', errors='ignore') as f:
                text = f.read()

            if not raw_text:
                text = self.tokenizer.encode(text, add_special_tokens=True, max_length=128, pad_to_max_length=True,
                                             return_tensors='pt')[0]

            label = torch.tensor(label_id).long()

            if mode != 'train':
                inputs.append(text)
                labels.append(label)
                index.append(len(index))
            elif val_quota[label_id] > 0:
                # The validation quota of a class is filled before any of its
                # samples go to the training split.
                v_inputs.append(text)
                v_labels.append(label)
                v_index.append(len(v_index))
                val_quota[label_id] -= 1
            elif train_quota[label_id] > 0:
                inputs.append(text)
                labels.append(label)
                index.append(len(index))
                train_quota[label_id] -= 1

        def bundle(xs, ys, ids):
            # Raw strings cannot be stacked into tensors, so every split,
            # including validation, is zipped when raw_text is requested.
            if raw_text:
                return zip(xs, ys, ids)
            return create_tensor_dataset(xs, ys, ids)

        dataset = bundle(inputs, labels, index)
        if mode == 'train':
            return dataset, bundle(v_inputs, v_labels, v_index)

        return dataset