#! -*- coding: utf-8
import os.path as path
import typing

import numpy as np
import torch
import torchvision

# Names exported by `from <module> import *`; only the loader is public.
__all__ = ["load_glue_dataset"]

# Maps each supported GLUE task name to the raw-text column(s) handed to the
# tokenizer: one column for single-sentence tasks, two for sentence-pair tasks.
_INPUT_FEATURES = dict(
    cola=["sentence"],
    mnli=["premise", "hypothesis"],
    mrpc=["sentence1", "sentence2"],
    qnli=["question", "sentence"],
    qqp=["question1", "question2"],
    rte=["sentence1", "sentence2"],
    sst2=["sentence"],
    stsb=["sentence1", "sentence2"],
    wnli=["sentence1", "sentence2"],
)


def _stsb_bins(scores, n_bins: int) -> typing.List[int]:
    """Discretize STS-B similarity scores into ``n_bins`` equal-width bin ids.

    Each score is divided by 5 (so scores are assumed to lie in [0, 5]) and
    clipped just below 1.0. A maximal score of 5 therefore lands in the last
    bin (``n_bins - 1``) instead of an out-of-range bin ``n_bins``.
    """
    # list() first: some `datasets` versions may return a lazy column object
    # rather than a plain list.
    normalized = (np.asarray(list(scores), dtype=float) / 5).clip(None, 0.99999999)
    return (normalized * n_bins).astype(int).tolist()


def load_glue_dataset(name: str, *args, datadir: typing.Optional[str] = None,
                      **kwargs) -> typing.Sequence[torch.utils.data.Dataset]:
    """Load a GLUE task, tokenize it, and return ``(train, eval)`` datasets.

    Args:
        name: Dataset name forwarded (lower-cased) to ``datasets.load_dataset``.
        *args: Accepted for interface compatibility; unused.
        datadir: ``cache_dir`` forwarded to ``datasets.load_dataset``.
        **kwargs:
            task (str, required): GLUE sub-task, one of the keys of
                ``_INPUT_FEATURES`` (case-insensitive).
            tokenizer_args (dict): options for ``AutoTokenizer.from_pretrained``.
                The keys ``name`` (default ``"bert-base-uncased"``) and
                ``max_length`` (default 128) are consumed by this loader.
                The caller's dict is NOT modified.
            input_columns (list[str]): columns exposed by ``set_format``.
                If omitted, this defaults to whichever of ``input_ids``,
                ``token_type_ids``, ``attention_mask`` and ``labels`` exist
                after tokenization.
            stsb_distribute (dict): ``stsb`` only. Its key ``n`` (default 1)
                is the number of bins used to discretize the similarity score
                into ``targets``.

    Returns:
        ``(trains, evals)``. ``trains`` is the tokenized train split.
        ``evals`` is the tokenized validation split. For ``mnli`` it is a
        ``ConcatDataset`` of the matched and mismatched validation splits.
        Both objects carry a ``targets`` list holding the labels (class ids,
        or STS-B bin ids).

    Raises:
        ValueError: if ``task`` is missing or is not a supported GLUE task.
    """
    name = name.lower()
    task = kwargs.get("task")
    if task is None:
        raise ValueError("load_glue_dataset() requires a `task` keyword argument")
    task = task.lower()
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if task not in _INPUT_FEATURES:
        raise ValueError(f"unsupported GLUE task {task!r}; expected one of {sorted(_INPUT_FEATURES)}")
    feature_cols = _INPUT_FEATURES[task]

    # Copy before popping the loader-specific keys. Popping from the caller's
    # dict would make a second call with the same dict silently fall back to
    # the default tokenizer name and max_length.
    tokenizer_config: typing.Dict = dict(kwargs.get("tokenizer_args") or {})
    tokenizer_name = tokenizer_config.pop("name", "bert-base-uncased")
    max_length = tokenizer_config.pop("max_length", 128)
    n_bins = (kwargs.get("stsb_distribute") or {}).get("n", 1)

    # Imported here so the module itself can be imported without these packages.
    from datasets import load_dataset
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **tokenizer_config)
    default_columns = ("input_ids", "token_type_ids", "attention_mask", "labels")

    def tokenize(batch):
        # One text column for single-sentence tasks, two for sentence-pair tasks.
        return tokenizer(*[batch[feature] for feature in feature_cols],
                         truncation=True, padding="max_length", max_length=max_length)

    def prepare(split):
        # Load one split, tokenize it, and expose the gold label as "labels".
        ds = load_dataset(name, task, split=split, cache_dir=datadir)
        return ds.map(tokenize, batched=True).rename_column("label", "labels")

    def targets_of(ds):
        # Called before set_format(type="torch"), as in the original ordering.
        labels = ds["labels"]
        return _stsb_bins(labels, n_bins) if task == "stsb" else list(labels)

    def to_torch(ds):
        if "input_columns" in kwargs:
            columns = kwargs["input_columns"]
        else:
            # Not every tokenizer emits token_type_ids; keep only defaults present.
            columns = [c for c in default_columns if c in ds.column_names]
        ds.set_format(type="torch", columns=columns)

    trains = prepare("train")
    trains.targets = targets_of(trains)
    to_torch(trains)

    eval_splits = (["validation_matched", "validation_mismatched"] if task == "mnli"
                   else ["validation"])
    eval_sets = [prepare(split) for split in eval_splits]
    # Flattened in split order so indices line up with the ConcatDataset below.
    eval_targets = [t for ds in eval_sets for t in targets_of(ds)]
    for ds in eval_sets:
        to_torch(ds)
    evals = eval_sets[0] if len(eval_sets) == 1 else torch.utils.data.ConcatDataset(eval_sets)
    evals.targets = eval_targets

    return trains, evals
