import torch
import datasets

# TODO: remove the `datasets.disable_caching()` call below once debugging is finished.
datasets.disable_caching()
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# Per-dataset loading configuration, keyed by the value of
# ``sargs.dataset_name``. Each value is a 5-tuple:
#   [0] dataset path passed to ``load_dataset``
#   [1] config name passed as ``load_dataset(name=...)`` (``None`` if unset)
#   [2] (train split name, eval split name) used to index the loaded dataset
#   [3] factor multiplied by 0.5 to obtain the ``test_size`` used when
#       ``load_sft_dataset`` splits the train split
#   [4] not read anywhere in the visible code -- NOTE(review): meaning cannot
#       be determined from here; confirm against other users of this table
DATASET_CONFIGS = {
    # Short response
    "sciq": ("allenai/sciq", None, ("train", "test"), 1.0, 1.0),
    "quartz": ("allenai/quartz", None, ("train", "test"), 1.0, 1.0),
    "quail": ("quail", None, ("train", "validation"), 0.5, 1.0),
    # Long response
    "hellaswag": ("Rowan/hellaswag", None, ("train", "validation"), 0.25, 1.0),
    "winogrande": (
        "winogrande",
        "winogrande_debiased",
        ("train", "validation"),
        1.0,
        1.0,
    ),
    "arc": ("allenai/ai2_arc", "ARC-Challenge", ("train", "test"), 1.0, 1.0),
}

# Prefix concatenated with ``sargs.dataset_name`` to form the path of the
# dataset that carries the "Diff_rating" column (see ``load_sft_dataset``).
DIFFICULTY_DATASET_ROOT = "NeuripsEnsemW2S/"


def load_sft_dataset(sargs):
    """Load one configured benchmark and carve it into three datasets.

    The benchmark is looked up in ``DATASET_CONFIGS`` by
    ``sargs.dataset_name``. The carving strategy depends on
    ``sargs.is_easy_to_hard``:

    * falsy -- the configured train split is split with
      ``train_test_split(test_size=0.5 * config[3], seed=sargs.seed)``. The
      ``"test"`` side of that split becomes the training set and the
      ``"train"`` side becomes the transfer set. The configured eval split is
      used unchanged.
    * truthy -- the ``"train"`` split of
      ``DIFFICULTY_DATASET_ROOT + sargs.dataset_name`` is sorted by its
      ``"Diff_rating"`` column. The last ``eval_size`` rows (``eval_size`` is
      the length of the benchmark's configured eval split) form the eval set;
      the rows before them are halved in sorted order into the training set
      (first half) and the transfer set (second half).

    In both modes the transfer set is capped at the training set's length.
    When ``sargs.test_limit`` is non-negative, each of the three sets is then
    truncated to at most that many leading rows.

    Args:
        sargs: Object exposing ``dataset_name``, ``is_easy_to_hard``,
            ``seed`` (read only in the non easy-to-hard mode) and
            ``test_limit``.

    Returns:
        The tuple ``(train_dataset, transfer_dataset, eval_dataset)``.
    """
    easy_to_hard = bool(sargs.is_easy_to_hard)

    # Both modes start from the same benchmark load.
    config = DATASET_CONFIGS[sargs.dataset_name]
    hub_path, config_name, split_names = config[0], config[1], config[2]
    benchmark = load_dataset(hub_path, name=config_name, cache_dir="./cache")

    if easy_to_hard:
        # Only the size of the benchmark's eval split is needed here.
        eval_size = len(benchmark[split_names[1]])
        ranked = load_dataset(
            DIFFICULTY_DATASET_ROOT + sargs.dataset_name,
            cache_dir="./cache",
        )["train"].sort("Diff_rating")

        pool_size = len(ranked) - eval_size
        midpoint = pool_size // 2
        train_dataset = ranked.select(range(midpoint))
        transfer_dataset = ranked.select(range(midpoint, pool_size))
        eval_dataset = ranked.select(range(pool_size, len(ranked)))
    else:
        train_source = benchmark[split_names[0]]
        eval_dataset = benchmark[split_names[1]]
        parts = train_source.train_test_split(
            test_size=0.5 * config[3], seed=sargs.seed
        )
        # The "test" side is deliberately used for training and the "train"
        # side for transfer.
        train_dataset = parts["test"]
        transfer_dataset = parts["train"]

    # The transfer set never exceeds the training set in size.
    transfer_cap = min(len(transfer_dataset), len(train_dataset))
    transfer_dataset = transfer_dataset.select(range(transfer_cap))

    limit = sargs.test_limit
    if limit >= 0:
        train_dataset, eval_dataset, transfer_dataset = (
            subset.select(range(min(len(subset), limit)))
            for subset in (train_dataset, eval_dataset, transfer_dataset)
        )

    return train_dataset, transfer_dataset, eval_dataset


def load_pretrain_model_tokenizer(sargs, accelerator, mode="weak"):
    """Load a pretrained causal LM together with its tokenizer.

    Args:
        sargs: Object exposing ``model_name``, ``strong_model_name`` and
            ``model_max_length``.
        accelerator: Object exposing ``local_process_index``, used as the
            target of the model's ``device_map``.
        mode: ``"weak"`` selects ``sargs.model_name``; any other value
            selects ``sargs.strong_model_name``.

    Returns:
        The tuple ``(model, tokenizer)``. The tokenizer's ``pad_token_id`` is
        set to its ``eos_token_id`` before returning.
    """
    if mode == "weak":
        checkpoint = sargs.model_name
    else:
        checkpoint = sargs.strong_model_name

    model_kwargs = {
        "torch_dtype": torch.bfloat16,
        # Map the root module ("") to this process's local index.
        "device_map": {"": accelerator.local_process_index},
        "trust_remote_code": False,
        "use_cache": True,
        "cache_dir": "./cache",
    }
    model = AutoModelForCausalLM.from_pretrained(checkpoint, **model_kwargs)

    tokenizer_kwargs = {
        "model_max_length": sargs.model_max_length,
        "padding": True,
        "truncation": True,
        # Important, do not change
        "padding_side": "left",
        "use_fast": True,
        "trust_remote_code": False,
        "cache_dir": "./cache",
    }
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, **tokenizer_kwargs)

    # Reuse the end-of-sequence token id for padding.
    tokenizer.pad_token_id = tokenizer.eos_token_id
    return model, tokenizer
