from contextlib import nullcontext
import time
import inspect
import logging
import os
import sys
import datasets
import torch
import torch_npu
torch_npu.npu.set_compile_mode(jit_compile=False)
from torch_npu.npu.amp import autocast
from torch_npu.contrib import transfer_to_npu
import transformers
from datasets import load_dataset
from transformers.trainer_utils import get_last_checkpoint
from hetero_rl.configs_obs import MoISScriptArguments, GRPOConfig
from hetero_rl.rewards import get_reward_funcs
from hetero_rl.utils import get_tokenizer
from hetero_rl.utils.callbacks import get_callbacks
from hetero_rl.utils.wandb_logging import init_wandb_training
from trl import GRPOTrainer, ModelConfig, TrlParser, get_peft_config
from hetero_rl.utils.data_utils import custom_loading_dataset
from transformers import TrainerCallback
from pathlib import Path
from trl.extras.profiling import profiling_decorator, profiling_context
from typing import Any, Union
from async_utils_checkpoint_fixing_obs import setup_fs_queue, push_to_fs_queue, obs_download_file

from transformers import TrainerControl
import torch.nn.functional as F
import warnings
import copy

import torch.utils.data
from accelerate.utils import gather, gather_object, is_peft_model, set_seed

from torch import nn
from transformers import (
    Trainer,
    is_wandb_available,
)
from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from trl.import_utils import is_vllm_available
from trl.models import  unwrap_model_for_generation
from trl.trainer.utils import (
    pad,
)
from accelerate.utils import reduce, broadcast
from hetero_rl.Time_Delay import get_delay_sampler

if is_vllm_available():
    from vllm import LLM, SamplingParams
    from vllm.sampling_params import GuidedDecodingParams

if is_wandb_available():
    import wandb
    # print("wandb has imported")

import torch.distributed as dist
import re
from trl.trainer.utils import selective_log_softmax
from torch.nn.utils.rnn import pad_sequence
import json
import obs
logger = logging.getLogger(__name__)


def nansum(tensor, dim=None, keepdim=False):
    """
    NaN-aware replacement for ``Tensor.sum``: NaN entries contribute nothing to the result.

    Args:
        tensor (torch.Tensor): Values to add up.
        dim (int or tuple of ints, optional): Dimension(s) to reduce.
            When omitted, every element is summed into a single scalar.
        keepdim (bool): Keep the reduced dimension(s) with size 1. Only consulted when
            ``dim`` is given.

    Returns:
        torch.Tensor: The sum over the requested dimension(s) with NaNs skipped. A full
            reduction that finds no valid entry returns a zero scalar.
    """
    if dim is not None:
        # Swap every NaN for a zero of the same dtype so it drops out of the reduction.
        zero = torch.tensor(0.0, device=tensor.device, dtype=tensor.dtype)
        return torch.where(torch.isnan(tensor), zero, tensor).sum(dim=dim, keepdim=keepdim)

    values = tensor.flatten()
    valid = ~torch.isnan(values)
    if valid.sum() > 0:
        return values[valid].sum()
    return torch.tensor(0.0, device=tensor.device, dtype=tensor.dtype)
    
def nanmean(tensor, dim=None, keepdim=False):
    """
    Compute the mean of a tensor, ignoring NaNs, with support for arbitrary dimensions.

    Args:
        tensor (torch.Tensor): Input tensor (floating point).
        dim (int or tuple of ints, optional): The dimension or dimensions to reduce.
            If None, the mean is taken over all elements.
        keepdim (bool): Whether the output tensor has reduced dimensions retained or not.
            Only used when ``dim`` is given.

    Returns:
        torch.Tensor: Mean of the non-NaN values along the specified dimension(s). Positions
            (or the whole result, when ``dim`` is None) with no valid value are NaN.
    """
    if dim is None:
        flat_tensor = tensor.flatten()
        mask = ~torch.isnan(flat_tensor)
        if mask.sum() > 0:
            return flat_tensor[mask].mean()
        return torch.tensor(float('nan'), device=tensor.device, dtype=tensor.dtype)

    mask = ~torch.isnan(tensor)
    valid_count = mask.sum(dim=dim, keepdim=keepdim)
    # Build both fillers with the input's dtype (as `nansum` does) so that `torch.where` never has to
    # rely on implicit type promotion for half/double inputs.
    zero = torch.tensor(0.0, device=tensor.device, dtype=tensor.dtype)
    nan = torch.tensor(float('nan'), device=tensor.device, dtype=tensor.dtype)
    sum_result = torch.where(mask, tensor, zero).sum(dim=dim, keepdim=keepdim)

    # Slices without any valid entry would be 0 / 0; report them explicitly as NaN.
    return torch.where(valid_count > 0, sum_result / valid_count, nan)

def merge(valid_rewards, new_rewards):
    """
    Combine two optional tensors into one.

    If either argument is None the other one is returned untouched (which may itself be None).
    Otherwise the two are concatenated along the first dimension, with `new_rewards` placed
    before `valid_rewards`.
    """
    if valid_rewards is None:
        return new_rewards
    if new_rewards is None:
        return valid_rewards
    return torch.cat([new_rewards, valid_rewards])


def merge_with_padding(valid_rewards, new_rewards, pad_token_id, left_pad=False):
    """
    Concatenate two optional 2D tensors after padding the narrower one.

    If either tensor is None the other is returned unchanged. Otherwise the tensor with the
    smaller second dimension is padded with `pad_token_id` (on the left when `left_pad` is
    True, on the right otherwise) to the width of the other, and the result is the
    concatenation along the first dimension with `new_rewards` first.
    """
    if valid_rewards is None:
        return new_rewards
    if new_rewards is None:
        return valid_rewards

    new_width = new_rewards.shape[1]
    valid_width = valid_rewards.shape[1]
    if new_width < valid_width:
        new_rewards = pad_sequence_to_length(new_rewards, valid_width, pad_token_id, left_pad)
    else:
        # Also taken when both widths match; the helper then returns its input as-is.
        valid_rewards = pad_sequence_to_length(valid_rewards, new_width, pad_token_id, left_pad)
    return torch.cat([new_rewards, valid_rewards])


def pad_sequence_to_length(tensors, max_seq_len, pad_token_id, left_pad=False):
    """
    Pad the last dimension of a 2D tensor (e.g. responses, logprobs) up to `max_seq_len`.

    input shape: [bs, seq_length]
    output shape: [bs, max_seq_length]

    The fill value is `pad_token_id`. Padding goes on the right by default and on the left when
    `left_pad` is True. A tensor that is already at least `max_seq_len` wide is returned as-is.
    """
    missing = max_seq_len - tensors.shape[-1]
    if missing <= 0:
        return tensors
    # F.pad takes (left, right) amounts for the last dimension.
    padding = (missing, 0) if left_pad else (0, missing)
    return F.pad(tensors, padding, "constant", pad_token_id)

# torch.nanstd doesn't exist, so we define it here
def nanstd(tensor: torch.Tensor) -> torch.Tensor:
    """
    Sample standard deviation of a 1D tensor with NaN entries left out.

    Args:
        tensor (`torch.Tensor`):
            Input tensor of shape `(N,)`.

    Returns:
        `torch.Tensor`:
            Standard deviation of the non-NaN values, using Bessel's correction.
    """
    deviations = tensor - nanmean(tensor, keepdim=True)
    variance = nanmean(deviations ** 2)  # biased variance over the non-NaN values
    valid_count = torch.sum(~torch.isnan(tensor))  # number of non-NaN values
    # Bessel's correction, applied in place so the variance keeps its dtype.
    variance.mul_(valid_count / (valid_count - 1))
    return torch.sqrt(variance)

class HeteroRLSampler(GRPOTrainer):

    def __init__(self,  delay_list, *args, **kwargs):
        """
        Build the sampler on top of `GRPOTrainer`, always without optimizers.

        Args:
            delay_list: Sequence of delay values. An iterator over it is stored as `self.delay_list` and its
                first 20 entries are logged, so it must support slicing.
            *args: Positional arguments forwarded to `GRPOTrainer.__init__`.
            **kwargs: Keyword arguments forwarded to `GRPOTrainer.__init__`, except that `fs_queue_path`
                (required) is consumed here and any `optimizers` entry is discarded. The `args` entry is also
                kept as `self.training_args`; `online_mode` and `sampler_id` are read from it.

        Raises:
            KeyError: If `fs_queue_path` is not supplied as a keyword argument.
        """
        self.fs_queue_path = kwargs.pop("fs_queue_path")

        # The sampler is always constructed with `optimizers=(None, None)`, so drop whatever the caller passed.
        kwargs.pop("optimizers", None)
        self.training_args = kwargs.get('args')
        super().__init__(*args, optimizers=(None, None), **kwargs)
        self.rank = self.accelerator.process_index

        logger.info(
            f"[Rank {self.rank}] grpo sampler initialized. "
        )
        self.queue_dir, _ = setup_fs_queue(self.fs_queue_path)

        self.log_interval = int(os.getenv("SAMPLER_LOG_INTERVAL", "10"))
        self.sync_weights_path = Path(os.getenv("SYNC_WEIGHTS_PATH", "/tmp/async_weights.pt"))
        self.last_sync_time = 0
        self._dataloader = self.get_train_dataloader()
        self._epoch_iterator = iter(self._dataloader)
        self.batch_ids = 0

        # Start the model version counter from the last `checkpoint-<step>` component of the model path, if
        # any. A regex is used instead of `split(...)` + `int(...)` so that paths which merely contain the
        # word "checkpoint" (e.g. ".../checkpoints/model") or end with a separator no longer raise ValueError.
        checkpoint_steps = re.findall(r"checkpoint-(\d+)", self.model.config._name_or_path)
        self.model_ids = int(checkpoint_steps[-1]) if checkpoint_steps else 0

        self.delay_list = iter(delay_list)
        self.online_mode = self.training_args.online_mode
        self.sampler_id = self.training_args.sampler_id
        logger.info(f"delay_list[:20]: {delay_list[:20]}")

        # OBS client credentials, endpoint and paths all come from the environment.
        access_key_id = os.environ.get('ACCESS_KEY_ID')  # AK
        secret_access_key = os.environ.get('SECRET_ACCESS_KEY')  # SK
        server = os.environ.get('OBS_SERVER')
        self.obs_client = obs.ObsClient(access_key_id=access_key_id, secret_access_key=secret_access_key, server=server)
        self.obs_sync_weights_path = os.getenv("OBS_SYNC_WEIGHTS_PATH")
        self.obs_queue_dir = os.getenv("OBS_FS_QUEUE_PATH")
        print(f"self.obs_queue_dir:{self.obs_queue_dir}")

    @profiling_decorator
    def _move_model_to_vllm(self):
        """
        Push the current weights of `self.model` into vLLM and reset vLLM's prefix cache.

        Weights are sent through `self.vllm_client.update_named_param` when `self.vllm_mode == "server"`
        (main process only) or loaded with `load_weights` on the in-process engine's model when
        `self.vllm_mode == "colocate"`. With DeepSpeed ZeRO-3 the parameters are accessed inside
        `deepspeed.zero.GatheredParameters`; with FSDP the work is delegated to
        `self._sync_fsdp_params_to_vllm`. For PEFT models the adapters are merged before the weights are
        pushed and unmerged again afterwards.
        """
        # For DeepSpeed ZeRO-3 and FSDP, we need to gather all parameters before operations
        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
        zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
        if zero_stage_3:
            import deepspeed
            gather_if_zero3 = deepspeed.zero.GatheredParameters
        else:
            # Outside ZeRO-3 the "gather" context is a no-op.
            gather_if_zero3 = nullcontext
        if is_peft_model(self.model):
            # With PEFT and FSDP/DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as
            # merging adapters in a sharded manner is not supported.
            # TODO: does this work with FSDP?
            with gather_if_zero3(list(self.model.parameters())):
                self.model.merge_adapter()

                # Update vLLM weights while parameters are gathered
                if self.is_fsdp_enabled:  # note if using FSDP, gather_if_zero3 is nullcontext
                    # Update vLLM weights while parameters are gathered
                    # For PEFT with FSDP we need to use the memory efficient post-order traversal
                    self._sync_fsdp_params_to_vllm(self.model)
                else:
                    # DeepSpeed ZeRO-3 with PEFT
                    for name, param in self.model.named_parameters():
                        # When using PEFT, we need to recover the original parameter name and discard some parameters
                        name = name.removeprefix("base_model.model.").replace(".base_layer", "")
                        # Parameters whose name contains the adapter prefix are not sent to vLLM.
                        if self.model.prefix in name:
                            continue
                        # When module to save, remove its prefix and discard the original module
                        if "original_module" in name:
                            continue
                        name = name.replace("modules_to_save.default.", "")

                        if self.vllm_mode == "server" and self.accelerator.is_main_process:
                            self.vllm_client.update_named_param(name, param.data)
                        elif self.vllm_mode == "colocate":
                            llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
                            llm_model.load_weights([(name, param.data)])
                # Unmerge adapters while parameters are still gathered
                self.model.unmerge_adapter()
                # Parameters will automatically be repartitioned when exiting the context
        else:
            # For non-PEFT models, simply gather (if needed) and update each parameter individually.
            if self.is_fsdp_enabled:
                self._sync_fsdp_params_to_vllm(self.model)  # use memory-efficient post-order traversal for FSDP
            else:
                for name, param in self.model.named_parameters():
                    # Gather one parameter at a time so only a single full parameter is materialized at once.
                    with gather_if_zero3([param]):
                        if self.vllm_mode == "server" and self.accelerator.is_main_process:
                            self.vllm_client.update_named_param(name, param.data)
                        elif self.vllm_mode == "colocate":
                            llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
                            llm_model.load_weights([(name, param.data)])

        # Reset cache on vLLM
        if self.vllm_mode == "server" and self.accelerator.is_main_process:
            self.vllm_client.reset_prefix_cache()
        elif self.vllm_mode == "colocate":
            self.llm.reset_prefix_cache()

    # Get the per-token log probabilities for the completions for the model and the reference model
    @profiling_decorator
    def new_get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, batch_size=None) -> torch.Tensor:
        """
        Return the log-probability `model` assigns to each of the last `logits_to_keep` tokens of `input_ids`.

        The batch is processed in chunks of `batch_size` rows (the whole batch when `batch_size` is not given)
        to reduce the memory peak, and logits are divided by `self.temperature` before the log-softmax. The
        per-chunk results are concatenated along the batch dimension.
        """
        num_rows = input_ids.size(0)
        chunk_size = batch_size or num_rows
        chunk_logps = []
        for start in range(0, num_rows, chunk_size):
            stop = start + chunk_size
            ids_chunk = input_ids[start:stop]
            mask_chunk = attention_mask[start:stop]

            # One extra logit is requested because the final position is dropped right below.
            outputs = model(input_ids=ids_chunk, attention_mask=mask_chunk, logits_to_keep=logits_to_keep + 1)
            # (B, L-1, V): the last logit is the prediction for the token after the sequence, so it is excluded.
            next_token_logits = outputs.logits[:, :-1, :]
            # Divide logits by sampling temperature.
            # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
            scaled_logits = next_token_logits / self.temperature
            target_ids = ids_chunk[:, -logits_to_keep:]
            chunk_logps.append(selective_log_softmax(scaled_logits, target_ids))
        return torch.cat(chunk_logps, dim=0)

    @profiling_decorator
    def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list):
        """
        Evaluate every configured reward function on this process's completions.

        Args:
            inputs: List of example dicts. Every key of the first example other than "prompt", "completion"
                and "completion_ids" is collected across examples and forwarded to callable reward functions
                as a keyword argument.
            prompts: Prompts, one per completion.
            completions: Generated completions aligned with `prompts`.
            completion_ids_list: Token ids of each completion, forwarded to callable reward functions.

        Returns:
            torch.Tensor: Tensor of shape `(len(prompts), len(self.reward_funcs))`. An entry is NaN where a
                callable reward function returned None. The result is local to this process (not gathered).
        """
        device = self.accelerator.device
        rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)

        # Repeat all input columns (but "prompt", "completion", and "completion_ids") to match the num of generations
        keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]]
        reward_kwargs = {key: [example[key] for example in inputs] for key in keys}

        for i, (reward_func, reward_processing_class, reward_func_name) in enumerate(
            zip(self.reward_funcs, self.reward_processing_classes, self.reward_func_names)
        ):
            with profiling_context(self, reward_func_name):
                if isinstance(reward_func, nn.Module):  # Module (no PretrainedModel) for compat with compiled models
                    if is_conversational(inputs[0]):
                        messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
                        texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
                    else:
                        texts = [p + c for p, c in zip(prompts, completions)]
                    reward_inputs = reward_processing_class(
                        text=texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
                    )
                    # Use the base `Trainer` implementation explicitly, exactly as
                    # `generate_and_score_completions` does for the prompt inputs. From this subclass, `super()`
                    # resolves to `GRPOTrainer`, whose `_prepare_inputs` is not the plain device-placement helper.
                    reward_inputs = Trainer._prepare_inputs(self, reward_inputs)
                    with torch.inference_mode():
                        rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
                else:
                    output_reward_func = reward_func(
                        prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs
                    )
                    # Convert None values to NaN
                    output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func]

                    rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

        # If all reward functions return None for a given row, issue a detailed warning
        all_nan_rows = torch.isnan(rewards_per_func).all(dim=1)
        if all_nan_rows.any():
            nan_row_idx = all_nan_rows.nonzero(as_tuple=True)[0][0]  # first offending row
            row_reward_kwargs = {key: value[nan_row_idx] for key, value in reward_kwargs.items()}
            row_reward_kwargs["prompt"] = prompts[nan_row_idx]
            row_reward_kwargs["completion"] = completions[nan_row_idx]
            warnings.warn(
                f"All reward functions returned None for the following kwargs: {row_reward_kwargs}. "
                "Please ensure that at least one reward function returns a valid reward."
            )

        # Unlike the usual GRPO flow, the rewards are NOT gathered across processes here: the caller receives
        # only this process's rows.
        return rewards_per_func

    def generate_and_score_completions(
        self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
    ) -> dict[str, Union[torch.Tensor, Any]]:
        """
        Generate completions for a batch of prompts, score them, and compute group-relative advantages.

        Everything is computed from this process's batch only; nothing is gathered across processes. Rewards
        are grouped with `rewards.view(-1, self.num_generations)`, so every `self.num_generations` consecutive
        rows of `inputs` are treated as one group.

        Args:
            inputs: List of example dicts. Each must contain a "prompt" entry; the remaining entries are made
                available to the reward functions.

        Returns:
            dict with the keys "prompt_ids", "prompt_mask", "completion_ids", "completion_mask",
            "old_per_token_logps", "ref_per_token_logps", "advantages", "model_ids",
            "sampler_per_token_logps" and "metrics". "metrics" is a deep copy of `self._metrics` taken before
            the train-mode metrics are cleared. The three log-prob entries may be None depending on the
            configuration.
        """
        device = self.accelerator.device
        mode = "train"

        prompts = [x["prompt"] for x in inputs]
        prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
        prompt_inputs = self.processing_class(
            text=prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
        )
        # Call the base `Trainer` implementation directly instead of going through `super()` (`GRPOTrainer`).
        prompt_inputs = Trainer._prepare_inputs(self, prompt_inputs)
        prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

        if self.max_prompt_length is not None:
            # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
            # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text,
            # because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation).
            prompt_ids = prompt_ids[:, -self.max_prompt_length :]
            prompt_mask = prompt_mask[:, -self.max_prompt_length :]
            prompts_text = self.processing_class.batch_decode(
                prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
            )
            prompts_text = [
                re.sub(rf"^({re.escape(self.processing_class.pad_token)})+", "", text) for text in prompts_text
            ]

        # Generate completions using either vLLM or regular generation
        if self.args.use_vllm:
            # First, load the current weights into vLLM if they changed since the last generation
            if self.state.global_step != self._last_loaded_step:
                self._move_model_to_vllm()
                self._last_loaded_step = self.state.global_step

            if self.guided_decoding_regex:
                guided_decoding = GuidedDecodingParams(backend="outlines", regex=self.guided_decoding_regex)
            else:
                guided_decoding = None

            generation_kwargs = {
                "n": 1,  # vLLM on each GPU generates only 1 in colocate mode
                "repetition_penalty": self.repetition_penalty,
                "temperature": self.temperature,
                "top_p": self.top_p,
                "top_k": -1 if self.top_k is None else self.top_k,
                "min_p": 0.0 if self.min_p is None else self.min_p,
                "max_tokens": self.max_completion_length,
                "guided_decoding": guided_decoding,
            }
            if self.args.generation_kwargs is not None:
                generation_kwargs.update(self.args.generation_kwargs)
            sampling_params = SamplingParams(**generation_kwargs)

            if self.vllm_tensor_parallel_size > 1:
                print("should not be here!")
                # Gather prompts from all ranks in the TP group and flatten.
                # Each rank starts with its own prompts; after gathering, all ranks see the full group set.
                orig_size = len(prompts_text)
                gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)]
                torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group)
                all_prompts_text = [p for sublist in gathered_prompts for p in sublist]
            else:
                all_prompts_text = prompts_text

            with profiling_context(self, "vLLM.generate"):
                all_outputs = self.llm.generate(all_prompts_text, sampling_params=sampling_params, use_tqdm=False)
            completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]

            if self.vllm_tensor_parallel_size > 1:
                # Slice completions for this rank within its TP group.
                # Each rank generates all outputs — we keep only our share.
                local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
                tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
                completion_ids = completion_ids[tp_slice]
            # Pad the completions, and concatenate them with the prompts
            completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
            completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
            prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        else:
            # Regular generation path
            with unwrap_model_for_generation(
                self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
            ) as unwrapped_model:
                if self.is_fsdp_enabled:
                    # `FSDP` is not imported in the visible module-level imports of this file; import it
                    # locally so this branch cannot fail with a NameError.
                    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

                    full_params_context = FSDP.summon_full_params(self.model_wrapped, recurse=False)
                else:
                    full_params_context = nullcontext()
                with full_params_context:
                    prompt_completion_ids = unwrapped_model.generate(
                        prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
                    )

            # Compute prompt length and extract completion ids
            prompt_length = prompt_ids.size(1)
            prompt_ids = prompt_completion_ids[:, :prompt_length]
            completion_ids = prompt_completion_ids[:, prompt_length:]

        # Mask everything after the first EOS token
        is_eos = completion_ids == self.processing_class.eos_token_id
        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
        sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
        completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

        # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need
        # to re-tokenize completions if the reward is computed from tokens.
        completion_ids_list = [
            [id.item() for id, m in zip(row, mask_row) if m] for row, mask_row in zip(completion_ids, completion_mask)
        ]

        # Sum along sequence dimension (dim=1) to get completion length per sequence, used for logging
        completion_lengths = completion_mask.sum(1)

        # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
        if self.mask_truncated_completions:
            truncated_completions = ~is_eos.any(dim=1)
            completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int()

        # Concatenate prompt_mask with completion_mask for logit computation
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)

        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
        batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size

        with torch.no_grad():
            # When using num_iterations == 1 and steps_per_generation <= gradient_accumulation_steps
            # old_per_token_logps == per_token_logps, so we can skip its computation here, and use
            # per_token_logps.detach() instead.
            if self.num_iterations > 1 or self.args.steps_per_generation > self.args.gradient_accumulation_steps:
                old_per_token_logps = self._get_per_token_logps(
                    self.model, prompt_completion_ids, attention_mask, logits_to_keep, batch_size
                )
            else:
                old_per_token_logps = None

            # Log-probs of the sampling policy itself, only computed for the loss types listed here
            sampler_per_token_logps = self._get_per_token_logps(
                self.model, prompt_completion_ids, attention_mask, logits_to_keep, batch_size
            ) if self.loss_type in ["grpo","bnpo", "dr_grpo", "gspo", "gepo", "EqP"] else None

            # Compute the per-token log probabilities for the reference model
            if self.beta != 0.0:
                if self.ref_model is not None:
                    ref_per_token_logps = self._get_per_token_logps(
                        self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
                    )
                else:
                    with self.accelerator.unwrap_model(self.model).disable_adapter():
                        ref_per_token_logps = self._get_per_token_logps(
                            self.model, prompt_completion_ids, attention_mask, logits_to_keep
                        )
            else:
                ref_per_token_logps = None

        # Decode the generated completions
        completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
        if is_conversational(inputs[0]):
            completions = []
            for prompt, completion in zip(prompts, completions_text):
                # A trailing assistant turn in the prompt is popped and prepended to the completion
                bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
                completions.append([{"role": "assistant", "content": bootstrap + completion}])
        else:
            completions = completions_text

        # Calculate rewards for each reward function. The rewards cover this process's completions only.
        rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)

        # Apply weights to each reward function's output and sum (NaN rewards are ignored)
        rewards = nansum((rewards_per_func * self.reward_weights.to(device).unsqueeze(0)), dim=1)

        # Compute grouped-wise rewards: row g of the view holds the rewards of group g
        grouped_rewards = rewards.view(-1, self.num_generations)
        mean_grouped_rewards = grouped_rewards.mean(dim=1)
        std_grouped_rewards = grouped_rewards.std(dim=1)
        is_std_zero = torch.isclose(std_grouped_rewards, torch.zeros_like(std_grouped_rewards))

        # Expand the per-group statistics back to one value per sample. Each statistic must be repeated
        # `num_generations` times *consecutively* (repeat_interleave semantics) to line up with `rewards`;
        # tiling the whole vector with `torch.cat([x] * n)` misaligns them as soon as there is more than one
        # group in the batch. unsqueeze/expand/reshape gives the interleaved layout using only basic ops.
        mean_grouped_rewards = mean_grouped_rewards.unsqueeze(1).expand(-1, self.num_generations).reshape(-1)
        std_grouped_rewards = std_grouped_rewards.unsqueeze(1).expand(-1, self.num_generations).reshape(-1)
        if self.loss_type == "bnpo":
            min_value, max_value=0.0, 1.0
            rewards = (rewards - min_value)/(max_value - min_value)
            batch_mean = mean_grouped_rewards.mean()
            batch_var = mean_grouped_rewards.var()
            a = (batch_mean*(1-batch_mean)/batch_var-1)*batch_mean if batch_var > 0 else torch.tensor(0.0, device=device)
            b = (batch_mean*(1-batch_mean)/batch_var-1)*(1-batch_mean) if batch_var > 0 else torch.tensor(0.0, device=device)
            alpha = torch.clamp(1+a/3, min=1.0)
            beta = torch.clamp(1+b/3, min=1.0)
            weight = torch.distributions.Beta(alpha, beta).log_prob(mean_grouped_rewards).exp()
            weight = torch.clamp(1/weight, min=0, max=1e6)
            advantages = weight * (rewards - mean_grouped_rewards)
        else:
            # Normalize the rewards to compute the advantages
            advantages = rewards - mean_grouped_rewards
            if self.scale_rewards:
                advantages = advantages / (std_grouped_rewards + 1e-4)

        # Log the metrics
        if mode == "train":
            self.state.num_input_tokens_seen += attention_mask.sum().item()
        self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]
        self._metrics[mode]['model_ids'] = [self.model_ids]

        # Log completion lengths, mean, min, max
        agg_completion_lengths = completion_lengths
        self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item())
        self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item())
        self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item())

        # Identify sequences that terminated with EOS and log their lengths
        agg_terminated_with_eos = is_eos.any(dim=1)
        term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos]
        clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths)
        self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio)
        if len(term_completion_lengths) == 0:  # edge case where no terminated sequences are found
            term_completion_lengths = torch.zeros(1, device=device)
        self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item())
        self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
        self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())

        # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
        for i, reward_func_name in enumerate(self.reward_func_names):
            mean_rewards = nanmean(rewards_per_func[:, i]).item()
            self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards)
            std_rewards = nanstd(rewards_per_func[:, i]).item()
            self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_rewards)
        self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item())
        self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())
        self._metrics[mode]["frac_reward_zero_std"].append(is_std_zero.float().mean().item())

        # Log prompt and completion texts
        self._textual_logs["prompt"].extend(prompts_text)
        self._textual_logs["completion"].extend(completions_text)
        for i, name in enumerate(self.reward_func_names):
            self._textual_logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
        # Store plain Python floats, like the reward logs above, rather than 0-dim device tensors.
        self._textual_logs["advantages"].extend(advantages.tolist())

        if self.log_completions:
            prompts_to_log = prompts_text
            completions_to_log = completions_text
            if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None:
                import pandas as pd

                # For logging
                table = {
                    "model_sync_times": [str(self.model_ids)]* len(rewards),
                    "step": [str(self.batch_ids)] * len(rewards),
                    "prompt": prompts_to_log,
                    "completion": completions_to_log,
                    "reward": rewards.tolist(),
                }
                df = pd.DataFrame(table)
                wandb.log({"completions": wandb.Table(dataframe=df)})

        self.batch_ids+=1
        # Snapshot the metrics for the returned batch, then reset the train-mode accumulators
        self_metrics = copy.deepcopy(self._metrics)
        self._metrics[mode].clear()
        return {
            "prompt_ids": prompt_ids,
            "prompt_mask": prompt_mask,
            "completion_ids": completion_ids,
            "completion_mask": completion_mask,
            "old_per_token_logps": old_per_token_logps,
            "ref_per_token_logps": ref_per_token_logps,
            "advantages": advantages,
            "model_ids": self.model_ids,
            "sampler_per_token_logps": sampler_per_token_logps,
            "metrics": self_metrics,
        }
    
    @profiling_decorator
    def _sync_model_weights(self, current_idx, file_path, timeout: int = 600):
        """Load newer learner weights into the sampler when they are available.

        The method polls ``self.sync_weights_path`` and compares its mtime with
        ``self.last_sync_time``. When the file is newer, the weights are fetched
        with ``obs_download_file``, loaded into ``self.model``, pushed to vLLM via
        ``self._move_model_to_vllm()``, and ``self.model_ids`` is set to the global
        step stored alongside the weights.

        When ``self.online_mode`` is truthy the method keeps polling (every 0.1 s)
        until new weights show up or ``timeout`` seconds elapse. Otherwise it
        returns as soon as it finds nothing new, so sampling continues with the
        weights that are already loaded.

        Args:
            current_idx: Index of the batch most recently drawn by the sampling
                loop. Written as JSON to ``file_path`` when the loaded payload's
                ``should_update`` flag is truthy.
            file_path: Path of the JSON file that records ``current_idx``.
            timeout: Maximum number of seconds to keep polling.

        Raises:
            RuntimeError: If the polling loop runs out of time without loading
                new weights. Any ``RuntimeError`` raised while loading is also
                propagated unchanged.
        """
        # Imported locally so the method does not depend on a module-level
        # ``import json`` (the visible file header does not provide one).
        import json

        poll_interval = 0.1
        try:
            start_time = time.time()
            while time.time() - start_time < timeout:
                if not self.sync_weights_path.exists():
                    if self.online_mode:
                        time.sleep(poll_interval)
                        continue
                    return

                current_mtime = self.sync_weights_path.stat().st_mtime
                if current_mtime > self.last_sync_time:
                    logger.info(f"[Sampler Rank-{self.rank}] Detected new weights from {self.sync_weights_path}. Loading...")
                    obs_download_file(self.obs_client, self.obs_sync_weights_path, self.sync_weights_path)
                    # NOTE(security): torch.load unpickles the file; this is only
                    # safe as long as the weight source is trusted.
                    should_update, global_step, state_dict = torch.load(self.sync_weights_path, map_location="cpu")
                    if should_update:
                        with open(file_path, 'w', encoding="utf-8") as f:
                            json.dump(current_idx, f)
                        print(f"save current index of training dataset: {current_idx}")

                    self.model.load_state_dict(state_dict)
                    self._move_model_to_vllm()
                    # NOTE(review): current_mtime was read before the download. If
                    # obs_download_file rewrites sync_weights_path, the next poll
                    # may see a newer mtime again -- confirm this is intended.
                    self.last_sync_time = current_mtime
                    old_ids = self.model_ids
                    self.model_ids = global_step
                    logger.info(f"[Sampler Rank-{self.rank}] New weights loaded successfully. model_ids:{old_ids}->{self.model_ids}")
                    return

                if not self.online_mode:
                    logger.info(f"[Sampler Rank-{self.rank}] Weights have not been updated yet. Keep using old weights for sampling.")
                    return
                # In online mode, the sampler keeps waiting for the latest model weights.
                time.sleep(poll_interval)

            print(f"WARNING [Rank {self.rank}]: Timed out after {time.time() - start_time} (Max-{timeout})s waiting for model weight from '{self.sync_weights_path}'.")
            error_message = (
                f"[Rank {self.rank}] CRITICAL: Timed out after {timeout} seconds "
                f"waiting for model weight from '{self.sync_weights_path}'. "
                "The learner process(es) might be down or stuck. Aborting training."
            )
            logger.error(error_message)
            raise RuntimeError(error_message)
        except FileNotFoundError as e:
            # Best effort: a file disappeared between the existence check and
            # its use. Keep sampling with the current weights, but leave a trace.
            logger.warning(f"[Sampler Rank-{self.rank}] File not found while syncing weights ({e}). Keeping current weights.")
        except RuntimeError:
            # Includes the timeout above; must reach the caller.
            raise
        except Exception as e:
            # Best effort: log with traceback, back off briefly, and let the
            # sampling loop continue with the weights already loaded.
            logger.exception(f"[Sampler] Rank-{self.rank} Error loading weights: {e}")
            time.sleep(1)

    def _get_next_batch(self, index):
        try:
            index += 1
            return next(self._epoch_iterator), index
        except StopIteration:
            logger.info(f"[Sampler] Rank-{self.rank} Dataset depleted. Re-creating dataloader iterator.")
            self._epoch_iterator = iter(self._dataloader)
            index = 0
            return next(self._epoch_iterator), index

    def run_sampling_loop(self, file_path):
        """Run the sampler's endless generate-and-publish loop.

        Each iteration first checks whether the current delay interval has
        elapsed (timestamps and delays are broadcast from process 0) and, if so,
        calls ``_sync_model_weights`` and draws the next delay from
        ``self.delay_list``. It then pulls the next batch, generates and scores
        completions under ``torch.no_grad()``, moves tensor values to CPU and
        hands the result to ``push_to_fs_queue``.

        Args:
            file_path: Forwarded to ``_sync_model_weights`` as the location of
                the JSON file that records the current batch index.
        """

        def _from_rank0(value):
            # Wrap a scalar in a one-element float64 tensor and broadcast the
            # copy held by process 0.
            tensor = torch.tensor([value], device=self.accelerator.device, dtype=torch.float64)
            return broadcast(tensor, from_process=0)

        logger.info(f"*** Starting Sampler Loop (PID: {os.getpid()}) on device: {self.accelerator.device} ***")
        generated = 0

        delay_time = _from_rank0(next(self.delay_list))
        logger.info(f"first_delay_time:{delay_time}")
        last_time = _from_rank0(time.time())
        logger.info(f"first_last_time:{last_time}")
        current_index = -1

        while True:
            elapsed = _from_rank0(time.time()) - last_time
            if elapsed > delay_time:
                self._sync_model_weights(current_index, file_path)
                print(f"[RANK-{self.rank}] model sync done, delay {elapsed.item():.1f}(>{delay_time.item():.1f}) second")
                # Restart the timer and draw the next delay interval.
                last_time = _from_rank0(time.time())
                delay_time = _from_rank0(next(self.delay_list))
            else:
                print(f"[RANK-{self.rank}] model sync: {(elapsed / delay_time * 100).item():.1f}%, spend time: {elapsed.item():.1f}(<{delay_time.item():.1f}) second")

            batch, current_index = self._get_next_batch(current_index)

            with torch.no_grad():
                self.control = TrainerControl()
                rollout_data = self.generate_and_score_completions(batch)

            # Move tensors to CPU before handing the rollout to the queue;
            # non-tensor values are passed through untouched.
            cpu_rollout_data = {}
            for key, value in rollout_data.items():
                cpu_rollout_data[key] = value.cpu() if isinstance(value, torch.Tensor) else value
            push_to_fs_queue(self, cpu_rollout_data)

            generated += 1
            if generated % self.log_interval != 0:
                continue
            queue_size = len(list(self.queue_dir.glob("data_*.pt")))
            logger.info(f"[RANK-{self.rank}] Generated and wrote batch #{generated}. Approximate queue size: {queue_size} ")


def main(script_args, training_args, model_args):
    """Build a ``HeteroRLSampler`` from the parsed arguments and run its loop.

    The function draws the delay schedule, seeds the RNGs, configures logging,
    optionally initialises Weights & Biases, loads the tokenizer and dataset,
    converts each example into a chat-style prompt, optionally shards the
    training split across samplers, optionally skips already-consumed examples
    when resuming, and finally starts ``trainer.run_sampling_loop``.

    Args:
        script_args: Script-level options (dataset name/config/split, prompt
            settings, reward and delay configuration).
        training_args: Training configuration (seed, output directory, logging,
            sampler count/id, resume flag, ...).
        model_args: Model options (name or path, revision, attention
            implementation, PEFT settings, ...).
    """
    # Imported locally so the function does not depend on a module-level
    # ``import json`` (the visible file header does not provide one).
    import json

    rank = training_args.local_rank

    # The delay schedule is drawn before set_seed() is called; this ordering is
    # kept as-is so the schedule is produced exactly as before.
    delay_sampler = get_delay_sampler(script_args)
    delay_list = delay_sampler.get_delay_list(n=50000)
    print(f"[RANK-{rank}] delay_list[:20]: {delay_list[:20]}")

    set_seed(training_args.seed)

    ###############
    # Setup logging
    ###############
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process a small summary
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"[Rank={rank}] Random seed {training_args.seed}")
    logger.info(f"Model parameters {model_args}")
    logger.info(f"Script parameters {script_args}")
    logger.info(f"Training parameters {training_args}")

    # Check for last checkpoint (informational only: the result is just logged).
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
        logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")

    if "wandb" in training_args.report_to:
        init_wandb_training(training_args)
        if rank == 0:
            # Only rank 0 opens a W&B run, named after this script file.
            current_file_name = os.path.basename(__file__)
            wandb.login()
            wandb.init(
                project=os.environ["WANDB_PROJECT"],
                entity=os.environ["WANDB_ENTITY"],
                name=current_file_name,
            )

    ################
    # Load tokenizer
    ################
    tokenizer = get_tokenizer(model_args, training_args)

    # Load the dataset
    if 'simplelr_qwen_level3to5' in script_args.dataset_name:
        dataset = custom_loading_dataset(
            script_args.dataset_name,
            max_length=training_args.max_prompt_length,
            tokenizer=tokenizer,
        )
    else:
        dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    # Get reward functions from the registry
    reward_funcs = get_reward_funcs(script_args)

    def make_conversation(example):
        """Turn one dataset row into a system + user chat prompt."""
        if script_args.use_think:
            system_content = script_args.system_prompt_think
            user_content = example["problem"]
        else:
            system_content = script_args.system_prompt_nothink
            user_content = example["problem"] + "/no_think"
        return {
            "prompt": [
                {"role": "system", "content": system_content},
                {"role": "user", "content": user_content},
            ]
        }

    dataset = dataset.map(make_conversation)

    for split in dataset:
        if "messages" in dataset[split].column_names:
            dataset[split] = dataset[split].remove_columns("messages")

    logger.info("*** Initializing model kwargs ***")
    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch.float16,
        # The KV cache is disabled whenever gradient checkpointing is enabled.
        use_cache=not training_args.gradient_checkpointing,
    )
    training_args.model_init_kwargs = model_kwargs

    fs_queue_path = os.getenv("FS_QUEUE_PATH", "/Qwen3-8B")

    train_split = dataset[script_args.dataset_train_split]
    if training_args.num_samplers != 1:
        # Every sampler shuffles with the same seed and then takes its own shard.
        shuffled_dataset = train_split.shuffle(seed=training_args.seed)
        sampling_dataset = shuffled_dataset.shard(
            num_shards=training_args.num_samplers,
            index=training_args.sampler_id,
        )
    else:
        sampling_dataset = train_split

    file_path = Path(training_args.output_dir) / 'last_index.json'
    if training_args.resume_from_checkpoint == "True":
        if file_path.exists():
            with open(file_path, 'r', encoding="utf-8") as f:
                last_index = json.load(f)
            # NOTE(review): the stored value is the sampling loop's batch index,
            # while it is applied here as a dataset row offset -- confirm that
            # these units match.
            start_idx = last_index + 1
            logger.info(f"Start training from existing dataset index {start_idx}")
            if start_idx >= len(sampling_dataset):
                logger.info("training complete!")
                # sys.exit() rather than the site-provided exit() helper, which
                # is meant for interactive sessions and may be absent.
                sys.exit()
            sampling_dataset = sampling_dataset.select(range(start_idx, len(sampling_dataset)))
        else:
            logger.info("No index file found! Start from scratch.")

    trainer = HeteroRLSampler(
        delay_list=delay_list,
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=sampling_dataset,
        eval_dataset=None,
        peft_config=get_peft_config(model_args),
        callbacks=get_callbacks(training_args, model_args),
        processing_class=tokenizer,
        fs_queue_path=fs_queue_path,
    )

    class ResetDataLoader(TrainerCallback):
        """Callback that drops the trainer's cached epoch iterator at epoch end."""

        # Set on the class right after its definition (see below).
        trainer = None

        def on_epoch_end(self, args, state, control, **kwargs):
            """
            Event called at the end of an epoch.
            """
            if hasattr(self.trainer, '_epoch_iterator'):
                print('reset epoch iter in trainer')
                del self.trainer._epoch_iterator

    ResetDataLoader.trainer = trainer
    trainer.add_callback(ResetDataLoader)

    ###############
    # Training loop
    ###############
    logger.info("*** Starting Sampler Sampling Loop ***")
    trainer.run_sampling_loop(file_path)

if __name__ == "__main__":
    # Parse the three argument groups (script, training, model) from the CLI
    # and/or config file, tolerating unknown arguments, then run the sampler.
    dataclass_types = (MoISScriptArguments, GRPOConfig, ModelConfig)
    parser = TrlParser(dataclass_types)
    script_args, training_args, model_args = parser.parse_args_and_config(
        fail_with_unknown_args=False
    )
    main(script_args=script_args, training_args=training_args, model_args=model_args)


