import os
import argparse
from typing import Dict, Callable

import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")))
from utility_functions import *


def format_prompt(text: str) -> str:
    """Wrap raw question text in the fixed instruction/answer prompt template.

    The result is the system-style instruction, a blank line, the question
    under a ``Question:`` header, a blank line, and a trailing ``Answer:`` cue
    with no newline after it.
    """
    instruction = (
        "You are an economic decision-making agent. "
        "Analyze the options and reply with your choice as a single letter: A, B, C, or D."
    )
    question_section = "Question:\n" + f"{text}"
    answer_cue = "Answer:"
    # Sections are separated by exactly one blank line.
    return "\n\n".join([instruction, question_section, answer_cue])


@torch.inference_mode()
def evaluate_model(model, tokenizer, dataset):
    """Measure how often the model's first generated letter matches the ground truth.

    Args:
        model: A causal LM exposing ``eval()``, ``device`` and ``generate()``.
        tokenizer: The tokenizer matching ``model``.
        dataset: Iterable of dicts with keys ``"prompt"`` (the full prompt
            string) and ``"preferred_choice"`` (the expected letter).

    Returns:
        The fraction of correct answers among responses whose first character
        is one of A/B/C/D. Responses that do not start with a valid letter are
        excluded from the denominator. Returns 0.0 when no response is valid.
    """
    model.eval()
    valid_choices = ("A", "B", "C", "D")
    correct = 0
    total = 0

    for item in tqdm(dataset, desc="Evaluating model"):
        prompt = item["prompt"]
        ground_truth = item["preferred_choice"]

        # Encode prompt normally (no chat template)
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_token_count = inputs["input_ids"].shape[1]

        outputs = model.generate(
            **inputs,
            max_new_tokens=5,  # short output
            pad_token_id=tokenizer.eos_token_id,
        )

        # Decode only the newly generated tokens. Slicing the decoded string by
        # len(prompt) is unreliable because decoding (with special tokens
        # skipped) does not always reproduce the prompt character-for-character.
        generated_ids = outputs[0][prompt_token_count:]
        response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

        # Only the first character of the completion is considered.
        predicted_choice = response[:1].upper()

        if predicted_choice in valid_choices:
            total += 1  # only count valid letter answers
            if predicted_choice == ground_truth:
                correct += 1
        # else: ignore non-letter answers (not added to total)

    return correct / total if total > 0 else 0.0



def _best_option(row, util_func, util_params):
    """Return the letter (A-D) of the option with the highest expected utility.

    For each option the row supplies ``p_<opt>``, ``reward1_<opt>`` and
    ``reward2_<opt>``; the expected utility is
    ``p * u(reward1) + (1 - p) * u(reward2)``.
    """
    expected = {}
    for option in ("a", "b", "c", "d"):
        p = row[f"p_{option}"]
        r1 = row[f"reward1_{option}"]
        r2 = row[f"reward2_{option}"]
        expected[option.upper()] = (
            p * util_func(r1, **util_params) + (1 - p) * util_func(r2, **util_params)
        )
    # max() keeps the first maximum, so ties resolve in A, B, C, D order.
    return max(expected, key=expected.get)


def main():
    """Command-line entry point: compare base vs. LoRA-finetuned accuracy.

    Steps: parse arguments, load the evaluation CSV, derive the ground-truth
    choice per row from the selected utility function, evaluate the quantized
    base model, attach the adapters from ``--model_path``, evaluate again, and
    print a summary report.
    """
    parser = argparse.ArgumentParser(description="Evaluate a DPO-finetuned model on a 4-option choice task.")

    # --- Model & Data Args ---
    parser.add_argument("--model_path", type=str, default="", help="Path to the directory containing the trained LoRA adapters.")
    parser.add_argument("--base_model_id", type=str, default="", help="Base model ID from Hugging Face Hub.")
    parser.add_argument("--data_csv", type=str, default="", help="Path to the 4-option evaluation CSV file.")
    # NOTE(review): --output_file is parsed but nothing in this function writes to it.
    parser.add_argument("--output_file", type=str, default="evaluation_results.csv", help="Filename for the detailed evaluation results.")
    parser.add_argument("--limit", type=int, default=1000, help="Limit the number of questions to evaluate for a quick test.")

    # --- Ground Truth Definition Args ---
    utility_functions = {
        "linear": linear_utility, "power": power_utility, "quadratic": quadratic_utility,
        "crra": crra_utility, "cara": cara_utility, "isoelastic": isoelastic_utility,
        "prospect_theory": prospect_theory_value, "hara": hara_utility,
        "expo_power_saha": expo_power_utility_saha,
    }
    # Names of the CLI arguments forwarded as keyword arguments to each utility.
    utility_param_names = {
        "linear": (),
        "power": ("alpha",),
        "quadratic": ("b",),
        "crra": ("theta",),
        "cara": ("alpha",),
        "isoelastic": ("theta", "lam"),
        "prospect_theory": ("alpha", "beta", "lam", "reference_point"),
        "hara": ("alpha", "beta", "gamma"),
        "expo_power_saha": ("alpha", "theta"),
    }
    parser.add_argument("--utility", type=str, default="crra", choices=list(utility_functions), help="Utility function to define the ground truth.")
    parser.add_argument("--alpha", type=float, default=0.88)
    parser.add_argument("--beta", type=float, default=0.88)
    parser.add_argument("--theta", type=float, default=2.0)
    parser.add_argument("--lam", type=float, default=2.25)
    parser.add_argument("--gamma", type=float, default=0.7)
    parser.add_argument("--b", type=float, default=0.1)
    parser.add_argument("--reference_point", type=float, default=500)

    # --- Hardware Args ---
    parser.add_argument("--gpu_id", type=str, default="0", help="The GPU ID(s) to use for evaluation.")

    args = parser.parse_args()

    # Fail fast with a clear message instead of a late, obscure error from
    # pandas or from_pretrained when a path/ID is left at its empty default.
    required = (
        ("--model_path", args.model_path),
        ("--base_model_id", args.base_model_id),
        ("--data_csv", args.data_csv),
    )
    missing = [flag for flag, value in required if not value]
    if missing:
        parser.error("missing required argument(s): " + ", ".join(missing))

    # --- 1. Setup Environment ---
    if args.gpu_id:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    # --- 2. Load Data ---
    print(f"Loading and preparing data from {args.data_csv}...")
    df = pd.read_csv(args.data_csv)
    if args.limit:  # a limit of 0 disables truncation
        df = df.head(args.limit)
    print(f"Evaluating on {len(df)} samples.")

    # --- 3. Select Utility Function ---
    util_func = utility_functions[args.utility]
    util_params = {name: getattr(args, name) for name in utility_param_names[args.utility]}

    print(f"Generating ground truth using utility='{args.utility}', params={util_params}")

    # --- 4. Generate Ground Truth Choices ---
    ground_truth_choices = [
        _best_option(row, util_func, util_params) for _, row in df.iterrows()
    ]
    df['preferred_choice'] = ground_truth_choices

    eval_dataset = [
        {"prompt": format_prompt(prompt_text), "preferred_choice": choice}
        for prompt_text, choice in zip(df['prompt_text'], ground_truth_choices)
    ]

    # --- 5. Load Models ---
    print(f"Loading base model: {args.base_model_id}...")
    quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
    base_model = AutoModelForCausalLM.from_pretrained(
        args.base_model_id,
        quantization_config=quant_config,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.base_model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # --- 6. Evaluate Base Model ---
    # The base model is evaluated before the adapters are attached to it.
    base_accuracy = evaluate_model(base_model, tokenizer, eval_dataset)

    # --- 7. Evaluate Fine-tuned Model ---
    print(f"\nLoading fine-tuned adapters from {args.model_path}...")
    ft_model = PeftModel.from_pretrained(base_model, args.model_path)
    finetuned_accuracy = evaluate_model(ft_model, tokenizer, eval_dataset)

    # --- 8. Final Report ---
    print("\n--- ✅ Final Report ---")
    print(f"Ground Truth Definition: Utility={args.utility}, Params={util_params}")
    print(f"Samples Evaluated: {len(df)}")
    print("-" * 25)
    print(f"Base Model Accuracy: {base_accuracy:.2%}")
    print(f"Fine-tuned Model Accuracy: {finetuned_accuracy:.2%}")
    improvement = finetuned_accuracy - base_accuracy
    print(f"Improvement from DPO: {improvement:+.2%}")
    print("-" * 25)


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
