import torch
import torch.nn.functional as F

from typing import Dict, List, Tuple
from parse import parse
import numpy as np
import logging
import time
import threading
from transformers import pipeline

from baselines.skywork_prm.skywork_prm import SkyworkO1_7B, SkyworkO1_1_5B
from lm_polygraph.stat_calculators.extract_claims import Claim
from lm_polygraph.stat_calculators.stat_calculator import StatCalculator
from lm_polygraph.utils.model import Model
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
from baselines.skywork_prm.io_utils import prepare_batch_input_for_model, derive_step_rewards, prepare_input

log = logging.getLogger()


class PRMStatCalculator(StatCalculator):
    """Per-step rewards from a Qwen2.5-Math-style PRM that uses ``<extra_0>`` separators.

    For every (input text, claims) pair the question is parsed out of the input
    text with the prompt template. The claims are joined into a single assistant
    turn with ``<extra_0>`` after each claim. The reward of a claim is the
    probability of class 1 at its separator token.

    NOTE(review): ``meta_info`` always advertises ``"prm_scores"`` while
    ``__call__`` returns ``{self.scores_key: ...}``; with a non-default
    ``scores_key`` the two disagree.
    """

    def __init__(
            self,
            prompt_path: str | None = None,
            model_path: str = "Qwen/Qwen2.5-Math-7B-PRM800K",
            device: str = "auto",
            scores_key: str = "prm_scores",
    ):
        """
        Args:
            prompt_path: optional path to a prompt template with a ``{q}`` field,
                used to parse the question out of each input text. Without it
                the template is ``"{q}"``, i.e. the whole input is the question.
            model_path: Hugging Face id or local path of the PRM.
            device: passed as ``device_map`` when the model is loaded.
            scores_key: key under which ``__call__`` returns the rewards.
        """
        self.model_path = model_path
        # split('/')[-1] returns the whole string when there is no '/'.
        self.model_id = model_path.split('/')[-1]
        self.device = device
        # Loaded lazily by init().
        self.prm_tokenizer = None
        self.prm_model = None
        if prompt_path:
            with open(prompt_path, 'r', encoding='utf-8') as prompt_file:
                self.prompt = prompt_file.read()
        else:
            self.prompt = "{q}"
        self.scores_key = scores_key

    @staticmethod
    def meta_info() -> Tuple[List[str], List[str]]:
        """Return (produced stats, required dependencies)."""
        return ["prm_scores"], ["claims"]

    def init(self):
        """Load tokenizer and model on first use; subsequent calls are no-ops."""
        if self.prm_model is not None:
            return
        log.info(f"Initializing {self.model_path} model on device={self.device}")
        self.prm_tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
        self.prm_model = AutoModel.from_pretrained(
            self.model_path,
            device_map=self.device,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        ).eval()

    def make_step_rewards(self, logits, token_masks):
        """Return class-1 probabilities at every separator position.

        Args:
            logits: model output; assumed ``(batch, seq_len, 2)`` with index 1
                being the "correct step" class — TODO confirm for other checkpoints.
            token_masks: boolean ``(batch, seq_len)`` tensor, True at separators.

        Returns:
            One list of floats per batch row, in separator order.
        """
        probabilities = F.softmax(logits, dim=-1)
        # Index separator rows by the mask itself. Multiplying by the mask and
        # then dropping zeros would also drop a probability that underflowed
        # to exactly 0, misaligning (or breaking) the (neg, pos) pairing.
        token_masks = token_masks.to(probabilities.device)
        return [
            sample_probs[sample_mask][:, 1].float().cpu().tolist()
            for sample_probs, sample_mask in zip(probabilities, token_masks)
        ]

    def get_rewards(self, question: str, steps: list[Claim]) -> list[float]:
        """Score each step of a solution to ``question``; one float per step."""
        self.init()
        if not steps:
            return []
        step_sep = "<extra_0>"
        messages = [
            {"role": "system", "content": "Please reason step by step, and put your final answer within \\boxed{}."},
            {"role": "user", "content": question},
            # A separator follows every step (including the last) so each gets a score.
            {"role": "assistant", "content": step_sep.join([c.claim_text for c in steps]) + step_sep},
        ]
        conversation_str = self.prm_tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
        input_ids = self.prm_tokenizer.encode(conversation_str, return_tensors="pt").to(self.prm_model.device)
        with torch.no_grad():
            outputs = self.prm_model(input_ids=input_ids)
        step_sep_id = self.prm_tokenizer.encode(step_sep)[0]
        token_masks = input_ids == step_sep_id
        return self.make_step_rewards(outputs[0], token_masks)[0]

    def __call__(self, dependencies: Dict[str, np.array], texts: List[str], model: Model, max_new_tokens: int = 100,
                 **kwargs) -> Dict[str, np.ndarray]:
        """Compute per-claim PRM rewards for every input text.

        Args:
            dependencies: must contain ``"claims"``, one list of Claim per text.
            texts: input texts; each must match the prompt template.
            model, max_new_tokens, kwargs: not used by this calculator.

        Returns:
            ``{self.scores_key: [[reward per claim] for each text]}``.

        Raises:
            ValueError: if an input text does not match the prompt template.
        """
        self.init()
        rewards: list[list[float]] = []
        for input_text, claims in zip(texts, dependencies["claims"]):
            match = parse(self.prompt, input_text)
            if match is None:
                raise ValueError(f"Input text does not match the prompt template: {input_text[:200]!r}")
            r = self.get_rewards(match.named['q'], claims)
            assert len(r) == len(claims), f"got {len(r)} rewards for {len(claims)} claims"
            rewards.append(r)
        return {self.scores_key: rewards}


class MathShepherdPRMCalculator(StatCalculator):
    """Per-step rewards from the Math-Shepherd PRM.

    Steps are rendered as ``"Step i: <text> ки"`` lines after the question.
    A step's reward is P('+') after a softmax over just the '+' / '-' logits,
    read at that step's ``ки`` tag token.

    NOTE(review): a literal ``ки`` inside the question or a step adds an extra
    scored position, which trips the length assertion in ``__call__``.
    """

    def __init__(
            self,
            prompt_path: str | None = None,
            model_path: str = "peiyi9979/math-shepherd-mistral-7b-prm",
            device: str = "auto",
            scores_key: str = "prm_scores",
    ):
        """
        Args:
            prompt_path: optional prompt template with a ``{q}`` field; defaults to ``"{q}"``.
            model_path: Hugging Face id or local path of the PRM.
            device: torch device; ``"auto"`` means ``cuda:0`` if available else ``cpu``.
            scores_key: key under which ``__call__`` returns the rewards.
        """
        self.model_path = model_path
        self.model_id = model_path.split('/')[-1]
        self.device = device
        if device == "auto":
            self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        # Loaded lazily by init().
        self.tokenizer = None
        self.model = None
        if prompt_path:
            with open(prompt_path, 'r', encoding='utf-8') as prompt_file:
                self.prompt = prompt_file.read()
        else:
            self.prompt = "{q}"
        self.step_tag = "ки"
        self.good_token = "+"
        self.bad_token = "-"
        self.step_tag_id = None
        self.candidate_token_ids = None
        self.scores_key = scores_key

    @staticmethod
    def meta_info() -> Tuple[List[str], List[str]]:
        """Return (produced stats, required dependencies)."""
        return ["prm_scores"], ["claims"]

    def init(self):
        """Load tokenizer/model and resolve special token ids on first use."""
        if self.model is not None:
            return
        device = self.device
        log.info(f"Initializing {self.model_path} model on device={device}")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.model = AutoModelForCausalLM.from_pretrained(self.model_path).to(device).eval()
        self.step_tag_id = self.tokenizer.encode(self.step_tag)[-1]
        # Encoding "+ -" and dropping the first id (the BOS token) gives the ids of '+' and '-'.
        self.candidate_token_ids = self.tokenizer.encode(f"{self.good_token} {self.bad_token}")[1:]

    def get_rewards(self, question: str, steps: list[Claim]) -> list[float]:
        """Return P('+') at each step tag; one float per step."""
        self.init()
        if not steps:
            return []

        # Rebuild the solution in Math-Shepherd's "Step i: ... ки" format.
        output_text = "".join(
            f"Step {i}: {step.claim_text.strip()} {self.step_tag}\n"
            for i, step in enumerate(steps, start=1)
        )
        input_text = f"{question.strip()} {output_text.strip()}"

        input_ids = self.tokenizer.encode(input_text, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            logits = self.model(input_ids).logits[:, :, self.candidate_token_ids]
            probs = F.softmax(logits, dim=-1)[:, :, 0]  # probability of '+'

        step_mask = input_ids == self.step_tag_id
        return probs[step_mask].cpu().tolist()

    def __call__(self, dependencies: Dict[str, np.array], texts: List[str], model: Model, max_new_tokens: int = 100,
                 **kwargs) -> Dict[str, np.ndarray]:
        """Compute per-claim PRM rewards for every input text.

        Args:
            dependencies: must contain ``"claims"``, one list of Claim per text.
            texts: input texts; each must match the prompt template.
            model, max_new_tokens, kwargs: not used by this calculator.

        Returns:
            ``{self.scores_key: [[reward per claim] for each text]}``.

        Raises:
            ValueError: if an input text does not match the prompt template.
        """
        self.init()
        rewards: list[list[float]] = []
        for input_text, claims in zip(texts, dependencies["claims"]):
            match = parse(self.prompt, input_text)
            if match is None:
                raise ValueError(f"Input text does not match the prompt template: {input_text[:200]!r}")
            r = self.get_rewards(match.named['q'], claims)
            assert len(r) == len(claims), f"got {len(r)} rewards for {len(claims)} claims"
            rewards.append(r)
        return {self.scores_key: rewards}


class SkyworkPRMStatCalculator(StatCalculator):
    """Per-step rewards from the Skywork-o1 Open PRM (Qwen-2.5 based).

    Claims are joined with ``step_token`` (a newline) into one response. The
    Skywork io_utils helpers then compute one reward per step.
    """

    def __init__(
            self,
            prompt_path: str | None = None,
            model_path: str = "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-7B",
            device: str = "auto",
            scores_key: str = "prm_scores",
    ):
        """
        Args:
            prompt_path: optional prompt template with a ``{q}`` field; defaults to ``"{q}"``.
            model_path: PRM id; only used to choose the 1.5B vs 7B wrapper in init().
            device: stored but not used by this class.
            scores_key: key under which ``__call__`` returns the rewards.
        """
        if prompt_path:
            with open(prompt_path, "r", encoding="utf-8") as prompt_file:
                self.prompt = prompt_file.read()
        else:
            self.prompt = "{q}"
        self.model_path = model_path
        self.model_id = model_path.split('/')[-1]
        self.device = device
        # Loaded lazily by init().
        self.tokenizer = None
        self.model = None
        self.step_token = "\n"
        self.scores_key = scores_key

    @staticmethod
    def meta_info() -> Tuple[List[str], List[str]]:
        """Return (produced stats, required dependencies)."""
        return ["prm_scores"], ["claims"]

    def init(self):
        """Load the Skywork PRM once, printing a heartbeat while it loads.

        NOTE(review): ``self.model_path`` only selects between the 1.5B and 7B
        wrapper classes, and ``self.device`` is not used; the checkpoint and
        placement are whatever ``load_model_and_tokenizer()`` chooses — confirm.
        """
        if self.model is not None:
            return
        print(f"Initializing Skywork PRM model from {self.model_path}...")

        start_time = time.time()
        done_loading = threading.Event()

        def keep_alive():
            while not done_loading.is_set():
                elapsed = time.time() - start_time
                print(f"\rStill loading... (elapsed: {elapsed:.1f}s)", flush=True)
                # Wakes immediately once loading finishes instead of sleeping a full second.
                done_loading.wait(1)

        # daemon=True: the heartbeat must never keep the interpreter alive by itself.
        heartbeat = threading.Thread(target=keep_alive, daemon=True)
        heartbeat.start()
        try:
            cls = SkyworkO1_1_5B if '1.5B' in self.model_path else SkyworkO1_7B
            self.model, self.tokenizer = cls.load_model_and_tokenizer()
        finally:
            # Stop the heartbeat even when loading raises; otherwise the thread
            # would keep printing forever.
            done_loading.set()
            heartbeat.join()
        print('Done!')

    def get_rewards(self, question: str, steps: List[Claim]) -> List[float]:
        """Return one reward per step for a solution to ``question``."""
        self.init()
        if not steps:
            return []

        # prepare_input receives step_token and presumably splits the response on it,
        # so a step_token inside a claim would yield extra steps and break the
        # one-reward-per-claim contract; normalise such occurrences to spaces.
        step_texts = [step.claim_text.strip().replace(self.step_token, " ") for step in steps]
        response = self.step_token.join(step_texts)

        input_ids, _step_locs, reward_flags = prepare_input(
            question, response, tokenizer=self.tokenizer, step_token=self.step_token
        )

        # Prepare batch-compatible inputs (a batch of one).
        input_ids_batch, attention_mask, reward_flags = prepare_batch_input_for_model(
            [input_ids], [reward_flags], self.tokenizer.pad_token_id
        )

        device = self.model.pretrained_model.device
        with torch.no_grad():
            _, _, rewards = self.model(
                input_ids=input_ids_batch.to(device),
                attention_mask=attention_mask.to(device),
                return_probs=True,
            )

        step_rewards = derive_step_rewards(rewards.detach().to("cpu", dtype=torch.float32), reward_flags)
        r = step_rewards[0]
        return r if isinstance(r, list) else r.tolist()

    def __call__(self, dependencies: Dict[str, np.array], texts: List[str], model: Model, max_new_tokens: int = 100,
                 **kwargs) -> Dict[str, np.ndarray]:
        """Compute per-claim PRM rewards for every input text.

        Args:
            dependencies: must contain ``"claims"``, one list of Claim per text.
            texts: input texts; each must match the prompt template.
            model, max_new_tokens, kwargs: not used by this calculator.

        Returns:
            ``{self.scores_key: [[reward per claim] for each text]}``.

        Raises:
            ValueError: if an input text does not match the prompt template.
        """
        self.init()
        rewards: list[list[float]] = []
        for input_text, claims in zip(texts, dependencies["claims"]):
            match = parse(self.prompt, input_text)
            if match is None:
                raise ValueError(f"Input text does not match the prompt template: {input_text[:200]!r}")
            r = self.get_rewards(match.named["q"], claims)
            assert len(r) == len(claims), f"got {len(r)} rewards for {len(claims)} claims"
            rewards.append(r)
        return {self.scores_key: rewards}


# --- RLHFlow Llama-3.1 PRM calculator ---
# Uses only the imports already at the top of this file; no extra dependencies.
class RLHFlowLlama31PRMCalculator(StatCalculator):
    """
    PRM adapter that follows RLHFlow's official evaluator logic:
    - Build a chat with each step as a user turn, followed by an assistant '+'
    - For each step (one forward pass), read P('+' | context) from the logits
    - Reference: https://github.com/RLHFlow/RLHF-Reward-Modeling/blob/main/math-rm/prm_evaluate.py
    """
    def __init__(
        self,
        prompt_path: str | None = None,
        model_path: str = "RLHFlow/Llama3.1-8B-PRM-Mistral-Data",
        device: str = "auto",
        scores_key: str = "prm_scores",
    ):
        """
        Args:
            prompt_path: optional prompt template with a ``{q}`` field; defaults to ``"{q}"``.
            model_path: Hugging Face id or local path of the PRM.
            device: torch device; ``"auto"`` means ``cuda:0`` if available else ``cpu``.
            scores_key: key under which ``__call__`` returns the rewards.
        """
        self.model_path = model_path
        self.model_id = model_path.split('/')[-1]
        self.device = device if device != "auto" else ("cuda:0" if torch.cuda.is_available() else "cpu")
        # Loaded lazily by init().
        self.tokenizer = None
        self.model = None
        if prompt_path:
            with open(prompt_path, 'r', encoding='utf-8') as prompt_file:
                self.prompt = prompt_file.read()
        else:
            self.prompt = "{q}"
        # Token ids for '+' and '-', resolved in init().
        self.plus_token_id = None
        self.minus_token_id = None
        self.candidate_token_ids = None
        self.scores_key = scores_key

    @staticmethod
    def meta_info() -> Tuple[List[str], List[str]]:
        """Return (produced stats, required dependencies)."""
        return ["prm_scores"], ["claims"]

    def init(self):
        """Load tokenizer/model (bf16, causal-LM head) and resolve '+'/'-' ids once."""
        if self.model is not None:
            return
        log.info(f"Initializing {self.model_path} model on device={self.device}")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        # Match RLHFlow's dtype and CausalLM head.
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path, torch_dtype=torch.bfloat16
        ).to(self.device).eval()

        # Padding settings used in the official script.
        self.tokenizer.padding_side = "right"
        if getattr(self.tokenizer, "pad_token", None) is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        if getattr(self.model.config, "pad_token_id", None) is None:
            self.model.config.pad_token_id = self.model.config.eos_token_id

        # [-1] skips any BOS the tokenizer prepends.
        self.plus_token_id = self.tokenizer.encode("+")[-1]
        self.minus_token_id = self.tokenizer.encode("-")[-1]
        self.candidate_token_ids = [self.plus_token_id, self.minus_token_id]

    def _score_last_plus(self, input_ids: torch.Tensor, logits: torch.Tensor) -> float:
        """
        Return P('+') (softmax over the '+'/'-' logits) for the last '+' token.

        The logits are read at (position of the last '+' in ``input_ids``) - 1,
        because logits[t] predicts token t+1. The conversation ends with the
        assistant '+', so that is the '+' being scored. Only when no '+' token
        is found does this fall back to the official fixed index -3, which
        depends on the chat template.
        """
        plus_positions = (input_ids[0] == self.plus_token_id).nonzero(as_tuple=True)[0]
        if plus_positions.numel() > 0:
            idx = int(plus_positions[-1].item()) - 1
        else:
            idx = -3

        cand_logits = logits[:, idx, self.candidate_token_ids]  # shape [1, 2]
        probs = F.softmax(cand_logits, dim=-1)[:, 0]            # P('+')
        return probs[0].detach().to('cpu', dtype=torch.float32).item()

    def get_rewards(self, question: str, steps: list[Claim]) -> list[float]:
        """
        One forward pass per step, as in RLHFlow's evaluator:
        conversation = [user: question + step1, assistant: '+', user: step2, assistant: '+', ...].
        Returns P('+') for each step, in order.
        """
        self.init()
        if not steps:
            return []

        rewards: list[float] = []
        conversation: list[Dict[str, str]] = []

        for k, step in enumerate(steps):
            # The question is prepended only to the first step's user turn.
            if k == 0:
                text = f"{question.strip()} {step.claim_text.strip()}"
            else:
                text = step.claim_text.strip()
            conversation.append({"role": "user", "content": text})
            conversation.append({"role": "assistant", "content": "+"})

            # Teacher-forced scoring of the conversation ending in the new '+'.
            input_ids = self.tokenizer.apply_chat_template(
                conversation, return_tensors="pt"
            ).to(self.model.device)

            with torch.no_grad():
                logits = self.model(input_ids).logits  # [1, T, V]

            rewards.append(self._score_last_plus(input_ids, logits))

        return rewards

    def __call__(self, dependencies: Dict[str, np.array], texts: List[str], model: Model,
                 max_new_tokens: int = 100, **kwargs) -> Dict[str, np.ndarray]:
        """Compute per-claim PRM rewards for every input text.

        Args:
            dependencies: must contain ``"claims"``, one list of Claim per text.
            texts: input texts; each must match the prompt template.
            model, max_new_tokens, kwargs: not used by this calculator.

        Returns:
            ``{self.scores_key: [[reward per claim] for each text]}``.

        Raises:
            ValueError: if an input text does not match the prompt template.
        """
        self.init()
        out: list[list[float]] = []
        for input_text, claims in zip(texts, dependencies["claims"]):
            match = parse(self.prompt, input_text)
            if match is None:
                raise ValueError(f"Input text does not match the prompt template: {input_text[:200]!r}")
            r = self.get_rewards(match.named['q'], claims)
            assert len(r) == len(claims), f"got {len(r)} rewards for {len(claims)} claims"
            out.append(r)
        return {self.scores_key: out}


class UniversalPRMCalculator(StatCalculator):
    """Per-step rewards from Universal-PRM.

    For step k the model sees the chat-formatted question followed by steps
    1..k, each terminated by ``separator``, plus EOS. The reward is the
    sigmoid of the model's first output.
    """

    def __init__(
        self,
        prompt_path: str | None = None,
        model_path: str = "universalprm/Universal-PRM",
        device: str = "auto",
        scores_key: str = "prm_scores",
    ):
        """
        Args:
            prompt_path: optional prompt template with a ``{q}`` field; defaults to ``"{q}"``.
            model_path: Hugging Face id or local path of the PRM.
            device: passed as ``device_map``; ``"auto"`` means ``cuda`` if available else ``cpu``.
            scores_key: key under which ``__call__`` returns the rewards.
        """
        super().__init__()
        self.model_path = model_path
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        # Loaded lazily by init().
        self.tokenizer = None
        self.model = None
        if prompt_path:
            with open(prompt_path, 'r', encoding='utf-8') as prompt_file:
                self.prompt = prompt_file.read()
        else:
            self.prompt = "{q}"
        self.separator = "\n\n"
        self.scores_key = scores_key

    @staticmethod
    def meta_info() -> Tuple[List[str], List[str]]:
        """Return (produced stats, required dependencies)."""
        return ["prm_scores"], ["claims"]

    def init(self):
        """Load tokenizer and model once (bf16, remote code, ``device_map=self.device``)."""
        if self.model is not None:
            return
        log.info(f"Initializing {self.model_path} on device={self.device}")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(
            self.model_path,
            device_map=self.device,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        ).eval()

    def _build_messages(self, question: str, reference_answer: str | None = None):
        """Build the system/user chat, appending a reference-answer line to the question."""
        if reference_answer and reference_answer.strip():
            q_wgt = question + "\n\n###\n\nThe reference answer is: " + reference_answer
        else:
            q_wgt = question + "\n\n###\n\nThe reference answer is: There is no reference answer for this question."
        return [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": q_wgt},
        ]

    def get_rewards(self, question: str, steps: list[Claim]) -> list[float]:
        """Return one sigmoid reward per step, scoring each prefix of the solution."""
        self.init()
        if not steps:
            return []

        messages = self._build_messages(question, reference_answer=None)
        query_ids = self.tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True
        )
        step_texts = [s.claim_text.strip() for s in steps]
        # Place inputs wherever the model actually lives. Comparing
        # self.device == "cuda" left inputs on CPU for e.g. "cuda:1".
        model_device = self.model.device

        rewards: list[float] = []
        with torch.no_grad():
            for k in range(1, len(step_texts) + 1):
                responses = self.separator.join(step_texts[:k]) + self.separator
                answer_tokens = self.tokenizer(responses)["input_ids"] + [self.tokenizer.eos_token_id]
                qa_ids = query_ids + answer_tokens

                input_ids = torch.tensor([qa_ids], dtype=torch.long, device=model_device)
                outputs = self.model(input_ids=input_ids)
                reward = torch.sigmoid(outputs[0]).detach().to("cpu", dtype=torch.float32).item()
                rewards.append(float(reward))
        return rewards

    def __call__(self, dependencies: Dict[str, np.array], texts: List[str], model: Model,
                 max_new_tokens: int = 100, **kwargs) -> Dict[str, np.ndarray]:
        """Compute per-claim PRM rewards for every input text.

        Args:
            dependencies: must contain ``"claims"``, one list of Claim per text.
            texts: input texts; each must match the prompt template.
            model, max_new_tokens, kwargs: not used by this calculator.

        Returns:
            ``{self.scores_key: [[reward per claim] for each text]}``.

        Raises:
            ValueError: if an input text does not match the prompt template.
        """
        self.init()
        out: list[list[float]] = []
        for input_text, claims in zip(texts, dependencies["claims"]):
            match = parse(self.prompt, input_text)
            if match is None:
                raise ValueError(f"Input text does not match the prompt template: {input_text[:200]!r}")
            r = self.get_rewards(match.named["q"], claims)
            assert len(r) == len(claims), f"got {len(r)} rewards for {len(claims)} claims"
            out.append(r)
        return {self.scores_key: out}


class H4Qwen25Math15BPRM02Calculator(StatCalculator):
    """Per-step rewards from HuggingFaceH4's token-classification PRM.

    For step k the text is question + steps 1..k joined and terminated by
    ``separator``. The reward is P(LABEL_1), derived from the pipeline's
    prediction for the last token.
    """

    def __init__(
        self,
        prompt_path: str | None = None,
        model_path: str = "HuggingFaceH4/Qwen2.5-Math-1.5B-Instruct-PRM-0.2",
        device: str = "auto",
        scores_key: str = "prm_scores",
    ):
        """
        Args:
            prompt_path: optional prompt template with a ``{q}`` field; defaults to ``"{q}"``.
            model_path: Hugging Face id or local path of the PRM.
            device: pipeline device; ``"auto"`` means ``cuda`` if available else ``cpu``.
            scores_key: key under which ``__call__`` returns the rewards.
        """
        super().__init__()
        self.model_path = model_path
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        # Created lazily by init().
        self.pipe = None
        if prompt_path:
            with open(prompt_path, 'r', encoding='utf-8') as prompt_file:
                self.prompt = prompt_file.read()
        else:
            self.prompt = "{q}"
        self.separator = "\n\n"  # IMPORTANT: must match the training separator (per model card)
        self.scores_key = scores_key

    @staticmethod
    def meta_info() -> Tuple[List[str], List[str]]:
        """Return (produced stats, required dependencies)."""
        return ["prm_scores"], ["claims"]

    def init(self):
        """Create the token-classification pipeline once."""
        if self.pipe is not None:
            return
        log.info(f"Initializing {self.model_path} token-classification pipeline on device={self.device}")
        self.pipe = pipeline("token-classification", model=self.model_path, device=self.device)

    def _p_true_from_last_entity(self, preds: List[Dict]) -> float:
        """Map the last prediction to P(LABEL_1); 0.5 when missing or unrecognised."""
        if not preds:
            return 0.5  # neutral fallback
        last = preds[-1]
        label = last.get("entity", "")
        score = float(last.get("score", 0.5))
        if label == "LABEL_1":
            return score
        if label == "LABEL_0":
            return 1.0 - score
        # Unexpected label name: neutral.
        return 0.5

    def get_rewards(self, question: str, steps: list[Claim]) -> list[float]:
        """Return one reward per step, scoring each prefix of the solution."""
        self.init()
        if not steps:
            return []

        step_texts = [s.claim_text.strip() for s in steps]
        rewards: list[float] = []
        for k in range(1, len(step_texts) + 1):
            text = self.separator.join([question] + step_texts[:k]) + self.separator
            rewards.append(self._p_true_from_last_entity(self.pipe(text)))
        return rewards

    def __call__(self, dependencies: Dict[str, np.array], texts: List[str], model: Model,
                 max_new_tokens: int = 100, **kwargs) -> Dict[str, np.ndarray]:
        """Compute per-claim PRM rewards for every input text.

        Args:
            dependencies: must contain ``"claims"``, one list of Claim per text.
            texts: input texts; each must match the prompt template.
            model, max_new_tokens, kwargs: not used by this calculator.

        Returns:
            ``{self.scores_key: [[reward per claim] for each text]}``.

        Raises:
            ValueError: if an input text does not match the prompt template.
        """
        self.init()
        out: list[list[float]] = []
        for input_text, claims in zip(texts, dependencies["claims"]):
            match = parse(self.prompt, input_text)
            if match is None:
                raise ValueError(f"Input text does not match the prompt template: {input_text[:200]!r}")
            r = self.get_rewards(match.named["q"], claims)
            assert len(r) == len(claims), f"got {len(r)} rewards for {len(claims)} claims"
            out.append(r)
        return {self.scores_key: out}


def load_prm_calculator_by_model_path(
        prompt_path: str | None = None,
        model_path: str = "Qwen/Qwen2.5-Math-7B-PRM800K",
        device: str = "auto",
        scores_key: str = "prm_scores",
):
    if model_path.startswith("Qwen/"):
        return PRMStatCalculator(
            prompt_path=prompt_path,
            model_path=model_path,
            device=device,
            scores_key=scores_key,
        )
    elif model_path.startswith("peiyi9979/"):
        return MathShepherdPRMCalculator(
            prompt_path=prompt_path,
            model_path=model_path,
            device=device,
            scores_key=scores_key,
        )
    elif model_path.startswith("RLHFlow/Llama3.1-8B-PRM-"):
        return RLHFlowLlama31PRMCalculator(
            prompt_path=prompt_path,
            model_path=model_path,
            device=device,
            scores_key=scores_key,
        )
    elif model_path.startswith("universalprm/"):
        return UniversalPRMCalculator(
            prompt_path=prompt_path,
            model_path=model_path,
            device=device,
            scores_key=scores_key,
        )
    elif model_path == "HuggingFaceH4/Qwen2.5-Math-1.5B-Instruct-PRM-0.2":
        return H4Qwen25Math15BPRM02Calculator(
            prompt_path=prompt_path,
            model_path=model_path,
            device=device,
            scores_key=scores_key,
        )
    elif "Skywork" in model_path:
        return SkyworkPRMStatCalculator(
            prompt_path=prompt_path,
            model_path=model_path,
            device=device,
            scores_key=scores_key,
        )
    elif model_path.startswith("GenPRM/"):
        from baselines.gen_prm import GenPRMStatCalculator, GenPRMStatCalculatorSimple
        if model_path.endswith('-simple'):
            return GenPRMStatCalculatorSimple(
                prompt_path=prompt_path,
                model_path=model_path[:-len('-simple')],
                device=device,
                scores_key=scores_key,
            )
        else:
            return GenPRMStatCalculator(
                prompt_path=prompt_path,
                model_path=model_path,
                device=device,
                scores_key=scores_key,
            )
    else:
        raise ValueError(f"Unsupported model path prefix for PRM model: {model_path}")


def load_stat_calculator(config, builder):
    """Build a :class:`PRMStatCalculator` from a config object.

    Args:
        config: config accessed via ``.get`` with optional keys ``prompt_path``
            (default None), ``model_path`` (default Qwen2.5-Math-7B-PRM800K),
            ``device`` (default "auto") and ``scores_key`` (default "prm_scores").
        builder: not used here; presumably required by the stat-calculator
            loader interface — confirm.

    NOTE(review): this always builds a PRMStatCalculator regardless of
    ``model_path``; ``load_prm_calculator_by_model_path`` handles other PRM families.
    """
    return PRMStatCalculator(
        prompt_path=config.get("prompt_path", None),
        model_path=config.get("model_path", "Qwen/Qwen2.5-Math-7B-PRM800K"),
        device=config.get("device", "auto"),
        # Previously a configured scores_key was silently ignored.
        scores_key=config.get("scores_key", "prm_scores"),
    )
