import os
import csv
import json
import logging
import collections
import dataclasses
import numpy as np
import pickle as pkl

from dataclasses import dataclass
from typing import List, Optional, Union
from sklearn.model_selection import KFold

from transformers import is_tf_available, is_torch_available


logger = logging.getLogger(__name__)


@dataclass
class InputExample:
    """
    One raw (untokenized) example for a sequence classification task.

    Attributes:
        guid: Unique identifier of the example.
        text_a: Untokenized text of the first sequence. This is the only
            text needed for single-sequence tasks.
        text_b: (Optional) Untokenized text of the second sequence, used by
            sequence-pair tasks.
        label: (Optional) Label of the example. Expected for train and dev
            examples, typically absent for test examples.
    """

    guid: str
    text_a: str
    text_b: Optional[str] = None
    label: Optional[str] = None

    def to_json_string(self):
        """Return the fields as a 2-space indented JSON object followed by a newline."""
        field_values = dataclasses.asdict(self)
        serialized = json.dumps(field_values, indent=2)
        return serialized + "\n"


# @dataclass(frozen=True)
@dataclass
class InputFeatures:
    """
    Model-ready features for a single example.

    Attribute names mirror the names of the corresponding model inputs.

    Attributes:
        guid: Identifier carried along with the features (annotated as a
            list of ints).
        input_ids: Vocabulary indices of the input sequence tokens.
        attention_mask: (Optional) Mask that keeps attention away from padding
            positions. Values are in ``[0, 1]``: usually ``1`` marks tokens
            that are NOT MASKED and ``0`` marks MASKED (padded) tokens.
        token_type_ids: (Optional) Segment indices distinguishing the first
            and second portions of the input. Only some models use them.
        label: (Optional) Target for the input: an int for classification,
            a float for regression.
    """

    guid: List[int]
    input_ids: List[int]
    attention_mask: Optional[List[int]] = None
    token_type_ids: Optional[List[int]] = None
    label: Optional[Union[int, float]] = None

    def to_json_string(self):
        """Return the fields as a compact single-line JSON object followed by a newline."""
        field_values = dataclasses.asdict(self)
        serialized = json.dumps(field_values)
        return serialized + "\n"


class DataProcessor:
    """Base class for data converters for sequence classification data sets.

    Subclasses implement the ``get_*`` hooks below.  The ``process_train_dev*``
    methods additionally call ``self.get_label(input=row)``, which this base
    class does not define: the subclass must provide it.  Rows for which it
    returns ``None`` are dropped.
    """

    # Tasks for which the first row of an input file is skipped (treated as
    # a header row).
    _HEADER_TASKS = ("mnli", "mrpc", "sst-2", "sts-b", "qqp", "qnli", "rte", "wnli")

    def get_example_from_tensor_dict(self, tensor_dict):
        """Gets an example from a dict with tensorflow tensors
        Args:
            tensor_dict: Keys and values should match the corresponding Glue
                tensorflow_dataset examples.
        """
        raise NotImplementedError()

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    def tfds_map(self, example):
        """Some tensorflow_datasets datasets are not formatted the same way the GLUE datasets are.
        This method converts examples to the correct format.

        When the data set has more than one label, ``example.label`` is
        treated as an index into ``get_labels()`` and replaced (in place) by
        the label found there.  The same ``example`` object is returned.
        """
        label_list = self.get_labels()
        if len(label_list) > 1:
            example.label = label_list[int(example.label)]
        return example

    def get_label_dist(self, labels, filename=None):
        """Report the number of examples, per-label counts and label proportions.

        Args:
            labels: Sequence of (hashable) labels, one per example.
            filename: Despite its name, an open writable text stream (anything
                with a ``write`` method).  When given, the report is written
                to it; when ``None``, the report is printed to stdout.
        """
        num_examples = len(labels)
        # Plain dicts so the rendered text is "{...}" rather than "Counter({...})".
        counts = dict(collections.Counter(labels))
        # An empty ``labels`` yields empty dicts, so no division by zero occurs.
        distribution = {label: count / num_examples for label, count in counts.items()}

        if filename is not None:
            filename.write("No. of examples : {}\n".format(num_examples))
            filename.write("Label counts : {}\n".format(counts))
            filename.write("Label distribution : {}\n".format(distribution))
        else:
            print("No. of examples : ", num_examples)
            print("Label counts : ", counts)
            print("Label distribution : ", distribution)

    def _filter_labeled_rows(self, rows, task_name, header_tasks):
        """Return ``(rows, labels)`` for the rows that have a non-``None`` label.

        The first row is skipped when ``task_name`` is in ``header_tasks``.
        """
        skip_first_row = task_name in header_tasks
        kept_rows = []
        kept_labels = []
        for i, row in enumerate(rows):
            if i == 0 and skip_first_row:
                continue
            label = self.get_label(input=row)
            if label is None:
                continue
            kept_labels.append(label)
            kept_rows.append(row)
        return kept_rows, kept_labels

    @staticmethod
    def _write_delimited(path, rows, delimiter, quotechar):
        """Write ``rows`` to ``path`` with ``csv.writer`` (UTF-8 with BOM)."""
        # newline="" lets the csv module control line endings itself, as the
        # csv documentation requires; without it Windows would emit "\r\r\n".
        with open(path, "w", encoding="utf-8-sig", newline="") as out_f:
            csv.writer(out_f, delimiter=delimiter, quotechar=quotechar).writerows(rows)

    def _log_section(self, log_f, title, labels):
        """Write ``title``, the label distribution of ``labels`` and a separator to ``log_f``."""
        log_f.write(title)
        self.get_label_dist(labels=labels, filename=log_f)
        log_f.write("-------------------------------\n")

    def _export_splits(
        self,
        out_dir,
        lines,
        labels,
        splits,
        delimiter,
        quotechar,
        log_f,
        n_train_examples=None,
        random_state=None,
    ):
        """Write ``train_v{k}.tsv`` / ``dev_v{k}.tsv`` per split, log them, pickle the indices.

        When ``n_train_examples`` is not ``None`` and a split's train part is
        larger than that, the train indices are down-sampled without
        replacement using ``random_state``.
        """
        split_indices = {}
        for idx, split_idx in enumerate(splits):
            train_idx, dev_idx = split_idx[0], split_idx[1]
            if n_train_examples is not None and len(train_idx) > n_train_examples:
                print(
                    "Downsampling train examples as ",
                    len(train_idx),
                    " is more than required ",
                    n_train_examples,
                )
                train_idx = np.random.RandomState(random_state).choice(
                    train_idx, size=n_train_examples, replace=False
                )
                split_indices[idx] = (train_idx, dev_idx)
            else:
                split_indices[idx] = split_idx

            version = idx + 1  # output files are numbered from 1
            self._write_delimited(
                os.path.join(out_dir, "train_v{}.tsv".format(version)),
                [lines[i] for i in train_idx],
                delimiter,
                quotechar,
            )
            self._write_delimited(
                os.path.join(out_dir, "dev_v{}.tsv".format(version)),
                [lines[i] for i in dev_idx],
                delimiter,
                quotechar,
            )

            log_f.write("Train label distribution: {}\n".format(version))
            self.get_label_dist(labels=[labels[i] for i in train_idx], filename=log_f)
            log_f.write("Dev label distribution: {}\n".format(version))
            self.get_label_dist(labels=[labels[i] for i in dev_idx], filename=log_f)
            log_f.write("-------------------------------\n")

        with open(os.path.join(out_dir, "split_indices.pkl"), "wb") as pkl_f:
            pkl.dump(split_indices, pkl_f)

    def _export_test_set(
        self,
        src_path,
        dst_path,
        task_name,
        header_tasks,
        delimiter,
        quotechar,
        sample=False,
        n_test_examples=None,
        random_state=None,
    ):
        """Copy the labeled rows of ``src_path`` to ``dst_path`` and return their labels.

        With ``sample=True`` exactly ``n_test_examples`` rows are drawn
        without replacement (seeded by ``random_state``) and only those are
        written; numpy raises ``ValueError`` if fewer rows are available.
        """
        rows, labels = self._filter_labeled_rows(
            self._read_tsv(src_path, delimiter=delimiter, quotechar=quotechar),
            task_name,
            header_tasks,
        )
        if sample:
            print("No. of test examples : ", len(rows))
            print("No. of examples to select : ", n_test_examples)
            sampled_indices = np.random.RandomState(random_state).choice(
                len(rows), n_test_examples, replace=False
            )
            rows = [rows[i] for i in sampled_indices]
            labels = [labels[i] for i in sampled_indices]
        self._write_delimited(dst_path, rows, delimiter, quotechar)
        return labels

    def process_train_dev(
        self,
        data_dir,
        task_name,
        train_filename="train.tsv",
        dev_filename="dev.tsv",
        test_filename="test.tsv",
        dev_filename1=None,
        test_filename1=None,
        n_dev_examples=1001,
        n_test_examples=1001,
        n_splits=5,
        random_state=42,
        delimiter="\t",
        quotechar=None,
    ):
        """Build train/dev splits and a test file under ``<data_dir>/processed``.

        The original train file is split into ``n_splits`` train/dev pairs
        (``train_v{k}.tsv`` / ``dev_v{k}.tsv``) via ``get_splits``; the split
        indices are pickled to ``split_indices.pkl``.  The original dev file
        becomes the test file.  Label distributions are appended to
        ``data.log``.

        Args:
            data_dir: Directory holding the input files; output goes to its
                ``processed`` sub-directory (created if missing).
            task_name: Task identifier; decides whether the first row of each
                input file is skipped and whether the test set is sub-sampled.
            train_filename: Name of the train file inside ``data_dir``.
            dev_filename: Name of the dev file that is turned into the test file.
            test_filename: Name of the test file to write.
            dev_filename1: (Optional) second dev file (logged as "mis-matched").
            test_filename1: (Optional) output name for ``dev_filename1``.  The
                second file is processed only when both names are given.
            n_dev_examples: Passed to ``get_splits`` as the dev size per split.
            n_test_examples: Number of test rows to sample for the tasks
                ``yelp``, ``yahooqa``, ``qqp`` and ``mnli``.
            n_splits: Number of splits to produce.
            random_state: Seed for splitting and sampling.
            delimiter: Field delimiter for reading and writing.
            quotechar: Quote character for reading and writing.
        """
        out_dir = os.path.join(data_dir, "processed")

        lines, labels = self._filter_labeled_rows(
            self._read_tsv(
                os.path.join(data_dir, train_filename),
                delimiter=delimiter,
                quotechar=quotechar,
            ),
            task_name,
            self._HEADER_TASKS,
        )

        splits = get_splits(
            n_samples=len(lines),
            n_dev_examples=n_dev_examples,
            n_splits_to_select=n_splits,
            random_state=random_state,
        )

        os.makedirs(out_dir, exist_ok=True)

        # The context manager closes the log even if a later step raises.
        with open(os.path.join(out_dir, "data.log"), "a+") as log_f:
            self._log_section(log_f, "Actual train label distribution\n", labels)

            self._export_splits(
                out_dir, lines, labels, splits, delimiter, quotechar, log_f
            )

            test_labels = self._export_test_set(
                os.path.join(data_dir, dev_filename),
                os.path.join(out_dir, test_filename),
                task_name,
                self._HEADER_TASKS,
                delimiter,
                quotechar,
                sample=task_name in ("yelp", "yahooqa", "qqp", "mnli"),
                n_test_examples=n_test_examples,
                random_state=random_state,
            )
            self._log_section(log_f, "Test label distribution\n", test_labels)

            if dev_filename1 is not None and test_filename1 is not None:
                # For the second dev file the first row is skipped only for "mnli".
                test_labels = self._export_test_set(
                    os.path.join(data_dir, dev_filename1),
                    os.path.join(out_dir, test_filename1),
                    task_name,
                    ("mnli",),
                    delimiter,
                    quotechar,
                )
                self._log_section(
                    log_f, "Test label distribution (mis-matched)\n", test_labels
                )

    def process_train_dev_lll(
        self,
        data_dir,
        task_name,
        train_filename="train.csv",
        dev_filename="test.csv",
        test_filename="test.tsv",
        n_train_examples=115000,
        n_dev_examples=5000,
        n_test_examples=7600,
        n_splits=5,
        random_state=42,
        delimiter=",",
        quotechar=None,
        processed_dir="processed_lll",
    ):
        """Variant of ``process_train_dev`` with train down-sampling.

        Works like ``process_train_dev`` but writes to
        ``<data_dir>/<processed_dir>``, caps the train part of every split at
        ``n_train_examples`` rows (sampled without replacement), and has no
        second ("mis-matched") dev file.

        Args:
            data_dir: Directory holding the input files.
            task_name: Task identifier; decides whether the first row of each
                input file is skipped and whether the test set is sub-sampled.
            train_filename: Name of the train file inside ``data_dir``.
            dev_filename: Name of the file that is turned into the test file.
            test_filename: Name of the test file to write.
            n_train_examples: Maximum number of train rows per split.
            n_dev_examples: Passed to ``get_splits`` as the dev size per split.
            n_test_examples: Number of test rows to sample for the tasks
                ``yelp``, ``yahooqa``, ``amzn``, ``dbpedia``, ``qqp``,
                ``qnli`` and ``mnli``.
            n_splits: Number of splits to produce.
            random_state: Seed for splitting and sampling.
            delimiter: Field delimiter for reading and writing.
            quotechar: Quote character for reading and writing.
            processed_dir: Name of the output sub-directory (created if missing).
        """
        out_dir = os.path.join(data_dir, processed_dir)

        lines, labels = self._filter_labeled_rows(
            self._read_tsv(
                os.path.join(data_dir, train_filename),
                delimiter=delimiter,
                quotechar=quotechar,
            ),
            task_name,
            self._HEADER_TASKS,
        )

        splits = get_splits(
            n_samples=len(lines),
            n_dev_examples=n_dev_examples,
            n_splits_to_select=n_splits,
            random_state=random_state,
        )

        os.makedirs(out_dir, exist_ok=True)

        # The context manager closes the log even if a later step raises.
        with open(os.path.join(out_dir, "data.log"), "a+") as log_f:
            self._log_section(log_f, "Actual train label distribution\n", labels)

            self._export_splits(
                out_dir,
                lines,
                labels,
                splits,
                delimiter,
                quotechar,
                log_f,
                n_train_examples=n_train_examples,
                random_state=random_state,
            )

            test_labels = self._export_test_set(
                os.path.join(data_dir, dev_filename),
                os.path.join(out_dir, test_filename),
                task_name,
                self._HEADER_TASKS,
                delimiter,
                quotechar,
                sample=task_name
                in ("yelp", "yahooqa", "amzn", "dbpedia", "qqp", "qnli", "mnli"),
                n_test_examples=n_test_examples,
                random_state=random_state,
            )
            self._log_section(log_f, "Test label distribution\n", test_labels)

    @classmethod
    def _read_tsv(cls, input_file, delimiter="\t", quotechar=None):
        """Reads a delimiter separated value file (tab by default) into a list of rows."""
        with open(input_file, "r", encoding="utf-8-sig") as f:
            return list(csv.reader(f, delimiter=delimiter, quotechar=quotechar))

    @classmethod
    def _read_jsonl(cls, input_file, quotechar=None):
        """Reads a JSON Lines file: one JSON document per line.

        Blank (whitespace-only) lines are skipped.  ``quotechar`` is accepted
        for signature compatibility and is not used.
        """
        with open(input_file, "r", encoding="utf-8-sig") as f:
            return [json.loads(line) for line in f if line.strip()]


def iter_test_indices(n_samples, fold_size, shuffle=True, random_state=42):
    """Yield consecutive folds of sample indices, each of about ``fold_size`` items.

    Adapted from K-Folds cross-validator (sklearn).

    ``n_samples // fold_size`` folds are produced.  Every fold has
    ``fold_size`` indices, except that the first
    ``n_samples % n_folds`` folds get one extra index.  Samples left over
    beyond that are not part of any fold.

    Args:
        n_samples: Total number of samples; indices are ``0 .. n_samples - 1``.
        fold_size: Base number of indices per fold.
        shuffle: Whether to shuffle the indices before cutting folds.
        random_state: Seed for the shuffle.

    Yields:
        A numpy integer array of indices for each fold.

    Raises:
        ValueError: If ``fold_size`` is not in ``1 .. n_samples`` (no fold
            could be formed).  As this is a generator, the error surfaces on
            the first iteration.
    """
    if fold_size <= 0 or fold_size > n_samples:
        # Previously this surfaced as an opaque ZeroDivisionError.
        raise ValueError(
            "fold_size must be between 1 and n_samples ({}); got {}".format(
                n_samples, fold_size
            )
        )

    indices = np.arange(n_samples)

    if shuffle:
        np.random.RandomState(random_state).shuffle(indices)

    n_splits = n_samples // fold_size
    # The builtin ``int`` replaces the ``np.int`` alias removed in NumPy 1.24.
    fold_sizes = np.full(n_splits, fold_size, dtype=int)
    fold_sizes[: n_samples % n_splits] += 1
    current = 0
    for _fold_size in fold_sizes:
        start, stop = current, current + _fold_size
        yield indices[start:stop]
        current = stop


def get_splits(
    n_samples, n_dev_examples, n_splits_to_select=5, shuffle=True, random_state=42
):
    """Build up to ``n_splits_to_select`` train/dev index splits.

    Adapted from K-Folds cross-validator (sklearn).

    Each fold produced by ``iter_test_indices`` becomes the dev part of one
    split; all remaining indices form the train part.

    Args:
        n_samples: Total number of samples.
        n_dev_examples: Fold size handed to ``iter_test_indices``.
        n_splits_to_select: Maximum number of splits to return.
        shuffle: Whether indices are shuffled before folds are cut.
        random_state: Seed for the shuffle.

    Returns:
        A list of ``(train_indices, dev_indices)`` tuples of numpy arrays.
        Both arrays are in ascending index order because they are selected
        with a boolean mask.
    """
    indices = np.arange(n_samples)
    splits = []

    for test_index in iter_test_indices(
        n_samples=n_samples,
        fold_size=n_dev_examples,
        shuffle=shuffle,
        random_state=random_state,
    ):
        if len(splits) >= n_splits_to_select:
            break

        # The builtin ``bool`` replaces the ``np.bool`` alias removed in NumPy 1.24.
        test_mask = np.zeros(n_samples, dtype=bool)
        test_mask[test_index] = True

        splits.append((indices[np.logical_not(test_mask)], indices[test_mask]))

    return splits
