

from verl import DataProto
import torch
from verl.utils.reward_score import gsm8k, math, multiply, countdown, kk, amc_aime, deepscaler
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
import json
from typing import Optional
import ray
import hydra
import random
import numpy as np

def _select_rm_score_fn(data_source):
    if data_source == 'openai/gsm8k':
        return gsm8k.compute_score
    elif data_source == 'DigitalLearningGmbH/MATH-lighteval' or data_source == "HuggingFaceH4/MATH-500":
        return math.compute_score
    elif "multiply" in data_source or "arithmetic" in data_source:
        return multiply.compute_score
    elif "countdown" in data_source:
        return countdown.compute_score
    elif "kk" in data_source:
        return kk.compute_score
    elif data_source == "amc" or data_source == "aime":
        return amc_aime.compute_score
    elif "dsr" in data_source or "deepscaler" in data_source or data_source == "":
        return deepscaler.rllm_reward_fn
    else:
        raise NotImplementedError


def _compute_confidence_ci(log_prob: torch.Tensor,
                           mask: torch.Tensor,
                           epsilon: float = 1e-8) -> torch.Tensor:

    masked_log_prob = log_prob * mask


    sequence_lengths = mask.sum(dim=-1)


    sum_log_prob = masked_log_prob.sum(dim=-1)


    mean_log_prob = sum_log_prob / (sequence_lengths + epsilon)


    ci = torch.exp(mean_log_prob)

    return ci
class RewardManager():
    """Rule-based reward manager used for both training and validation rewards.

    Each sample's valid prompt and response tokens are decoded together and
    scored by the function chosen by ``_select_rm_score_fn`` for its
    ``data_source``. The scalar score is written at the last valid response
    position of an otherwise-zero reward tensor shaped like ``responses``.
    """

    def __init__(self, tokenizer, num_examine) -> None:
        # Used to decode token ids back to text for scoring and printing.
        self.tokenizer = tokenizer
        # Max number of decoded samples printed per data source on each call.
        self.num_examine = num_examine

    def __call__(self, data: DataProto, save_analysis_path: Optional[str] = None):
        """Compute token-level rewards for a batch.

        Args:
            data: Batch with tensor fields ``prompts``, ``responses``,
                ``attention_mask`` (and ``response_log_probs`` when saving), plus
                non-tensor fields ``data_source`` and ``reward_model['ground_truth']``.
            save_analysis_path: If given, per-sample decoded text, score,
                confidence, ground truth and data source are written to this
                path as JSON. Write errors are printed, not raised.

        Returns:
            ``data.batch['rm_scores']`` if present. Otherwise a float32 tensor
            shaped like ``responses`` that is zero except at each sample's last
            valid response position, which holds that sample's score.
        """
        # Scores were already produced upstream; return them unchanged.
        if 'rm_scores' in data.batch.keys():
            return data.batch['rm_scores']

        responses = data.batch['responses']
        reward_tensor = torch.zeros_like(responses, dtype=torch.float32)

        is_save_mode = save_analysis_path is not None

        confidences_list = [0.0] * len(data)
        if is_save_mode:
            # attention_mask covers prompt + response; keep only the trailing response columns.
            response_length = responses.size(1)
            response_only_mask = data.batch['attention_mask'][:, -response_length:]
            confidences_tensor = _compute_confidence_ci(data.batch['response_log_probs'], response_only_mask)
            confidences_list = confidences_tensor.cpu().tolist()

        all_samples_data_for_json = []
        already_print_data_sources = {}

        for i in range(len(data)):
            data_item = data[i]

            prompt_ids = data_item.batch['prompts']
            prompt_length = prompt_ids.shape[-1]
            attention_mask = data_item.batch['attention_mask']

            # Valid prompt tokens are the trailing ones (left padding). Slice from an
            # explicit start index: ``[-0:]`` would wrongly return the whole prompt.
            valid_prompt_length = int(attention_mask[:prompt_length].sum())
            valid_prompt_ids = prompt_ids[prompt_length - valid_prompt_length:]

            # Valid response tokens are the leading ones (right padding).
            response_ids = data_item.batch['responses']
            valid_response_length = int(attention_mask[prompt_length:].sum())
            valid_response_ids = response_ids[:valid_response_length]

            sequences = torch.cat((valid_prompt_ids, valid_response_ids))
            sequences_str = self.tokenizer.decode(sequences)

            ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth']
            data_source = data_item.non_tensor_batch['data_source']
            compute_score_fn = _select_rm_score_fn(data_source)

            # The DeepScaleR reward takes different keyword arguments from the other scorers.
            if "dsr" in data_source or "deepscaler" in data_source or data_source == "":
                score = compute_score_fn(data_source=data_source, llm_solution=sequences_str, ground_truth=ground_truth)
            else:
                score = compute_score_fn(solution_str=sequences_str, ground_truth=ground_truth)
            # Only the last valid response position receives the reward.
            reward_tensor[i, valid_response_length - 1] = score

            if is_save_mode:
                all_samples_data_for_json.append({
                    "full_sequence": sequences_str,
                    "score": score,
                    "confidence": confidences_list[i],
                    "ground_truth": ground_truth,
                    "data_source": data_source
                })

            # Print up to ``num_examine`` decoded samples per data source for inspection.
            printed = already_print_data_sources.get(data_source, 0)
            if printed < self.num_examine:
                already_print_data_sources[data_source] = printed + 1
                print(sequences_str)

        if is_save_mode:
            print(f"Aggregating data for saving to {save_analysis_path}...")
            # Best-effort dump: a failure here must not abort training/validation.
            try:
                with open(save_analysis_path, 'w', encoding='utf-8') as f:
                    json.dump(all_samples_data_for_json, f, ensure_ascii=False, indent=4)
                print(f"Analysis data successfully saved as JSON to {save_analysis_path}")
            except Exception as e:
                print(f"Error! Failed to save analysis JSON data: {e}")

        return reward_tensor



@hydra.main(config_path='config', config_name='ppo_trainer', version_base=None)
def main(config):
    """Hydra entry point: start Ray if needed and run ``main_task`` remotely.

    Args:
        config: Config composed by Hydra from ``ppo_trainer`` under ``config/``.
    """
    import os

    if not ray.is_initialized():
        # Environment variables applied to Ray tasks/actors through ``runtime_env``.
        ray.init(runtime_env={'env_vars': {'TOKENIZERS_PARALLELISM': 'true',
                                           'NCCL_DEBUG': 'WARN',
                                           "VLLM_ATTENTION_BACKEND": "XFORMERS",
                                           # Forward the driver's key instead of overwriting it
                                           # with an empty string; "" remains the fallback.
                                           "WANDB_API_KEY": os.environ.get("WANDB_API_KEY", "")}})

    ray.get(main_task.remote(config))



@ray.remote
def main_task(config):
    """Build tokenizer, workers and reward managers, then run PPO training.

    Steps:
    1. Print the config with interpolations resolved, then resolve it in place.
    2. Obtain a local copy of the actor model and load its tokenizer.
    3. Pick FSDP or Megatron worker classes from ``actor.strategy``.
    4. Map actor/rollout, critic, reference policy (and optionally a reward
       model) onto one shared resource pool.
    5. Create the rule-based ``RewardManager``s and run ``RayPPOTrainer.fit()``.

    Args:
        config: The Hydra config passed from ``main``.

    Raises:
        NotImplementedError: For an unsupported actor or reward-model strategy.
    """
    from verl.utils.fs import copy_local_path_from_hdfs
    from pprint import pprint
    from omegaconf import OmegaConf

    pprint(OmegaConf.to_container(config, resolve=True))
    # Resolve in place so every later read sees concrete values.
    OmegaConf.resolve(config)

    # Local path of the actor model (helper name suggests an HDFS copy when needed).
    local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)

    from verl.utils import hf_tokenizer
    tokenizer = hf_tokenizer(local_path)

    if "Qwen2.5" in local_path:
        train_files = config.data.train_files
        # ``train_files`` may be a single path or a list of paths. On a list, ``in``
        # tests element equality rather than substring, so join the paths first.
        if not isinstance(train_files, str):
            train_files = " ".join(str(path) for path in train_files)
        if "gsm8k" in train_files or "math" in train_files:
            # NOTE(review): 151643 is presumably Qwen2.5's ``<|endoftext|>`` id — confirm.
            tokenizer.eos_token_id = 151643

    # Actor and critic must use the same distributed backend.
    if config.actor_rollout_ref.actor.strategy == 'fsdp':
        assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
        from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
        from verl.single_controller.ray import RayWorkerGroup
        ray_worker_group_cls = RayWorkerGroup

    elif config.actor_rollout_ref.actor.strategy == 'megatron':
        assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
        from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
        from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
        ray_worker_group_cls = NVMegatronRayWorkerGroup

    else:
        raise NotImplementedError

    from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role

    role_worker_mapping = {
        Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
        Role.Critic: ray.remote(CriticWorker),
        Role.RefPolicy: ray.remote(ActorRolloutRefWorker)
    }

    # A single pool with n_gpus_per_node entries per node, shared by all roles.
    global_pool_id = 'global_pool'
    resource_pool_spec = {
        global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
    }
    mapping = {
        Role.ActorRollout: global_pool_id,
        Role.Critic: global_pool_id,
        Role.RefPolicy: global_pool_id,
    }

    # Optional model-based reward worker, placed in the same pool.
    if config.reward_model.enable:
        if config.reward_model.strategy == 'fsdp':
            from verl.workers.fsdp_workers import RewardModelWorker
        elif config.reward_model.strategy == 'megatron':
            from verl.workers.megatron_workers import RewardModelWorker
        else:
            raise NotImplementedError
        role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
        mapping[Role.RewardModel] = global_pool_id

    # Training rewards print nothing; validation prints one decoded sample per source.
    reward_fn = RewardManager(tokenizer=tokenizer, num_examine=0)
    val_reward_fn = RewardManager(tokenizer=tokenizer, num_examine=1)

    resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

    trainer = RayPPOTrainer(config=config,
                            tokenizer=tokenizer,
                            role_worker_mapping=role_worker_mapping,
                            resource_pool_manager=resource_pool_manager,
                            ray_worker_group_cls=ray_worker_group_cls,
                            reward_fn=reward_fn,
                            val_reward_fn=val_reward_fn)
    trainer.init_workers()
    trainer.fit()


def set_random_seed(seed: int, deterministic: bool = True):
    """Seed Python, NumPy and PyTorch RNGs and optionally force deterministic kernels.

    Only the calling process is affected. NOTE(review): ``main_task`` runs as a
    Ray remote task, presumably in another process — confirm whether it needs
    its own seeding.

    Args:
        seed: Seed for ``random``, ``numpy`` and ``torch`` (CPU and all CUDA devices).
        deterministic: When True (default, same as before), enable cuDNN
            determinism, disable cuDNN benchmarking and require deterministic
            PyTorch algorithms.
    """
    import os

    if deterministic:
        # Deterministic cuBLAS (CUDA >= 10.2) needs this workspace setting; without it,
        # ``use_deterministic_algorithms(True)`` makes affected CUDA ops raise. It must be
        # set before cuBLAS initializes; a user-provided value is respected.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)



if __name__ == '__main__':
    # Seed this (launching) process before Hydra composes the config and Ray starts.
    # NOTE(review): Ray tasks may run in other processes and would not share these
    # RNG states — confirm whether workers need separate seeding.
    set_random_seed(seed=42)
    main()
