from verl import DataProto
import torch

from verl.utils.reward_score import gsm8k, math, custom_math


def _select_rm_score_fn(data_source):
    """Return the scoring function that handles ``data_source``.

    Matching is tried in this order, and the first hit wins:

    1. exactly ``'openai/gsm8k'``                     -> ``gsm8k.compute_score``
    2. contains ``'custom_math'`` or ``'aime'``      -> ``custom_math.compute_score``
    3. contains ``'lighteval/MATH'``                  -> ``math.compute_score``

    Args:
        data_source: Dataset identifier string.

    Returns:
        The ``compute_score`` callable of the matching scoring module.

    Raises:
        NotImplementedError: If ``data_source`` matches none of the rules
            above. The message names the offending data source.
    """
    if data_source == 'openai/gsm8k':
        return gsm8k.compute_score
    if any(marker in data_source for marker in ('custom_math', 'aime')):
        return custom_math.compute_score
    if 'lighteval/MATH' in data_source:
        return math.compute_score
    # Name the rejected source so a misconfigured dataset is easy to diagnose.
    raise NotImplementedError(f'No reward score function for data source: {data_source!r}')
    

class QwenRewardManager:
    """Reward manager that scores decoded rollouts against their ground truth.

    Calling an instance on a batch returns either the precomputed
    ``rm_scores`` tensor, or a float32 tensor shaped like
    ``data.batch['responses']`` that is zero everywhere except at index
    ``valid_response_length - 1`` of each row, which holds that sample's score.
    """

    # Substrings of ``data_source`` that select the batched scoring path, i.e.
    # one ``custom_math.compute_score`` call for the whole batch. Both the
    # routing test and the mixed-batch check in ``__call__`` use this tuple so
    # they cannot disagree.
    _BATCHED_SOURCE_MARKERS = ('custom_math', 'aime')

    def __init__(self, tokenizer, num_examine, compute_score=None) -> None:
        """Store the tokenizer and the console-printing budget.

        Args:
            tokenizer: Object whose ``decode`` method turns token ids into text.
            num_examine: Maximum number of decoded sequences printed to the
                console per data source on each call (per-sample path only).
            compute_score: Currently ignored; the scoring function is always
                chosen from each sample's ``data_source``.
        """
        self.tokenizer = tokenizer
        self.num_examine = num_examine

    @classmethod
    def _uses_batched_scorer(cls, data_source):
        """Return True if ``data_source`` contains a batched-scoring marker."""
        return any(marker in data_source for marker in cls._BATCHED_SOURCE_MARKERS)

    def _decode_sample(self, data_item):
        """Decode the unmasked prompt + response tokens of one sample.

        Returns:
            A ``(sequences_str, ground_truth, valid_response_length)`` tuple,
            where ``valid_response_length`` is the attention-mask sum over the
            response part of the sequence.
        """
        prompt_ids = data_item.batch['prompts']
        prompt_length = prompt_ids.shape[-1]

        # Valid prompt tokens are taken from the END of the prompt and valid
        # response tokens from the START of the response.
        # NOTE(review): this presumes left-padded prompts and right-padded
        # responses -- confirm against the rollout code.
        valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum()
        valid_prompt_ids = prompt_ids[-valid_prompt_length:]

        response_ids = data_item.batch['responses']
        valid_response_length = data_item.batch['attention_mask'][prompt_length:].sum()
        valid_response_ids = response_ids[:valid_response_length]

        sequences_str = self.tokenizer.decode(torch.cat((valid_prompt_ids, valid_response_ids)))
        ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth']
        return sequences_str, ground_truth, valid_response_length

    def __call__(self, data: DataProto):
        """Compute the token-level reward tensor for a batch.

        Args:
            data: Batch providing ``batch['prompts']``, ``batch['responses']``,
                ``batch['attention_mask']`` and, per sample,
                ``non_tensor_batch['data_source']`` and
                ``non_tensor_batch['reward_model']['ground_truth']``.

        Returns:
            ``data.batch['rm_scores']`` unchanged when that key is present;
            otherwise a float32 tensor shaped like ``data.batch['responses']``
            holding each sample's score at index ``valid_response_length - 1``.

        Raises:
            AssertionError: If the batch mixes batched-scoring data sources
                with per-sample ones.
            NotImplementedError: Propagated from ``_select_rm_score_fn`` when
                a per-sample data source has no scorer.
        """
        # If there is an rm score, return it directly instead of recomputing.
        if 'rm_scores' in data.batch.keys():
            return data.batch['rm_scores']

        reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32)

        items = [data[i] for i in range(len(data))]
        batched_flags = [
            self._uses_batched_scorer(item.non_tensor_batch['data_source'])
            for item in items
        ]

        if all(batched_flags):
            # Batched path: decode everything, then score in a single call.
            decoded = [self._decode_sample(item) for item in items]
            scores = custom_math.compute_score(
                solution_strs=[sequences_str for sequences_str, _, _ in decoded],
                ground_truths=[ground_truth for _, ground_truth, _ in decoded],
            )
            for i, (_, _, valid_response_length) in enumerate(decoded):
                reward_tensor[i, valid_response_length - 1] = scores[i]
            return reward_tensor

        # Per-sample path. A batched-scoring source must not reach it: here the
        # scorer is called with the singular ``solution_str``/``ground_truth``
        # keywords rather than the plural ones used by the batched call above.
        assert not any(batched_flags), (
            'Batch mixes batched-scoring data sources '
            f'{self._BATCHED_SOURCE_MARKERS} with per-sample ones; '
            'split the batch by data source.'
        )

        printed_per_source = {}
        for i, item in enumerate(items):
            sequences_str, ground_truth, valid_response_length = self._decode_sample(item)

            # Select the scorer for this sample's dataset.
            data_source = item.non_tensor_batch['data_source']
            compute_score_fn = _select_rm_score_fn(data_source)

            score = compute_score_fn(solution_str=sequences_str, ground_truth=ground_truth)
            reward_tensor[i, valid_response_length - 1] = score

            # Print at most ``num_examine`` decoded samples per data source.
            already_printed = printed_per_source.get(data_source, 0)
            if already_printed < self.num_examine:
                printed_per_source[data_source] = already_printed + 1
                print(sequences_str)

        return reward_tensor