from pathlib import Path
import time
from typing import Literal
from dotenv import load_dotenv, find_dotenv
from openai import OpenAI
import os
import sys
import json

from common.task import Domain
from common.verdict import SolverReply

# Directory containing this module (symlinks resolved), so the paths below do
# not depend on the current working directory.
SCRIPT_DIR = Path(__file__).resolve().parents[0]
# Cache root for LLM responses and verifier verdicts: a folder named "logs"
# two levels above SCRIPT_DIR.
LOGS_DIR = SCRIPT_DIR.parent.parent.joinpath("logs")


class LLM:
    """OpenAI-compatible chat client that caches every response on disk.

    Responses are written under ``logs_folder``, whose name encodes the model
    name, reasoning effort and temperature. Calling ``call`` again for the same
    problem/judge configuration returns the cached text instead of hitting the
    API.
    """

    model_name: str
    temperature: float
    max_tokens: int
    client: OpenAI
    reasoning_effort: str | None
    logs_folder: Path

    def __init__(
        self,
        model_name: str,
        temperature: float,
        max_tokens: int,
        reasoning_effort: str | None = None,
    ):
        """Store the model settings and build the API client from the environment.

        Variables are loaded from the nearest ``.env`` file with
        ``override=True``. ``OPENAI_BASE_URL`` and ``OPENAI_API_KEY`` are
        required: if either is missing, a message is printed and the process
        exits with status 1. ``API_VERSION`` is optional; when it is set and
        non-empty, it is sent as the ``api-version`` query parameter.
        """
        self.model_name = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.reasoning_effort = reasoning_effort

        # The cache folder name includes every setting passed here, so runs
        # with different settings never share cached answers.
        if reasoning_effort:
            self.logs_folder = (
                LOGS_DIR / f"{model_name}-{reasoning_effort}-t{temperature}"
            )
        else:
            self.logs_folder = LOGS_DIR / f"{model_name}-t{temperature}"

        load_dotenv(find_dotenv(), override=True)
        # Tuple (not set) so the variables are checked/reported in a stable order.
        for arg in ("OPENAI_BASE_URL", "OPENAI_API_KEY"):
            if arg not in os.environ:
                print(f"Missing env variable: {arg}")
                if arg == "OPENAI_BASE_URL":
                    print(
                        "Did you mean to use `https://api.openai.com/v1/` (or your Azure endpoint)?"
                    )
                sys.exit(1)

        client_kwargs = {
            "api_key": os.environ["OPENAI_API_KEY"],
            "base_url": os.environ["OPENAI_BASE_URL"],
        }
        # API_VERSION is optional; .get() avoids a KeyError when it is unset.
        api_version = os.environ.get("API_VERSION")
        if api_version:
            client_kwargs["default_query"] = {"api-version": api_version}
        self.client = OpenAI(**client_kwargs)

    def call(
        self,
        prompt: str,
        problem_id: int,
        problem_domain: Domain,
        system_prompt: str | None = None,
        judge_type: None | Literal["nl", "nl-base", "equivalence"] = None,
        judge_order: None | Literal["answer_first", "llm_solution_first"] = None,
    ) -> str:
        """
        Returns the LLM response.

        If a non-empty cached response exists for this problem/judge
        configuration, it is returned without calling the API. Otherwise the
        request is retried (with capped backoff) until a non-empty response
        arrives, which is then written to the cache and returned.

        system_prompt can be automatically generated from the domain
        via Domain.to_system_prompt() (or to_system_prompt_for_judge()
        when judge_type is given). However, it is also provided
        separately as an option.
        """
        if existing := self.read_from_log(
            problem_id, problem_domain, judge_type, judge_order
        ):
            return existing

        if system_prompt is None:
            if judge_type is not None:
                system_prompt = problem_domain.to_system_prompt_for_judge(judge_type)
            else:
                system_prompt = problem_domain.to_system_prompt()

        # The request is identical on every retry, so build it once.
        completion_params = {
            "model": self.model_name,
            "temperature": self.temperature,
            "max_completion_tokens": self.max_tokens,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
        }
        if self.reasoning_effort:
            completion_params["reasoning_effort"] = self.reasoning_effort

        wait_time = 10
        max_wait = 35
        while True:
            try:
                response = self.client.chat.completions.create(**completion_params)
                content = response.choices[0].message.content
            except Exception as exc:  # including rate limit error
                # Keep retrying, but make the failure visible: a silent loop
                # hides non-transient errors such as bad credentials.
                print(
                    f"LLM call for problem {problem_id} failed "
                    f"({type(exc).__name__}: {exc}); retrying in {wait_time}s",
                    file=sys.stderr,
                )
                content = None
            if content:
                break
            time.sleep(wait_time)
            wait_time = min(int(wait_time * 1.5), max_wait)

        # Written outside the try: a cache-write failure must surface instead
        # of silently triggering another (billed) API call.
        self.log(content, problem_id, problem_domain, judge_type, judge_order)
        return content

    def _verifier_log_path(self, problem_id: int, problem_domain: Domain) -> Path:
        """Return the JSON file path used to cache a verifier verdict."""
        return self.logs_folder / "verifier" / str(problem_domain) / f"{problem_id}.log"

    def log_for_verifier(
        self, problem_id: int, problem_domain: Domain, verdict: SolverReply
    ):
        """Write *verdict* as JSON to the verifier cache, replacing any old entry.

        NOTE(review): counterexample and other_details are stored via str(), so
        read_for_verifier gets strings back rather than the original objects.
        Confirm that SolverReply accepts that.
        """
        log_file = self._verifier_log_path(problem_id, problem_domain)
        log_file.parent.mkdir(exist_ok=True, parents=True)

        log_data = {
            "verdict": verdict.verdict,
            "counterexample": str(verdict.counterexample),
            "error_message": verdict.error_message,
            "other_details": str(verdict.other_details),
            "judge_score": verdict.judge_score,
        }

        # Same encoding as read_for_verifier so the round trip is symmetric.
        with open(log_file, "w", encoding="utf-8") as f:
            json.dump(log_data, f, indent=2)

    def read_for_verifier(
        self,
        problem_id: int,
        problem_domain: Domain,
    ) -> SolverReply | None:
        """Load a cached verifier verdict.

        Returns None when there is no cache file or when it cannot be decoded
        into a SolverReply (corrupt JSON, bad encoding, or mismatched fields).
        """
        log_file = self._verifier_log_path(problem_id, problem_domain)

        try:
            with open(log_file, "r", encoding="utf-8") as f:
                verdict_data = json.load(f)
            return SolverReply(**verdict_data)
        except FileNotFoundError:
            return None
        # ValueError covers JSONDecodeError and UnicodeDecodeError.
        except (ValueError, TypeError, KeyError):
            return None

    def _log_path(
        self,
        problem_id: int,
        problem_domain: Domain,
        judge_type: None | Literal["nl", "nl-base", "equivalence"],
        judge_order: None | Literal["answer_first", "llm_solution_first"],
    ) -> Path:
        """Return the cache file path for a solver or judge response.

        Solver responses live in ``<logs>/<domain>/<id>.log``. Judge responses
        live in ``<logs>/judge-<type>[-<order>]/<domain>/<id>.log``.
        """
        folder = self.logs_folder
        if judge_type is not None:
            judge_key = f"judge-{judge_type}"
            if judge_order:
                judge_key += f"-{judge_order}"
            folder = folder / judge_key
        return folder / str(problem_domain) / f"{problem_id}.log"

    def log(
        self,
        llm_result: str,
        problem_id: int,
        problem_domain: Domain,
        judge_type: None | Literal["nl", "nl-base", "equivalence"],
        judge_order: None | Literal["answer_first", "llm_solution_first"],
    ):
        """Write *llm_result* to its cache file, creating folders as needed."""
        log_file = self._log_path(problem_id, problem_domain, judge_type, judge_order)
        log_file.parent.mkdir(exist_ok=True, parents=True)

        with open(log_file, "w", encoding="utf-8") as f:
            f.write(llm_result)

    def read_from_log(
        self,
        problem_id: int,
        problem_domain: Domain,
        judge_type: None | Literal["nl", "nl-base", "equivalence"],
        judge_order: None | Literal["answer_first", "llm_solution_first"],
    ) -> str | None:
        """Return the cached response text, or None if it has not been cached."""
        log_file = self._log_path(problem_id, problem_domain, judge_type, judge_order)

        try:
            with open(log_file, "r", encoding="utf-8") as f:
                return f.read()
        except FileNotFoundError:
            return None
