# -*- coding: utf-8 -*-
import os.path as path
import typing

import numpy as np
import torch
import torchvision

__all__ = ["load_glue_dataset"]

# Maps each supported GLUE task name to the dataset column(s) holding its
# input text.  `load_glue_dataset` passes these columns, in this order, as the
# positional arguments of the tokenizer (one entry for single-sentence tasks,
# two for sentence-pair tasks) and rejects any task that is not a key here.
_INPUT_FEATURES = {
    "cola": ["sentence"],
    "mnli": ["premise", "hypothesis"],
    "mrpc": ["sentence1", "sentence2"],
    "qnli": ["question", "sentence"],
    "qqp": ["question1", "question2"],
    "rte": ["sentence1", "sentence2"],
    "sst2": ["sentence"],
    "stsb": ["sentence1", "sentence2"],
    "wnli": ["sentence1", "sentence2"],
}

"""
import torch
from datasets import load_dataset
from transformers import AutoTokenizer
name,task = "glue", "wnli"
cache_dir="datas/glue/wnli"
trains = load_dataset(name, task, split="train", cache_dir=cache_dir)
tokenizer_name, max_length = "bert-base-uncased", 128
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

feature_cols = ["sentence1", "sentence2"]
trains = trains.map(lambda d: tokenizer(*[d[feature] for feature in feature_cols],
                                            truncation=True, padding="max_length", max_length=max_length),
                    batched=True)

input_columns = ["input_ids", "token_type_ids", "attention_mask", "labels"]
trains = trains.rename_column("label", "labels")
trains.set_format(type="torch", columns=input_columns)


evals = [load_dataset(name, task, split=split, cache_dir=cache_dir)
            for split in (["validation"] if task != "mnli" else ["validation_matched", "validation_mismatched"])]
evals = [e.map(lambda d: tokenizer(*[d[feature] for feature in feature_cols],
                                    truncation=True, padding="max_length", max_length=max_length),
                batched=True)
            for e in evals]
evals = [e.rename_column("label", "labels") for e in evals]
if len(evals) == 1:
    evals = evals[0]
    targets = evals["labels"] if name != "stsb" else [0] * len(evals)
    evals.set_format(type="torch", columns=input_columns)
else:
    targets = np.concatenate([e["labels"] if name != "stsb" else [0] * len(e)
                                for e in evals]).tolist()
    for e in evals:
        e.set_format(type="torch", columns=input_columns)
    evals = torch.utils.data.ConcatDataset(evals)
"""


def _tokenize_split(split, tokenizer, feature_cols, max_length):
    """Tokenize the text columns of one split and rename ``label`` to ``labels``."""

    def encode(batch):
        # One positional argument per text column: a single sentence, or a
        # (first, second) pair for two-sentence tasks.
        return tokenizer(*(batch[col] for col in feature_cols),
                         truncation=True, padding="max_length", max_length=max_length)

    return split.map(encode, batched=True).rename_column("label", "labels")


def _split_targets(split, task):
    """Return the ``labels`` column of ``split``, or a zero placeholder for STS-B."""
    # STS-B carries similarity scores rather than class ids, so a constant
    # placeholder is exposed instead of the raw labels.
    if task == "stsb":
        return [0] * len(split)
    return split["labels"]


def load_glue_dataset(name: str, *args, datadir: typing.Optional[str] = None,
                      **kwargs) -> typing.Sequence[torch.utils.data.Dataset]:
    """Load, tokenize and torch-format the train/validation splits of a GLUE task.

    Args:
        name: Dataset name forwarded to ``datasets.load_dataset`` (lower-cased).
        *args: Accepted for interface compatibility; unused.
        datadir: Cache directory forwarded as ``cache_dir``.
        **kwargs:
            task (str, required): Task name; must be a key of ``_INPUT_FEATURES``.
            tokenizer_args (dict): Optional. ``name`` (default
                ``"bert-base-uncased"``) and ``max_length`` (default ``128``)
                are consumed here; every other entry is forwarded to
                ``AutoTokenizer.from_pretrained``. The dict is not modified.
            input_columns (list): Columns exposed through ``set_format``. When
                omitted, defaults to ``input_ids``, ``token_type_ids``,
                ``attention_mask`` and ``labels``, restricted to the columns
                the tokenizer actually produced.

    Returns:
        ``(trains, evals)``. Each carries a ``targets`` attribute holding its
        labels (all zeros for ``stsb``). For ``mnli`` the matched and mismatched
        validation splits are tokenized separately and returned as a single
        ``torch.utils.data.ConcatDataset``.

    Raises:
        ValueError: If ``task`` is missing, not a string, or not supported.
    """
    task = kwargs.get("task")
    if not isinstance(task, str):
        raise ValueError("load_glue_dataset requires a `task` keyword argument, e.g. task='sst2'")
    task = task.lower()
    name = name.lower()
    # Validate before any download starts; an `assert` would vanish under -O.
    if task not in _INPUT_FEATURES:
        raise ValueError(
            f"unsupported GLUE task {task!r}; expected one of {sorted(_INPUT_FEATURES)}")
    feature_cols = _INPUT_FEATURES[task]

    # Work on a copy so popping our own keys does not alter the caller's dict.
    tokenizer_config: typing.Dict = dict(kwargs.get("tokenizer_args") or {})
    tokenizer_name = tokenizer_config.pop("name", "bert-base-uncased")
    max_length = tokenizer_config.pop("max_length", 128)

    from datasets import load_dataset
    from transformers import AutoTokenizer

    if task == "mnli":
        eval_splits = ["validation_matched", "validation_mismatched"]
    else:
        eval_splits = ["validation"]
    trains = load_dataset(name, task, split="train", cache_dir=datadir)
    evals = [load_dataset(name, task, split=split, cache_dir=datadir)
             for split in eval_splits]
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **tokenizer_config)

    trains = _tokenize_split(trains, tokenizer, feature_cols, max_length)
    evals = [_tokenize_split(e, tokenizer, feature_cols, max_length) for e in evals]

    if "input_columns" in kwargs:
        input_columns = kwargs["input_columns"]
    else:
        # Not every tokenizer emits `token_type_ids`; asking `set_format` for a
        # column that does not exist fails, so keep only what is present.
        available = set(trains.column_names)
        input_columns = [col for col in ("input_ids", "token_type_ids", "attention_mask", "labels")
                         if col in available]

    # Targets are read before `set_format` so they are captured in the
    # dataset's default (non-torch) output format.
    trains.targets = _split_targets(trains, task)
    trains.set_format(type="torch", columns=input_columns)

    eval_targets = [_split_targets(e, task) for e in evals]
    for e in evals:
        e.set_format(type="torch", columns=input_columns)
    if len(evals) == 1:
        evals = evals[0]
        targets = eval_targets[0]
    else:
        targets = np.concatenate(eval_targets).tolist()
        evals = torch.utils.data.ConcatDataset(evals)
    evals.targets = targets

    return trains, evals
