                                                      
                                                                 

from collections import defaultdict
from typing_extensions import override
from types import SimpleNamespace

from megatron.training.global_vars import get_args

from gpatch.training.utils import print_with_rank_and_datetime
from gpatch.training.v3.ppo_actor import PPOActorTrainerV3


class MathRLActorTrainer(PPOActorTrainerV3):
    """PPO actor trainer (V3) with math-RL overrides of the replay hooks.

    Overrides three hooks of ``PPOActorTrainerV3``: rollout-batch acceptance,
    bookkeeping of samples to replay, and the per-sample "replay now?" decision.
    """

    @override
    def is_rollout_batch_accepted(self, rb):
        """Return whether the rollout batch ``rb`` is accepted.

        rb: mapping of rollout-batch fields. Only ``rb['sample_useful']`` is
            read, and only its first element is inspected (presumably the flag
            is the same for every row of one sample -- TODO confirm).

        returns: the Python scalar ``rb['sample_useful'][0].item()``, used as
            the acceptance flag.
        """
        return rb['sample_useful'][0].item()

    @override
    def update_replay_samples_dict(self, rb, sample_idx):
        """Register a sample for replay, re-queue it, or give up on it.

        First call for ``sample_idx``: the prompt is extracted from row 0 of
        ``rb``. It is stored in ``self.replay_samples_dict[sample_idx]`` as a
        ``SimpleNamespace(prompt_data=..., replay_times=1)``, and
        ``sample_idx`` is appended to ``self.replay_queue``.

        Later calls: if the stored ``replay_times`` exceeds
        ``args.ppo_dynamic_sampling_max_replay``, the entry is removed from
        ``self.replay_samples_dict``, a give-up message is printed, and the
        sample is not re-queued. Otherwise ``replay_times`` is incremented
        and ``sample_idx`` is appended to ``self.replay_queue`` again.

        rb: mapping of rollout-batch fields. Reads ``prompt_lengths``
            (asserted 1-D), ``response_tokens`` and ``gt_label``, plus
            ``train_data_consuming_progress`` if present (``None`` otherwise).
            Only row 0 of each tensor is used.
        sample_idx: key of the sample in ``self.replay_samples_dict``.

        returns: None
        """
        max_replay_times = get_args().ppo_dynamic_sampling_max_replay

        if sample_idx not in self.replay_samples_dict:
            prompt_lengths = rb['prompt_lengths']
            assert prompt_lengths.ndim == 1
            # Only row 0 is stored, so convert just that element instead of
            # materialising the whole tensor as a Python list.
            prompt_len = prompt_lengths[0].item()
            # The first `prompt_len` tokens of row 0 of response_tokens are
            # taken as the prompt (presumably each row is prompt followed by
            # response -- confirm against the rollout producer).
            prompt_tokens = rb["response_tokens"][0][:prompt_len]
            prompt_data = {
                "prompt_token_ids": [dict(prompt_token_ids=prompt_tokens.tolist())],
                "lpad_lens": prompt_lengths[0].view(1),
                "gt_label": rb['gt_label'][0].view(1),
                "train_data_consuming_progress": rb.get('train_data_consuming_progress', None),
            }
            self.replay_samples_dict[sample_idx] = SimpleNamespace(
                prompt_data=prompt_data, replay_times=1)
            self.replay_queue.append(sample_idx)
            return

        entry = self.replay_samples_dict[sample_idx]
        # NOTE(review): with `>` a sample is re-queued while
        # replay_times <= max_replay_times. Counting the first call, it can
        # therefore enter replay_queue up to max_replay_times + 1 times.
        # Confirm this is intended rather than `>=`.
        if entry.replay_times > max_replay_times:
            self.replay_samples_dict.pop(sample_idx)
            print_with_rank_and_datetime(
                f"failure. give up replay sample {sample_idx=} replay_times {entry.replay_times}"
            )
        else:
            entry.replay_times += 1
            self.replay_queue.append(sample_idx)

    @override
    def is_time_to_replay_samples(self, replay_queue, replay_samples_dict):
        """Decide, for each queued sample, whether it should be replayed now.

        replay_queue: list of sample_idx values awaiting replay.
        replay_samples_dict: {sample_idx: sample_data}. Not consulted by this
            implementation.

        returns: a list of bools, one per entry of ``replay_queue``. This
            implementation always returns all ``True``, i.e. every queued
            sample is replayed.
        """
        return [True] * len(replay_queue)
