import random
import time

from openai import OpenAI
from transformers import AutoTokenizer

# Localhost ports of the OpenAI-compatible endpoints (http://localhost:<port>/v1);
# OmniJudge creates one client per port.
PORT_LIST = [8090, 8091]
# Maximum number of prompt tokens sent to the judge; longer prompts are truncated.
# NOTE(review): 8192 is presumably the served model's context length -- confirm.
MAX_TOKEN_LENGTH = 8192 - 316  # 316 is a buffer for the prompt
# Hugging Face model ID from which the tokenizer is loaded.
MODEL_NAME = "KbsdJames/Omni-Judge"


class OmniJudge:
    """Client for grading math answers with locally served Omni-Judge models.

    One OpenAI-compatible client is created per port. Each request starts on
    a randomly chosen client and rotates to the next client after a failure.

    See:
        - https://huggingface.co/KbsdJames/Omni-Judge
        - https://github.com/KbsdJames/Omni-MATH/tree/main/Omni-Judge_eval
    """

    def __init__(
        self,
        num_trials: int = 5,
        *,
        ports: "list[int] | None" = None,
        max_token_length: "int | None" = None,
    ):
        """Connect to the judge servers and load the tokenizer.

        Args:
            num_trials: Maximum number of completion attempts made by each
                call to ``get_judge``.
            ports: Localhost ports of the OpenAI-compatible servers.
                Defaults to ``PORT_LIST``.
            max_token_length: Prompt length, in tokens, beyond which the
                prompt is truncated. Defaults to ``MAX_TOKEN_LENGTH``.

        Raises:
            ValueError: If ``ports`` is given but empty.
        """
        port_list = list(PORT_LIST if ports is None else ports)
        if not port_list:
            raise ValueError("At least one port is required")

        self.clients = [
            OpenAI(
                api_key="EMPTY",
                base_url=f"http://localhost:{port}/v1",
            )
            for port in port_list
        ]
        self.tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME, trust_remote_code=True
        )
        # Each server is addressed by the ID of the first model it reports.
        self.models = [client.models.list().data[0].id for client in self.clients]
        self.num_trials = num_trials
        self.max_token_length = (
            MAX_TOKEN_LENGTH if max_token_length is None else max_token_length
        )

    def get_judge(
        self, problem: str, reference_answer: str, student_answer: str
    ) -> tuple[str, str, str, str]:
        """Ask the judge model to grade ``student_answer`` for ``problem``.

        The prompt is built with the tokenizer's ``get_context`` and cut to
        ``self.max_token_length`` tokens when it is longer than that.

        Args:
            problem: The problem statement.
            reference_answer: The reference answer to compare against.
            student_answer: The answer being judged.

        Returns:
            A tuple ``(result_text, answer, judgement, justification)``:
            the raw completion text followed by the three fields that the
            tokenizer's ``parse_response`` extracts from it.

        Raises:
            RuntimeError: If no completion was obtained within
                ``self.num_trials`` attempts.
        """
        formatted_context = self.tokenizer.get_context(
            problem,
            reference_answer,
            student_answer,
        )
        # Encode once and reuse the IDs for both the length check and the cut.
        token_ids = self.tokenizer.encode(formatted_context)
        if len(token_ids) > self.max_token_length:
            # NOTE(review): truncation drops the tail of the prompt -- confirm
            # that the judge still produces a parseable answer in this case.
            formatted_context = self.tokenizer.decode(
                token_ids[: self.max_token_length]
            )
            print(f"Truncated conversation to {self.max_token_length} tokens")

        result_text = self._complete_with_retry(formatted_context)

        # NOTE(review): presumably parse_response returns a mapping with these
        # three keys; a missing key propagates as KeyError -- confirm.
        pred_truth = self.tokenizer.parse_response(result_text)
        answer = pred_truth["answer"]
        judgement = pred_truth["judgement"]
        justification = pred_truth["justification"]

        return result_text, answer, judgement, justification

    def _complete_with_retry(self, prompt: str) -> str:
        """Return the completion text for ``prompt``, retrying across clients."""
        client_idx = random.randrange(len(self.clients))
        base_delay = 10
        last_error = None

        for attempt in range(self.num_trials):
            try:
                response = self.clients[client_idx].completions.create(
                    model=self.models[client_idx],
                    prompt=prompt,
                    max_tokens=300,
                    temperature=0,
                    stop=["</s>", "<|eot_id|>"],
                )
                return response.choices[0].text
            except Exception as e:
                # Retry boundary: any failure (transport, server, malformed
                # response) is logged and retried on the next client.
                print(f"Error: {e}")
                last_error = e

            if attempt + 1 >= self.num_trials:
                # No retry follows the final attempt, so do not sleep.
                break

            # Exponential Backoff and Jitter
            # See: https://aws.amazon.com/jp/blogs/architecture/exponential-backoff-and-jitter/
            #
            # The exponential term 2 ** attempt is multiplied by a random
            # jitter factor between 0.5 and 1.5, then added to base_delay.
            final_delay = base_delay + (2**attempt) * random.uniform(0.5, 1.5)
            print(
                f"Sleeping {final_delay:.2f} seconds before retry (attempt {attempt+1}/{self.num_trials})..."
            )
            time.sleep(final_delay)
            client_idx = (client_idx + 1) % len(self.clients)

        raise RuntimeError(
            f"Omni-Judge request failed after {self.num_trials} attempt(s)"
        ) from last_error
