from transformers import AutoTokenizer
from datasets import load_dataset
from functools import partial

def load_tokenized_downstream_dataset(task_name: str, tokenizer: AutoTokenizer, max_length: int):
    """Load the train/test splits for ``task_name`` and tokenize both.

    Args:
        task_name: One of "pyqa", "logiqa", "arc-e", "arc-c", "winogrande",
            "mmlu" or "wsc".
        tokenizer: Tokenizer forwarded to the task-specific tokenize function.
        max_length: Maximum sequence length forwarded to that function.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple; each split is the result of
        ``.map(tokenize_fn, batched=True)`` on the corresponding raw split.

    Raises:
        ValueError: If ``task_name`` is not one of the supported tasks.
    """
    # The raw splits are fetched first, before the tokenize function is chosen.
    raw_train, raw_test = load_downstream_dataset(task_name)

    # Both ARC variants share one tokenize function.
    tokenizers_by_task = {
        "pyqa": _tokenize_pyqa,
        "logiqa": _tokenize_logiqa,
        "arc-e": _tokenize_arc,
        "arc-c": _tokenize_arc,
        "winogrande": _tokenize_winogrande,
        "mmlu": _tokenize_mmlu,
        "wsc": _tokenize_wsc,
    }
    task_tokenizer = tokenizers_by_task.get(task_name)
    if task_tokenizer is None:
        raise ValueError(f"Invalid task name: {task_name}")

    # Bind the tokenizer settings so ``map`` only has to supply the batch.
    tokenize_fn = partial(task_tokenizer, tokenizer=tokenizer, max_length=max_length)

    tokenized_train = raw_train.map(tokenize_fn, batched=True)
    tokenized_test = raw_test.map(tokenize_fn, batched=True)
    return tokenized_train, tokenized_test

def load_downstream_dataset(task_name: str):
    """Load the raw train and evaluation splits for a downstream task.

    Args:
        task_name: One of "pyqa", "logiqa", "arc-e", "arc-c", "winogrande",
            "mmlu" or "wsc".

    Returns:
        A ``(train_dataset, test_dataset)`` tuple as returned by
        ``load_dataset`` for the task's two configured splits.

    Raises:
        ValueError: If ``task_name`` is not one of the supported tasks.
    """
    # Each row: task name, positional ``load_dataset`` arguments
    # (path and optional config name), train split, evaluation split.
    sources = (
        ("pyqa", ("ybisk/piqa",), "train", "validation"),
        ("logiqa", ("EleutherAI/logiqa",), "train", "test"),
        ("arc-e", ("ai2_arc", "ARC-Easy"), "train", "test"),
        ("arc-c", ("ai2_arc", "ARC-Challenge"), "train", "test"),
        ("winogrande", ("winogrande", "winogrande_xl"), "train", "validation"),
        ("mmlu", ("cais/mmlu", "all"), "auxiliary_train", "test"),
        ("wsc", ("super_glue", "wsc.fixed"), "train", "validation"),
    )

    for known_task, dataset_args, train_split, eval_split in sources:
        if task_name == known_task:
            break
    else:
        # No row matched: nothing is loaded for an unknown task.
        raise ValueError(f"Invalid task name: {task_name}")

    train_dataset = load_dataset(*dataset_args, split=train_split, trust_remote_code=True)
    test_dataset = load_dataset(*dataset_args, split=eval_split, trust_remote_code=True)
    return train_dataset, test_dataset

def _tokenize_pyqa(examples, tokenizer: AutoTokenizer, max_length: int):
    context_name = "goal"
    ending_names = [f"sol{i+1}" for i in range(2)]
    first_sentences = [[context] * 2 for context in examples[context_name]]
    second_sentences = [
        [f"{examples[end][i]}" for end in ending_names] for i in range(len(examples[context_name]))
    ]

    first_sentences = sum(first_sentences, [])
    second_sentences = sum(second_sentences, [])

    tokenized_examples = tokenizer(
        first_sentences,
        second_sentences,
        padding="max_length", 
        max_length=max_length
    )

    out = {k: [v[i:i+2] for i in range(0, len(v), 2)] for k, v in tokenized_examples.items()}

    return out

def _tokenize_logiqa(examples, tokenizer: AutoTokenizer, max_length: int):
    first_sentences = [[context + " " + question] * 4 for context, question in zip(examples['context'], examples['question'])]
    second_sentences = [options + [""] * (4-len(options)) for options in examples['options']]

    first_sentences = sum(first_sentences, [])
    second_sentences = sum(second_sentences, [])

    tokenized_examples = tokenizer(
        first_sentences,
        second_sentences,
        padding="max_length", 
        max_length=max_length
    )

    out = {k: [v[i:i+4] for i in range(0, len(v), 4)] for k, v in tokenized_examples.items()}
    
    out['label'] = ["abcd".index(x) for x in examples['label']]

    return out

def _tokenize_arc(examples, tokenizer: AutoTokenizer, max_length: int):
    first_sentences = [[question] * 5 for question in examples['question']]
    second_sentences = [choices['text'] + [""] * (5-len(choices['text'])) for choices in examples['choices']]

    first_sentences = sum(first_sentences, [])
    second_sentences = sum(second_sentences, [])

    tokenized_examples = tokenizer(
        first_sentences,
        second_sentences,
        padding="max_length", 
        max_length=max_length
    )

    out = {k: [v[i:i+5] for i in range(0, len(v), 5)] for k, v in tokenized_examples.items()}

    out['label'] = [choices['label'].index(x) for x, choices in zip(examples['answerKey'], examples['choices'])]

    return out

def _tokenize_winogrande(examples, tokenizer: AutoTokenizer, max_length: int):
    first_sentences = [[sentence] * 2 for sentence in examples['sentence']]
    second_sentences = [[option1, option2] for option1, option2 in zip(examples['option1'], examples['option2'])]

    first_sentences = sum(first_sentences, [])
    second_sentences = sum(second_sentences, [])

    tokenized_examples = tokenizer(
        first_sentences,
        second_sentences,
        padding="max_length", 
        max_length=max_length
    )

    out = {k: [v[i:i+2] for i in range(0, len(v), 2)] for k, v in tokenized_examples.items()}

    for i, x in enumerate(examples['answer']):
        if x not in ["1", "2"]:
            print(f"Invalid answer: {x} at index {i}")

    out['label'] = [int(answer) - 1 for answer in examples['answer']]

    return out

def _tokenize_mmlu(examples, tokenizer: AutoTokenizer, max_length: int):
    first_sentences = [[question] * 4 for question in examples['question']]
    second_sentences = [choices for choices in examples['choices']]

    first_sentences = sum(first_sentences, [])
    second_sentences = sum(second_sentences, [])

    tokenized_examples = tokenizer(
        first_sentences,
        second_sentences,
        truncation=True,
        padding="max_length", 
        max_length=max_length
    )

    out = {k: [v[i:i+4] for i in range(0, len(v), 4)] for k, v in tokenized_examples.items()}

    out['label'] = examples['answer']

    return out

def _tokenize_wsc(examples, tokenizer: AutoTokenizer, max_length: int):
    first_sentences = [[f"{sentence if sentence[-1] == '.' else sentence + '.'}"] * 2 for sentence in examples['text']]
    second_sentences = [[
        f"\"{span2}\" does not refer to \"{span1}\".", 
        f"\"{span2}\" refers to \"{span1}\"."
    ] for span1, span2 in zip(examples['span1_text'], examples['span2_text'])]

    first_sentences = sum(first_sentences, [])
    second_sentences = sum(second_sentences, [])

    tokenized_examples = tokenizer(
        first_sentences,
        second_sentences,
        padding="max_length", 
        max_length=max_length, 
        truncation=True,
    )

    out = {k: [v[i:i+2] for i in range(0, len(v), 2)] for k, v in tokenized_examples.items()}

    out['label'] = [int(answer) for answer in examples['label']]

    return out