from typing import Any, Dict, Literal, Optional

import torch
from sandbagging_research_sprint.sandbagging_evaluation.evaluate_sandbagging import (
    extract_answer,
)
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


def _safe_ratio(numerator: int, denominator: int) -> float:
    """Return ``numerator / denominator``, or 0.0 when ``denominator`` is not positive."""
    return numerator / denominator if denominator > 0 else 0.0


def evaluate_model(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    dataloader: DataLoader,
    evaluation_type: Literal[
        "epo_probes_individual",
        "epo_probes_general",
        "epo_steering_vector_individual",
        "epo_steering_vector_general",
    ],
    prompt_key: str = "prompt",
    max_new_tokens: int = 8,
    temperature: float = 1.0,
    wandb_logger: Optional[Any] = None,
    *,
    do_sample: Optional[bool] = None,
) -> Dict[str, Any]:
    """Unified evaluation function for EPO experiments.

    Generates a short completion for every prompt in ``dataloader``, extracts
    an answer from it with ``extract_answer``, and scores it against
    ``batch["correct_answer"]``.

    Args:
        model: Causal LM used for generation.
        tokenizer: Tokenizer matching ``model``. It must be able to pad
            (``padding=True`` is used).
        dataloader: Yields batches (dict-like) containing ``prompt_key``,
            ``"correct_answer"``, ``"prefix"``, ``"sandbagging_environment"``
            and ``"desired_answer"``, each indexable per example.
        evaluation_type: Label echoed in the log line and the wandb records.
        prompt_key: Batch key holding the prompt strings. Its last
            ``_``-separated token is stored as ``"max_mode"`` on each record.
        max_new_tokens: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``. Hugging Face only
            applies it when sampling is enabled, either by ``do_sample=True``
            or by the model's generation config.
        wandb_logger: Optional callable invoked with one dict per example.
        do_sample: If not ``None``, forwarded to ``model.generate``. Leave as
            ``None`` to keep the model's generation-config default, which was
            the original behaviour.

    Returns:
        A dict with raw counters ``correct``,
        ``correct_excluding_non_responses``, ``total`` and
        ``total_excluding_non_responses``, per-example records under
        ``answers``, and derived ``accuracy``,
        ``accuracy_excluding_non_responses`` and ``response_rate`` (floats,
        0.0 when the denominator is zero). An empty string returned by
        ``extract_answer`` counts as a non-response.
    """
    results: Dict[str, Any] = {
        "correct": 0,
        "correct_excluding_non_responses": 0,
        "total": 0,
        "total_excluding_non_responses": 0,
        "answers": [],
    }

    print(f"Evaluating {evaluation_type} using prompt key: {prompt_key}")

    # Depends only on prompt_key, so derive it once rather than per example.
    max_mode = prompt_key.split("_")[-1]

    # Use same tokenization and generation settings as steering_model.py.
    generate_kwargs: Dict[str, Any] = {
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.pad_token_id,
        "temperature": temperature,
    }
    if do_sample is not None:
        # Only pass when explicitly requested. Passing None would override the
        # model's generation config instead of deferring to it.
        generate_kwargs["do_sample"] = do_sample

    with torch.no_grad():
        for batch in tqdm(dataloader):
            prompts = batch[prompt_key]
            correct_answers = batch["correct_answer"]

            inputs = tokenizer(
                prompts,
                return_tensors="pt",
                padding=True,
            ).to(model.device)

            outputs = model.generate(**inputs, **generate_kwargs)

            # Every row is padded to the same width, and generate() echoes the
            # (padded) prompt before the new tokens. One cut point therefore
            # isolates the continuation for the whole batch.
            # NOTE(review): this assumes left padding for decoder-only models
            # (as presumably configured to match steering_model.py). With
            # right padding, pad tokens sit between prompt and continuation.
            # Confirm the tokenizer's padding_side.
            prompt_width = inputs["input_ids"].shape[1]

            for j, output in enumerate(outputs):
                answer_text = tokenizer.decode(
                    output[prompt_width:], skip_special_tokens=True
                )
                predicted_answer = extract_answer(answer_text)
                correct_answer = correct_answers[j]
                prompt = prompts[j]

                if wandb_logger is not None:
                    wandb_logger(
                        {
                            "predicted_answer": predicted_answer,
                            "correct_answer": correct_answer,
                            "prompt": prompt,
                            "prompt_key": prompt_key,
                            "evaluation_type": evaluation_type,
                        }
                    )

                is_correct = predicted_answer == correct_answer
                is_an_answer = predicted_answer != ""

                results["total"] += 1
                results["correct"] += int(is_correct)
                if is_an_answer:
                    results["total_excluding_non_responses"] += 1
                    results["correct_excluding_non_responses"] += int(is_correct)

                results["answers"].append(
                    {
                        "is_correct": is_correct,
                        "is_an_answer": is_an_answer,
                        "correct_answer": correct_answer,
                        "predicted": predicted_answer,
                        "answer_text": answer_text,
                        "prefix": batch["prefix"][j],
                        "prompt": prompt,
                        "sandbagging_environment": batch["sandbagging_environment"][j],
                        "desired_answer": batch["desired_answer"][j],
                        "epo_prompt": prompt,
                        "max_mode": max_mode,
                    }
                )

    total = results["total"]
    answered = results["total_excluding_non_responses"]
    results.update(
        {
            "accuracy": _safe_ratio(results["correct"], total),
            "accuracy_excluding_non_responses": _safe_ratio(
                results["correct_excluding_non_responses"], answered
            ),
            "response_rate": _safe_ratio(answered, total),
        }
    )

    return results
