import os

from datasets import load_from_disk
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, DataCollatorWithPadding, default_data_collator

from .distributed_utils import DistGroups

# Copied from run_glue_no_trainer.py.
# Maps each GLUE task name to the pair of dataset columns holding its text
# inputs; the second entry is None for single-sentence tasks.
task_to_keys = {
    task: (first_column, second_column)
    for task, first_column, second_column in (
        ("cola", "sentence", None),
        ("mnli", "premise", "hypothesis"),
        ("mrpc", "sentence1", "sentence2"),
        ("qnli", "question", "sentence"),
        ("qqp", "question1", "question2"),
        ("rte", "sentence1", "sentence2"),
        ("sst2", "sentence", None),
        ("stsb", "sentence1", "sentence2"),
        ("wnli", "sentence1", "sentence2"),
    )
}


def _resolve_sentence_keys(task_name, column_names):
    """Return the ``(sentence1_key, sentence2_key)`` columns to tokenize.

    Args:
        task_name: Key into ``task_to_keys``, or ``None`` to infer the text
            columns from ``column_names``.
        column_names: Column names of the training split.

    Returns:
        A 2-tuple of column names; the second element is ``None`` when only a
        single text column is used.

    Raises:
        KeyError: If ``task_name`` is not ``None`` and not in ``task_to_keys``.
        ValueError: If ``task_name`` is ``None`` and there is no column other
            than ``"label"``.
    """
    if task_name is not None:
        return task_to_keys[task_name]

    # No known task: fall back to reasonable defaults based on the columns.
    non_label_column_names = [name for name in column_names if name != "label"]
    if not non_label_column_names:
        raise ValueError("dataset has no text column besides 'label'")
    if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
        return "sentence1", "sentence2"
    if len(non_label_column_names) >= 2:
        return non_label_column_names[0], non_label_column_names[1]
    return non_label_column_names[0], None


def get_glue_dataloader(
    datasets_path,
    task_name,
    model_name_or_path=None,
    pad_to_max_length=False,
    test_size=0.1,
    preprocessing_num_workers=4,
    use_fp16=False,
    per_device_train_batch_size=10,
    per_device_eval_batch_size=10,
    max_seq_length=8192,
    split_seed=None,
):
    """Build train / eval / test dataloaders for a GLUE-style dataset on disk.

    The dataset is loaded with ``load_from_disk``, tokenized, and its
    ``"train"`` split is further divided into a train part and a held-out test
    part. Only the train dataloader uses a ``DistributedSampler`` (built from
    ``DistGroups["dp"]``); the eval and test dataloaders iterate their whole
    dataset in order.

    Args:
        datasets_path: Directory containing the saved dataset(s).
        task_name: Name of the task; the dataset is read from
            ``datasets_path/task_name`` and the text columns are looked up in
            ``task_to_keys``. If ``None``, the dataset is read from
            ``datasets_path`` itself and the text columns are inferred.
        model_name_or_path: Identifier passed to
            ``AutoTokenizer.from_pretrained``. Required despite the ``None``
            default.
        pad_to_max_length: If true, pad every example to ``max_seq_length``
            during tokenization; otherwise pad dynamically per batch.
        test_size: Forwarded to ``train_test_split`` for the held-out split.
        preprocessing_num_workers: Number of processes used by ``map``.
        use_fp16: If true (and padding dynamically), pad batches to a
            multiple of 8.
        per_device_train_batch_size: Batch size of the train dataloader.
        per_device_eval_batch_size: Batch size of the eval and test
            dataloaders.
        max_seq_length: Maximum tokenized length; longer inputs are truncated.
        split_seed: Seed forwarded to ``train_test_split``. With the default
            ``None`` the split is not explicitly seeded; in multi-process runs
            pass the same value on every rank so that all ranks agree on which
            examples are train and which are test.

    Returns:
        ``(train_dataloader, eval_dataloader, test_dataloader, tokenizer)``.

    Raises:
        ValueError: If ``model_name_or_path`` is ``None``.
    """
    if model_name_or_path is None:
        raise ValueError("model_name_or_path is required to load the tokenizer")

    # os.path.join rejects None, so only append the task name when one is given.
    dataset_dir = datasets_path if task_name is None else os.path.join(datasets_path, task_name)
    raw_datasets = load_from_disk(dataset_dir, keep_in_memory=True)

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
    sentence1_key, sentence2_key = _resolve_sentence_keys(task_name, raw_datasets["train"].column_names)

    # NOTE: no model is configured here, so there is no label2id / id2label
    # mapping to populate; labels are passed through unchanged below.

    padding = "max_length" if pad_to_max_length else False

    def preprocess_function(examples):
        # Tokenize one text column, or a pair of columns for two-sentence tasks.
        texts = (
            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
        )
        result = tokenizer(*texts, padding=padding, max_length=max_seq_length, truncation=True)

        if "label" in examples:
            # Rename the column to "labels", the name the model is expected to use.
            result["labels"] = examples["label"]
        return result

    processed_datasets = raw_datasets.map(
        preprocess_function,
        batched=True,
        remove_columns=raw_datasets["train"].column_names,
        num_proc=preprocessing_num_workers,
        desc="Running tokenizer on dataset",
    )

    # Carve the held-out test set out of the original training split.
    splits = processed_datasets["train"].train_test_split(test_size=test_size, seed=split_seed)
    train_dataset = splits["train"]
    test_dataset = splits["test"]
    eval_dataset = processed_datasets["validation_matched" if task_name == "mnli" else "validation"]

    if pad_to_max_length:
        # Padding was already done to max length, so the default collator only
        # needs to convert everything to tensors.
        data_collator = default_data_collator
    else:
        # Otherwise pad dynamically to the longest sample in each batch. With
        # mixed precision, additionally pad to a multiple of 8.
        data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=(8 if use_fp16 else None))

    sampler = DistributedSampler(
        train_dataset, shuffle=True, num_replicas=DistGroups["dp"].size(), rank=DistGroups["dp"].rank()
    )
    train_dataloader = DataLoader(
        train_dataset,
        # Shuffling is delegated to the sampler; DataLoader does not allow
        # shuffle=True together with an explicit sampler.
        shuffle=False,
        sampler=sampler,
        collate_fn=data_collator,
        batch_size=per_device_train_batch_size,
        pin_memory=True,
    )

    eval_dataloader = DataLoader(
        eval_dataset,
        collate_fn=data_collator,
        batch_size=per_device_eval_batch_size,
        pin_memory=True,
    )
    test_dataloader = DataLoader(
        test_dataset,
        collate_fn=data_collator,
        batch_size=per_device_eval_batch_size,
        pin_memory=True,
    )
    return train_dataloader, eval_dataloader, test_dataloader, tokenizer
