import re
from typing import Dict, Optional, Tuple

from termcolor import colored

from dllm_generation import dllm_original_generation
from prompts.vera_prompts import VERA_ANSWER_SYMBOL, get_vera_prompt


class DllmRewardModel(object):
    """Score a candidate solution by asking a DLLM a series of verifier prompts.

    Each verifier named in ``load_domain_specific_verifiers`` is rendered with
    ``get_vera_prompt``, answered with ``dllm_original_generation`` using this
    instance's model/tokenizer, and parsed into a boolean. Every failed
    verifier subtracts its weight (from ``get_reward_list``) from the reward,
    which is then divided by the total weight of the verifiers that ran.
    """

    def __init__(
        self,
        *,
        dllm_type: str,
        model,
        tokenizer,
        device: str = "cuda",
        rule_format_string: Optional[str] = None,
        gen_steps: int = 64,
        max_new_tokens: int = 128,
        llada_block_length: int = 32,
        temperature: float = 0.0,
        dream_top_p: float = 0.95,
        dream_alg: str = "entropy",
        dream_alg_temp: float = 0.0,
    ):
        """
        Reward model that uses the same DLLM backend to produce verifier responses.

        The verifiers are the same as in rewards/reward.py but generation uses DLLM.

        Args:
            dllm_type: backend selector forwarded to ``dllm_original_generation``.
                ``"llada"`` adds ``block_length`` to the generation kwargs;
                ``"dream"`` adds ``top_p``, ``alg`` and ``alg_temp``; any other
                value gets only the shared keys.
            model, tokenizer, device: forwarded unchanged to the generator.
            rule_format_string: optional regex; when not ``None``, ``get_reward``
                applies an extra penalty if the solution does not match it.
            gen_steps, max_new_tokens, temperature: shared generation settings.
            llada_block_length: LLaDA-only ``block_length``.
            dream_top_p, dream_alg, dream_alg_temp: Dream-only sampling settings.
        """
        self.dllm_type = dllm_type
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.rule_format_string = rule_format_string

        # Small, fast generation defaults for verifiers
        self.gen_kwargs = {
            "steps": gen_steps,
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
        }
        if dllm_type == "llada":
            self.gen_kwargs.update({"block_length": llada_block_length})
        elif dllm_type == "dream":
            self.gen_kwargs.update({
                "top_p": dream_top_p,
                "alg": dream_alg,
                "alg_temp": dream_alg_temp,
            })

    def load_domain_specific_verifiers(self):
        """Return the verifier names queried by ``get_verifications``, in order."""
        return [
            "calculation_check",
            "answer_correct",
            "answer_completeness",
            "understanding_check",
        ]

    def get_verifications(self, question: str, solution: str) -> Dict[str, bool]:
        """Run every verifier on ``(question, solution)`` with the DLLM.

        Returns:
            Mapping of verifier name to its parsed approval (True = approved).
        """
        veras = self.load_domain_specific_verifiers()
        verifications = {}
        for vera_type in veras:
            vera_prompt = get_vera_prompt(vera_type, question, solution)
            # Use DLLM generation for verifier
            response, _ = dllm_original_generation(
                dllm_type=self.dllm_type,
                model=self.model,
                tokenizer=self.tokenizer,
                device=self.device,
                input_text=vera_prompt,
                **self.gen_kwargs,
            )
            verifications[vera_type] = self.extract_verifier_approval(response)
        return verifications

    @staticmethod
    def _parse_boolean_verdict(answer: str) -> Optional[bool]:
        """Return the verdict given by the first ``true``/``false`` token in *answer*.

        Markdown bold markers (``*``) are removed and matching is
        case-insensitive. The EARLIEST token wins, so text such as
        "False - the claim is not true" is read as ``False``. Returns ``None``
        when neither token occurs.
        """
        found = re.search(r"true|false", answer.replace("*", "").lower())
        if found is None:
            return None
        return found.group(0) == "true"

    def extract_verifier_approval(self, verifier_response: str) -> bool:
        """Parse a verifier's free-text response into an approval boolean.

        The text after the LAST occurrence of ``VERA_ANSWER_SYMBOL``
        (case-insensitive) is taken as the answer; its first ``true``/``false``
        token decides the result. A missing or unparseable answer logs a
        warning and is treated as ``False``.
        """
        if verifier_response is None:
            # Treat a missing generation like an empty one instead of crashing in len().
            verifier_response = ""
        verb = VERA_ANSWER_SYMBOL.lower()
        # Greedy leading ``.*`` with DOTALL anchors on the last occurrence of the symbol.
        pattern = re.compile(r'.*{}(.*)'.format(re.escape(verb)), flags=re.DOTALL | re.IGNORECASE)
        match = pattern.search(verifier_response)
        answer = match.group(1).strip() if match else None
        if not answer:
            print(colored(
                f"WARNING in extract_verifier_approval: answer not found. Full response (len={len(verifier_response)}):\n"
                f"{'-' * 30}\n{verifier_response}\n{'-' * 30}",
                "yellow",
            ))
            return False

        verdict = self._parse_boolean_verdict(answer)
        if verdict is None:
            print(colored(
                "WARNING in extract_verifier_approval: failed to parse boolean. Treat as False.",
                "yellow",
            ))
            return False
        return verdict

    def _weighted_penalty(self, verifications: Dict[str, bool]) -> Tuple[int, int]:
        """Return ``(penalty, total)`` for a verifier-result mapping.

        ``total`` is the summed weight of all verifiers present in
        *verifications*; ``penalty`` is the summed weight of those that
        disapproved. Raises ``KeyError`` for a name missing from
        ``get_reward_list``.
        """
        weights = self.get_reward_list()
        total = sum(weights[name] for name in verifications)
        penalty = sum(weights[name] for name, ok in verifications.items() if not ok)
        return penalty, total

    def get_reward(self, question: str, solution: str, input_text: str = None) -> float:
        """Return the normalized verifier reward for *solution*, printing each verdict.

        The reward is ``-(weight of failed verifiers)``, minus 2 more when
        ``rule_format_string`` is set and the solution does not match it,
        divided by the total verifier weight (``0.0`` if that total is 0).

        NOTE(review): the format penalty is not added to the normalizer, so the
        result can fall below -1.0; confirm this is intended.

        Args:
            input_text: accepted for interface compatibility; not used here.
        """
        verifications = self.get_verifications(question, solution)
        for name, ok in verifications.items():
            if ok:
                print(colored(f"Verifier {name} approved the solution.", "green"))
            else:
                print(colored(f"Verifier {name} disapproved the solution.", "red"))

        penalty, total = self._weighted_penalty(verifications)
        reward = -penalty

        if self.rule_format_string is not None:
            if not self.get_rule_format_verify(solution):
                print(colored("Verifier Rule Format disapproved the solution.", "red"))
                reward -= 2
            else:
                print(colored("Verifier Rule Format approved the solution.", "green"))

        return reward / total if total > 0 else 0.0

    def get_rule_format_verify(self, solution: str) -> bool:
        """Return True if *solution* matches ``rule_format_string`` anywhere.

        An empty or ``None`` pattern always passes. The pattern is searched
        with ``re.DOTALL`` so ``.`` can span newlines.
        """
        pattern = self.rule_format_string
        if not pattern:
            return True
        return re.search(pattern, solution, re.DOTALL) is not None

    def get_reward_answer_only(self, question: str, solution: str) -> float:
        """Return the normalized verifier reward without logging or format checks.

        Equals ``-(weight of failed verifiers) / total weight``, or ``0.0``
        when the total weight is 0.
        """
        verifications = self.get_verifications(question, solution)
        penalty, total = self._weighted_penalty(verifications)
        return -penalty / total if total > 0 else 0.0

    def get_reward_list(self) -> Dict[str, int]:
        """Return each verifier's weight; failing it subtracts this from the reward."""
        return {
            "calculation_check": 2,
            "answer_correct": 1,
            "answer_completeness": 2,
            "understanding_check": 1,
        }
