from typing import List, Dict

import numpy as np
import sglang as sgl
from transformers import PreTrainedTokenizer
from sglang.lang.interpreter import ProgramState


@sgl.function
def get_input_logprob(s: ProgramState, prompt, retries=10, top_logprobs_num=1):
    """Append ``prompt`` to the state and request log-probabilities for it.

    A generation call named ``"result"`` is issued with ``max_tokens=0`` and
    ``return_logprob=True``; it is repeated until the state reports no error
    or ``retries`` attempts have been made.

    Args:
        s: Program state the prompt and the generation call are appended to.
        prompt: Text appended to the state before the generation call.
        retries: Maximum number of generation attempts.
        top_logprobs_num: Forwarded unchanged to ``sgl.gen``.
    """
    s += prompt

    # The options are the same on every attempt, so build them once.
    gen_options = dict(
        max_tokens=0,
        return_logprob=True,
        logprob_start_len=0,
        top_logprobs_num=top_logprobs_num,
        return_text_in_logprobs=True,
    )

    for attempt in range(retries):
        s += sgl.gen("result", **gen_options)
        if s.error() is not None:
            # Report the failure and fall through to the next attempt.
            print(f"Attempt {attempt} failed with error: {s.error()}")
            continue
        # First error-free attempt: nothing more to do.
        return


class PreferenceVerifier:
    """Score candidate outputs from their token log-probabilities.

    Each output is appended to two prompts taken from the sample
    (``_pref_prompt`` and ``_non_pref_prompt``). The per-token
    log-probabilities of the output under each prompt are fetched from the
    configured backend and combined into a per-token reward.
    """

    def __init__(self, args, tokenizer: PreTrainedTokenizer):
        """Store the configuration and create the backend endpoint.

        Args:
            args: Object exposing ``base_url``, ``api_key``, ``max_threads``,
                ``mixed_weight`` and ``discount_factor`` attributes.
            tokenizer: Tokenizer used to count prompt and response tokens.
                NOTE(review): presumably this must match the tokenizer of the
                served model so the counts line up -- confirm.
        """
        self.args = args
        self.endpoint = sgl.RuntimeEndpoint(base_url=args.base_url, api_key=args.api_key)
        self.max_threads = args.max_threads
        self.mixed_weight = args.mixed_weight
        self.discount_factor = args.discount_factor
        self.tokenizer = tokenizer

    def score(self, sample: Dict, outputs: List[str]) -> List[np.ndarray]:
        """Compute one reward array per output.

        For every output the reward is
        ``non_pref + mixed_weight * (pref - non_pref)``, evaluated element-wise
        on the response log-probabilities, then multiplied by
        ``discount_factor ** i`` where ``i`` is the position of the output in
        ``outputs``.

        Args:
            sample: Mapping with the keys ``"_pref_prompt"`` and
                ``"_non_pref_prompt"``.
            outputs: Candidate response strings.

        Returns:
            A list with one NumPy array per output, in input order.

        Raises:
            KeyError: If either prompt key is missing from ``sample``.
        """
        pref_prompt = sample["_pref_prompt"]
        non_pref_prompt = sample["_non_pref_prompt"]

        pref_logprobs = self._get_response_logprob(pref_prompt, outputs)
        non_pref_logprobs = self._get_response_logprob(non_pref_prompt, outputs)

        rewards = []
        for index, (pref, non_pref) in enumerate(zip(pref_logprobs, non_pref_logprobs)):
            base = np.array(non_pref)
            # NOTE(review): both arrays are expected to have the same length;
            # the response is tokenized after two different prompts, so verify
            # that the token counts cannot diverge at the prompt boundary.
            gain = np.array(pref) - base
            reward = base + self.mixed_weight * gain
            # NOTE(review): the discount exponent is the index of the output
            # within ``outputs``, not a token position -- confirm intent.
            rewards.append(reward * self.discount_factor ** index)
        return rewards

    @staticmethod
    def _response_tail(token_logprobs, response_len):
        """Return the last ``response_len`` items of ``token_logprobs``.

        A plain ``token_logprobs[-response_len:]`` would return the *entire*
        list for ``response_len == 0`` (because ``-0 == 0``) and would drop a
        prefix for negative values, so non-positive lengths yield ``[]``.
        """
        if response_len <= 0:
            return []
        return token_logprobs[-response_len:]

    def _get_response_logprob(self, prompt, responses):
        """Fetch the log-probabilities of each response's tokens given ``prompt``.

        The number of response tokens is estimated as the token count of
        ``prompt + response`` minus the token count of ``prompt`` alone, and
        that many trailing entries of the backend's ``input_token_logprobs``
        are kept.

        Args:
            prompt: Prompt text placed in front of every response.
            responses: Iterable of response strings.

        Returns:
            A list (one item per response) of lists holding the first element
            of each kept ``input_token_logprobs`` entry.
        """
        texts = [prompt + response for response in responses]
        if not texts:
            # Nothing to score; skip the tokenizer and the backend round trip.
            return []

        prompt_len = len(self.tokenizer.encode(prompt))
        response_lens = [len(self.tokenizer.encode(text)) - prompt_len for text in texts]

        states = get_input_logprob.run_batch(
            [dict(prompt=text) for text in texts],
            num_threads=self.max_threads,
            backend=self.endpoint,
        )

        per_response = []
        for state, response_len in zip(states, response_lens):
            entries = state.get_meta_info("result")["input_token_logprobs"]
            # Only the first element of each entry is used as the log-probability.
            token_logprobs = [entry[0] for entry in entries]
            per_response.append(self._response_tail(token_logprobs, response_len))
        return per_response

    def __call__(self, sample: Dict, outputs: List[str]) -> List[np.ndarray]:
        """Alias for :meth:`score` so the verifier can be used as a callable."""
        return self.score(sample, outputs)
