import random
import time

import numpy as np
from openai import OpenAI
from transformers import AutoTokenizer

# Ports of the local OpenAI-compatible servers RewardModel talks to; one
# client is created per port (8080 is currently not used).
PORT_LIST = [8081, 8082, 8083]

# Upper bound on the number of tokens sent per request; longer conversations
# are truncated to this length. A 16-token safety margin is kept below 4096.
# NOTE(review): 4096 presumably matches the served model's context window —
# confirm against the server configuration.
_TOKEN_BUDGET = 4096
_TOKEN_MARGIN = 16
MAX_TOKEN_LENGTH = _TOKEN_BUDGET - _TOKEN_MARGIN


def sigmoid(x: float) -> float:
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` of ``x``.

    Evaluated in a numerically stable way: ``exp`` is only ever applied to a
    non-positive argument, so very large ``|x|`` saturates to 0.0 / 1.0
    instead of overflowing inside ``np.exp`` (which emits a RuntimeWarning).
    """
    if x >= 0:
        # exp(-x) <= 1 here, so this cannot overflow.
        return float(1.0 / (1.0 + np.exp(-x)))
    # For negative x use the algebraically equal exp(x) / (1 + exp(x));
    # exp(x) < 1 here, so again no overflow is possible.
    z = np.exp(x)
    return float(z / (1.0 + z))


class RewardModel:
    """Reward model implementation for mathematical problem evaluation.

    Scores a (problem, answer) pair by rendering it with the tokenizer's chat
    template and sending it to one of several local OpenAI-compatible servers
    (one per port in ``PORT_LIST``) through the embeddings endpoint. The last
    element of the returned embedding is used as the reward.

    Based on the implementation from vllm-project:
    https://github.com/vllm-project/vllm/pull/8896
    """

    def __init__(
        self,
        reward_model_name: str = "Qwen/Qwen2.5-Math-RM-72B",
        num_trials: int = 7,
        is_sigmoid: bool = True,
    ):
        """Create one client per server and load the tokenizer.

        Args:
            reward_model_name: Hugging Face model id whose tokenizer (and chat
                template) is used to format and truncate conversations.
            num_trials: Maximum number of request attempts per
                ``get_reward`` call.
            is_sigmoid: If True, ``get_reward`` passes the raw score through
                ``sigmoid`` before returning it.
        """
        self.clients = [
            OpenAI(
                api_key="EMPTY",
                base_url=f"http://localhost:{port}/v1",
            )
            for port in PORT_LIST
        ]
        self.tokenizer = AutoTokenizer.from_pretrained(
            reward_model_name, trust_remote_code=True
        )
        # Ask each server which model it serves; the first listed model id is
        # used for every request sent to that server.
        models = [client.models.list() for client in self.clients]
        self.models = [model.data[0].id for model in models]
        self.num_trials = num_trials
        self.is_sigmoid = is_sigmoid

    def _build_conversation(self, problem: str, answer: str) -> str:
        """Render the chat for (problem, answer), truncated to MAX_TOKEN_LENGTH tokens."""
        chat = [
            {
                "role": "system",
                "content": "Please reason step by step, and put your final answer within \\boxed{}.",
            },
            {
                "role": "user",
                "content": problem,
            },
            {
                "role": "assistant",
                "content": answer,
            },
        ]

        conversation_str = self.tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=False
        )
        # Tokenize once and reuse the ids for both the length check and the
        # truncation.
        token_ids = self.tokenizer.encode(conversation_str)
        if len(token_ids) > MAX_TOKEN_LENGTH:
            # NOTE(review): this keeps the head and drops the tail, so the end
            # of the answer (and any closing template tokens) is lost; confirm
            # that is the intended truncation side for the reward score.
            conversation_str = self.tokenizer.decode(token_ids[:MAX_TOKEN_LENGTH])
            print(f"Truncated conversation to {MAX_TOKEN_LENGTH} tokens")
        return conversation_str

    def get_reward(self, problem: str, answer: str) -> float:
        """Return the reward for ``answer`` as a response to ``problem``.

        A random server is tried first; on failure the request is retried on
        the next server (round-robin) after an exponential-backoff delay, for
        at most ``num_trials`` attempts in total.

        Args:
            problem: User prompt (the math problem).
            answer: Assistant response to be scored.

        Returns:
            The last element of the embedding returned by the server, passed
            through ``sigmoid`` when ``is_sigmoid`` is True.

        Raises:
            RuntimeError: If every attempt failed (chained from the last
                error), or if ``num_trials`` is less than 1.
        """
        conversation_str = self._build_conversation(problem, answer)

        # Start on a random server to spread load across PORT_LIST.
        selected_client_idx = random.randint(0, len(self.clients) - 1)

        base_delay = 10
        last_error = None
        for i in range(self.num_trials):
            try:
                responses = self.clients[selected_client_idx].embeddings.create(
                    input=[conversation_str],
                    model=self.models[selected_client_idx],
                )
            except Exception as e:  # any client/server failure is retried
                last_error = e
                print(f"Error: {e}")
                if i + 1 >= self.num_trials:
                    # Final attempt failed: sleeping now would only delay the error.
                    break
                # Exponential backoff with jitter on top of a fixed base delay:
                #   delay = base_delay + 2**i * U(0.5, 1.5)
                # See: https://aws.amazon.com/jp/blogs/architecture/exponential-backoff-and-jitter/
                exp_factor = 2**i
                jitter_factor = random.uniform(0.5, 1.5)
                final_delay = base_delay + exp_factor * jitter_factor

                print(
                    f"Sleeping {final_delay:.2f} seconds before retry (attempt {i+1}/{self.num_trials})..."
                )
                time.sleep(final_delay)
                # Retry on the next server in case this one is down.
                selected_client_idx = (selected_client_idx + 1) % len(self.clients)
            else:
                score = responses.data[0].embedding[-1]
                return sigmoid(score) if self.is_sigmoid else score

        # Previously this path crashed with UnboundLocalError on `responses`.
        raise RuntimeError(
            f"Reward request failed after {self.num_trials} attempt(s)"
        ) from last_error
