import evaluate
import frozendict
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from evaluate import load

from scipy.special import softmax

# Load the "record" configuration of the "super_glue" dataset.
dataset = load_dataset(path="super_glue", name="record")

# Tokenizer for the same "bert-base-uncased" checkpoint that is fine-tuned below.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="bert-base-uncased"
)

def preprocess_function(examples):
    """Expand a batch of examples into one row per candidate entity.

    For each example, every candidate entity is substituted for the
    ``@placeholder`` marker in the query, and the result is appended to the
    passage with the tokenizer's separator token in between. Each generated
    row remembers which example and which entity it came from so that
    per-entity scores can later be grouped back by query.

    Args:
        examples: Batched mapping of column name to list. Must provide
            "passage", "query", "entities" and "idx"; "answers" is optional.

    Returns:
        A dict of equal-length lists with keys "text", "original_idx",
        "entity_idx" and "entity". When the batch has an "answers" column,
        a "labels" list is included as well (1 if the entity is one of the
        example's answers, else 0).
    """
    has_labels = "answers" in examples
    num_examples = len(examples["passage"])
    # Without an "answers" column, pair every example with an empty answer
    # list so the zip below has a uniform shape.
    answers_list = (
        examples["answers"] if has_labels else [[] for _ in range(num_examples)]
    )
    separator = tokenizer.sep_token  # loop-invariant, look it up once

    texts = []
    labels = []
    original_idxs = []
    entity_idxs = []
    entities = []

    batch_columns = zip(
        examples["passage"],
        examples["query"],
        examples["entities"],
        examples["idx"],
        answers_list,
    )
    for passage, query, curr_entities, idx, curr_answers in batch_columns:
        answer_set = set(curr_answers)  # O(1) membership per candidate
        for ent_idx, entity in enumerate(curr_entities):
            texts.append(passage + separator + query.replace("@placeholder", entity))
            # Store a plain dict copy. This value is serialized into a dataset
            # column, so an immutable wrapper would not be what is read back,
            # and non-dict mapping types may be rejected by the conversion
            # (TODO confirm for the installed frozendict build).
            original_idxs.append(dict(idx))
            entity_idxs.append(ent_idx)
            entities.append(entity)
            if has_labels:
                labels.append(1 if entity in answer_set else 0)

    result = {
        "text": texts,
        "original_idx": original_idxs,
        "entity_idx": entity_idxs,
        "entity": entities,
    }
    if has_labels:
        result["labels"] = labels
    return result

# Expand the train and validation splits with preprocess_function, dropping
# each split's original columns so only the generated columns remain.
tokenized_train, tokenized_val = (
    dataset[split_name].map(
        preprocess_function,
        batched=True,
        remove_columns=dataset[split_name].column_names,
    )
    for split_name in ("train", "validation")
)

def tokenize_function(examples):
    """Tokenize the batch's "text" column to a fixed length of 512.

    Args:
        examples: Batched mapping that provides a "text" column.

    Returns:
        Whatever the module-level ``tokenizer`` returns for that column when
        called with max-length padding and truncation enabled.
    """
    encode_options = {
        "padding": "max_length",
        "truncation": True,
        "max_length": 512,
    }
    batch_texts = examples["text"]
    return tokenizer(batch_texts, **encode_options)

# Run tokenize_function over both expanded splits in batches.
tokenized_train, tokenized_val = (
    expanded_split.map(tokenize_function, batched=True)
    for expanded_split in (tokenized_train, tokenized_val)
)

# Sequence classifier with two output labels, matching the 0/1 labels
# produced by preprocess_function.
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=2,
)

training_args = TrainingArguments(
    output_dir="./results",
    # Optimisation settings.
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # Evaluate every 10 steps; checkpoints are also saved on a step schedule.
    evaluation_strategy="steps",
    eval_steps=10,
    save_strategy="steps",
    # Reload the best checkpoint at the end, ranked by the "f1" value that
    # compute_metrics returns.
    load_best_model_at_end=True,
    metric_for_best_model="f1",
)

def compute_metrics(eval_pred):
    """Score per-entity predictions with the "super_glue"/"record" metric.

    Each row of the module-level ``tokenized_val`` is one (query, candidate
    entity) pair. The rows are grouped by their originating query, the
    candidate with the highest class-1 probability is taken as that query's
    prediction, and the predictions are scored against the gold answers from
    ``dataset["validation"]``.

    Args:
        eval_pred: A ``(logits, labels)`` pair. ``logits`` must have one row
            per row of ``tokenized_val``, in the same order, with the class-1
            score in column 1. ``labels`` is not used; gold answers are
            looked up from the dataset instead.

    Returns:
        A dict with the metric's "f1" and "exact_match" values.

    Raises:
        ValueError: If the number of prediction rows differs from the number
            of rows in ``tokenized_val``, since scores could then not be
            matched to entities.
    """
    logits, _labels = eval_pred
    # Probability of class 1, i.e. "this entity is an answer" (label 1 in
    # preprocess_function).
    probs = softmax(logits, axis=-1)[:, 1]

    original_idxs = tokenized_val["original_idx"]
    entities = tokenized_val["entity"]
    if len(probs) != len(entities):
        raise ValueError(
            f"Got {len(probs)} predictions for {len(entities)} validation rows; "
            "compute_metrics can only score predictions made on tokenized_val."
        )

    # Keep the highest-probability entity for every query. No probability
    # threshold is applied: an empty prediction_text cannot earn credit
    # against a non-empty gold answer, so always committing to the
    # top-scoring candidate scores at least as well for every query.
    best_entity = {}
    best_prob = {}
    for idx, entity, prob in zip(original_idxs, entities, probs):
        key = (idx["passage"], idx["query"])  # hashable identity of the query
        # Strict ">" keeps the first candidate on ties.
        if key not in best_prob or prob > best_prob[key]:
            best_prob[key] = prob
            best_entity[key] = entity

    # One entry per query, ordered by (passage, query).
    query_keys = sorted(best_entity)

    predictions = [
        {
            "idx": {"passage": passage_id, "query": query_id},
            "prediction_text": best_entity[(passage_id, query_id)],
        }
        for passage_id, query_id in query_keys
    ]

    # Read only the two columns needed for the gold-answer lookup instead of
    # materializing every full validation row on each evaluation.
    validation = dataset["validation"]
    gold_answers = {
        (idx["passage"], idx["query"]): answers
        for idx, answers in zip(validation["idx"], validation["answers"])
    }
    references = [
        {
            "idx": {"passage": passage_id, "query": query_id},
            "answers": gold_answers.get((passage_id, query_id), []),
        }
        for passage_id, query_id in query_keys
    ]

    # Load and compute the ReCoRD metric.
    metric = evaluate.load("super_glue", "record")
    results = metric.compute(predictions=predictions, references=references)
    return {"f1": results["f1"], "exact_match": results["exact_match"]}

# Wire the model, training arguments, expanded/tokenized splits and the
# metric callback into a Trainer.
trainer = Trainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    eval_dataset=tokenized_val,
    train_dataset=tokenized_train,
)

# Start fine-tuning.
trainer.train()