import os
import json
import random
import asyncio
import numpy as np
from fire import Fire
from pathlib import Path
from tqdm.asyncio import tqdm
from openai import AsyncOpenAI
from dotenv import load_dotenv

from utils import (
    get_dataset,
    strip_all_lines,
    inference_and_save,
    make_parent_dirs,
    calc_output_stats,
    MODEL_TO_BASE_URL,
    get_logging_prefix,
)

def get_verbalized_prompt(
    question: str,
    predicted_answer: str,
) -> str:
    """Build the prompt asking the model to verbalize a confidence score.

    The opening instruction mentions choices when the question text contains
    "Choices:" (multiple-choice) and omits them otherwise (free-form). The model
    is asked to reply in the form ``CONFIDENCE: <score>%``.

    Args:
        question: The question text, possibly including its choices.
        predicted_answer: The model's earlier answer to the question.
    Returns:
        The prompt text after post-processing by ``strip_all_lines``.
    """
    instruction = (
        "Here is a question, its choices, and your predicted answer."
        if "Choices:" in question
        else "Here is a question and your predicted answer."
    )
    body = f"""
    {instruction}

    Question: {question}
    Your predicted answer: {predicted_answer}

    Please provide the confidence score of your predicted answer in the following format directly:
    CONFIDENCE: <a score between 0% and 100%>
    """
    return strip_all_lines(body.strip())

def get_ptrue_prompt(
    question: str,
    predicted_answer: str,
) -> str:
    """Build the P(true) prompt asking whether the predicted answer is correct.

    The opening instruction mentions choices when the question text contains
    "Choices:" (multiple-choice) and omits them otherwise (free-form). The model
    is asked to answer with a single YES or NO token.

    Args:
        question: The question text, possibly including its choices.
        predicted_answer: The model's earlier answer to the question.
    Returns:
        The prompt text after post-processing by ``strip_all_lines``.
    """
    instruction = (
        "Here is a question, its choices, and your predicted answer."
        if "Choices:" in question
        else "Here is a question and your predicted answer."
    )
    body = f"""
    {instruction}

    Question: {question}
    Your predicted answer: {predicted_answer}

    Is your predicted answer correct? Please directly output a single token YES or NO.
    """
    return strip_all_lines(body.strip())

def extract_confidence_score(confidence_text: str) -> float | None:
    conf_str = confidence_text.split("CONFIDENCE: ")[-1].split("%")[0]
    if not conf_str.isdigit():
        print(f"Could not extract confidence score from {confidence_text[-100:]}")
        return None
    return min(float(conf_str) / 100, 1.0)

def extract_ptrue(data: dict) -> float | None:
    top_logprobs = data["response"]["choices"][0]["logprobs"]["content"][0]["top_logprobs"]
    yes_logprob = -100
    no_logprob = -100
    for d in top_logprobs:
        if d["token"] == "YES":
            yes_logprob = d["logprob"]
        elif d["token"] == "NO":
            no_logprob = d["logprob"]
    softmax = np.exp(yes_logprob) / (np.exp(yes_logprob) + np.exp(no_logprob))
    return softmax.item()

def read_confs(confidence_output_jsonl: str, calibration_method: str) -> list[float]:
    """Read per-question confidence scores from a confidence-estimation JSONL file.

    Each record's position is taken from the trailing integer of its
    ``custom_id`` (e.g. "request-7" -> 7). The returned list covers indices
    0..max_index, so requests missing from the file do not cause IndexError.
    Scores that could not be extracted, and indices with no record, are filled
    with the mean of the successfully extracted scores. If an index appears more
    than once, the last record wins.

    Args:
        confidence_output_jsonl: Path to the JSONL written by the confidence run.
        calibration_method: "verbalized" (parse ``CONFIDENCE: x%`` from the message)
            or "ptrue" (renormalize YES/NO logprobs).
    Returns:
        One confidence per index in ``[0, max_index]``.
    Raises:
        ValueError: for an unknown calibration method, an empty file, or when no
            score could be extracted at all.
    """
    if calibration_method not in ("verbalized", "ptrue"):
        raise ValueError(f"Calibration method {calibration_method} not supported.")

    index_to_conf: dict[int, float | None] = {}
    with open(confidence_output_jsonl, "r") as f:
        for line in f:
            if not line.strip():  # tolerate blank/trailing lines
                continue
            data = json.loads(line)
            index = int(data["custom_id"].split("-")[-1])
            if calibration_method == "verbalized":
                response_text = data["response"]["choices"][0]["message"]["content"]
                # Content may be None (e.g. empty completion); treat it as unparseable text.
                index_to_conf[index] = extract_confidence_score(response_text or "")
            else:
                index_to_conf[index] = extract_ptrue(data)

    if not index_to_conf:
        raise ValueError(f"No confidence records found in {confidence_output_jsonl}")

    confs = [index_to_conf.get(i) for i in range(max(index_to_conf) + 1)]

    # calculate the averaged confidence score
    valid_confs = [conf for conf in confs if conf is not None]
    if not valid_confs:
        raise ValueError(f"Could not extract any confidence score from {confidence_output_jsonl}")
    avg_conf = sum(valid_confs) / len(valid_confs)
    print(f"Averaged confidence score: {avg_conf}")
    # setting None (unparseable or missing) to averaged confidence score
    return [avg_conf if conf is None else conf for conf in confs]

def read_correctness(dataset: str, eval_output_jsonl: str) -> list[int]:
    """Convert grader output into binary correctness labels.

    Only "simple_qa_verified" is supported. Its grades come from
    ``run_freeform_qa.read_grades``, and a grade of "A" maps to 1 while every
    other grade maps to 0.

    Args:
        dataset: Dataset name.
        eval_output_jsonl: Path to the grader's JSONL output.
    Returns:
        A list of 0/1 correctness labels, in the order returned by ``read_grades``.
    Raises:
        ValueError: if the dataset is not supported.
    """
    if dataset != "simple_qa_verified":
        raise ValueError(f"Dataset {dataset} not supported.")
    from run_freeform_qa import read_grades  # imported lazily; only needed for this dataset
    return [int(grade == "A") for grade in read_grades(eval_output_jsonl)]

def calc_ece(confidences: list[float], corrects: list[int], n_bins: int = 10) -> tuple[float, dict[float, float]]:
    """Calculate the Expected Calibration Error (ECE) of a list of confidences and correctness targets.

    Samples are grouped into ``n_bins`` equal-width bins ``[k/n_bins, (k+1)/n_bins)``.
    A confidence of exactly 1.0 goes into the last bin. ECE is the
    sample-weighted mean of ``|bin accuracy - bin mean confidence|``.

    Args:
        confidences: A list of confidences for each prediction. Range: [0, 1].
        corrects: A list of correctness targets for each prediction. Possible values: {0, 1}.
        n_bins: The number of bins to use for the calibration. Defaults to 10.
    Returns:
        (ece, conf2acc):
        ece (float): The ECE of the predictions.
        conf2acc (dict[float, float]): Maps each non-empty bin's mean confidence
            to that bin's accuracy.
    Raises:
        ValueError: if the lengths differ, the inputs are empty, ``n_bins < 1``,
            or any confidence lies outside [0, 1] (including NaN).
    """
    if len(confidences) != len(corrects):
        raise ValueError("confidences and corrects must have the same length")

    if len(confidences) == 0:
        raise ValueError("confidences and corrects must have non-zero length")

    if n_bins < 1:
        raise ValueError("n_bins must be a positive integer")

    # An out-of-range (or NaN) confidence would land in no bin yet still count
    # toward total_samples, silently shrinking the ECE, so reject it up front.
    if not all(0.0 <= conf <= 1.0 for conf in confidences):
        raise ValueError("confidences must lie in [0, 1]")

    bin_boundaries = [i / n_bins for i in range(n_bins + 1)]

    ece = 0.0
    conf2acc: dict[float, float] = {}
    total_samples = len(confidences)
    for bin_lower, bin_upper in zip(bin_boundaries[:-1], bin_boundaries[1:]):
        # Bins are half-open; the last bin also takes conf == 1.0.
        in_bin = [
            i
            for i, conf in enumerate(confidences)
            if bin_lower <= conf < bin_upper or (bin_upper == 1.0 and conf == 1.0)
        ]
        if not in_bin:
            continue

        bin_accuracy = sum(corrects[i] for i in in_bin) / len(in_bin)
        bin_confidence = sum(confidences[i] for i in in_bin) / len(in_bin)
        bin_weight = len(in_bin) / total_samples

        ece += bin_weight * abs(bin_accuracy - bin_confidence)
        conf2acc[bin_confidence] = bin_accuracy
    return ece, conf2acc

def calc_auc_unnorm_prr(conf_corr_pairs: list[tuple[float, int]]) -> tuple[float, list[tuple[float, float]]]:
    """Calculate the unnormalized AUC of an accuracy-vs-rejection-rate curve.

    Samples are rejected in the order given, so pass pairs sorted with the
    first-to-reject first (e.g. ascending confidence). After rejecting k samples
    (k = 0 .. N-1), the accuracy of the remaining N-k samples is recorded at
    rejection rate k/N. The AUC is the mean of these N accuracies. Rejecting
    all N samples is skipped because no samples would remain to score.

    The input list is not modified.

    Args:
        conf_corr_pairs: A list of (sorted) (confidence, correctness) pairs,
            with correctness in {0, 1}.
    Returns:
        (auc_prr, prr_curve):
        auc_prr (float): The AUC of the PRR curve.
        prr_curve (list[tuple[float, float]]): A list of (rejection_rate, accuracy) pairs.
    Raises:
        ValueError: if ``conf_corr_pairs`` is empty.
    """
    total_samples = len(conf_corr_pairs)
    if total_samples == 0:
        raise ValueError("conf_corr_pairs must have non-zero length")

    remaining_correct = sum(correct for _, correct in conf_corr_pairs)
    accuracy = remaining_correct / total_samples
    auc_unnorm_prr = accuracy
    prr_curve = [(0.0, accuracy)]
    # Iterate over a slice instead of pop() so the caller's list is left intact.
    # The last pair is excluded because rejecting every sample leaves an empty set.
    for num_rejected, (_, correct) in enumerate(conf_corr_pairs[:-1], start=1):
        remaining_correct -= correct
        accuracy = remaining_correct / (total_samples - num_rejected)
        auc_unnorm_prr += accuracy
        prr_curve.append((num_rejected / total_samples, accuracy))
    auc_unnorm_prr /= total_samples
    return auc_unnorm_prr, prr_curve

def calc_prr(confidences: list[float], corrects: list[int]) -> tuple[float, dict]:
    """Calculate the prediction rejection ratio (PRR) and the curves behind it.

    PRR = (AUC_conf - AUC_random) / (AUC_oracle - AUC_random). The terms are:

    - AUC_conf rejects samples in ascending order of confidence.
    - AUC_oracle rejects all incorrect samples first.
    - AUC_random equals the overall accuracy.

    Each AUC is computed by ``calc_auc_unnorm_prr``.

    When every prediction is correct, or every one is incorrect, the oracle
    equals the random baseline and PRR is undefined. In that case it is returned
    as NaN instead of raising ZeroDivisionError.

    Args:
        confidences: Per-prediction confidences.
        corrects: Per-prediction correctness labels in {0, 1}.
    Returns:
        (prr, details): ``details`` holds the three AUCs and the confidence/oracle curves.
    Raises:
        ValueError: if the lengths differ or the inputs are empty.
    """
    if len(confidences) != len(corrects):
        raise ValueError("confidences and corrects must have the same length")

    if len(confidences) == 0:
        raise ValueError("confidences and corrects must have non-zero length")

    conf_pairs = sorted(zip(confidences, corrects), key=lambda x: x[0])
    auc_conf_prr, conf_prr_curve = calc_auc_unnorm_prr(conf_pairs)
    # Stable sort on correctness puts all incorrect samples (0) first.
    oracle_pairs = sorted(zip(confidences, corrects), key=lambda x: x[1])
    auc_oracle_prr, oracle_prr_curve = calc_auc_unnorm_prr(oracle_pairs)
    auc_random_prr = sum(corrects) / len(corrects)  # the same as accuracy
    denominator = auc_oracle_prr - auc_random_prr
    prr = (auc_conf_prr - auc_random_prr) / denominator if denominator != 0 else float("nan")
    return prr, {
        "auc_conf_prr": auc_conf_prr,
        "conf_prr_curve": conf_prr_curve,
        "auc_oracle_prr": auc_oracle_prr,
        "oracle_prr_curve": oracle_prr_curve,
        "auc_random_prr": auc_random_prr,
    }

async def main(
    model: str,
    dataset: str,
    calibration_method: str,  # the prompt type for the calibration method (e.g., verbalized, ptrue, etc.)
    prompt_type: str = "no_risk",  # the prompt type for the downstream task (e.g., freeform QA, shortform QA, etc.)
    temperature: float = 0.0,
    max_completion_tokens: int = 10000,
    max_concurrent_requests: int = 500,
    reasoning_effort: str = None,  # for GPT and Gemini models, possible values: ["none", "low", "medium", "high"] ("none" is not supported for every model)
    thinking_budget: int = None,  # for Claude models
    use_vertex_ai: bool = False,
    grader_model: str = "openai/gpt-oss-20b",
    seed: int = 42,
    debug: bool = False,
) -> None:
    """Run confidence estimation on previously generated answers and print calibration metrics.

    Steps:
    1. Read questions via ``get_dataset(dataset)`` and the model's answers from
       ``{logging_prefix}__raw-output.jsonl``.
    2. Unless ``{logging_prefix}__confidence-estimation-by-{calibration_method}.jsonl``
       already exists, send one confidence prompt per answer, with at most
       ``max_concurrent_requests`` requests in flight.
    3. Write output stats to the matching ``.json`` file.
    4. Compare confidences against the grader results in
       ``{logging_prefix}__eval-results-by-<grader>.jsonl`` and print ECE and PRR.

    Raises:
        ValueError: for an unsupported model family or calibration method.
    """
    # Preparation
    random.seed(seed)
    load_dotenv()
    if model.startswith("gpt"):
        api_key_env_var = "OPENAI_API_KEY"
    elif model.startswith("gemini"):
        api_key_env_var = "GEMINI_API_KEY"
    elif model.startswith("claude"):
        api_key_env_var = "ANTHROPIC_API_KEY"
    else:
        raise ValueError(f"Model {model} not supported.")
    # Resolve the prompt builder (get_<method>_prompt) up front so an unknown
    # method fails with a clear message instead of a bare KeyError.
    get_prompt = globals().get(f"get_{calibration_method}_prompt")
    if get_prompt is None:
        raise ValueError(f"Calibration method {calibration_method} not supported.")
    logging_prefix = get_logging_prefix(dataset, model, prompt_type, reasoning_effort, thinking_budget)
    raw_output_jsonl = f"{logging_prefix}__raw-output.jsonl"
    confidence_output_jsonl = f"{logging_prefix}__confidence-estimation-by-{calibration_method}.jsonl"
    make_parent_dirs(confidence_output_jsonl)
    # Read questions
    df = get_dataset(dataset)
    questions = df["question"].tolist()
    # Read predicted answers
    predicted_answers = [None] * len(questions)
    with open(raw_output_jsonl, "r") as f:
        for line in f:
            data = json.loads(line)
            index = int(data["custom_id"].split("-")[-1])
            predicted_answer = data["response"]["choices"][0]["message"]["content"]
            if predicted_answer:
                predicted_answer = predicted_answer.split("ANSWER: ")[-1]
            else:
                predicted_answer = "I don't know"
            predicted_answers[index] = predicted_answer
    # Inference
    if debug:
        questions = questions[:10]
        predicted_answers = predicted_answers[:10]
    print(f"Running confidence estimation for {len(questions)} questions...")
    if thinking_budget is not None:  # for claude models (the temperature must be set to 1 for thinking modes)
        temperature = 1.0
    if not Path(confidence_output_jsonl).exists():
        sem = asyncio.Semaphore(max_concurrent_requests)
        if use_vertex_ai and model.startswith("gemini"):
            from google import genai
            client = genai.Client(
                vertexai=True,
                project=os.getenv("GCP_PROJECT_ID"),
                location=os.getenv("GCP_LOCATION"),
            )
        else:
            client = AsyncOpenAI(
                api_key=os.getenv(api_key_env_var, None),
                base_url=MODEL_TO_BASE_URL.get(model, "https://api.openai.com/v1"),
            )
        generation_config = {
            "model": model,
            "max_completion_tokens": max_completion_tokens,
            "temperature": temperature,
        }
        if calibration_method == "ptrue":
            # P(true) is read from the first token's top alternatives.
            generation_config["logprobs"] = True
            generation_config["top_logprobs"] = 5
        if reasoning_effort is not None:
            generation_config["reasoning_effort"] = reasoning_effort
        if thinking_budget is not None:
            generation_config["extra_body"] = {"thinking": {"type": "enabled", "budget_tokens": int(thinking_budget)}}
        if model == "claude-3-5-haiku-20241022":
            generation_config["max_completion_tokens"] = 8192
        tasks = list()
        for i, (question, predicted_answer) in enumerate(zip(questions, predicted_answers)):
            prompt = get_prompt(question=question, predicted_answer=predicted_answer)
            req_body = {
                "messages": [{"role": "user", "content": prompt}],
                **generation_config,
            }
            tasks.append(inference_and_save(
                sem=sem,
                client=client,
                output_jsonl=confidence_output_jsonl,
                custom_id=f"request-{i}",
                req_body=req_body,
            ))
        # total must be the number of scheduled tasks (10 in debug mode), not len(df).
        for coroutine in tqdm(asyncio.as_completed(tasks), desc="Running inference", total=len(tasks), dynamic_ncols=True):
            await coroutine
    stats_confidence_output = calc_output_stats(confidence_output_jsonl)
    # "...jsonl"[:-1] -> "...json": the stats file sits next to the JSONL output.
    with open(confidence_output_jsonl[:-1], "w") as f:
        json.dump(stats_confidence_output, f, indent=4)

    # Print results (calibration metrics)
    confs = read_confs(confidence_output_jsonl, calibration_method)
    eval_output_jsonl = f"{logging_prefix}__eval-results-by-{grader_model.split('/')[-1]}.jsonl"
    correctness = read_correctness(dataset, eval_output_jsonl)
    if debug:
        # Only the first len(questions) items were scored in debug mode; align both
        # lists so the metric functions do not fail on a length mismatch.
        confs = confs[: len(questions)]
        correctness = correctness[: len(questions)]
    ece, _ = calc_ece(confs, correctness)
    prr, _ = calc_prr(confs, correctness)
    print(f"ECE: {ece:.2%}; PRR: {prr:.2%}")

if __name__ == "__main__":
    # Expose `main` as the command-line entry point via python-fire.
    # NOTE(review): `main` is a coroutine function; this relies on the installed
    # fire version running coroutines itself — confirm before upgrading fire/Python.
    Fire(component=main)
