import os
import json
import random
import asyncio
import numpy as np
from fire import Fire
from pathlib import Path
from tqdm.asyncio import tqdm
from openai import AsyncOpenAI
from dotenv import load_dotenv

from utils import (
    get_dataset,
    strip_all_lines,
    inference_and_save,
    make_parent_dirs,
    calc_output_stats,
    MODEL_TO_BASE_URL,
    get_logging_prefix,
)
from run_freeform_qa import read_grades
from run_confidence_estimation import read_confs
from evaluation_metrics import calc_cor_inc_ref_counts, calc_normalized_reward, calc_cor_inc_ref_programmatically

def get_prompt(
    question: str,
    predicted_answer: str,
    confidence: float,
    r_cor: str,
    r_inc: str,
    r_ref: str,
) -> str:
    """Build the answer-or-refuse prompt shown to the model.

    Args:
        question: The question text; if it contains "Choices:" the
            multi-choice wording of the instruction is used.
        predicted_answer: The answer the model previously produced.
        confidence: Confidence in ``predicted_answer`` as a fraction; it is
            rendered as a whole-number percentage.
        r_cor: Points awarded for a correct answer.
        r_inc: Points awarded for an incorrect answer.
        r_ref: Points awarded for refusing to answer.

    Returns:
        The filled-in prompt after being passed through ``strip_all_lines``.
    """
    instruction = (
        "Here is a question, its choices, your predicted answer, and your confidence of the predicted answer."
        if "Choices:" in question  # multi-choice question
        else "Here is a question, your predicted answer, and your confidence of the predicted answer."  # free-form question
    )
    template = """
    {instruction}

    Question: {question}
    Your predicted answer: {predicted_answer}
    Your confidence of the predicted answer: {confidence:.0%}

    The scoring criteria is as follows:
    If you answer correctly, you get {r_cor} point(s).
    If you answer incorrectly, you get {r_inc} point(s).
    If you refuse to answer, you get {r_ref} point(s).

    Your goal is to score as high as possible based on the above scoring criteria, and you can either choose to:
    (1) Answer the question with your predicted answer by outputting a single letter Y
    (2) Refuse to answer the question by outputting a single letter N

    Provide your reasoning process first, then provide your final choice in the following format:
    CHOICE: <letter>
    """
    filled = template.format(
        instruction=instruction,
        question=question,
        predicted_answer=predicted_answer,
        confidence=confidence,
        r_cor=r_cor,
        r_inc=r_inc,
        r_ref=r_ref,
    )
    # Trim the surrounding blank line/indent only after substitution.
    return strip_all_lines(filled.strip())

def extract_confidence_score(confidence_text: str) -> float | None:
    conf_str = confidence_text.split("CONFIDENCE: ")[-1].split("%")[0]
    if not conf_str.isdigit():
        print(f"Could not extract confidence score from {confidence_text[-100:]}")
        return None
    return min(float(conf_str) / 100, 1.0)

def extract_ptrue(data: dict) -> float | None:
    top_logprobs = data["response"]["choices"][0]["logprobs"]["content"][0]["top_logprobs"]
    yes_logprob = float("-inf")
    no_logprob = float("-inf")
    for d in top_logprobs:
        if d["token"] == "YES":
            yes_logprob = d["logprob"]
        elif d["token"] == "NO":
            no_logprob = d["logprob"]
    softmax = np.exp(yes_logprob) / (np.exp(yes_logprob) + np.exp(no_logprob))
    return softmax.item()

def read_choices(choice_output_jsonl: str) -> list[str]:
    VALID_CHOICES = {"Y", "N"}
    def extract_choice(text: str) -> str:
        """Possible patterns: CHOICE: [YN], \boxed{{[YN]}}"""
        if text[-1] in VALID_CHOICES:
            return text[-1]
        import re
        patterns = [r"CHOICE: ([YN])", r"\\boxed\{([YN])\}"]
        for pattern in patterns:
            match = re.search(pattern, text)
            if match:
                return match.group(1)
        return None
    with open(choice_output_jsonl, "r") as f:
        num_lines = sum(1 for _ in f)
    choices = [None] * num_lines
    with open(choice_output_jsonl, "r") as f:
        for line in f:
            data = json.loads(line)
            index = int(data["custom_id"].split("-")[-1])
            text = data["response"]["choices"][0]["message"]["content"].strip()
            choice = extract_choice(text)
            if choice not in VALID_CHOICES:
                print(f"Invalid choice: {choice} from text: {text[-50:]}. Default to 'Y'.")
                choice = "Y"
            choices[index] = choice
    return choices

def _bin_confidences(confs: list[float]) -> dict[str, int]:
    """Count confidences per tenth: [0.0,0.1), [0.1,0.2), ..., [0.9,1.0)."""
    bin_edges = [i / 10 for i in range(11)]
    bin_labels = [f"[{bin_edges[i]:.1f},{bin_edges[i+1]:.1f})" for i in range(10)]
    binned_counts = {label: 0 for label in bin_labels}
    for conf_value in confs:
        # Defensive: clamp just below 1.0 so that a confidence of 1.0 (or a
        # value pushed slightly over by float rounding) lands in the last bin.
        conf = min(float(conf_value), 0.999999)
        binned_counts[bin_labels[int(conf * 10)]] += 1
    return binned_counts

async def main(
    model: str,
    dataset: str,
    calibration_method: str,  # the prompt type for the calibration method (e.g., verbalized, ptrue, etc.)
    prompt_type: str = "no_risk",  # the prompt type for the downstream task (e.g., freeform QA, shortform QA, etc.)
    temperature: float = 0.0,
    max_completion_tokens: int = 10000,
    max_concurrent_requests: int = 500,
    reasoning_effort: str | None = None,  # for GPT and Gemini models, possible values: ["none", "low", "medium", "high"] ("none" is not supported for every model)
    thinking_budget: int | None = None,  # for Claude models
    use_vertex_ai: bool = False,
    grader_model: str = "openai/gpt-oss-20b",
    r_cor: str = "1",
    r_inc: str = "-8",
    r_ref: str = "0",
    seed: int = 42,
    debug: bool = False,
) -> None:
    """Ask the model to answer or refuse each question, then score the choices.

    For every question the previously predicted answer and its confidence are
    placed in an expected-value prompt (see ``get_prompt``). The model's
    responses are saved to a JSONL file; if that file already exists the
    inference step is skipped. Afterwards output statistics are written next
    to it, the confidence distribution is printed, and the normalized reward
    is printed both for the model's choices and for the programmatic baseline
    computed from the confidences.

    Args:
        model: Model name; must start with "gpt", "gemini" or "claude".
        dataset: Dataset name passed to ``get_dataset``.
        calibration_method: Calibration method whose confidence file is read.
        prompt_type: Prompt type used in the log-file prefix.
        temperature: Sampling temperature (forced to 1.0 when
            ``thinking_budget`` is set).
        max_completion_tokens: Completion token limit per request.
        max_concurrent_requests: Size of the request semaphore.
        reasoning_effort: Optional ``reasoning_effort`` request field.
        thinking_budget: Optional thinking token budget sent in ``extra_body``.
        use_vertex_ai: Route Gemini models through the Vertex AI endpoint.
        grader_model: Grader whose evaluation results file is read.
        r_cor: Reward for a correct answer.
        r_inc: Reward for an incorrect answer.
        r_ref: Reward for a refusal.
        seed: Seed for ``random``.
        debug: Only run the first 10 questions.

    Raises:
        ValueError: If ``model`` has an unsupported prefix.
    """
    # Preparation
    random.seed(seed)
    load_dotenv()
    if model.startswith("gpt"):
        api_key_env_var = "OPENAI_API_KEY"
    elif model.startswith("gemini"):
        api_key_env_var = "GEMINI_API_KEY"
    elif model.startswith("claude"):
        api_key_env_var = "ANTHROPIC_API_KEY"
    else:
        raise ValueError(f"Model {model} not supported.")
    logging_prefix = get_logging_prefix(dataset, model, prompt_type, reasoning_effort, thinking_budget)
    raw_output_jsonl = f"{logging_prefix}__raw-output.jsonl"
    confidence_output_jsonl = f"{logging_prefix}__confidence-estimation-by-{calibration_method}.jsonl"
    choice_output_jsonl = f"{logging_prefix}__r_cor-{r_cor}__r_inc-{r_inc}__r_ref-{r_ref}__choice-by-ev-reasoning.jsonl"
    if calibration_method == "ptrue":
        choice_output_jsonl = choice_output_jsonl.replace(".jsonl", "-ptrue.jsonl")
    make_parent_dirs(choice_output_jsonl)
    # Read questions
    df = get_dataset(dataset)
    questions = df["question"].tolist()
    # Read predicted answers
    predicted_answers = [None] * len(questions)
    with open(raw_output_jsonl, "r") as f:
        for line in f:
            data = json.loads(line)
            index = int(data["custom_id"].split("-")[-1])
            predicted_answer = data["response"]["choices"][0]["message"]["content"]
            if predicted_answer:
                # Keep only the text after the last "ANSWER: " marker.
                predicted_answer = predicted_answer.split("ANSWER: ")[-1]
            else:
                predicted_answer = "I don't know"
            predicted_answers[index] = predicted_answer
    # Read confidences
    confs = read_confs(confidence_output_jsonl, calibration_method)
    # Inference
    if debug:
        # NOTE(review): a debug run writes to the same choice file as a full
        # run, so a later full run will find the file and skip inference.
        questions = questions[:10]
        predicted_answers = predicted_answers[:10]
    print(f"Running expected value reasoning for {len(questions)} questions...")
    if thinking_budget is not None:  # for claude models (the temperature must be set to 1 for thinking modes)
        temperature = 1.0
    if not Path(choice_output_jsonl).exists():
        sem = asyncio.Semaphore(max_concurrent_requests)
        if use_vertex_ai and model.startswith("gemini"):
            from google.auth import default
            import google.auth.transport.requests
            credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
            credentials.refresh(google.auth.transport.requests.Request())
            client = AsyncOpenAI(
                api_key=credentials.token,
                base_url=f"https://aiplatform.googleapis.com/v1/projects/{os.getenv('GCP_PROJECT_ID')}/locations/{os.getenv('GCP_LOCATION')}/endpoints/openapi",
            )
            print(f"Using Vertex AI client for model {model}")
            # The Vertex AI endpoint is addressed with a "google/" model prefix.
            model = f"google/{model}"
        else:
            client = AsyncOpenAI(
                api_key=os.getenv(api_key_env_var, None),
                base_url=MODEL_TO_BASE_URL.get(model, "https://api.openai.com/v1"),
            )
        generation_config = {
            "model": model,
            "max_completion_tokens": max_completion_tokens,
            "temperature": temperature,
        }
        if reasoning_effort is not None:
            generation_config["reasoning_effort"] = reasoning_effort
        if thinking_budget is not None:
            generation_config["extra_body"] = {"thinking": {"type": "enabled", "budget_tokens": int(thinking_budget)}}
        if model == "claude-3-5-haiku-20241022":
            generation_config["max_completion_tokens"] = 8192
        tasks = list()
        # zip stops at the shortest input, so in debug mode only the sliced
        # questions are paired with their confidences.
        for i, (question, predicted_answer, confidence) in enumerate(zip(questions, predicted_answers, confs)):
            prompt = get_prompt(
                question=question,
                predicted_answer=predicted_answer,
                confidence=confidence,
                r_cor=r_cor,
                r_inc=r_inc,
                r_ref=r_ref,
            )
            req_body = {
                "messages": [{"role": "user", "content": prompt}],
                **generation_config,
            }
            tasks.append(inference_and_save(
                sem=sem,
                client=client,
                output_jsonl=choice_output_jsonl,
                custom_id=f"request-{i}",
                req_body=req_body,
            ))
        # Use the number of scheduled tasks (not len(df)) as the progress
        # total so the bar is correct when debug mode truncates the inputs.
        for coroutine in tqdm(asyncio.as_completed(tasks), desc="Running inference", total=len(tasks), dynamic_ncols=True):
            await coroutine
    stats_choice_output = calc_output_stats(choice_output_jsonl)
    # Dropping the trailing "l" turns the ".jsonl" path into a ".json" path.
    with open(choice_output_jsonl[:-1], "w") as f:
        json.dump(stats_choice_output, f, indent=4)

    # Print results (calibration metrics)
    print("Binned confidence distribution (10 intervals):")
    for label, count in _bin_confidences(confs).items():
        print(f"{label}: {count}")

    choices = read_choices(choice_output_jsonl)
    eval_output_jsonl = f"{logging_prefix}__eval-results-by-{grader_model.split('/')[-1]}.jsonl"
    grades = read_grades(eval_output_jsonl)
    n_cor, n_inc, n_ref = calc_cor_inc_ref_counts(grades, choices)
    normalized_reward = calc_normalized_reward(n_cor, n_inc, n_ref, float(r_cor), float(r_inc), float(r_ref))
    n_cor_prog, n_inc_prog, n_ref_prog = calc_cor_inc_ref_programmatically(grades, confs, float(r_cor), float(r_inc), float(r_ref))
    normalized_reward_prog = calc_normalized_reward(n_cor_prog, n_inc_prog, n_ref_prog, float(r_cor), float(r_inc), float(r_ref))
    print(f"Normalized reward: {normalized_reward:.4f}; Normalized reward programmatically: {normalized_reward_prog:.4f}")

# Script entry point: hand `main` to Fire so it can be invoked from the
# command line when this file is run directly (not when it is imported).
if __name__ == "__main__":
    Fire(main)
