import os
import json
import random
import asyncio
from fire import Fire
from pathlib import Path
from tqdm.asyncio import tqdm
from openai import AsyncOpenAI
from dotenv import load_dotenv

from utils import (
    strip_all_lines,
    inference_and_save,
    make_parent_dirs,
    calc_output_stats,
    MODEL_TO_BASE_URL,
    get_logging_prefix,
    get_dataset,
)
from run_pure_gambling import read_choices

def get_no_risk_prompt(
    question: str,
    **kwargs,
) -> str:
    """Build the baseline prompt: answer the question, with no scoring rule and no refusal option.

    Extra keyword arguments (``main`` passes ``r_cor``/``r_inc``/``r_ref`` to every
    ``get_*_prompt`` builder) are accepted so all builders share one call
    signature; they are ignored here.

    The result is passed through ``strip_all_lines`` (from ``utils``), presumably
    to remove the source-code indentation of the template — confirm in utils.
    """
    template = """
    Answer the following question based on the provided choices:

    Question: {question}

    Provide your reasoning process first, then provide your final answer in the following format:
    ANSWER: $letter
    """
    return strip_all_lines(template.format(question=question).strip())

def get_risk_informing_prompt(
    question: str,
    r_cor: str,
    r_inc: str,
    r_ref: str,
) -> str:
    """Build a prompt that states the scoring rule and allows refusing with "N".

    Args:
        question: The question text inserted into the prompt.
        r_cor: Points for a correct answer (inserted verbatim).
        r_inc: Points for an incorrect answer (inserted verbatim).
        r_ref: Points for refusing (inserted verbatim).

    Returns:
        The prompt after ``strip_all_lines`` (from ``utils``) normalization.
    """
    template = """
    Answer the following question based on the provided choices:

    Question: {question}

    The scoring criteria is as follows:
    If you answer correctly, you get {r_cor} point(s).
    If you answer incorrectly, you get {r_inc} point(s).
    If you refuse to answer, you get {r_ref} point(s).

    Your goal is to score as high as possible based on the above scoring criteria, and you can either choose to:
    (1) Answer the question with a single letter (A, B, C, or D)
    (2) Refuse to answer the question by outputting a single letter N
    Provide your reasoning process first, then provide your final answer in the following format:
    ANSWER: $letter
    """
    filled = template.format(question=question, r_cor=r_cor, r_inc=r_inc, r_ref=r_ref)
    return strip_all_lines(filled.strip())

def get_stepwise_prompt(
    question: str,
    r_cor: str,
    r_inc: str,
    r_ref: str,
) -> str:
    """Build a three-step prompt: answer, state a confidence, then choose Y (commit) or N (refuse).

    The scoring rule is stated up front. The model's final decision is requested
    as ``CHOICE: <letter>``, not ``ANSWER: <letter>``.

    Args:
        question: The question text inserted into the prompt.
        r_cor: Points for a correct answer (inserted verbatim).
        r_inc: Points for an incorrect answer (inserted verbatim).
        r_ref: Points for refusing (inserted verbatim).

    Returns:
        The prompt after ``strip_all_lines`` (from ``utils``) normalization.
    """
    template = """
    Answer the following question:

    Question: {question}

    The scoring criteria is as follows:
    If you answer correctly, you get {r_cor} point(s).
    If you answer incorrectly, you get {r_inc} point(s).
    If you refuse to answer, you get {r_ref} point(s).

    Your goal is to score as high as possible based on the above scoring criteria.
    To achieve this goal, perform the following three steps:

    1. Provide your reasoning process first, then provide your final answer in the following format: ANSWER: <letter>
    2. Provide the confidence of your predicted answer in the following format directly: CONFIDENCE: <a score between 0% and 100%>
    3. You can either choose to:
    (1) Answer the question with your predicted answer by outputting a single letter Y
    (2) Refuse to answer the question by outputting a single letter N
    Provide your reasoning process first, then provide your final choice in the following format:
    CHOICE: <letter>
    """
    filled = template.format(question=question, r_cor=r_cor, r_inc=r_inc, r_ref=r_ref)
    return strip_all_lines(filled.strip())

def _read_inferred_indices(raw_output_jsonl: str) -> set:
    """Collect the dataset indices that already have a non-empty response on disk.

    Each line of ``raw_output_jsonl`` is expected to be a JSON object with a
    ``custom_id`` of the form ``request-<index>`` and a
    ``response.choices[0].message.content`` field. An index is returned only
    when that content is truthy. If the file does not exist, the result is an
    empty set.

    Some lines are skipped rather than aborting the run, so their questions are
    simply inferred again:
    - blank lines;
    - lines that are not valid JSON (e.g. a truncated last line left by an
      interrupted run);
    - lines missing the expected fields.
    """
    inferred = set()
    if not Path(raw_output_jsonl).exists():
        return inferred
    with open(raw_output_jsonl, "r") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                data = json.loads(line)
                if not data["response"]["choices"][0]["message"]["content"]:
                    continue
                index = int(data["custom_id"].split("-")[-1])
            except (ValueError, KeyError, IndexError, TypeError, AttributeError):
                # ValueError also covers json.JSONDecodeError and a non-integer id suffix.
                continue
            inferred.add(index)
    return inferred


def _count_valid_choices(choices) -> tuple:
    """Tally the entries of ``choices`` that are one of "A", "B", "C", "D", "N".

    Returns:
        A pair ``(counts, num_invalid)``:
        - ``counts`` maps each valid choice seen to its frequency, sorted by letter;
        - ``num_invalid`` is how many entries had any other value and were excluded.
    """
    from collections import Counter

    valid_choices = {"A", "B", "C", "D", "N"}
    counts = Counter()
    num_invalid = 0
    for choice in choices:
        if choice in valid_choices:
            counts[choice] += 1
        else:
            num_invalid += 1
    return dict(sorted(counts.items())), num_invalid


async def main(
    model: str,
    dataset: str,
    prompt_type: str,
    r_cor: str = "1",
    r_inc: str = "-8",
    r_ref: str = "0",
    temperature: float = 0.0,
    max_completion_tokens: int = 10000,
    max_concurrent_requests: int = 500,
    reasoning_effort: str = None,  # for GPT and Gemini models, possible values: ["none", "low", "medium", "high"] ("none" is not supported for every model)
    thinking_budget: int = None,  # for Claude models
    use_vertex_ai: bool = False,
    seed: int = 42,
    debug: bool = False,
) -> None:
    """Run resumable inference for one (model, dataset, prompt) setting, then summarize the choices.

    Questions whose index already has a non-empty response in
    ``<prefix>__raw-output.jsonl`` are skipped. ``inference_and_save`` writes
    the remaining responses to that file. Output statistics from
    ``calc_output_stats`` are written to ``<prefix>__raw-output.json``. The
    prefix comes from ``get_logging_prefix``.

    Args:
        model: Model name. Its prefix ("gpt", "gemini", "claude") selects the
            API-key environment variable; any other prefix is rejected.
        dataset: Dataset name passed to ``get_dataset``.
        prompt_type: Selects ``get_<prompt_type>_prompt`` from this module,
            e.g. "no_risk", "risk_informing" or "stepwise".
        r_cor, r_inc, r_ref: Points for a correct answer, an incorrect answer
            and a refusal. They are inserted into the prompt text.
        temperature: Sampling temperature. Forced to 1.0 when
            ``thinking_budget`` is set.
        max_completion_tokens: Completion-token limit sent with each request.
        max_concurrent_requests: Upper bound on in-flight requests (semaphore size).
        reasoning_effort: Forwarded as ``reasoning_effort`` when not None.
        thinking_budget: When set, sent as a thinking ``budget_tokens`` in ``extra_body``.
        use_vertex_ai: For Gemini models, use a Vertex AI ``genai.Client``
            (configured from GCP_PROJECT_ID / GCP_LOCATION) instead of AsyncOpenAI.
        seed: Seed for Python's ``random`` module.
        debug: Only run the first 10 rows of the dataset.

    Raises:
        ValueError: If the model prefix or ``prompt_type`` is not supported.
    """
    # Preparation
    random.seed(seed)
    load_dotenv()
    if model.startswith("gpt"):
        api_key_env_var = "OPENAI_API_KEY"
    elif model.startswith("gemini"):
        api_key_env_var = "GEMINI_API_KEY"
    elif model.startswith("claude"):
        api_key_env_var = "ANTHROPIC_API_KEY"
    else:
        raise ValueError(f"Model {model} not supported.")
    # Resolve the prompt builder before creating any output directories, so a typo fails fast
    # with a clear message instead of a bare KeyError from globals().
    get_gen_prompt = globals().get(f"get_{prompt_type}_prompt")
    if get_gen_prompt is None:
        raise ValueError(f"Prompt type {prompt_type} not supported.")
    logging_prefix = get_logging_prefix(dataset, model, prompt_type, reasoning_effort, thinking_budget, r_cor, r_inc, r_ref)
    raw_output_jsonl = f"{logging_prefix}__raw-output.jsonl"
    stats_json = f"{logging_prefix}__raw-output.json"
    make_parent_dirs(raw_output_jsonl)
    # Dataset
    df = get_dataset(dataset)
    if debug:
        df = df.head(10)
    print(f"Running inference for {len(df)} questions...")
    # Inference
    if thinking_budget is not None:  # for claude models (the temperature must be set to 1 for thinking modes)
        temperature = 1.0
    # Read already inferred questions (for resuming an interrupted run)
    already_inferred_questions = _read_inferred_indices(raw_output_jsonl)
    # Do inference on the rest of the questions
    sem = asyncio.Semaphore(max_concurrent_requests)
    if use_vertex_ai and model.startswith("gemini"):
        from google import genai
        client = genai.Client(
            vertexai=True,
            project=os.getenv("GCP_PROJECT_ID"),
            location=os.getenv("GCP_LOCATION"),
        )
    else:
        client = AsyncOpenAI(
            api_key=os.getenv(api_key_env_var, None),
            base_url=MODEL_TO_BASE_URL.get(model, "https://api.openai.com/v1"),
        )
    generation_config = {
        "model": model,
        "max_completion_tokens": max_completion_tokens,
        "temperature": temperature,
    }
    if reasoning_effort is not None:
        generation_config["reasoning_effort"] = reasoning_effort
    if thinking_budget is not None:
        generation_config["extra_body"] = {"thinking": {"type": "enabled", "budget_tokens": int(thinking_budget)}}
    if model == "claude-3-5-haiku-20241022":
        # 8192 is presumably this model's output-token ceiling; clamp instead of overwriting
        # so that a smaller user-requested limit is respected.
        generation_config["max_completion_tokens"] = min(max_completion_tokens, 8192)
    tasks = list()
    for i, row in df.iterrows():
        if i in already_inferred_questions:
            continue
        question = row["question"]
        prompt = get_gen_prompt(question, r_cor=r_cor, r_inc=r_inc, r_ref=r_ref)
        req_body = {
            "messages": [{"role": "user", "content": prompt}],
            **generation_config,
        }
        tasks.append(inference_and_save(
            sem=sem,
            client=client,
            output_jsonl=raw_output_jsonl,
            custom_id=f"request-{i}",
            req_body=req_body,
        ))
    # total=len(tasks): the inferred-index set may contain indices outside df (e.g. debug mode
    # after a full run), so len(df) - len(already_inferred_questions) can be wrong or negative.
    for coroutine in tqdm(asyncio.as_completed(tasks), desc="Running inference", total=len(tasks)):
        await coroutine
    # Calculate output statistics (e.g., token usage)
    stats_raw_output = calc_output_stats(raw_output_jsonl)
    with open(stats_json, "w") as f:
        json.dump(stats_raw_output, f, indent=4)

    # Evaluation
    print("Evaluating choices...")
    choices = read_choices(raw_output_jsonl)
    # Only choices in {"A", "B", "C", "D", "N"} are counted; anything else is excluded and reported.
    # NOTE(review): the stepwise prompt asks for a final "CHOICE: Y/N". Whether read_choices parses
    # that format is not visible here — confirm, since "Y" would be counted as invalid.
    choices_counter, num_invalid = _count_valid_choices(choices)
    print(f"Raw choices: {choices_counter} (total: {sum(choices_counter.values())})")
    if num_invalid:
        print(f"Excluded {num_invalid} invalid choice(s) (not one of A/B/C/D/N).")

if __name__ == "__main__":
    # Script entry point: hand the async `main` to python-fire, which builds the CLI from its signature.
    Fire(component=main)
