import torch
from datasets import load_dataset
from tqdm import tqdm
from evaluate import load
from nesim.experiments.gpt_neo_125m import get_untrained_model_and_tokenizer, get_checkpoint
from nesim.utils.checkpoint import get_checkpoint_path_gpt_neo_125m


# --- Evaluation configuration -------------------------------------------------
# Training step whose checkpoints are compared.
global_step = 10500
# Directory that holds the saved GPT-Neo 125M checkpoints.
checkpoint_dir = "/home/XXXX-4/repos/nesim/training/gpt_neo_125m/checkpoints"
# Device the checkpoints are loaded onto.
device = "cuda:0"
# Non-zero topo scales to evaluate; topo_scale=0 is registered as "baseline".
topo_scales = [1, 5, 10, 50]

# Keyword arguments common to every checkpoint-path lookup.
_checkpoint_path_kwargs = {
    "checkpoints_dir": checkpoint_dir,
    "global_step": global_step,
}

# Maps a run name to its checkpoint path; None is used for the untrained model.
checkpoints_map = {
    "untrained": None,
    # "pretrained": "pretrained",
    "baseline": get_checkpoint_path_gpt_neo_125m(
        topo_scale=0, **_checkpoint_path_kwargs
    ),
}

for topo_scale in topo_scales:
    run_name = f"topo_{topo_scale}"
    checkpoints_map[run_name] = get_checkpoint_path_gpt_neo_125m(
        topo_scale=topo_scale, **_checkpoint_path_kwargs
    )

def generate_answer(model, tokenizer, question, context, max_length=20):
    """Generate a short free-text answer to ``question`` given ``context``.

    A "Context / Question / Answer" prompt is tokenized, ``model.generate`` is
    asked to continue it, and only the newly generated tokens are decoded.

    Args:
        model: Object exposing ``.device`` and ``.generate(...)``.
        tokenizer: Callable tokenizer exposing ``eos_token_id`` and ``decode``.
        question: Question text inserted into the prompt.
        context: Passage text inserted into the prompt.
        max_length: Maximum number of tokens allowed beyond the prompt.

    Returns:
        The decoded continuation, special tokens skipped and outer whitespace
        stripped.
    """
    prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer: "

    # Tokenize and place every returned tensor on the model's device.
    encoded = {
        name: tensor.to(model.device)
        for name, tensor in tokenizer(prompt, return_tensors="pt").items()
    }
    prompt_len = len(encoded["input_ids"][0])

    with torch.no_grad():
        generated = model.generate(
            **encoded,
            # Total length budget = prompt tokens + allowed new tokens.
            max_length=prompt_len + max_length,
            # The EOS id is passed as the padding id for generation.
            pad_token_id=tokenizer.eos_token_id,
        )

    # Drop the echoed prompt so only the continuation is decoded.
    continuation = generated[0][prompt_len:]
    return tokenizer.decode(continuation, skip_special_tokens=True).strip()

# Load the SQuAD v2 validation split once; every checkpoint is scored on it.
dataset = load_dataset("squad_v2", split="validation")

# Initialize the SQuAD v2 metric
metric = load("squad_v2")

# Cap on the number of examples scored per checkpoint; None scores the whole
# split. Selecting the subset up front (instead of counting and breaking inside
# the loop) lets tqdm report the real total.
max_num_samples = 100
if max_num_samples is None:
    eval_examples = dataset
else:
    eval_examples = dataset.select(range(min(max_num_samples, len(dataset))))

# Resolve the execution device once, under its own name, so the configured
# `device` is never overwritten and every checkpoint is loaded the same way.
# Falls back to CPU when CUDA is not available.
eval_device = device if torch.cuda.is_available() else "cpu"

for checkpoint_name, checkpoint_path in checkpoints_map.items():
    model, tokenizer = get_checkpoint(checkpoint_path, device=eval_device)
    # generate_answer also passes an explicit pad_token_id on every call.
    model.generation_config.pad_token_id = tokenizer.pad_token_id
    model.to(eval_device)
    # Evaluation only: make sure train-time layers such as dropout are off,
    # regardless of the mode get_checkpoint returned the model in.
    model.eval()

    # Predictions and references in the format the squad_v2 metric expects.
    all_predictions = []
    all_references = []

    for example in tqdm(eval_examples, desc=checkpoint_name):
        # Get the model's prediction
        prediction = generate_answer(
            model, tokenizer, example["question"], example["context"]
        )

        all_predictions.append(
            {
                "id": example["id"],
                "prediction_text": prediction,
                # GPT-Neo doesn't provide this, so we set it to 0
                "no_answer_probability": 0.0,
            }
        )
        all_references.append({"id": example["id"], "answers": example["answers"]})

    result = metric.compute(predictions=all_predictions, references=all_references)
    print(checkpoint_name, result)
