from abc import ABC, abstractmethod
from typing import List
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModelForCausalLM
from .config import reward_skywork_path, reward_shepherd_path, reward_armorm_path, reward_grm_path, reward_skyworko1_path, reward_genprm_path
from .model import ModelWrapper
import torch
from .io_utils import prepare_input, prepare_batch_input_for_model, derive_step_rewards
from .prm_model import PRM_MODEL
from collections import defaultdict
import random 
import numpy as np
from utils.load_data import extract_answer
from vllm import LLM, SamplingParams
import math
import re

# Pairwise-judge instruction (choose Output (a) or Output (b)).
# NOTE(review): not referenced elsewhere in this file — confirm it is still used by importers.
vanilla_prompt = """You are a helpful assistant in evaluating the quality of the outputs for a given instruction. 
Your goal is to select the best output for the given instruction. 
Select the Output (a) or Output (b) that is better for the given instruction. 
The two outputs are generated by two different AI chatbots respectively. 
Do NOT provide any explanation for your choice.
"""

# System instruction for SelectReward: the judge must answer only with '\boxed{index}'.
select_reward_instruction = """You are a math teacher in evaluating the quality of the outputs for a given question. 
Your task is to review and critique different solutions and select the best one from them. 
Please do NOT provide any explanation for your choice. You should only put the index of the best solution in the format '\\boxed{index}'.
"""
# User prompt for SelectReward; {cot} receives every candidate rendered with select_solution_template.
select_reward_prompt = """Question:
{question}
Solutions:
{cot}
"""
# One numbered candidate inside select_reward_prompt (SelectReward numbers them from 0).
select_solution_template = """Solution {idx}: 
{content}
"""


# System instruction for LLMReward: the solution is shown as indexed paragraphs.
llm_reward_instruction = """The following is a math problem and a solution (split into paragraphs, enclosed with tags and indexed from 0): 
"""

# User prompt for LLMReward. It intentionally ends mid-sentence ("... in "); LLMReward
# appends " \boxed{}." to complete the answer-format instruction.
llm_reward_prompt = """[Math Problem] 

{question}

[Solution] 

{cot}

Your task is to review and critique the solution paragraph by paragraph. Once you identify an error in a paragraph, return the index of the paragraph where the earliest error occurs. Otherwise, return the index of -1 (which typically denotes "not found"). 

Please do NOT provide any explanation for your choice. You should only put your final answer (i.e., the index) in """

# Wraps one solution paragraph in <paragraph_{idx}> tags for llm_reward_prompt.
step_template = """<paragraph_{idx}> 
{content}
</paragraph_{idx}> 
"""

# User prompt for GenReward; GenReward appends 'Yes' after the rendered prompt and
# scores it with ModelWrapper.cal_logits.
grm_reward_prompt = """Question:
{question}
Solution:
{cot}
Verification: 
Is the answer correct (Yes/No)?"""

# System instruction for GenReward (prepended to the user turn for Mistral/Gemma models).
grm_reward_instruction = """You are a math teacher in evaluating the quality of the outputs for a given question. 
At the end of the Solution verification, when you give your final grade, write it in the form "Yes/No".
"""


# GenPRM system instruction: judge paragraphs with \boxed{Yes} / \boxed{No}.
genprm_reward_instruction = "You are a math teacher. Your task is to review and critique the paragraphs in solution directly. Output your judgement in the format of `\\boxed{Yes}` if the paragraph is correct, or `\\boxed{No}` if the paragraph is incorrect."
# Forced prefix for GenPRM's judgement; generation continues right after '\boxed'.
genprm_output_template = "<output>\n**Judgement**: $\\boxed"
# Forced prefix that opens GenPRM's analysis section.
genprm_analyze_template = "<analyze>\nLet's analyze the last paragraph step by step: "
# Prefix for a GenPRM code-verification section.
# NOTE(review): not referenced in this file — confirm whether it is used elsewhere.
genprm_verify_template = "<verify>\nLet's use python code to find any potential error:\n```python\n"


class Reward(ABC):
    """Base class for reward models that score candidate responses to a question.

    Subclasses load their model/tokenizer in ``_initialize`` and implement
    ``score``; the selection helpers below are shared by all of them.
    """

    def __init__(self, model_path: str, remote: bool, dataset: str):
        """Resolve the model path and let the subclass load its model.

        Args:
            model_path: path template; a ``{dir}`` placeholder is filled with
                ``'usercache'`` when ``remote`` is true, otherwise ``'publiccache'``.
            remote: selects the cache directory above; kept on ``self.remote``.
            dataset: dataset name forwarded to ``extract_answer``.
        """
        if remote:
            self.model_path = model_path.format(dir='usercache')
        else:
            self.model_path = model_path.format(dir='publiccache')
        self.model = None
        self.tokenizer = None
        self.dataset = dataset
        self.remote = remote
        self._initialize()

    @abstractmethod
    def _initialize(self):
        """Load ``self.model`` (and ``self.tokenizer`` if needed) from ``self.model_path``."""

    @abstractmethod
    def score(self, question: str, responses: List[str], **kwargs) -> List[float]:
        """Abstract method: return one score per response to ``question``."""

    def find_most_confident_answer(self, question, completions: List[str]):
        """Return the answer of the highest-scoring completion.

        Scores are computed with ``score(..., step_reward=False, agg='last')``.

        Returns:
            ``(answer, completion, confidence)`` for the best completion, or
            ``(None, None, None)`` when ``completions`` is None or empty.
        """
        if not completions:
            # Same arity as the normal return so callers can always unpack 3 values.
            return None, None, None

        answers = [extract_answer(c, self.dataset) for c in completions]
        scores = self.score(question, completions, step_reward=False, agg='last')
        confidence = max(scores)
        max_index = scores.index(confidence)
        return answers[max_index], completions[max_index], confidence

    def stochastic_select_response(self, completion2score, completions):
        """Sample a completion with probability proportional to its score.

        Only the single highest-scoring entry is kept as a candidate (``[:1]``),
        so the draw is effectively deterministic.
        NOTE(review): confirm the top-1 truncation is intended.

        Args:
            completion2score: mapping completion -> score.
            completions: the original candidate list; the returned id indexes it.

        Returns:
            ``(answer, completion, id_in_completions, confidence)``.
        """
        ranked = sorted(completion2score.items(), key=lambda x: x[1], reverse=True)[:1]
        # Distinct names so the ``completions`` argument is not shadowed: the id
        # must refer to the caller's list, not to the truncated candidate tuple.
        candidates, candidate_scores = zip(*ranked)
        total_score = sum(candidate_scores)
        try:
            probabilities = [s / total_score for s in candidate_scores]
            sampled_completion = random.choices(candidates, weights=probabilities, k=1)[0]
        except (ZeroDivisionError, ValueError):
            # Zero total score, or weights rejected by random.choices: sample uniformly.
            sampled_completion = random.choices(candidates, k=1)[0]
        confidence = completion2score[sampled_completion]
        most_confident_answer = extract_answer(sampled_completion, self.dataset)
        id_of_most_confident = completions.index(sampled_completion)
        return most_confident_answer, sampled_completion, id_of_most_confident, confidence


# Example subclass 1: a reward model initialized from a specific model path.
class SkyworkReward(Reward):
    """Outcome reward from a sequence-classification model with a single output label."""

    def _initialize(self):
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_path,
            torch_dtype=torch.bfloat16,
            device_map='auto',
            num_labels=1,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)

    def _score_pair(self, question, response):
        """Return the single logit for one (question, response) chat."""
        conversation = [
            {"role": "user", "content": question},
            {"role": "assistant", "content": response},
        ]
        token_ids = self.tokenizer.apply_chat_template(conversation, tokenize=True, return_tensors="pt")
        token_ids = token_ids.to(self.model.device)
        with torch.no_grad():
            logits = self.model(token_ids).logits
            return logits[0][0].item()

    def score(self, question, responses, **kwargs):
        """Return one scalar reward per response (extra kwargs are ignored)."""
        return [self._score_pair(question, response) for response in responses]


class SkyworkO1Reward(Reward):
    """Process reward model (PRM_MODEL) producing a reward for every step of a response.

    Steps are delimited by ``prepare_input`` using ``step_token="."``.
    """

    def _initialize(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        # NOTE: torch_dtype=torch.bfloat16 was commented out in the original and
        # is still not passed; the dtype is left to PRM_MODEL.from_pretrained.
        self.model = PRM_MODEL.from_pretrained(self.model_path, device_map='auto').eval()

    def score(self, question, responses, **kwargs):
        """Score each response with the PRM.

        Keyword Args:
            step_reward (bool): if true, return each response's list of step
                rewards. Defaults to False.
            agg (str): how step rewards are collapsed when ``step_reward`` is
                false: ``'last'`` (final step), ``'min'`` (weakest step), or
                anything else (including None) for the product of all steps.

        Returns:
            One entry per response, either a step-reward list or an aggregated score.
        """
        step_scores = []
        for response in responses:
            processed_data = [prepare_input(question, response, tokenizer=self.tokenizer, step_token=".")]
            input_ids, _steps, reward_flags = zip(*processed_data)
            input_ids, attention_mask, reward_flags = prepare_batch_input_for_model(
                input_ids, reward_flags, self.tokenizer.pad_token_id
            )
            device = self.model.pretrained_model.device
            # Inference only: skip autograd bookkeeping.
            with torch.no_grad():
                _, _, rewards = self.model(
                    input_ids=input_ids.to(device),
                    attention_mask=attention_mask.to(device),
                    return_probs=True,
                )
            # derive_step_rewards yields one entry per batch item (here a batch of one).
            step_scores += derive_step_rewards(rewards, reward_flags)

        if kwargs.get('step_reward', False):
            return step_scores
        agg = kwargs.get('agg')
        if agg == 'last':
            return [steps[-1] for steps in step_scores]
        if agg == 'min':
            return [np.min(np.array(steps)) for steps in step_scores]
        return [np.prod(np.array(steps)) for steps in step_scores]

    def stochastic_find_most_confident_answer(
        self,
        question,
        completions: List[str]
    ):
        """Score completions (product aggregation) and pick one with ``stochastic_select_response``.

        Returns:
            ``(answer, completion, index, confidence)``, or four Nones when
            ``completions`` is empty.
        """
        if not completions:
            return None, None, None, None
        scores = self.score(question, completions, step_reward=False)
        completion2score = {completions[i]: scores[i] for i in range(len(completions))}
        most_confident_answer, sampled_completion, id_of_most_confident, confidence = self.stochastic_select_response(
            completion2score, completions
        )
        return most_confident_answer, sampled_completion, id_of_most_confident, confidence
    
class GenPRMReward(Reward):
    """Generative process reward model (GenPRM) served with vLLM.

    The solution is fed as a multi-turn conversation (one user turn per
    blank-line-separated paragraph). The model writes an analysis, then a boxed
    Yes/No judgement whose token probabilities give the reward.
    """

    def _initialize(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.model = LLM(self.model_path)
        # Initial sampling parameters carry no stop strings.
        self._update_params(stop_tokens=None)

    def _update_params(self, stop_tokens=None):
        """Rebuild ``self.sampling_params`` with the given stop strings (None = no stop strings)."""
        self.sampling_params = SamplingParams(
            stop=stop_tokens,
            temperature=0.6,
            top_p=0.95,
            max_tokens=2048,
            include_stop_str_in_output=True,
            logprobs=20,
            top_k=20,
            repetition_penalty=1.0
        )

    def get_reward_score(self, out):
        """Turn the generated boxed Yes/No judgement into a probability of "Yes".

        Args:
            out: a vLLM completion output exposing ``text``, ``token_ids`` and
                per-position ``logprobs`` (mapping token id -> Logprob).

        Returns:
            ``P(Yes) / (P(Yes) + P(No))``, read at the position of the last
            generated decision token. Returns 0.5 (neutral) if no decision is
            found in the text, if the decision token id is not among the
            generated ids, or if both probabilities underflow to zero.
        """
        boxed_match = re.search(r'(Yes|No)\}', out.text, re.IGNORECASE)
        if not boxed_match:
            # return neutral score if no decision found
            return 0.5

        yes_token = self.tokenizer.encode('Yes')[-1]
        no_token = self.tokenizer.encode('No')[-1]
        decided_yes = boxed_match.group(1).capitalize() == "Yes"
        decision_token = yes_token if decided_yes else no_token
        other_token = no_token if decided_yes else yes_token

        tokens = list(out.token_ids)
        try:
            # Position of the LAST generated occurrence of the decision token.
            position = len(tokens) - 1 - tokens[::-1].index(decision_token)
        except ValueError:
            # The word appears in the text but was tokenized differently in context,
            # so there is no probability to read: stay neutral instead of crashing.
            return 0.5
        position_logprobs = out.logprobs[position]

        decision_prob = math.exp(position_logprobs[decision_token].logprob)
        try:
            other_prob = math.exp(position_logprobs[other_token].logprob)
        except KeyError:
            # The alternative is not among the top-k logprobs reported for this
            # position: use the smallest reported logprob as its stand-in.
            min_logprob = min(lp.logprob for lp in position_logprobs.values())
            other_prob = math.exp(min_logprob)

        yes_prob = decision_prob if decided_yes else other_prob
        no_prob = other_prob if decided_yes else decision_prob

        # two-way softmax over {Yes, No}
        softmax_denominator = yes_prob + no_prob
        if softmax_denominator == 0:
            return 0.5  # in case of division by zero, assign neutral score
        return yes_prob / softmax_denominator

    def build_prompt(self, messages):
        """Render ``messages`` with the chat template and strip a trailing EOS.

        Removing the final EOS (optionally followed by a newline) leaves the
        prompt open so generation continues the last assistant turn.
        """
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
        if prompt.endswith(f"{self.tokenizer.eos_token}\n"):
            prompt = prompt[:-len(f"{self.tokenizer.eos_token}\n")]
        elif prompt.endswith(self.tokenizer.eos_token):
            prompt = prompt[:-len(self.tokenizer.eos_token)]
        return prompt

    def score(self, question, responses, **kwargs):
        """Return one GenPRM reward in [0, 1] per response.

        ``kwargs`` (e.g. ``step_reward``/``agg``) are accepted for interface
        compatibility and ignored: one judgement is produced per response.
        """
        scores = []
        for response in responses:
            paragraphs = response.split('\n\n')
            paragraphs[0] = question + '\n' + paragraphs[0]
            if paragraphs and paragraphs[-1] == '':
                paragraphs.pop()

            conversation = [{'role': 'system', 'content': genprm_reward_instruction}]
            for paragraph in paragraphs:
                conversation.append({'role': 'user', 'content': paragraph})
                conversation.append({'content': '', 'role': 'assistant'})
            prompt = self.build_prompt(conversation)

            # Stage 1: free-form analysis, stopped at the closing tag.
            self._update_params(stop_tokens=['</analyze>\n'])
            analysis = self.model.generate(
                prompt + genprm_analyze_template, sampling_params=self.sampling_params, use_tqdm=False
            )[0].outputs[0]

            # Stage 2: continue after the analysis with the forced "\boxed" judgement prefix.
            cur_prompt = genprm_analyze_template + analysis.text + genprm_output_template
            self._update_params(stop_tokens=['</output>\n'])
            judgement = self.model.generate(
                prompt + cur_prompt, sampling_params=self.sampling_params, use_tqdm=False
            )[0].outputs[0]
            scores.append(self.get_reward_score(judgement))
        return scores

    def stochastic_find_most_confident_answer(
        self,
        question,
        completions: List[str]
    ):
        """Score completions and pick one with ``stochastic_select_response``.

        Returns:
            ``(answer, completion, index, confidence)``, or four Nones when
            ``completions`` is empty.
        """
        if not completions:
            return None, None, None, None
        scores = self.score(question, completions, step_reward=False)
        completion2score = {completions[i]: scores[i] for i in range(len(completions))}
        most_confident_answer, sampled_completion, id_of_most_confident, confidence = self.stochastic_select_response(
            completion2score, completions
        )
        return most_confident_answer, sampled_completion, id_of_most_confident, confidence

    


class ShepherdReward(Reward):
    """Step-level reward: probability of '+' (vs '-') predicted at every 'ки' step tag."""

    def _initialize(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            device_map='auto'
            ).eval()
        good_token = '+'
        bad_token = '-'
        step_tag = 'ки'
        # Token ids of '+' and '-'; the first encoded id is dropped (presumably BOS — verify per tokenizer).
        self.candidate_tokens = self.tokenizer.encode(f"{good_token} {bad_token}")[1:]  # [648, 387]
        self.step_tag_id = self.tokenizer.encode(f"{step_tag}")[-1]  # 12902

    def _add_step_tag(self, response):
        """Insert a step tag after every '.' of ``response``.

        Empty lines are dropped and the remaining lines are concatenated without
        a separator; each '.' becomes '. ки\\n'.
        """
        return ''.join(line.replace('.', '. ки\n') for line in response.split('\n') if line)

    def score(self, question, responses, **kwargs):
        """Score each response at its step tags.

        Keyword Args:
            step_reward (bool): if true, return each response's tensor of step
                scores. Defaults to False.
            agg (str): ``'last'``, ``'min'``, or anything else for the product of
                the step scores.

        NOTE(review): a response without any '.' has no step tags, so ``'last'`` and
        ``'min'`` fail on the empty step tensor — confirm the desired fallback.
        """
        scores = []
        for res in responses:
            content = f"{question} {self._add_step_tag(res)}"
            input_ids = torch.tensor([self.tokenizer.encode(content)]).to(self.model.device)

            with torch.no_grad():
                logits = self.model(input_ids).logits[:, :, self.candidate_tokens]
                # softmax over just {'+', '-'}; index 0 is '+'
                good_prob = logits.softmax(dim=-1)[:, :, 0]
                step_score = good_prob[input_ids == self.step_tag_id].to('cpu')
            scores.append(step_score)

        if kwargs.get('step_reward', False):
            return scores
        agg = kwargs.get('agg')
        if agg == 'last':
            return [score[-1] for score in scores]
        if agg == 'min':
            return [np.min(np.array(score)) for score in scores]
        return [np.prod(np.array(score)) for score in scores]


class ArmormReward(Reward):
    """Reward read from the ``score`` field of a remote-code sequence-classification model."""

    def _initialize(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_path,
            device_map='auto',
            trust_remote_code=True,
        )

    def _score_single(self, question, response):
        """Return the float score for one (question, response) chat."""
        conversation = [
            {"role": "user", "content": question},
            {"role": "assistant", "content": response},
        ]
        encoded = self.tokenizer.apply_chat_template(
            conversation,
            return_tensors="pt",
            padding=True,
        )
        encoded = encoded.to(self.model.device)
        with torch.no_grad():
            return self.model(encoded).score.float().item()

    def score(self, question, responses, **kwargs):
        """Return one float score per response (extra kwargs are ignored)."""
        return [self._score_single(question, response) for response in responses]

class GRMReward(Reward):
    """Scalar reward from a remote-code sequence-classification model loaded in float16."""

    def _initialize(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_path,
            device_map='auto',
            trust_remote_code=True,
            torch_dtype=torch.float16,
        )

    def score(self, question, responses, **kwargs):
        """Return one float reward per response (extra kwargs are ignored)."""
        # Named separately so the method's own ``**kwargs`` is not rebound.
        # NOTE(review): padding='max_length' without max_length pads to the
        # tokenizer's model_max_length — confirm this is intended.
        encode_options = {"padding": 'max_length', "truncation": True, "return_tensors": "pt"}
        rewards = []
        for res in responses:
            chat = [{"role": "user", "content": question}, {"role": "assistant", "content": res}]
            rendered = self.tokenizer.apply_chat_template(chat, tokenize=False)
            encoded = self.tokenizer.encode_plus(rendered, **encode_options)

            device = self.model.device
            input_ids = encoded["input_ids"][0].view(1, -1).to(device)
            attention_mask = encoded["attention_mask"][0].view(1, -1).to(device)
            with torch.no_grad():
                reward_tensor = self.model(input_ids, attention_mask=attention_mask)[0]
                rewards.append(reward_tensor.cpu().detach().item())
        return rewards


class LLMReward(Reward):
    """Prompt a general LLM for the index of the earliest erroneous paragraph."""

    def _initialize(self):
        self.model = ModelWrapper(self.model_path)

    def score(self, question, responses, **kwargs):
        """Return, per response, the value ``extract_answer`` parses from the LLM reply.

        Each response is split on '.', and each non-empty piece is stripped and
        wrapped with ``step_template``.
        """
        results = []
        for res in responses:
            pieces = [piece.strip() for piece in res.split('.') if piece]
            solution = '.'.join(
                step_template.format(idx=idx, content=piece) for idx, piece in enumerate(pieces)
            )
            messages = [
                {"role": "system", "content": llm_reward_instruction},
                {"role": "user", "content": llm_reward_prompt.format(question=question, cot=solution) + " \\boxed{}."},
            ]
            # NOTE(review): SelectReward indexes ``generate(...)[0]``; here the raw
            # return value is parsed — confirm ModelWrapper.generate's return type.
            reply = self.model.generate(messages, max_tokens=200)
            results.append(extract_answer(reply, 'math'))
        return results


class SelectReward(Reward):
    """LLM-as-judge that picks the single best response among all candidates."""

    def _initialize(self):
        # Only the text after the last '-' of the name is kept as the model name.
        self.model_path = self.model_path.split('-')[-1]
        self.model = ModelWrapper(self.model_path, self.remote)

    def score(self, question, responses, **kwargs):
        """Return a one-hot list marking the response the judge selected.

        Returns:
            A list with ``1`` at the chosen index and ``0`` elsewhere. Returns all
            zeros (no preference) when the judge's answer is missing, is not an
            integer, or is outside ``[0, len(responses))``. Returns ``[]`` for no
            responses, without calling the model.
        """
        if not responses:
            return []
        cot = "".join(
            select_solution_template.format(idx=i, content=response) for i, response in enumerate(responses)
        )
        messages = [
            {"role": "system", "content": select_reward_instruction},
            {"role": "user", "content": select_reward_prompt.format(question=question, cot=cot)}
        ]
        reply = self.model.generate(messages, max_tokens=200, greedy=True)[0]
        best_index = extract_answer(reply, self.dataset)

        scores = [0] * len(responses)
        try:
            chosen = int(best_index)
        except (TypeError, ValueError):
            # Unparseable judge output: express no preference instead of crashing.
            return scores
        # Reject out-of-range indices; a negative value would otherwise silently wrap around.
        if 0 <= chosen < len(responses):
            scores[chosen] = 1
        return scores


class GenReward(Reward):
    """Generative verifier: scores the token 'Yes' appended after a verification prompt."""

    def _initialize(self):
        # Only the text after the last '-' of the name is kept as the model name.
        self.model_path = self.model_path.split('-')[-1]
        self.model = ModelWrapper(self.model_path, self.remote)

    def _build_messages(self, question, res):
        """Build the verification chat for one response, ending with an empty assistant turn."""
        prompt = grm_reward_prompt.format(question=question, cot=res)
        if self.model.is_mistral or self.model.is_gemma:
            # No separate system turn for these models: the instruction is prepended to the user turn.
            turns = [{"role": "user", "content": grm_reward_instruction + prompt}]
        else:
            turns = [
                {"role": "system", "content": grm_reward_instruction},
                {"role": "user", "content": prompt},
            ]
        turns.append({"role": "assistant", "content": ''})
        return turns

    def score(self, question, responses, **kwargs):
        """Return ``ModelWrapper.cal_logits`` of 'Yes' for each response (extra kwargs are ignored)."""
        scores = []
        for res in responses:
            rendered = self.model.tokenizer.apply_chat_template(self._build_messages(question, res), tokenize=False)
            # Drop everything from the last '<' (presumably the empty assistant turn's
            # end marker) so 'Yes' continues the assistant turn.
            text = rendered.rpartition('<')[0] + 'Yes'
            scores.append(self.model.cal_logits(text, 'Yes'))
        return scores
    

class SelfReward(Reward):
    """Self-evaluation reward: the policy LLM scores its own responses.

    ``score`` uses ``ModelWrapper.cal_prob``; answer selection is by majority vote.
    """

    def _initialize(self):
        # Only the text after the last '-' of the name is kept as the model name.
        self.model_path = self.model_path.split('-')[-1]
        self.model = ModelWrapper(self.model_path, self.remote)

    def score(self, question, responses, **kwargs):
        """Return ``ModelWrapper.cal_prob`` of each response given the question prompt."""
        scores = []
        for res in responses:
            content = [{"role": "user", "content": question}, {"role": "assistant", "content": ''}]
            rendered = self.model.tokenizer.apply_chat_template(content, tokenize=False)
            # Cut from the last '<' (presumably the empty assistant turn's end marker)
            # so the response continues the assistant turn.
            text = rendered.rpartition('<')[0] + res
            scores.append(self.model.cal_prob(text, res))
        return scores

    def find_most_confident_answer(self, question, completions: List[str]):
        """Return the majority-vote answer over ``completions``.

        Returns:
            ``(answer, first completion giving that answer, vote share)``, or
            ``(None, None, None)`` when ``completions`` is None or empty.
        """
        if not completions:
            # Same arity as the normal return so callers can always unpack 3 values.
            return None, None, None

        answers = [extract_answer(c, self.dataset) for c in completions]
        # Ties go to the answer that occurs first in ``answers``.
        pred = max(answers, key=answers.count)
        confidence = answers.count(pred) / len(answers)
        max_index = answers.index(pred)
        return pred, completions[max_index], confidence

# class SCReward(Reward):
    


# Factory: return an instance of the concrete Reward subclass selected by ``name``.
def reward_factory(name: str, remote: bool, dataset: str) -> Reward:
    """Build the reward model selected by ``name``.

    Exact names map to dedicated reward models with configured paths. Otherwise
    the prefixes 'self', 'select' and 'gen' (checked in that order) select an
    LLM-based reward built from ``name``. Any other name falls back to LLMReward.
    Exact names are checked first, so 'genprm' is not captured by the 'gen' prefix.
    """
    dedicated = {
        'skywork': (SkyworkReward, reward_skywork_path),
        'shepherd': (ShepherdReward, reward_shepherd_path),
        'armorm': (ArmormReward, reward_armorm_path),
        'genprm': (GenPRMReward, reward_genprm_path),
        'grm': (GRMReward, reward_grm_path),
        'skyworko1': (SkyworkO1Reward, reward_skyworko1_path),
    }
    if name in dedicated:
        reward_cls, path = dedicated[name]
        return reward_cls(path, remote, dataset)

    for prefix, reward_cls in (('self', SelfReward), ('select', SelectReward), ('gen', GenReward)):
        if name.startswith(prefix):
            return reward_cls(name, remote, dataset)
    return LLMReward(name, remote, dataset)