from typing import List, Dict, Any
import numpy as np
from transformers import AutoTokenizer
from utils.llm import OpenAIClient
from utils.concurrency import run_batch
from search.decoding import ResponseState

def make_multi_turn_messages(question: str, outputs: List[str]):
    """Build one three-turn verification conversation per candidate output.

    Each conversation consists of the user question, the candidate as the
    assistant turn (only the text after the last ``</think>`` tag when one
    is present), and a user turn requesting a step-by-step check that ends
    with a Yes/No grade.

    Args:
        question: The original problem statement.
        outputs: Candidate responses to be verified.

    Returns:
        A list holding one chat-message list per entry of ``outputs``, in
        the same order.
    """
    verify_request = (
        'Please verify the solution step by step. '
        'At the end of the solution verification, when you give your final grade, '
        'write it in the form "Is the answer correct (Yes/No)? X", '
        'where X is either Yes or No.'
    )

    conversations = []
    for candidate in outputs:
        # rpartition returns the whole string as the tail when the tag is
        # absent, so no separate membership check is needed.
        final_answer = candidate.rpartition("</think>")[2]
        conversations.append(
            [
                {"role": "user", "content": question},
                {"role": "assistant", "content": final_answer},
                {"role": "user", "content": verify_request},
            ]
        )
    return conversations

def get_multi_turn_verification_score(state: ResponseState) -> float:
    """Turn the verifier's last-token logprobs into a signed confidence score.

    Reads ``state.meta_info["result"]["output_top_logprobs"]``, a list with
    one entry per generated token where each entry is a list of
    ``[logprob, token_id, token_text]`` candidates. Only the last entry is
    inspected.

    The magnitude is the probability of the first candidate minus the
    probability of the second one (or just the first candidate's
    probability when there is only one).

    Args:
        state: Rollout state whose ``meta_info`` holds the verifier result.

    Returns:
        ``+gap`` when the first candidate's text contains "yes", ``-gap``
        when it contains "no", and ``-1.0`` when no logprobs are available
        or the first candidate is neither.
    """
    logprobs = state.meta_info.get("result", {}).get("output_top_logprobs", [])
    # Guard both "no tokens at all" and "last token has no candidates"; the
    # latter would otherwise raise IndexError below.
    if not logprobs or not logprobs[-1]:
        return -1.0

    # NOTE(review): assumes candidates are ordered most-likely first --
    # confirm against whatever fills "output_top_logprobs".
    candidates = logprobs[-1]
    top_logprob, _, top_text = candidates[0]
    prob_gap = np.exp(top_logprob)
    if len(candidates) >= 2:
        prob_gap -= np.exp(candidates[1][0])

    # Substring match ("yes" is checked before "no").
    verdict = top_text.lower()
    if "yes" in verdict:
        return float(prob_gap)
    if "no" in verdict:
        return float(-prob_gap)
    return -1.0

# Per-template registries. All three tables share the same template-name
# keys so a verifier can be configured with a single template string.
_MULTI_TURN_TEMPLATE = "multi-turn"

# template name -> callable(question, outputs) building the chat messages.
VERIFIER_MESSAGES_BUILDERS = {_MULTI_TURN_TEMPLATE: make_multi_turn_messages}

# template name -> stop strings passed to the verifier generation call.
VERIFIER_STOP_TEMPLATES = {_MULTI_TURN_TEMPLATE: ["Is the answer correct (Yes/No)? "]}

# template name -> callable(state) converting a rollout into a score.
VERIFIER_SCORE_FUNCTIONS = {_MULTI_TURN_TEMPLATE: get_multi_turn_verification_score}

def _collect_top_logprobs(logprobs: Any) -> List[list]:
    """Flatten a logprobs payload into per-token ``[logprob, 0, token]`` rows.

    Each generated token yields one list of candidates. The middle element
    is always ``0``: no token id is read from the payload, so it is only a
    placeholder keeping the ``[logprob, token_id, token_text]`` layout.
    """
    rows = []
    for token_lp in getattr(logprobs, "content", None) or []:
        alternatives = getattr(token_lp, "top_logprobs", None)
        if alternatives:
            rows.append([[alt.logprob, 0, alt.token] for alt in alternatives])
        else:
            # Missing, None, or empty alternatives: fall back to the sampled
            # token itself so every row has at least one candidate.
            rows.append([[token_lp.logprob, 0, token_lp.token]])
    return rows


def verifier_rollout_func(llm: OpenAIClient, prompt: str, temperature: float = 0.0, max_tokens: int = 2048, stop: List[str] = [], retries: int = 10) -> ResponseState:
    """Run one verifier generation and record its text and top logprobs.

    Calls ``llm.generate`` with ``logprobs=2`` and retries on any exception,
    up to ``retries`` attempts in total.

    Args:
        llm: Client used to generate the verification.
        prompt: Fully rendered prompt string.
        temperature: Sampling temperature forwarded to the client.
        max_tokens: Generation length limit forwarded to the client.
        stop: Stop strings forwarded to the client.
        retries: Maximum number of attempts. With ``retries <= 0`` no call
            is made and the returned state carries no result.

    Returns:
        A ``ResponseState`` whose ``response_text`` is the generated text
        and whose ``meta_info["result"]`` holds ``finish_reason``,
        ``output_text`` and, when the response includes logprobs,
        ``output_top_logprobs``.

    Raises:
        Exception: Whatever the final attempt raised, re-raised unchanged.
    """
    state = ResponseState(prompt=prompt)
    # The default ``stop`` list is shared between calls; pass the client a
    # copy so nothing downstream can mutate it.
    stop_strs = list(stop) if isinstance(stop, list) else stop
    for attempt in range(retries):
        try:
            res = llm.generate(
                prompt=prompt,
                temperature=temperature,
                max_tokens=max_tokens,
                stop=stop_strs,
                logprobs=2,
            )
            # Build the result locally and attach it only once parsing has
            # succeeded, so a failed attempt never leaves a partial result.
            result = {
                "finish_reason": {"type": res["finish_reason"]},
                "output_text": res["text"],
            }
            if "logprobs" in res:
                result["output_top_logprobs"] = _collect_top_logprobs(res["logprobs"])
            state.response_text = res["text"]
            state.meta_info["result"] = result
            break
        except Exception as e:
            print(f"Attempt {attempt} failed with error: {e}")
            if attempt == retries - 1:
                # Bare raise keeps the original traceback intact.
                raise
    return state

class GenerativeVerifier:
    """Scores candidate answers by asking an LLM to verify each of them.

    The verifier renders a chat prompt per candidate with the model's
    tokenizer, runs the verification generations in parallel, and converts
    each rollout into a score with the function registered for the
    configured template.
    """

    def __init__(self, args: Any):
        """Read the ``generative_verifier_*`` settings from ``args``.

        Args:
            args: Configuration object providing ``generative_verifier_url``,
                ``generative_verifier_api_key``,
                ``generative_verifier_model_path``,
                ``generative_verifier_template``,
                ``generative_verifier_temperature``,
                ``generative_verifier_max_prompt_length``, ``max_tokens`` and
                ``max_threads``.

        Raises:
            KeyError: If the template name is not in the verifier registries.
        """
        self.args = args
        self.llm = OpenAIClient(base_url=args.generative_verifier_url, api_key=args.generative_verifier_api_key, model=args.generative_verifier_model_path)

        self.model_path = args.generative_verifier_model_path
        self.template = args.generative_verifier_template
        self.temperature = args.generative_verifier_temperature
        self.max_tokens = args.max_tokens
        self.max_threads = args.max_threads
        self.max_prompt_length = args.generative_verifier_max_prompt_length
        # The tokenizer is used to render chat prompts and to count tokens.
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.make_messages_fn = VERIFIER_MESSAGES_BUILDERS[self.template]
        self.stop_strs = VERIFIER_STOP_TEMPLATES[self.template]
        self.score_fn = VERIFIER_SCORE_FUNCTIONS[self.template]

    def score(self, question: str, outputs: List[str]) -> List[List[float]]:
        """Score every candidate in ``outputs`` for ``question``.

        Prompts longer than ``max_prompt_length`` tokens are not sent to the
        model and keep a score of ``0.0``.

        Args:
            question: The original problem statement.
            outputs: Candidate responses to verify.

        Returns:
            One single-element list ``[score]`` per candidate, in input order.
        """
        conversations = self.make_messages_fn(question, outputs)
        prompts = [
            self.tokenizer.apply_chat_template(msg, add_generation_prompt=True, tokenize=False)
            for msg in conversations
        ]
        ret_scores = [0.0] * len(prompts)

        # Keep the original position next to each prompt that fits, so the
        # scores can be written back in place after the batch run.
        valid = [
            (idx, prompt)
            for idx, prompt in enumerate(prompts)
            if len(self.tokenizer.encode(prompt)) <= self.max_prompt_length
        ]
        if not valid:
            # Nothing to verify: skip the batch call rather than request a
            # pool of zero threads.
            return [[score] for score in ret_scores]

        run_params = [
            dict(llm=self.llm, prompt=prompt, temperature=self.temperature, max_tokens=self.max_tokens, stop=self.stop_strs)
            for _, prompt in valid
        ]
        # Never use more threads than jobs, and honour the configured cap
        # when one is set.
        num_threads = len(run_params)
        if self.max_threads and self.max_threads > 0:
            num_threads = min(num_threads, self.max_threads)

        states = run_batch(verifier_rollout_func, run_params, num_threads=num_threads)
        for (idx, _), state in zip(valid, states):
            ret_scores[idx] = self.score_fn(state)
        return [[score] for score in ret_scores]

    def __call__(self, question: str, outputs: List[str]) -> List[List[float]]:
        """Alias for :meth:`score`."""
        return self.score(question, outputs)
