from typing import List, Dict

import numpy as np
import sglang as sgl
from transformers import AutoTokenizer, PreTrainedTokenizer
from sglang.lang.interpreter import ProgramState

def make_generative_rm_messages(prompt: str, outputs: List[str]):
    """Build one judge conversation per candidate response.

    Each conversation is three turns:
      1. the original ``prompt`` as a user turn,
      2. the candidate as an assistant turn, keeping only the text after the
         last ``</think>`` tag (the whole text when no tag is present),
      3. a user turn asking whether that response is preferred, to be answered
         as ``Preferred: Yes`` / ``Preferred: No``.

    Args:
        prompt: The original user prompt.
        outputs: Candidate assistant responses.

    Returns:
        A list of chat-message lists, aligned index-for-index with ``outputs``.
    """
    judge_instruction = "Please act as an impartial judge and evaluate the quality of the assistant's response. A preferred response is helpful, harmless, and accurately follows instructions. Is this a preferred response? Answer 'Yes' or 'No' in the format 'Preferred: X'."

    def _conversation(candidate: str) -> List[Dict[str, str]]:
        # rsplit on the last tag leaves only the final answer; with no tag the
        # whole candidate is returned unchanged.
        final_answer = candidate.rsplit("</think>", 1)[-1]
        return [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": final_answer},
            {"role": "user", "content": judge_instruction},
        ]

    return list(map(_conversation, outputs))


def get_generative_rm_verification_score(state: "ProgramState") -> float:
    """Turn the judge's final-token top-2 logprobs into a scalar preference score.

    Reads ``state.get_meta_info("result")["output_top_logprobs"]`` and inspects
    the candidates of the LAST generated token. Each candidate is an entry
    ``(logprob, token_id, token_text)``.

    Returns:
        With ``p1``/``p2`` the probabilities of the top-2 candidates (``p2`` is
        0 when only one candidate is reported):

        * ``p1 - p2`` when the most likely token is "yes",
        * ``-(p1 - p2)`` when it is "no",
        * ``-1.0`` otherwise, including when no usable logprobs were returned
          (e.g. the generation failed).

        Token text is compared case-insensitively after stripping whitespace.
        The whole token must equal "yes"/"no", so tokens like " not" or "Now"
        are not mistaken for a "No" verdict.
    """
    meta_info = state.get_meta_info("result")
    if not meta_info:
        # No result recorded (e.g. every attempt failed): treat as not preferred.
        return -1.0
    top_logprobs = meta_info.get("output_top_logprobs")
    if not top_logprobs or not top_logprobs[-1]:
        return -1.0

    candidates = top_logprobs[-1]
    first_logprob, _first_token_id, first_token_text = candidates[0]
    # If the server reported a single candidate, the runner-up has probability 0.
    second_prob = float(np.exp(candidates[1][0])) if len(candidates) > 1 else 0.0
    prob_gap = float(np.exp(first_logprob)) - second_prob

    verdict = (first_token_text or "").strip().lower()
    if verdict == "yes":
        return prob_gap
    if verdict == "no":
        return -prob_gap
    return -1.0


@sgl.function
def verifier_rollout(s: ProgramState, prompt, temperature=0.0, max_tokens=2048, stop=None, retries=10):
    """SGLang program: append ``prompt`` and generate the judge verdict into ``s["result"]``.

    The generation requests the top-2 logprobs, with token text, for each output
    token. ``get_generative_rm_verification_score`` later reads them back via
    ``state.get_meta_info("result")``.

    Args:
        s: Program state supplied by SGLang.
        prompt: Fully rendered judge prompt text.
        temperature: Sampling temperature for the verdict generation.
        max_tokens: Generation budget for the verdict.
        stop: Stop strings. ``None`` (the default) means no stop strings.
        retries: Maximum number of generation attempts while ``s.error()``
            reports a failure.
    """
    if stop is None:
        # Avoid a shared mutable default; [] is what the old default passed.
        stop = []
    s += prompt
    for attempt in range(retries):
        s += sgl.gen("result", temperature=temperature, max_tokens=max_tokens, stop=stop, return_logprob=True, top_logprobs_num=2, return_text_in_logprobs=True)
        if s.error() is None:
            break
        # NOTE(review): each retry appends another gen to the same state. Confirm
        # that SGLang's executor keeps processing after an error; otherwise these
        # retries cannot succeed.
        print(f"Attempt {attempt} failed with error: {s.error()}")


class GenerativeRMServerVerifier:
    """Score candidate responses with a generative reward model behind an SGLang endpoint.

    The pipeline for each response is:

    1. Build a judge conversation with ``make_generative_rm_messages``.
    2. Render it through the tokenizer's chat template.
    3. Run it through ``verifier_rollout`` against ``args.base_url``.
    4. Convert the result to a scalar with ``get_generative_rm_verification_score``.
    """

    def __init__(self, args, tokenizer: "PreTrainedTokenizer"):
        """Create the endpoint client and cache generation settings.

        Args:
            args: Config namespace. Reads ``base_url``, ``api_key``,
                ``temperature``, ``prompt_length``, ``max_tokens`` and
                ``max_threads``.
            tokenizer: Used to render chat templates and to count prompt tokens.
        """
        self.args = args
        self.endpoint = sgl.RuntimeEndpoint(base_url=args.base_url, api_key=args.api_key)

        self.temperature = args.temperature
        self.max_prompt_length = args.prompt_length  # token budget for one rendered judge prompt
        self.max_tokens = args.max_tokens
        self.max_threads = args.max_threads

        self.tokenizer = tokenizer
        self.make_messages_fn = make_generative_rm_messages
        # Generation stops at the verdict prefix. Presumably this makes the last
        # generated token carry the Yes/No decision that score_fn inspects;
        # confirm against the server's stop-string handling.
        self.stop_strs = ["Preferred: "]
        self.score_fn = get_generative_rm_verification_score

    def score(self, sample: Dict, outputs: List[str]) -> List[List[float]]:
        """Score each response in ``outputs`` against ``sample["_pref_prompt"]``.

        Args:
            sample: Must contain the key ``"_pref_prompt"`` (the user prompt).
            outputs: Candidate responses to judge.

        Returns:
            One single-element list ``[score]`` per entry of ``outputs``, in
            order. Responses whose rendered judge prompt is longer than
            ``max_prompt_length`` tokens are not sent to the server and get a
            neutral ``0.0``.
        """
        user_prompt = sample["_pref_prompt"]
        conversations = self.make_messages_fn(user_prompt, outputs)
        judge_prompts = [
            self.tokenizer.apply_chat_template(conv, add_generation_prompt=True, tokenize=False)
            for conv in conversations
        ]
        # NOTE(review): encode() may add special tokens on top of those the chat
        # template already inserted, so this count can be slightly high.
        valid_idx = [
            i
            for i, text in enumerate(judge_prompts)
            if len(self.tokenizer.encode(text)) <= self.max_prompt_length
        ]

        ret_scores = [0.0] * len(judge_prompts)
        if not valid_idx:
            # Nothing fits the budget: skip the server round-trip entirely.
            return [[value] for value in ret_scores]

        run_params = [
            dict(prompt=judge_prompts[i], temperature=self.temperature, max_tokens=self.max_tokens, stop=self.stop_strs)
            for i in valid_idx
        ]
        states = verifier_rollout.run_batch(run_params, num_threads=self.max_threads, backend=self.endpoint)
        # run_batch returns one state per entry of run_params, in the same order.
        for idx, state in zip(valid_idx, states):
            ret_scores[idx] = self.score_fn(state)
        return [[value] for value in ret_scores]

    def __call__(self, sample: Dict, outputs: List[str]) -> List[List[float]]:
        """Alias for :meth:`score`."""
        return self.score(sample, outputs)
