import os
import json
import random
import asyncio
from fire import Fire
from pathlib import Path
from tqdm.asyncio import tqdm
from openai import AsyncOpenAI
from dotenv import load_dotenv

from utils import (
    inference_and_save,
    make_parent_dirs,
    calc_output_stats,
    MODEL_TO_BASE_URL,
    get_logging_prefix,
)

def get_prompt(
    question: str,
    p_cor: str,
    p_inc: str,
    p_ref: str,
) -> str:
    """Render a question template and append the answer-format instruction.

    Args:
        question: A ``str.format`` template; the placeholders ``{p_cor}``,
            ``{p_inc}`` and ``{p_ref}`` are substituted (others are not supplied).
        p_cor: Value substituted for ``{p_cor}``.
        p_inc: Value substituted for ``{p_inc}``.
        p_ref: Value substituted for ``{p_ref}``.

    Returns:
        The rendered question, a blank line, then an instruction asking for
        reasoning followed by a final ``ANSWER: <letter>`` line.
    """
    rendered = question.format(p_cor=p_cor, p_inc=p_inc, p_ref=p_ref)
    answer_format = (
        "Provide your reasoning process first, then provide your final answer in the following format:\n"
        "ANSWER: <letter>"
    )
    return "\n\n".join((rendered, answer_format))

def read_choices(choice_output_jsonl: str) -> list[str]:
    """Extract the letter choice from every response in a raw-output JSONL file.

    Each line is a JSON record whose ``custom_id`` ends in ``-<index>`` and whose
    model reply is at ``response.choices[0].message.content``.

    Args:
        choice_output_jsonl: Path to the JSONL file of model responses.

    Returns:
        A list indexed by request index. Each entry is one of ``"A"``, ``"B"``,
        ``"C"``, ``"D"``, ``"N"``, or ``"Invalid"`` when no choice could be
        extracted (including empty or missing content). Indices that have no
        record in the file stay ``None``. The list is at least as long as the
        number of records and always long enough to hold the largest index.
    """
    import re

    VALID_CHOICES = {"A", "B", "C", "D", "N"}
    # Compiled once instead of on every response.
    patterns = [
        re.compile(r"ANSWER: ([ABCDN])"),
        re.compile(r"\\boxed\{([ABCDN])\}"),
        re.compile(r"text\{([ABCDN])\}"),
    ]

    def extract_choice(text: str):
        """Return the chosen letter in ``text``, or None if none is found.

        Tried in order: a bare trailing letter, ``ANSWER: X``, ``\\boxed{X}``,
        ``\\text{X}`` (X in A/B/C/D/N).
        """
        if not text:  # e.g. whitespace-only content that was stripped to ""
            return None
        if text[-1] in VALID_CHOICES:
            return text[-1]
        for pattern in patterns:
            match = pattern.search(text)
            if match:
                return match.group(1)
        return None

    records = []
    with open(choice_output_jsonl, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines instead of failing in json.loads
            data = json.loads(line)
            index = int(data["custom_id"].split("-")[-1])
            text = data["response"]["choices"][0]["message"]["content"]
            # Reset per record so empty content never inherits a previous choice.
            choice = extract_choice(text.strip()) if text else None
            records.append((index, choice if choice in VALID_CHOICES else "Invalid"))

    # Size by both record count and highest index, so a missing request (gap in
    # the indices) cannot push a later index out of range.
    highest_index = max((index for index, _ in records), default=-1)
    choices = [None] * max(len(records), highest_index + 1)
    for index, choice in records:
        choices[index] = choice
    return choices

def _create_client(model, use_vertex_ai, api_key_env_var):
    """Build the AsyncOpenAI client: Vertex AI's OpenAI-compatible endpoint or the provider API.

    Args:
        model: Model name (used for the base-URL lookup and logging).
        use_vertex_ai: True to authenticate with Google default credentials and
            target the Vertex AI endpoint built from GCP_PROJECT_ID / GCP_LOCATION.
        api_key_env_var: Environment variable holding the API key for the
            non-Vertex path.
    """
    if use_vertex_ai:
        # Imported lazily so google-auth is only needed for Vertex AI runs.
        from google.auth import default
        import google.auth.transport.requests
        credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
        # NOTE(review): the access token is fetched once and never refreshed;
        # very long runs may outlive it — confirm if that matters.
        credentials.refresh(google.auth.transport.requests.Request())
        client = AsyncOpenAI(
            api_key=credentials.token,
            base_url=f"https://aiplatform.googleapis.com/v1/projects/{os.getenv('GCP_PROJECT_ID')}/locations/{os.getenv('GCP_LOCATION')}/endpoints/openapi",
        )
        print(f"Using Vertex AI client for model {model}")
        return client
    return AsyncOpenAI(
        api_key=os.getenv(api_key_env_var, None),
        base_url=MODEL_TO_BASE_URL.get(model, "https://api.openai.com/v1"),
    )


async def main(
    model: str,
    reasoning_effort: str = None,  # for GPT and Gemini models, possible values: ["none", "low", "medium", "high"] ("none" is not supported for every model)
    thinking_budget: int = None,  # for Claude models
    use_vertex_ai: bool = False,
    data_dir: Path = Path("./data/pure_gambling"),
    num_prompts: int = 100,
    temperature: float = 0.0,
    max_completion_tokens: int = 10000,
    max_concurrent_requests: int = 100,
    p_cor: str = "1",
    p_inc: str = "-8",
    p_ref: str = "0",
    seed: int = 42,
) -> None:
    """Run zero-shot CoT inference on the pure-gambling prompts and print the choice counts.

    Reads ``{data_dir}/{i}.txt`` for ``i`` in ``range(num_prompts)``, fills each
    template with ``p_cor``/``p_inc``/``p_ref``, and sends the requests
    concurrently (at most ``max_concurrent_requests`` at once). Responses go to
    ``{logging_prefix}__raw-output.jsonl``. If that file already exists,
    inference is skipped and the existing file is re-analysed. Output statistics
    are written next to it as ``...__raw-output.json``, and a sorted count of the
    extracted choices is printed.

    Raises:
        ValueError: if ``model`` does not start with ``gpt``, ``gemini`` or ``claude``.
    """
    # Preparation
    random.seed(seed)
    load_dotenv()
    if model.startswith("gpt"):
        api_key_env_var = "OPENAI_API_KEY"
    elif model.startswith("gemini"):
        api_key_env_var = "GEMINI_API_KEY"
    elif model.startswith("claude"):
        api_key_env_var = "ANTHROPIC_API_KEY"
    else:
        raise ValueError(f"Model {model} not supported.")
    # Fire hands CLI values over as plain strings; `/` below needs a Path.
    data_dir = Path(data_dir)
    logging_prefix = get_logging_prefix(
        dataset="pure_gambling",
        model=model,
        prompt_type="zero-shot-cot",
        reasoning_effort=reasoning_effort,
        thinking_budget=thinking_budget,
        r_cor=p_cor,
        r_inc=p_inc,
        r_ref=p_ref,
    )
    raw_output_jsonl = f"{logging_prefix}__raw-output.jsonl"
    make_parent_dirs(raw_output_jsonl)
    # Read questions
    questions = [(data_dir / f"{i}.txt").read_text(encoding="utf-8") for i in range(num_prompts)]
    # Inference
    print(f"Running expected value reasoning for {len(questions)} pure gambling questions...")
    if thinking_budget is not None:  # for claude models (the temperature must be set to 1 for thinking modes)
        temperature = 1.0
    # An existing raw-output file is treated as a cached run: skip inference.
    if not Path(raw_output_jsonl).exists():
        sem = asyncio.Semaphore(max_concurrent_requests)
        use_vertex = use_vertex_ai and model.startswith("gemini")
        client = _create_client(model, use_vertex, api_key_env_var)
        generation_config = {
            # Vertex AI's OpenAI-compatible endpoint expects a "google/" prefix.
            "model": f"google/{model}" if use_vertex else model,
            "max_completion_tokens": max_completion_tokens,
            "temperature": temperature,
        }
        if reasoning_effort is not None:
            generation_config["reasoning_effort"] = reasoning_effort
        if thinking_budget is not None:
            generation_config["extra_body"] = {"thinking": {"type": "enabled", "budget_tokens": int(thinking_budget)}}
        if model == "claude-3-5-haiku-20241022":
            # Capped for this model (presumably its output-token limit).
            generation_config["max_completion_tokens"] = 8192
        tasks = []
        for i, question in enumerate(questions):
            prompt = get_prompt(
                question=question,
                p_cor=p_cor,
                p_inc=p_inc,
                p_ref=p_ref,
            )
            req_body = {
                "messages": [{"role": "user", "content": prompt}],
                **generation_config,
            }
            tasks.append(inference_and_save(
                sem=sem,
                client=client,
                output_jsonl=raw_output_jsonl,
                custom_id=f"request-{i}",
                req_body=req_body,
            ))
        try:
            for coroutine in tqdm(asyncio.as_completed(tasks), desc="Running inference", total=len(questions), dynamic_ncols=True):
                await coroutine
        finally:
            # Release the client's HTTP connection pool even if a request failed.
            await client.close()
    stats_raw_output = calc_output_stats(raw_output_jsonl)
    # Dropping the trailing "l" turns "...raw-output.jsonl" into "...raw-output.json".
    with open(raw_output_jsonl[:-1], "w", encoding="utf-8") as f:
        json.dump(stats_raw_output, f, indent=4)
    # Read and print results
    answers = read_choices(raw_output_jsonl)
    from collections import Counter
    # Sort by str(key): unfilled indices are None, which cannot be compared with str.
    answers_counter = dict(sorted(Counter(answers).items(), key=lambda item: str(item[0])))
    print(answers_counter)

if __name__ == "__main__":
    # CLI entry point: Fire maps command-line flags onto main()'s parameters.
    Fire(component=main)
