from abc import ABC, abstractmethod
from typing import List
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModelForCausalLM
from .config import reward_skywork_path, reward_shepherd_path, reward_armorm_path, reward_grm_path, reward_skyworko1_path
from .model import ModelWrapper
import torch
from .io_utils import prepare_input, prepare_batch_input_for_model, derive_step_rewards
from .prm_model import PRM_MODEL
from collections import defaultdict
import random 
import numpy as np
from utils.load_data import extract_answer

# System-style instruction for pairwise comparison: asks the judge to pick
# "Output (a)" or "Output (b)" without explanation.
# NOTE(review): not referenced by any code in the visible part of this module.
vanilla_prompt = """You are a helpful assistant in evaluating the quality of the outputs for a given instruction. 
Your goal is to select the best output for the given instruction. 
Select the Output (a) or Output (b) that is better for the given instruction. 
The two outputs are generated by two different AI chatbots respectively. 
Do NOT provide any explanation for your choice.
"""

# System message for SelectReward: the judge must answer with the index of the
# best solution inside '\boxed{...}'. This string is used verbatim (it is not
# passed through str.format), so the literal "{index}" is safe here.
select_reward_instruction = """You are a math teacher in evaluating the quality of the outputs for a given question. 
Your task is to review and critique different solutions and select the best one from them. 
Please do NOT provide any explanation for your choice. You should only put the index of the best solution in the format '\\boxed{index}'.
"""
# User message template for SelectReward; placeholders: {question}, {cot}
# ({cot} receives the concatenated, numbered candidate solutions).
select_reward_prompt = """Question:
{question}
Solutions:
{cot}
"""
# Template for one numbered candidate solution; placeholders: {idx}, {content}.
select_solution_template = """Solution {idx}: 
{content}
"""


# System message for LLMReward: introduces a problem plus a solution that has
# been split into indexed, tagged paragraphs.
llm_reward_instruction = """The following is a math problem and a solution (split into paragraphs, enclosed with tags and indexed from 0): 
"""

# User message template for LLMReward; placeholders: {question}, {cot}.
# The text deliberately ends mid-sentence ("... in "): LLMReward.score appends
# " \\boxed{}." after formatting to complete the instruction.
llm_reward_prompt = """[Math Problem] 
 
{question}
 
[Solution] 
 
{cot}
 
Your task is to review and critique the solution paragraph by paragraph. Once you identify an error in a paragraph, return the index of the paragraph where the earliest error occurs. Otherwise, return the index of -1 (which typically denotes "not found"). 
 
Please do NOT provide any explanation for your choice. You should only put your final answer (i.e., the index) in """

# Wraps one solution step in <paragraph_N> tags; placeholders: {idx}, {content}.
step_template = """<paragraph_{idx}> 
{content}
</paragraph_{idx}> 
"""

# User message template for GenReward; placeholders: {question}, {cot}.
# Ends with a Yes/No question so that the probability of the continuation
# "Yes" can be used as the score (see GenReward.score).
grm_reward_prompt = """Question:
{question}
Solution:
{cot}
Verification: 
Is the answer correct (Yes/No)?"""

# Instruction for GenReward: sent as the system message, or prepended to the
# user message for models flagged is_mistral / is_gemma.
grm_reward_instruction = """You are a math teacher in evaluating the quality of the outputs for a given question. 
At the end of the Solution verification, when you give your final grade, write it in the form "Yes/No".
"""


class Reward(ABC):
    """Abstract base class for reward models that score candidate responses.

    Subclasses implement ``_initialize`` (load model/tokenizer) and ``score``
    (rate every response to a question).

    Args:
        model_path: Path or name template. A ``{dir}`` placeholder, if present,
            is filled with ``'usercache'`` when ``remote`` is true and with
            ``'publiccache'`` otherwise.
        remote: Selects the cache directory; also stored as ``self.remote``.
        dataset: Dataset name forwarded to ``extract_answer``.
    """

    def __init__(self, model_path: str, remote: bool, dataset: str):
        cache_dir = 'usercache' if remote else 'publiccache'
        self.model_path = model_path.format(dir=cache_dir)
        self.model = None
        self.tokenizer = None
        self.dataset = dataset
        self.remote = remote
        # Subclasses load their model here, so every attribute they read
        # (model_path, remote, dataset) is assigned before this call.
        self._initialize()

    @abstractmethod
    def _initialize(self):
        """Load the model (and tokenizer, if any) into ``self``."""
        pass

    @abstractmethod
    def score(self, question: str, responses: List[str], **kwargs) -> List[float]:
        """Abstract method that computes one score per response.

        Some subclasses read the keyword arguments ``step_reward`` and ``agg``;
        others ignore ``kwargs`` entirely.
        """
        pass

    def find_most_confident_answer(self, question, completions: List[str]):
        """Return the completion with the highest score.

        Args:
            question: The question the completions respond to.
            completions: Candidate completions.

        Returns:
            A 3-tuple ``(answer, completion, confidence)`` where ``confidence``
            is the highest score and ``answer`` is extracted from the winning
            completion. ``(None, None, None)`` when ``completions`` is empty
            or ``None``, so the arity is the same on every path.
        """
        if not completions:
            return None, None, None

        answers = [extract_answer(c, self.dataset) for c in completions]
        scores = self.score(question, completions, step_reward=False, agg='last')
        # First maximal element wins, matching max() followed by list.index().
        max_index, confidence = max(enumerate(scores), key=lambda pair: pair[1])
        return answers[max_index], completions[max_index], confidence

    def stochastic_select_response(self, completion2score, completions):
        """Pick a completion from a completion -> score mapping.

        Only the single best-scoring entry is kept as a candidate (``[:1]``),
        so in practice the top-scoring completion is always returned.

        Args:
            completion2score: Mapping from completion text to its score.
            completions: The original list of completions; used to report the
                position of the selected completion.

        Returns:
            ``(answer, completion, index, confidence)`` where ``index`` is the
            position of the selected completion in ``completions`` (``None``
            if it cannot be located there). Four ``None`` values when
            ``completion2score`` is empty.
        """
        if not completion2score:
            return None, None, None, None

        ranked = sorted(completion2score.items(), key=lambda x: x[1], reverse=True)[:1]
        # Keep the caller's ``completions`` intact: it is needed below to
        # compute the index in the original list.
        candidates, scores = zip(*ranked)
        total_score = sum(scores)
        try:
            probabilities = [score / total_score for score in scores]
            sampled_completion = random.choices(candidates, weights=probabilities, k=1)[0]
        except Exception:
            # Weighted sampling can fail (e.g. a zero total score); fall back
            # to uniform sampling rather than aborting.
            sampled_completion = random.choices(candidates, k=1)[0]
        confidence = completion2score[sampled_completion]
        most_confident_answer = extract_answer(sampled_completion, self.dataset)
        try:
            id_of_most_confident = completions.index(sampled_completion)
        except (AttributeError, ValueError):
            id_of_most_confident = None
        return most_confident_answer, sampled_completion, id_of_most_confident, confidence


# Example subclass 1: initialized for one specific named reward model.
class SkyworkReward(Reward):
    """Reward model backed by a single-label sequence-classification head."""

    def _initialize(self):
        """Load the classifier in bfloat16 with automatic device placement, then the tokenizer."""
        load_options = {
            'torch_dtype': torch.bfloat16,
            'device_map': 'auto',
            'num_labels': 1,
        }
        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_path, **load_options)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)

    def score(self, question, responses, **kwargs):
        """Return one scalar logit per response; ``kwargs`` is ignored."""

        def rate(answer):
            # Each (question, answer) pair is scored as a two-turn chat.
            conversation = [
                {"role": "user", "content": question},
                {"role": "assistant", "content": answer},
            ]
            token_ids = self.tokenizer.apply_chat_template(
                conversation, tokenize=True, return_tensors="pt"
            ).to(self.model.device)
            with torch.no_grad():
                return self.model(token_ids).logits[0][0].item()

        return [rate(res) for res in responses]


class SkyworkO1Reward(Reward):
    """Process reward model (``PRM_MODEL``) that yields per-step rewards."""

    def _initialize(self):
        """Load the tokenizer and the PRM in eval mode with automatic device placement."""
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.model = PRM_MODEL.from_pretrained(self.model_path, device_map='auto').eval()

    def score(self, question, responses, **kwargs):
        """Score each response, optionally aggregating its step rewards.

        Args:
            question: The question being answered.
            responses: Candidate responses; steps are delimited by ``"."``.
            **kwargs:
                step_reward: If truthy, return the raw per-step rewards.
                    Defaults to ``False`` when omitted.
                agg: Aggregation when ``step_reward`` is falsy: ``'last'``
                    (last step), ``'min'`` (minimum), anything else (product).

        Returns:
            A list with one entry per response: per-step rewards when
            ``step_reward`` is truthy, otherwise the aggregated value.
        """
        device = self.model.pretrained_model.device
        scores = []
        for response in responses:
            processed_data = [prepare_input(question, response, tokenizer=self.tokenizer, step_token=".")]
            input_ids, steps, reward_flags = zip(*processed_data)
            input_ids, attention_mask, reward_flags = prepare_batch_input_for_model(
                input_ids, reward_flags, self.tokenizer.pad_token_id
            )
            # Inference only: skip autograd bookkeeping, as the other
            # torch-based reward models in this module do.
            with torch.no_grad():
                _, _, rewards = self.model(
                    input_ids=input_ids.to(device),
                    attention_mask=attention_mask.to(device),
                    return_probs=True,
                )
            # derive_step_rewards presumably returns one list of step rewards
            # per batch item (batch size is 1 here) -- TODO confirm.
            scores += derive_step_rewards(rewards, reward_flags)

        # .get() so a call without step_reward (accepted by the scorers that
        # ignore kwargs) does not raise KeyError here.
        if kwargs.get('step_reward', False):
            return scores

        agg = kwargs.get('agg')
        if agg == 'last':
            return [score[-1] for score in scores]
        if agg == 'min':
            return [np.min(np.array(score)) for score in scores]
        return [np.prod(np.array(score)) for score in scores]

    def stochastic_find_most_confident_answer(
        self,
        question,
        completions: List[str]
    ):
        """Score the completions and select one via ``stochastic_select_response``.

        No ``agg`` is passed to ``score``, so step rewards are aggregated by
        product.

        Returns:
            ``(answer, completion, index, confidence)``, or four ``None``
            values when ``completions`` is empty or ``None``.
        """
        if not completions:
            return None, None, None, None
        scores = self.score(question, completions, step_reward=False)
        completion2score = dict(zip(completions, scores))
        return self.stochastic_select_response(completion2score, completions)
    


class ShepherdReward(Reward):
    """Process reward model that reads '+'/'-' logits at step-tag positions."""

    def _initialize(self):
        """Load the tokenizer and causal LM; cache the ids of the special tokens."""
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            device_map='auto'
            ).eval()
        good_token = '+'
        bad_token = '-'
        step_tag = 'ки'
        # Drop the first encoded id and keep the ids for the good/bad tokens.
        self.candidate_tokens = self.tokenizer.encode(f"{good_token} {bad_token}")[1:] # [648, 387]
        self.step_tag_id = self.tokenizer.encode(f"{step_tag}")[-1] # 12902

    def _add_step_tag(self, response):
        """Insert the step tag after every '.' in ``response``.

        Newlines in the input are removed (lines are concatenated without a
        separator) and each '.' becomes ``'. ки\\n'``.
        """
        return response.replace('\n', '').replace('.', '. ки\n')

    def score(self, question, responses, **kwargs):
        """Score each response, optionally aggregating its step scores.

        Args:
            question: The question being answered.
            responses: Candidate responses.
            **kwargs:
                step_reward: If truthy, return the raw per-step scores.
                    Defaults to ``False`` when omitted.
                agg: Aggregation when ``step_reward`` is falsy: ``'last'``
                    (last step), ``'min'`` (minimum), anything else (product).

        Returns:
            A list with one entry per response: a CPU tensor of step scores
            when ``step_reward`` is truthy, otherwise the aggregated value.
        """
        scores = []
        for res in responses:
            tag_res = self._add_step_tag(res)
            content = f"{question} {tag_res}"
            input_ids = torch.tensor([self.tokenizer.encode(content)]).to(self.model.device)

            with torch.no_grad():
                # Restrict the vocabulary to the good/bad tokens, then take
                # the softmax probability of the first (good) token.
                logits = self.model(input_ids).logits[:, :, self.candidate_tokens]
                good_prob = logits.softmax(dim=-1)[:, :, 0]
                # Keep only the positions holding the step tag.
                step_score = good_prob[input_ids == self.step_tag_id].to('cpu')
            scores.append(step_score)

        # .get() so a call without step_reward (accepted by the scorers that
        # ignore kwargs) does not raise KeyError here.
        if kwargs.get('step_reward', False):
            return scores

        agg = kwargs.get('agg')
        if agg == 'last':
            return [score[-1] for score in scores]
        if agg == 'min':
            return [np.min(np.array(score)) for score in scores]
        return [np.prod(np.array(score)) for score in scores]


class ArmormReward(Reward):
    """Reward model whose forward output exposes a scalar ``score`` attribute."""

    def _initialize(self):
        """Load the tokenizer and the classifier (remote code enabled)."""
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_path, device_map='auto', trust_remote_code=True
        )

    def score(self, question, responses, **kwargs):
        """Return one float score per response; ``kwargs`` is ignored."""

        def rate(answer):
            # Each (question, answer) pair is scored as a two-turn chat.
            chat = [
                {"role": "user", "content": question},
                {"role": "assistant", "content": answer},
            ]
            encoded = self.tokenizer.apply_chat_template(
                chat, return_tensors="pt", padding=True
            ).to(self.model.device)
            with torch.no_grad():
                result = self.model(encoded)
                return result.score.float().item()

        return [rate(res) for res in responses]

class GRMReward(Reward):
    """Reward model loaded as a float16 sequence classifier with remote code."""

    def _initialize(self):
        """Load the tokenizer and the float16 classifier."""
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_path,
            device_map='auto',
            trust_remote_code=True,
            torch_dtype=torch.float16,
        )

    def score(self, question, responses, **kwargs):
        """Return one scalar reward per response; the caller's ``kwargs`` are not read."""
        # Tokenizer options are the same for every response.
        encode_options = {"padding": 'max_length', "truncation": True, "return_tensors": "pt"}
        rewards = []
        for answer in responses:
            chat = [
                {"role": "user", "content": question},
                {"role": "assistant", "content": answer},
            ]
            # Render the chat to text first, then tokenize it separately.
            rendered = self.tokenizer.apply_chat_template(chat, tokenize=False)
            tokens = self.tokenizer.encode_plus(rendered, **encode_options)

            device = self.model.device
            ids = tokens["input_ids"][0].view(1, -1).to(device)
            mask = tokens["attention_mask"][0].view(1, -1).to(device)
            with torch.no_grad():
                # The first element of the model output is taken as the reward.
                reward_tensor = self.model(ids, attention_mask=mask)[0]
                rewards.append(reward_tensor.cpu().detach().item())
        return rewards


class LLMReward(Reward):
    """Reward that asks a generative LLM for the index of the earliest wrong step."""

    def _initialize(self):
        """Wrap the model at ``self.model_path`` in a ``ModelWrapper``."""
        self.model = ModelWrapper(self.model_path)

    def score(self, question, responses, **kwargs):
        """Return, per response, the answer extracted from the LLM's critique.

        Each response is split on '.', wrapped into indexed paragraphs and
        sent with the critique prompt. The value appended for each response is
        whatever ``extract_answer(..., 'math')`` returns for the generation.
        ``kwargs`` is ignored.
        """
        verdicts = []
        for candidate in responses:
            # Empty fragments are dropped before stripping, so a
            # whitespace-only fragment survives as an empty step.
            steps = [piece.strip() for piece in candidate.split('.') if piece]
            paragraphs = [
                step_template.format(idx=position, content=step)
                for position, step in enumerate(steps)
            ]
            solution = '.'.join(paragraphs)
            user_text = llm_reward_prompt.format(question=question, cot=solution) + " \\boxed{}."
            messages = [
                {"role": "system", "content": llm_reward_instruction},
                {"role": "user", "content": user_text},
            ]
            # NOTE(review): SelectReward indexes generate(...)[0]; here the raw
            # return value is passed on unchanged -- confirm this is intended.
            generation = self.model.generate(messages, max_tokens=200)
            verdicts.append(extract_answer(generation, 'math'))
        return verdicts


class SelectReward(Reward):
    """Reward that asks a generative LLM to pick the best of several solutions."""

    def _initialize(self):
        """Derive the model name from the text after the last '-' and wrap it."""
        # NOTE(review): a model name that itself contains '-' is truncated to
        # its last segment -- confirm the naming convention with callers.
        self.model_path = self.model_path.split('-')[-1]
        self.model = ModelWrapper(self.model_path, self.remote)

    def score(self, question, responses, **kwargs):
        """Return a one-hot score list marking the solution chosen by the LLM.

        Args:
            question: The question being answered.
            responses: Candidate solutions, presented to the model indexed
                from 0.
            **kwargs: Ignored.

        Returns:
            A list of ``len(responses)`` integers, with 1 at the chosen index
            and 0 elsewhere. All zeros when the model's answer cannot be
            parsed as an integer or lies outside ``[0, len(responses))``.
            An empty list when ``responses`` is empty (the model is not
            queried).
        """
        if not responses:
            return []

        cot = "".join(
            select_solution_template.format(idx=i, content=content)
            for i, content in enumerate(responses)
        )
        messages = [
            {"role": "system", "content": select_reward_instruction},
            {"role": "user", "content": select_reward_prompt.format(question=question, cot=cot)}
        ]
        response = self.model.generate(messages, max_tokens=200, greedy=True)[0]
        best_index = extract_answer(response, self.dataset)

        scores = [0] * len(responses)
        try:
            chosen = int(best_index)
        except (TypeError, ValueError):
            # The model did not produce a usable index; select nothing
            # instead of aborting the whole scoring run.
            return scores
        # Reject out-of-range answers; a negative value would otherwise
        # silently select a solution counted from the end of the list.
        if 0 <= chosen < len(scores):
            scores[chosen] = 1
        return scores


class GenReward(Reward):
    """Generative reward: scores a solution via ``cal_logits`` for the answer 'Yes'."""

    def _initialize(self):
        """Derive the model name from the text after the last '-' and wrap it."""
        model_name = self.model_path.split('-')[-1]
        self.model_path = model_name
        self.model = ModelWrapper(model_name, self.remote)

    def score(self, question, responses, **kwargs):
        """Return one ``cal_logits(text, 'Yes')`` value per response; ``kwargs`` is ignored."""
        scores = []
        for res in responses:
            verification = grm_reward_prompt.format(question=question, cot=res)
            if self.model.is_mistral or self.model.is_gemma:
                # For these models the instruction is folded into the user turn
                # instead of being sent as a system message.
                messages = [
                    {"role": "user", "content": grm_reward_instruction + verification},
                ]
            else:
                messages = [
                    {"role": "system", "content": grm_reward_instruction},
                    {"role": "user", "content": verification},
                ]
            messages.append({"role": "assistant", "content": ''})
            rendered = self.model.tokenizer.apply_chat_template(messages, tokenize=False)
            # Cut the rendered chat at its last '<' (presumably the closing
            # special token of the empty assistant turn -- TODO confirm) and
            # append the answer to be scored.
            text = rendered.rpartition('<')[0] + 'Yes'
            scores.append(self.model.cal_logits(text, 'Yes'))
        return scores
    

class SelfReward(Reward):
    """Reward from the generating model itself, with majority-vote answer selection."""

    def _initialize(self):
        """Derive the model name from the text after the last '-' and wrap it."""
        self.model_path = self.model_path.split('-')[-1]
        self.model = ModelWrapper(self.model_path, self.remote)

    def score(self, question, responses, **kwargs):
        """Return one ``cal_prob`` value per response; ``kwargs`` is ignored."""
        scores = []
        for res in responses:
            content = [{"role": "user", "content": question}, {"role": "assistant", "content": ''}]
            rendered = self.model.tokenizer.apply_chat_template(content, tokenize=False)
            # Cut the rendered chat at its last '<' (presumably the closing
            # special token of the empty assistant turn -- TODO confirm) and
            # append the response to be scored.
            text = rendered.rpartition('<')[0] + res
            scores.append(self.model.cal_prob(text, res))
        return scores

    def find_most_confident_answer(self, question, completions: List[str]):
        """Return the majority-vote answer among the completions.

        ``score`` is not used here: the answer that occurs most often wins
        (the earliest one on ties) and its share of the votes is the
        confidence.

        Returns:
            A 3-tuple ``(answer, completion, confidence)`` where ``completion``
            is the first completion yielding the winning answer.
            ``(None, None, None)`` when ``completions`` is empty or ``None``,
            so the arity is the same on every path.
        """
        if not completions:
            return None, None, None

        answers = [extract_answer(c, self.dataset) for c in completions]
        pred = max(answers, key=answers.count)
        # pred comes from answers, so its count is at least 1 and the
        # confidence is always positive.
        confidence = answers.count(pred) / len(answers)
        max_index = answers.index(pred)
        return pred, completions[max_index], confidence

# class SCReward(Reward):
    


# Factory method: returns an instance of the concrete subclass selected by name.
def reward_factory(name: str, remote: bool, dataset: str) -> Reward:
    """Build the reward model identified by ``name``.

    Exact names map to a fixed class and configured path. Names starting with
    ``'self'``, ``'select'`` or ``'gen'`` (checked in that order) map to the
    matching LLM-based class and pass ``name`` itself as the model path. Any
    other name falls back to ``LLMReward``.
    """
    exact_matches = {
        'skywork': (SkyworkReward, reward_skywork_path),
        'shepherd': (ShepherdReward, reward_shepherd_path),
        'armorm': (ArmormReward, reward_armorm_path),
        'grm': (GRMReward, reward_grm_path),
        'skyworko1': (SkyworkO1Reward, reward_skyworko1_path),
    }
    if name in exact_matches:
        reward_cls, model_path = exact_matches[name]
        return reward_cls(model_path, remote, dataset)

    prefix_matches = (
        ('self', SelfReward),
        ('select', SelectReward),
        ('gen', GenReward),
    )
    for prefix, reward_cls in prefix_matches:
        if name.startswith(prefix):
            return reward_cls(name, remote, dataset)

    return LLMReward(name, remote, dataset)