import torch
from transformers import AutoModel, AutoTokenizer
import torch.nn.functional as F

def make_step_rewards(logits, token_masks):
    """Extract the per-step positive-class probabilities for each sample.

    Args:
        logits: Tensor of shape (bs, seq_len, num_labels) holding the raw
            classification scores for every token position.
        token_masks: Tensor of shape (bs, seq_len) that is non-zero/True at
            the positions to score (the step-separator tokens) and
            zero/False everywhere else.

    Returns:
        A list with one entry per sample; each entry is a list of floats,
        the probability of label index 1 at every selected position, in
        sequence order.
    """
    probabilities = F.softmax(logits, dim=-1)  # bs, seq_len, num_labels
    selected = token_masks.bool()  # bs, seq_len

    # Select rows with the mask itself rather than filtering on
    # "probability != 0": a saturated softmax can yield an exact 0.0 for one
    # label, which would drop that value and misalign (or break) the
    # regrouping of the remaining values into per-step rows.
    return [
        sample[keep][:, 1].cpu().tolist()
        for sample, keep in zip(probabilities, selected)
    ]

class Reward_Prm:
    """Score an assistant answer step by step with a process reward model.

    The answer is rendered through the tokenizer's chat template with a
    separator marker appended after every step; the model output at each
    marker position is converted to a per-step score by
    ``make_step_rewards``.
    """

    # Marker inserted after every step; its token id is used to locate the
    # positions whose scores are read.
    STEP_SEPARATOR = "<extra_0>"

    def __init__(self, model_path, device):
        """Load the tokenizer and model from ``model_path`` onto ``device``."""
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(
            model_path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
        )
        self.model.to(device)

    def _score_steps(self, system_prompt, user, steps):
        """Return the list of per-step scores for an answer given as steps."""
        separator = self.STEP_SEPARATOR
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user},
            # Every step, including the last, is followed by a separator so
            # that each step has exactly one position to score.
            {"role": "assistant", "content": separator.join(steps) + separator},
        ]
        conversation_str = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
        )
        input_ids = self.tokenizer.encode(
            conversation_str,
            return_tensors="pt",
        ).to(self.device)

        # Only plain floats leave this method, so no autograd graph is
        # needed; skipping it avoids holding activations in memory.
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids)

        step_sep_id = self.tokenizer.encode(separator)[0]
        token_masks = input_ids == step_sep_id
        return make_step_rewards(outputs[0], token_masks)[0]

    def __call__(self, system_prompt, user, answer):
        """Return the mean step score of ``answer``.

        Args:
            system_prompt: Content of the system message.
            user: Content of the user message.
            answer: Assistant answer as one string; it is split into steps
                on blank lines ("\\n\\n") and empty paragraphs are dropped.

        Returns:
            The arithmetic mean of the per-step scores, as a float.

        Raises:
            ZeroDivisionError: If no separator token is found in the
                tokenized conversation, so there are no scores to average.
        """
        steps = [
            paragraph.strip()
            for paragraph in answer.split('\n\n')
            if paragraph.strip()
        ]
        step_scores = self._score_steps(system_prompt, user, steps)
        return sum(step_scores) / len(step_scores)

    def eval_split_reward(self, system_prompt, user, answer):
        """Return the individual step scores for a pre-split answer.

        Args:
            system_prompt: Content of the system message.
            user: Content of the user message.
            answer: Iterable of step strings; the steps are joined with the
                separator marker as given, without stripping or filtering.

        Returns:
            A list of floats, one score per separator position found.
        """
        return self._score_steps(system_prompt, user, answer)