#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from itertools import accumulate

import logging
import torch
import torch.nn.functional as F
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)

from sal.config import Config
from sal.models.skywork_o1_prm.io_utils import (
    derive_step_rewards,
    prepare_batch_input_for_model,
    prepare_input,
)
from sal.models.skywork_o1_prm.prm_model import SkyworkPRMModel

# Vocabulary ids of the two label tokens the Math-Shepherd PRM is read out on.
# batched_math_shepherd_inference softmaxes over these two logits and keeps
# index 0 as the step score. Presumably the Mistral ids of "+" and "-"
# (good / bad step) — TODO confirm against the model card.
CANDIDATE_TOKENS = [648, 387]
# Vocabulary id of the step tag: MathShepherd.score appends " ки" after each
# step, and scores are read at every position holding this id. Presumably the
# Mistral id of "ки" — verify.
STEP_TAG_ID = 12902

logger = logging.getLogger(__name__)


def batched_math_shepherd_inference(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    inputs: list[str],
    batch_size: int,
) -> list[list[float]]:
    """Score every tagged step of each input with a Math-Shepherd style PRM.

    Inputs are tokenized and run through ``model`` in chunks of
    ``batch_size``. At every position holding ``STEP_TAG_ID`` the logits of
    ``CANDIDATE_TOKENS`` are softmaxed and the probability of the first
    candidate is kept as that step's score.

    Args:
        model: causal LM whose ``logits`` are indexed as (batch, seq, vocab).
        tokenizer: tokenizer matching ``model``; it needs a pad token because
            each chunk is padded to a common length.
        inputs: fully formatted prompts with step tags already inserted.
        batch_size: number of inputs per forward pass; must be >= 1.

    Returns:
        One list per input (in input order) holding one score per step-tag
        occurrence, in sequence order.

    Raises:
        ValueError: if ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    output_scores: list[list[float]] = []
    for start in range(0, len(inputs), batch_size):
        encoded = tokenizer(
            inputs[start : start + batch_size], padding=True, return_tensors="pt"
        ).to(model.device)
        with torch.no_grad():
            logits = model(**encoded).logits[:, :, CANDIDATE_TOKENS]
            scores = logits.softmax(dim=-1)[:, :, 0]
            step_mask = encoded.input_ids == STEP_TAG_ID
            # Boolean masking flattens row-major, so splitting the flat list
            # by per-row tag counts yields one sublist per input.
            step_scores_flat = scores[step_mask].tolist()
            counts = step_mask.sum(dim=1).tolist()

        offset = 0
        for count in counts:
            output_scores.append(step_scores_flat[offset : offset + count])
            offset += count

        # Release this chunk's tensors before the next forward pass.
        del encoded, logits, scores, step_mask
        torch.cuda.empty_cache()

    return output_scores


class PRM:
    """Base class for the process reward models (PRMs) in this module.

    Subclasses provide ``load_model_and_tokenizer`` and ``score``. The
    constructor records the search config first, then loads the model and
    tokenizer, forwarding every extra keyword argument to the loader.
    """

    def __init__(self, search_config: Config, **model_kwargs):
        self.search_config = search_config
        model, tokenizer = self.load_model_and_tokenizer(**model_kwargs)
        self.model = model
        self.tokenizer = tokenizer

    def load_model_and_tokenizer(
        self, **model_kwargs
    ) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
        """Return a ``(model, tokenizer)`` pair. Must be overridden."""
        raise NotImplementedError

    def score(
        self, questions: list[str], outputs: list[list[str]]
    ) -> list[list[float]]:
        """Return step scores for every completion of every question.

        ``outputs[i]`` holds the candidate completions for ``questions[i]``.
        The subclasses in this module return, per question, a list with one
        entry per completion holding that completion's step scores.
        Must be overridden.
        """
        raise NotImplementedError


class MathShepherd(PRM):
    """Math-Shepherd PRM (``peiyi9979/math-shepherd-mistral-7b-prm``).

    Each step of a completion is terminated with the " ки" step tag, and the
    model's prediction at every tag is used as that step's score.
    """

    def load_model_and_tokenizer(
        self, **model_kwargs
    ) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
        """Load the Math-Shepherd model (fp16, FlashAttention-2) and tokenizer.

        ``PRM.__init__`` always calls this with ``**model_kwargs``; they are
        forwarded to ``from_pretrained`` like in the other PRM loaders, so
        passing extra kwargs no longer raises ``TypeError``.
        """
        model_id = "peiyi9979/math-shepherd-mistral-7b-prm"
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Reuse EOS as the pad token so that batched inputs can be padded.
        tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            attn_implementation="flash_attention_2",
            torch_dtype=torch.float16,
            **model_kwargs,
        ).eval()
        return model, tokenizer

    def score(
        self, questions: list[str], outputs: list[list[str]]
    ) -> list[list[float]]:
        """Score every step of every completion.

        Args:
            questions: the problems being solved.
            outputs: ``outputs[i]`` holds the candidate completions for
                ``questions[i]``; steps inside a completion are separated by
                blank lines.

        Returns:
            For each question, one list of step scores per completion.

        Raises:
            ValueError: if ``questions`` and ``outputs`` differ in length.
        """
        inputs_for_prm = []
        lengths = []
        for question, output in zip(questions, outputs, strict=True):
            prompt = self.search_config.system_prompt + "\n" + question + "\n"
            # Tag the end of every step: each blank-line separator gets " ки",
            # and the last step gets one too unless the text already ended
            # with a (now tagged) separator.
            special_outputs = [o.replace("\n\n", " ки\n\n") for o in output]
            special_outputs = [
                o + " ки" if o[-2:] != "\n\n" else o for o in special_outputs
            ]
            inputs_for_prm.extend(f"{prompt} {o}" for o in special_outputs)
            lengths.append(len(output))

        # TODO: tokenize each batch independently so there is less padding and faster inference
        flat_scores = batched_math_shepherd_inference(
            self.model,
            self.tokenizer,
            inputs_for_prm,
            self.search_config.prm_batch_size,
        )

        # Regroup the flat per-completion scores into one list per question.
        ends = list(accumulate(lengths))
        starts = [0] + ends[:-1]
        output_scores = [flat_scores[i:j] for i, j in zip(starts, ends)]

        # TODO: strip out the reward for previous steps
        for output_score, output in zip(output_scores, outputs):
            assert len(output_score) == len(output), (
                f"{len(output_score)} != {len(output)}"
            )

        return output_scores


class RLHFFlow(PRM):
    """RLHFlow Llama-3.1-8B PRM trained on DeepSeek data.

    Each reasoning step is presented as a user turn followed by an assistant
    turn of "+". The step score is the probability the model gives to "+"
    (versus "-") at that assistant slot.
    """

    def load_model_and_tokenizer(
        self, **model_kwargs
    ) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
        """Load the model (bf16) and tokenizer and cache the "+"/"-" token ids."""
        tokenizer = AutoTokenizer.from_pretrained(
            "/path/to/models/llama-3.1-8b-prm-deepseek-data"
        )
        model = AutoModelForCausalLM.from_pretrained(
            "/path/to/models/llama-3.1-8b-prm-deepseek-data",
            device_map="auto",
            torch_dtype=torch.bfloat16,
            **model_kwargs,
        ).eval()
        # _score_batched calls the model without an attention mask, so the
        # padding must come after the real tokens, where causal attention keeps
        # it from influencing them.
        tokenizer.padding_side = "right"
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = model.config.eos_token_id

        # Take the last id of each encoding, i.e. skip any prefix such as BOS.
        plus_tag_id = tokenizer.encode("+")[-1]
        minus_tag_id = tokenizer.encode("-")[-1]
        self.candidate_tokens = [plus_tag_id, minus_tag_id]

        return model, tokenizer

    def score(
        self,
        questions: list[str],
        outputs: list[list[str]],
        batched: bool = True,
        batch_size=8,
    ) -> list[list[float]]:
        """Score every step of every completion.

        In batched mode the batch size is halved on each CUDA out-of-memory
        error and the whole call is retried, down to a batch size of 1.

        Args:
            questions: the problems being solved.
            outputs: ``outputs[i]`` holds the candidate completions for
                ``questions[i]``; steps are separated by blank lines.
            batched: if False, score one step prefix at a time with
                ``_score_single`` (``batch_size`` is then ignored).
            batch_size: initial number of completions per forward pass.

        Returns:
            For each question, one list of step scores per completion.

        Raises:
            ValueError: if ``batch_size`` < 1 in batched mode.
            torch.cuda.OutOfMemoryError: if scoring still runs out of memory
                with a batch size of 1.
        """
        if not batched:
            return self._score_single(questions, outputs)
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        current_batch_size = batch_size
        while True:
            try:
                logger.debug(f"Attempting RLHFFlow._score_batched with batch_size: {current_batch_size}")
                # _score_batched processes all questions and outputs with the given batch size.
                return self._score_batched(questions, outputs, batch_size=current_batch_size)
            except torch.cuda.OutOfMemoryError as e:
                logger.warning(
                    f"CUDA out of memory in RLHFFlow._score_batched with batch_size {current_batch_size}. "
                    f"Attempting to reduce batch size. Error: {e}"
                )
                if current_batch_size == 1:
                    logger.error(
                        f"CUDA out of memory even with batch_size 1. Re-raising original OOM error for questions: {questions[:1]}..."
                    )
                    raise
            except Exception as e:
                logger.error(f"An unexpected error occurred in _score_batched with batch_size {current_batch_size} for questions: {questions[:1]}... Error: {e}")
                raise

            # Emptying the cache only after the except clause has ended: the
            # exception's traceback (dropped there) keeps the failed attempt's
            # tensors alive while the clause runs.
            torch.cuda.empty_cache()
            current_batch_size = max(1, current_batch_size // 2)
            logger.info(f"Retrying RLHFFlow._score_batched with new batch_size: {current_batch_size} for questions: {questions[:1]}...")

    def _score_single(self, questions: list[str], outputs: list[list[str]]):
        """Score steps with one unpadded forward pass per growing step prefix.

        reference code: https://github.com/RLHFlow/RLHF-Reward-Modeling/blob/main/math-rm/prm_evaluate.py
        """
        all_scores = []
        for question, answers in zip(questions, outputs, strict=True):
            all_step_scores = []
            for ans in answers:
                single_step_score = []
                conversation = []
                for k, step in enumerate(ans.split("\n\n")):
                    # TODO: add the system prompt like we did for math shepard?
                    text = question + " " + step if k == 0 else step
                    conversation.append({"content": text, "role": "user"})
                    conversation.append({"content": "+", "role": "assistant"})
                    input_ids = self.tokenizer.apply_chat_template(
                        conversation, return_tensors="pt"
                    ).to(self.model.device)
                    with torch.no_grad():
                        # Simple version: the +/- label for the latest step is
                        # predicted at position -3 of the templated dialogue.
                        logits = self.model(input_ids).logits[
                            :, -3, self.candidate_tokens
                        ]
                        # Column 0 is the probability of "+" (1 would be "-").
                        step_scores = logits.softmax(dim=-1)[:, 0]
                        single_step_score.append(
                            step_scores[0]
                            .detach()
                            .to("cpu", dtype=torch.float32)
                            .item()
                        )

                all_step_scores.append(single_step_score)
            all_scores.append(all_step_scores)
        return all_scores

    def _score_batched(
        self, questions: list[str], outputs: list[list[str]], batch_size: int = 2
    ):
        """Score steps for many completions, ``batch_size`` completions per pass.

        The model is trained to predict "+" or "-" after each step, but those
        tokens are not unique in the text, so they cannot mark the step
        positions. Each completion is therefore templated twice: once with
        "+" assistant turns (fed to the model), and once with a dummy "ки"
        token in the same slots (used only to locate those slots).
        """
        # Index 1 presumably skips BOS; assumes "ки" encodes to a single
        # token — TODO confirm for this tokenizer.
        special_tok_id = self.tokenizer("ки", return_tensors="pt").input_ids[0, 1]

        conversations = []
        conversations2 = []
        for question, answers in zip(questions, outputs, strict=True):
            for ans in answers:
                conversation = []
                conversation2 = []
                for k, step in enumerate(ans.split("\n\n")):
                    text = question + " " + step if k == 0 else step
                    conversation.append({"content": text, "role": "user"})
                    conversation.append({"content": "+", "role": "assistant"})
                    # Parallel dialogue whose "ки" tokens mark where to read scores.
                    conversation2.append({"content": text, "role": "user"})
                    conversation2.append({"content": "ки", "role": "assistant"})

                conversations.append(conversation)
                conversations2.append(conversation2)

        output_scores = []
        for start in range(0, len(conversations), batch_size):
            convs_batch = conversations[start : start + batch_size]
            convs2_batch = conversations2[start : start + batch_size]
            inputs_batch = self.tokenizer.apply_chat_template(
                convs_batch, padding=True, return_tensors="pt"
            ).to(self.model.device)
            inputs2_batch = self.tokenizer.apply_chat_template(
                convs2_batch, padding=True, return_tensors="pt"
            ).to(self.model.device)
            # Both templatings must line up token-for-token for the mask to apply.
            assert inputs_batch.shape == inputs2_batch.shape
            with torch.no_grad():
                logits = self.model(inputs_batch).logits[:, :, self.candidate_tokens]
                # Column 0 is the probability of "+" (1 would be "-").
                scores = logits.softmax(dim=-1)[:, :, 0]

                for row in range(len(convs_batch)):
                    # Position t predicts token t+1, so read scores at N-1 for
                    # every "ки" at position N.
                    step_scores_flat = scores[row, :-1][
                        inputs2_batch[row, 1:] == special_tok_id
                    ].tolist()
                    output_scores.append(step_scores_flat)

        # Regroup the flat per-completion scores into one list per question
        # (lengths already checked equal by the strict zip above).
        flat = iter(output_scores)
        return [[next(flat) for _ in answers] for answers in outputs]


class SkyworkO1(PRM):
    """Skywork o1 Open PRM family.

    Concrete subclasses pick the checkpoint by implementing
    ``load_model_and_tokenizer`` on top of ``_load_model_and_tokenizer``.
    """

    @classmethod
    def _load_model_and_tokenizer(
        cls, prm_model_path, **model_kwargs
    ) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
        """Load the Skywork PRM at ``prm_model_path`` (bf16) and its tokenizer."""
        tokenizer = AutoTokenizer.from_pretrained(
            prm_model_path, trust_remote_code=True
        )
        model = SkyworkPRMModel.from_pretrained(
            prm_model_path,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            **model_kwargs,
        ).eval()

        return model, tokenizer

    def score(
        self, questions: list[str], outputs: list[list[str]]
    ) -> list[list[float]]:
        """Score the completions of each question in one padded forward pass.

        Steps are split on single newlines (``step_token="\\n"``). For each
        question the result of ``derive_step_rewards`` is appended; presumably
        that is one list of step rewards per completion — verify in io_utils.
        A question with no completions yields ``[]``, matching the other PRMs.

        Raises:
            ValueError: if ``questions`` and ``outputs`` differ in length.
        """
        # reference code: https://huggingface.co/Skywork/Skywork-o1-Open-PRM-Qwen-2.5-7B#huggingface-inference
        all_scores = []
        for question, answers in zip(questions, outputs, strict=True):
            if not answers:
                # zip(*[]) below would have nothing to unpack.
                all_scores.append([])
                continue
            processed_data = [
                prepare_input(
                    question, answer, tokenizer=self.tokenizer, step_token="\n"
                )
                for answer in answers
            ]
            input_ids, _steps, reward_flags = zip(*processed_data)
            input_ids, attention_mask, reward_flags = prepare_batch_input_for_model(
                input_ids, reward_flags, self.tokenizer.pad_token_id
            )
            device = self.model.pretrained_model.device
            with torch.no_grad():
                _, _, rewards = self.model(
                    input_ids=input_ids.to(device),
                    attention_mask=attention_mask.to(device),
                    return_probs=True,
                )
                all_step_scores = derive_step_rewards(
                    rewards.detach().to("cpu", dtype=torch.float32), reward_flags
                )
            all_scores.append(all_step_scores)
        return all_scores


class Qwen_2_5_Math(PRM):
    """Qwen2.5-Math process reward model.

    Steps inside the assistant turn are separated by the ``<extra_0>`` token.
    The step score is the probability of label 1 in the model's output at
    each such token.
    """

    @classmethod
    def _load_model_and_tokenizer(
        cls, prm_model_path, **model_kwargs
    ) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
        """Load the PRM at ``prm_model_path`` (bf16, remote code) and its tokenizer."""
        tokenizer = AutoTokenizer.from_pretrained(
            prm_model_path, trust_remote_code=True
        )
        model = AutoModel.from_pretrained(
            prm_model_path,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            **model_kwargs,
        ).eval()

        return model, tokenizer

    def score(
        self, questions: list[str], outputs: list[list[str]], batch_size=256
    ) -> list[list[float]]:
        """Score every step of every completion, backing off on CUDA OOM.

        ``batch_size`` is the caller-supplied starting batch size (this method
        does not read it from the search config). It is halved on each CUDA
        out-of-memory error and the whole call is retried, down to 1.

        Returns:
            For each question, one list of step scores per completion.

        Raises:
            ValueError: if ``batch_size`` < 1.
            torch.cuda.OutOfMemoryError: if scoring still runs out of memory
                with a batch size of 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        current_batch_size = batch_size
        while True:
            try:
                logger.debug(f"Attempting Qwen_2_5_Math.score with batch_size: {current_batch_size}")
                return self._score_batched(questions, outputs, batch_size=current_batch_size)
            except torch.cuda.OutOfMemoryError as e:
                logger.warning(
                    f"CUDA out of memory in Qwen_2_5_Math.score with batch_size {current_batch_size}. "
                    f"Attempting to reduce batch size. Error: {e}"
                )
                if current_batch_size == 1:
                    logger.error(
                        f"CUDA out of memory even with batch_size 1. Re-raising original OOM error for questions: {questions[:1]}..."
                    )
                    raise
            except Exception as e:
                logger.error(f"An unexpected error occurred in Qwen_2_5_Math.score with batch_size {current_batch_size}: {e}")
                raise

            # Empty the cache only after the except clause has ended: the
            # exception's traceback (dropped there) keeps the failed attempt's
            # tensors alive while the clause runs.
            torch.cuda.empty_cache()
            current_batch_size = max(1, current_batch_size // 2)
            logger.info(f"Retrying Qwen_2_5_Math.score with new batch_size: {current_batch_size}")

    def _score_batched(
        self, questions: list[str], outputs: list[list[str]], batch_size: int = 8
    ) -> list[list[float]]:
        """Score completions question by question, ``batch_size`` per forward pass."""
        # The separator id does not depend on the batch; compute it once.
        step_sep_id = self.tokenizer.encode("<extra_0>")[0]
        all_scores = []

        for question, answers in zip(questions, outputs, strict=True):
            question_scores = []

            # Process this question's answers in chunks of batch_size.
            for start in range(0, len(answers), batch_size):
                processed_responses = []
                for answer in answers[start : start + batch_size]:
                    messages = [
                        {
                            "role": "system",
                            "content": "Please reason step by step, and put your final answer within \\boxed{}.",
                        },
                        {"role": "user", "content": question},
                        {
                            "role": "assistant",
                            # Every step, including the last, ends with a separator.
                            "content": answer.replace("\n\n", "<extra_0>") + "<extra_0>",
                        },
                    ]
                    conversation_str = self.tokenizer.apply_chat_template(
                        messages, tokenize=False, add_generation_prompt=False
                    )
                    processed_responses.append(conversation_str)

                # NOTE(review): truncation can drop trailing <extra_0> tokens,
                # which would silently yield fewer scores than steps.
                input_ids = self.tokenizer(
                    processed_responses, return_tensors="pt", padding=True, truncation=True
                )["input_ids"].to(self.model.device)

                with torch.no_grad():
                    model_outputs = self.model(input_ids=input_ids)

                token_masks = input_ids == step_sep_id
                batch_step_rewards = self.make_step_rewards(model_outputs[0], token_masks)
                question_scores.extend(batch_step_rewards)

                # Free this chunk's GPU memory before the next one.
                del input_ids, model_outputs, token_masks
                torch.cuda.empty_cache()

            all_scores.append(question_scores)

        return all_scores

    @staticmethod
    def make_step_rewards(logits, token_masks):
        """Return, per sequence, the label-1 probability at each masked position.

        Args:
            logits: tensor indexed as (batch, seq_len, num_labels).
            token_masks: (batch, seq_len) mask, truthy at step-separator
                positions.

        Returns:
            One list of Python floats per sequence, in position order.
        """
        probabilities = F.softmax(logits, dim=-1)  # bs, seq_len, num_labels

        all_scores_res = []
        for sample_probs, sample_mask in zip(probabilities, token_masks):
            # Select rows with the mask itself rather than by non-zero values:
            # a probability that underflows to exactly 0 would otherwise shift
            # the flattened values and misalign the (num_labels)-wide rows.
            positive_probs = sample_probs[sample_mask.bool()][:, 1]
            all_scores_res.append(positive_probs.cpu().tolist())

        return all_scores_res


class SkyworkO1_1_5B(SkyworkO1):
    """Skywork o1 Open PRM, Qwen-2.5 1.5B checkpoint from the Hugging Face Hub."""

    def load_model_and_tokenizer(
        self, **model_kwargs
    ) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
        """Delegate to the shared Skywork loader with the 1.5B checkpoint id."""
        return SkyworkO1._load_model_and_tokenizer(
            "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B", **model_kwargs
        )


class SkyworkO1_7B(SkyworkO1):
    """Skywork o1 Open PRM, Qwen-2.5 7B checkpoint from the Hugging Face Hub."""

    def load_model_and_tokenizer(
        self, **model_kwargs
    ) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
        """Delegate to the shared Skywork loader with the 7B checkpoint id."""
        return SkyworkO1._load_model_and_tokenizer(
            "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-7B", **model_kwargs
        )


class Qwen_2_5_Math_7B(Qwen_2_5_Math):
    """Qwen2.5-Math PRM, 7B checkpoint loaded from a local directory."""

    def load_model_and_tokenizer(
        self, **model_kwargs
    ) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
        """Delegate to the shared Qwen2.5-Math loader with the local 7B path."""
        return Qwen_2_5_Math._load_model_and_tokenizer(
            "/path/to/models/qwen2.5-math-prm-7b", **model_kwargs
        )


def load_prm(config: Config) -> PRM:
    """Instantiate the PRM registered for ``config.prm_path``.

    Paths are compared with ``==`` in table order; the first match is
    constructed with ``config``.

    Raises:
        NotImplementedError: if no PRM is registered for ``config.prm_path``.
    """
    # The Skywork o1 classes (SkyworkO1_1_5B / SkyworkO1_7B) exist in this
    # module but are not registered here.
    registered = (
        ("/path/to/models/llama-3.1-8b-prm-deepseek-data", RLHFFlow),
        ("/path/to/models/qwen2.5-math-prm-7b", Qwen_2_5_Math_7B),
    )
    for known_path, prm_cls in registered:
        if config.prm_path == known_path:
            return prm_cls(config)

    raise NotImplementedError(f"PRM {config.prm_path} not implemented")
