# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Metrics related to the PPO trainer.
"""

from collections import defaultdict
from functools import partial
from typing import Any, Callable, Optional

import numpy as np
import torch

from verl import DataProto
from verl.trainer.ppo.core_algos import AdvantageEstimator
from verl.utils import as_torch_index
from verl.utils.import_utils import deprecated


@deprecated("verl.utils.metric.reduce_metrics")
def reduce_metrics(metrics: dict[str, list[Any]]) -> dict[str, Any]:
    """
    Deprecated shim that forwards to ``verl.utils.metric.reduce_metrics``.

    Kept only for backward compatibility; new code should import the function from
    ``verl.utils.metric`` directly.

    Args:
        metrics: Mapping from metric name to the list of values recorded for it.

    Returns:
        Whatever ``verl.utils.metric.reduce_metrics`` returns for ``metrics`` — a dictionary
        with the same keys, each list reduced to a single value (its mean).

    Example:
        >>> metrics = {"loss": [1.0, 2.0, 3.0], "accuracy": [0.8, 0.9, 0.7]}
        >>> reduce_metrics(metrics)
        {"loss": 2.0, "accuracy": 0.8}
    """
    # Imported at call time rather than module import time (presumably to avoid an import
    # cycle between the trainer and utils packages — NOTE(review): confirm).
    from verl.utils.metric import reduce_metrics as _reduce_metrics_impl

    return _reduce_metrics_impl(metrics)


# Advantage-estimator groups. ``compute_data_metrics`` tests membership (``in``) against these
# tuples to decide how the ``step_weight/*`` metrics are derived:
#   * weight method  -> "reweight" (_REWEIGHT_*), "future_only" (_FUTURE_ONLY_*), "invcount" (_INVCOUNT_*)
#   * normalization  -> plain mean, token-count weighted mean (*_TOKENCOUNT), mean + lower clip
#                       (*_CLIP_MIN) or root-mean-square (*_RMS).
_REWEIGHT_ESTIMATORS = (AdvantageEstimator.GRPO_VERK_STEP_REWARD_STEP_NORM_REWEIGHT,)
_REWEIGHT_TOKENCOUNT_ESTIMATORS = (AdvantageEstimator.GRPO_VERK_STEP_REWARD_STEP_NORM_REWEIGHT_TOKENCOUNT,)
_REWEIGHT_CLIP_MIN_ESTIMATORS = (AdvantageEstimator.GRPO_VERK_STEP_REWARD_STEP_NORM_REWEIGHT_CLIP_MIN,)
_REWEIGHT_RMS_ESTIMATORS = (AdvantageEstimator.GRPO_VERK_STEP_REWARD_STEP_NORM_REWEIGHT_RMS,)
_FUTURE_ONLY_ESTIMATORS = (
    AdvantageEstimator.GRPO_VERK_STEP_REWARD_STEP_NORM_REWEIGHT_FUTURE_ONLY,
    AdvantageEstimator.GRPO_VERK_STEP_REWARD_STEP_NORM_REWEIGHT_FUTURE_ONLY_NOSTD,
    AdvantageEstimator.GRPO_VERK_STEP_REWARD_GLOBAL_NORM_REWEIGHT_FUTURE_ONLY_NOSTD,
)
_FUTURE_ONLY_TOKENCOUNT_ESTIMATORS = (
    AdvantageEstimator.GRPO_VERK_STEP_REWARD_STEP_NORM_REWEIGHT_FUTURE_ONLY_TOKENCOUNT,
)
_INVCOUNT_ESTIMATORS = (
    AdvantageEstimator.GRPO_VERK_STEP_REWARD_STEP_NORM_INVCOUNT,
    AdvantageEstimator.GRPO_VERK_STEP_REWARD_STEP_NORM_INVCOUNT_NOSTD,
    AdvantageEstimator.GRPO_VERK_STEP_REWARD_GLOBAL_NORM_INVCOUNT_NOSTD,
)
_INVCOUNT_TOKENCOUNT_ESTIMATORS = (AdvantageEstimator.GRPO_VERK_STEP_REWARD_STEP_NORM_INVCOUNT_TOKENCOUNT,)


def _compute_verk_weights(
    turn_successes: torch.Tensor, g: torch.Tensor, eps: float = 1e-8
) -> tuple[torch.Tensor, torch.Tensor]:
    reached = turn_successes != -1
    success = turn_successes == 1
    G = int(g.max().item()) + 1 if g.numel() > 0 else 0
    K = turn_successes.shape[1]
    sum_reached = torch.zeros((G, K), device=turn_successes.device, dtype=torch.float32).index_add_(
        0, g, reached.to(dtype=torch.float32)
    )
    sum_success = torch.zeros((G, K), device=turn_successes.device, dtype=torch.float32).index_add_(
        0, g, success.to(dtype=torch.float32)
    )
    p_hat = torch.zeros_like(sum_success)
    step_mask = sum_reached > 0
    p_hat[step_mask] = sum_success[step_mask] / sum_reached[step_mask]
    one_minus = (1.0 - p_hat).clamp(min=eps, max=1.0)
    if K == 1:
        w = torch.ones((G, 1), device=turn_successes.device, dtype=torch.float32)
    else:
        prefix = torch.cumprod(one_minus, dim=1)
        suffix = torch.cumprod(one_minus.flip(1), dim=1).flip(1)
        left = torch.ones_like(one_minus)
        right = torch.ones_like(one_minus)
        left[:, 1:] = prefix[:, :-1]
        right[:, :-1] = suffix[:, 1:]
        w = left * right
    return w, step_mask


def _compute_future_only_weights(
    turn_successes: torch.Tensor, g: torch.Tensor, eps: float = 1e-8
) -> tuple[torch.Tensor, torch.Tensor]:
    reached = turn_successes != -1
    success = turn_successes == 1
    G = int(g.max().item()) + 1 if g.numel() > 0 else 0
    K = turn_successes.shape[1]
    sum_reached = torch.zeros((G, K), device=turn_successes.device, dtype=torch.float32).index_add_(
        0, g, reached.to(dtype=torch.float32)
    )
    sum_success = torch.zeros((G, K), device=turn_successes.device, dtype=torch.float32).index_add_(
        0, g, success.to(dtype=torch.float32)
    )
    p_hat = torch.zeros_like(sum_success)
    step_mask = sum_reached > 0
    p_hat[step_mask] = sum_success[step_mask] / sum_reached[step_mask]
    one_minus_group = torch.ones_like(p_hat)
    one_minus_group[step_mask] = (1.0 - p_hat[step_mask]).clamp(min=eps, max=1.0)
    w_group = torch.ones_like(one_minus_group)
    if K > 1:
        suffix = torch.cumprod(one_minus_group.flip(1), dim=1).flip(1)
        w_group[:, :-1] = suffix[:, 1:]
    return w_group, step_mask


def _compute_invcount_weights(turn_successes: torch.Tensor, g: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    reached = turn_successes != -1
    G = int(g.max().item()) + 1 if g.numel() > 0 else 0
    K = turn_successes.shape[1]
    count_reached = torch.zeros((G, K), device=turn_successes.device, dtype=torch.float32).index_add_(
        0, g, reached.to(dtype=torch.float32)
    )
    w = torch.zeros_like(count_reached)
    nonzero = count_reached > 0
    w[nonzero] = 1.0 / count_reached[nonzero]
    step_mask = count_reached > 0
    return w, step_mask


def _mean_normalize_weights(w: torch.Tensor, step_mask: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    step_counts = step_mask.sum(dim=1).clamp_min(1.0)
    mean_w = (w * step_mask).sum(dim=1) / step_counts
    mean_w = mean_w.clamp_min(eps)
    return w / mean_w.unsqueeze(1)


def _rms_normalize_weights(w: torch.Tensor, step_mask: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    step_counts = step_mask.sum(dim=1).clamp_min(1.0)
    rms_w = torch.sqrt((w * w * step_mask).sum(dim=1) / step_counts)
    rms_w = rms_w.clamp_min(eps)
    return w / rms_w.unsqueeze(1)


def _token_count_mean_normalize_weights(
    w: torch.Tensor, g: torch.Tensor, turn_ids: torch.Tensor, response_mask: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
    B = turn_ids.shape[0]
    K = w.shape[1]
    token_mask = response_mask.to(dtype=torch.float32)
    token_counts_bk = torch.zeros((B, K), device=w.device, dtype=torch.float32)
    token_counts_bk.scatter_add_(1, turn_ids, token_mask)
    token_counts_gk = torch.zeros((w.shape[0], K), device=w.device, dtype=torch.float32).index_add_(
        0, g, token_counts_bk
    )
    denom = token_counts_gk.sum(dim=1).clamp_min(1.0)
    mean_w = (w * token_counts_gk).sum(dim=1) / denom
    mean_w = mean_w.clamp_min(eps)
    return w / mean_w.unsqueeze(1)


def _compute_response_info(batch: DataProto) -> dict[str, Any]:
    """
    Split the batch attention mask into prompt and response parts and count their tokens.

    The response part is the last ``batch.batch["responses"].shape[-1]`` columns of
    ``batch.batch["attention_mask"]``; everything before it is the prompt part.

    Args:
        batch: A DataProto whose ``batch`` holds ``"responses"`` and ``"attention_mask"``.

    Returns:
        A dictionary with:
            - response_mask: the response columns of the attention mask
            - prompt_length: per-sample sum of the prompt columns, as float
            - response_length: per-sample sum of the response columns, as float
    """
    attention_mask = batch.batch["attention_mask"]
    response_width = batch.batch["responses"].shape[-1]

    prompt_part = attention_mask[:, :-response_width]
    response_part = attention_mask[:, -response_width:]

    return {
        "response_mask": response_part,
        "prompt_length": prompt_part.sum(-1).float(),
        "response_length": response_part.sum(-1).float(),
    }


# Lower clip applied to mean-normalized step weights for the *_CLIP_MIN estimators when the
# algorithm config does not provide ``grpo_verk_step_weight_clip_min``.
_DEFAULT_STEP_WEIGHT_CLIP_MIN = 0.25


def _scalar_stats(prefix: str, values: torch.Tensor) -> dict[str, float]:
    """Return ``{prefix}/mean``, ``{prefix}/max``, ``{prefix}/min`` of a non-empty tensor as Python scalars."""
    return {
        f"{prefix}/mean": torch.mean(values).detach().item(),
        f"{prefix}/max": torch.max(values).detach().item(),
        f"{prefix}/min": torch.min(values).detach().item(),
    }


def _length_clip_ratio(lengths: torch.Tensor, limit: int) -> float:
    """Return the fraction of entries in ``lengths`` that are exactly equal to ``limit``."""
    return torch.mean(torch.eq(lengths, limit).float()).detach().item()


def _turn_success_metrics(batch: DataProto) -> dict[str, float]:
    """Per-turn success rate among rows that reached the turn (``turn_success/p_turn{k}_cond``)."""
    metrics: dict[str, float] = {}
    if "turn_successes" not in batch.batch:
        return metrics
    turn_successes = batch.batch["turn_successes"]
    if turn_successes.dim() != 2 or turn_successes.numel() == 0:
        return metrics

    reached = turn_successes != -1
    success = turn_successes == 1
    sum_reached = reached.sum(dim=0).to(dtype=torch.float32)
    sum_success = success.sum(dim=0).to(dtype=torch.float32)
    # NaN marks turns that no row reached.
    p_cond = torch.full_like(sum_success, float("nan"))
    reached_mask = sum_reached > 0
    p_cond[reached_mask] = sum_success[reached_mask] / sum_reached[reached_mask]
    for k in range(p_cond.shape[0]):
        metrics[f"turn_success/p_turn{k}_cond"] = p_cond[k].detach().item()
    return metrics


def _turn_reached_metrics(batch: DataProto) -> dict[str, float]:
    """Mean count of responses per ``uid`` group reaching each turn, over groups where it was reached."""
    metrics: dict[str, float] = {}
    if "turn_successes" not in batch.batch or "uid" not in batch.non_tensor_batch:
        return metrics
    turn_successes = batch.batch["turn_successes"]
    if turn_successes.dim() != 2 or turn_successes.numel() == 0:
        return metrics
    g = as_torch_index(batch.non_tensor_batch["uid"], device=turn_successes.device)
    if g.numel() != turn_successes.shape[0] or g.numel() == 0:
        return metrics

    reached = turn_successes != -1
    num_groups = int(g.max().item()) + 1
    num_steps = turn_successes.shape[1]
    count_reached = torch.zeros((num_groups, num_steps), device=turn_successes.device, dtype=torch.float32).index_add_(
        0, g, reached.to(dtype=torch.float32)
    )
    for k in range(num_steps):
        mask = count_reached[:, k] > 0
        if torch.any(mask):
            metrics[f"turn_reached/mean_responses_turn{k}_cond"] = count_reached[mask, k].mean().detach().item()
        else:
            metrics[f"turn_reached/mean_responses_turn{k}_cond"] = float("nan")
    return metrics


def _step_reward_std_metrics(batch: DataProto) -> dict[str, float]:
    """Population std of each turn's step reward over rows that reached it (``step_reward/std_turn{k}_cond``)."""
    metrics: dict[str, float] = {}
    if "assistant_turn_rewards" not in batch.batch or "turn_successes" not in batch.batch:
        return metrics
    step_rewards = batch.batch["assistant_turn_rewards"]
    turn_successes = batch.batch["turn_successes"]
    if not (
        step_rewards.dim() == 2
        and step_rewards.numel() > 0
        and turn_successes.dim() == 2
        and turn_successes.shape == step_rewards.shape
    ):
        return metrics

    reached = turn_successes != -1
    step_rewards = step_rewards.to(dtype=torch.float32)
    for k in range(step_rewards.shape[1]):
        mask = reached[:, k]
        if torch.any(mask):
            std_k = torch.std(step_rewards[mask, k], unbiased=False)
            metrics[f"step_reward/std_turn{k}_cond"] = std_k.detach().item()
        else:
            metrics[f"step_reward/std_turn{k}_cond"] = float("nan")
    return metrics


def _step_output_length_metrics(batch: DataProto, response_mask: torch.Tensor) -> dict[str, float]:
    """Mean per-turn response-token count (``step_output_length/mean_turn{k}_cond``) over rows counted for that turn."""
    metrics: dict[str, float] = {}
    if "assistant_turn_ids" not in batch.batch:
        return metrics
    turn_ids = batch.batch["assistant_turn_ids"]
    if turn_ids.dim() != 2 or turn_ids.shape != response_mask.shape or turn_ids.numel() == 0:
        return metrics

    # The number of turns comes from turn_successes when available (and then only rows that
    # reached a turn are counted), else from assistant_turn_rewards, else from the largest turn id.
    reached = None
    if "turn_successes" in batch.batch and batch.batch["turn_successes"].dim() == 2:
        turn_successes = batch.batch["turn_successes"]
        reached = turn_successes != -1
        num_steps = turn_successes.shape[1]
    elif "assistant_turn_rewards" in batch.batch and batch.batch["assistant_turn_rewards"].dim() == 2:
        num_steps = batch.batch["assistant_turn_rewards"].shape[1]
    else:
        num_steps = int(turn_ids.max().item()) + 1

    for k in range(num_steps):
        token_mask = response_mask & (turn_ids == k)
        lengths = token_mask.sum(dim=1).to(dtype=torch.float32)
        mask = reached[:, k] if reached is not None else lengths > 0
        if torch.any(mask):
            metrics[f"step_output_length/mean_turn{k}_cond"] = torch.mean(lengths[mask]).detach().item()
        else:
            metrics[f"step_output_length/mean_turn{k}_cond"] = float("nan")
    return metrics


def _resolve_step_weight_mode(
    adv_estimator: Optional[AdvantageEstimator],
) -> tuple[Optional[str], Optional[str], bool]:
    """Map an estimator to ``(weight method, normalization mode, clip_min)``; method is None if not step-weighted."""
    if adv_estimator in _REWEIGHT_RMS_ESTIMATORS:
        return "reweight", "rms", False
    if adv_estimator in _REWEIGHT_TOKENCOUNT_ESTIMATORS:
        return "reweight", "tokencount", False
    if adv_estimator in _REWEIGHT_CLIP_MIN_ESTIMATORS:
        return "reweight", "mean", True
    if adv_estimator in _REWEIGHT_ESTIMATORS:
        return "reweight", "mean", False
    if adv_estimator in _FUTURE_ONLY_TOKENCOUNT_ESTIMATORS:
        return "future_only", "tokencount", False
    if adv_estimator in _FUTURE_ONLY_ESTIMATORS:
        return "future_only", "mean", False
    if adv_estimator in _INVCOUNT_TOKENCOUNT_ESTIMATORS:
        return "invcount", "tokencount", False
    if adv_estimator in _INVCOUNT_ESTIMATORS:
        return "invcount", "mean", False
    return None, None, False


def _step_weight_metrics(
    batch: DataProto,
    response_mask: torch.Tensor,
    adv_estimator: Optional[AdvantageEstimator],
    algo_config: Optional[Any],
) -> dict[str, float]:
    """Mean normalized step weight per turn over rows that reached it (``step_weight/{label}/turn{k}_cond``)."""
    metrics: dict[str, float] = {}
    if adv_estimator is None or "turn_successes" not in batch.batch or "uid" not in batch.non_tensor_batch:
        return metrics
    turn_successes = batch.batch["turn_successes"]
    if turn_successes.dim() != 2 or turn_successes.numel() == 0:
        return metrics
    method, norm_mode, clip_min = _resolve_step_weight_mode(adv_estimator)
    if method is None:
        return metrics

    with torch.no_grad():
        g = as_torch_index(batch.non_tensor_batch["uid"], device=turn_successes.device)
        if g.numel() != turn_successes.shape[0] or g.numel() == 0:
            return metrics

        if method == "reweight":
            w, step_mask = _compute_verk_weights(turn_successes, g)
        elif method == "future_only":
            w, step_mask = _compute_future_only_weights(turn_successes, g)
        else:
            w, step_mask = _compute_invcount_weights(turn_successes, g)

        if norm_mode == "rms":
            w_norm = _rms_normalize_weights(w, step_mask)
        elif norm_mode == "tokencount":
            # Token-count normalization needs per-token turn ids aligned with the response mask;
            # without them the metric is skipped.
            w_norm = None
            turn_ids = batch.batch["assistant_turn_ids"] if "assistant_turn_ids" in batch.batch else None
            if turn_ids is not None and turn_ids.dim() == 2 and turn_ids.shape == response_mask.shape:
                turn_ids = turn_ids.to(device=w.device, dtype=torch.long)
                num_steps = w.shape[1]
                if num_steps > 0:
                    # Out-of-range ids (e.g. padding) would break scatter_add_; masked tokens add 0 anyway.
                    turn_ids = turn_ids.clamp(min=0, max=num_steps - 1)
                w_norm = _token_count_mean_normalize_weights(w, g, turn_ids, response_mask.to(device=w.device))
        else:
            w_norm = _mean_normalize_weights(w, step_mask)
            if clip_min:
                min_step_weight = _DEFAULT_STEP_WEIGHT_CLIP_MIN
                if algo_config is not None:
                    min_step_weight = algo_config.get("grpo_verk_step_weight_clip_min", min_step_weight)
                if min_step_weight is None:
                    min_step_weight = _DEFAULT_STEP_WEIGHT_CLIP_MIN
                w_norm = w_norm.clamp_min(float(min_step_weight))

        if w_norm is None:
            return metrics

        label = method
        if norm_mode == "tokencount":
            label = f"{method}_tokencount"
        elif norm_mode == "rms":
            label = f"{method}_rms"
        elif clip_min:
            label = f"{method}_clip_min"

        reached = turn_successes != -1
        for k in range(w_norm.shape[1]):
            mask = reached[:, k]
            if torch.any(mask):
                # Broadcast each group's weight back to its rows, then average over rows that reached turn k.
                weights_k = w_norm[g, k]
                metrics[f"step_weight/{label}/turn{k}_cond"] = weights_k[mask].mean().detach().item()
            else:
                metrics[f"step_weight/{label}/turn{k}_cond"] = float("nan")
    return metrics


def compute_data_metrics(
    batch: DataProto,
    use_critic: bool = True,
    adv_estimator: Optional[AdvantageEstimator] = None,
    algo_config: Optional[Any] = None,
) -> dict[str, Any]:
    """
    Computes various metrics from a batch of data for PPO training.

    This function calculates metrics related to scores, rewards, advantages, returns, values,
    and sequence lengths from a batch of data. It provides statistical information (mean, max, min)
    for each metric category, plus optional multi-turn / per-step statistics when the batch carries
    the corresponding fields.

    A sample is treated as aborted when its response (the trailing ``responses``-width columns of
    ``attention_mask``) has no attended tokens. Score and reward statistics use non-aborted samples only.

    Args:
        batch: A DataProto object containing batch data with token-level scores, rewards, advantages, etc.
        use_critic: Whether to include critic-specific metrics. Defaults to True.
        adv_estimator: Advantage estimator to gate step-weight metrics. Defaults to None.
        algo_config: Algorithm config for weight normalization settings (read via
            ``.get("grpo_verk_step_weight_clip_min")``). Defaults to None.

    Returns:
        A dictionary of metrics including:
            - critic/score/mean, max, min: Statistics about sequence scores (non-aborted samples)
            - critic/rewards/mean, max, min: Statistics about sequence rewards (non-aborted samples)
            - critic/advantages/mean, max, min: Statistics about advantages
            - critic/returns/mean, max, min: Statistics about returns
            - critic/values/mean, max, min: Statistics about critic values (if use_critic=True)
            - critic/vf_explained_var: Explained variance of the value function (if use_critic=True)
            - response_length/mean, max, min, clip_ratio: Statistics about response lengths
            - response_length_non_aborted/mean, max, min, clip_ratio: Same, non-aborted samples only
            - response/aborted_ratio: Fraction of samples whose response length is zero
            - prompt_length/mean, max, min, clip_ratio: Statistics about prompt lengths
            - num_turns/mean, max, min: Statistics about the number of multi-turn conversations
            - assistant_turns/mean, max, min: Statistics about the number of assistant turns
            - tool_call_counts/mean, max, min: Statistics about tool calls (if present)
            - turn_success/p_turn{k}_cond: Success rate of turn k among rows that reached it
            - turn_reached/mean_responses_turn{k}_cond: Mean responses per uid group reaching turn k
            - step_reward/std_turn{k}_cond: Std of turn-k step rewards among rows that reached it
            - step_output_length/mean_turn{k}_cond: Mean turn-k response-token count
            - step_weight/{label}/turn{k}_cond: Mean normalized step weight (step-weighted estimators only)

    Raises:
        ValueError: If every sample in the batch is aborted.
    """
    sequence_score = batch.batch["token_level_scores"].sum(-1)
    sequence_reward = batch.batch["token_level_rewards"].sum(-1)

    advantages = batch.batch["advantages"]
    returns = batch.batch["returns"]

    max_response_length = batch.batch["responses"].shape[-1]
    # Only the width of the prompt columns is needed, so no dtype conversion of the slice.
    max_prompt_length = batch.batch["attention_mask"][:, :-max_response_length].size(-1)
    response_mask = batch.batch["response_mask"].bool()

    response_info = _compute_response_info(batch)
    prompt_length = response_info["prompt_length"]
    response_length = response_info["response_length"]

    aborted_mask = response_length == 0
    non_aborted_mask = ~aborted_mask
    # Check before any reduction: torch.max/min on the empty non-aborted tensors would otherwise
    # fail first with an opaque error instead of this message.
    if not bool(non_aborted_mask.any()):
        raise ValueError("All samples are aborted, this should not happen.")

    non_aborted_response_length = response_length[non_aborted_mask]
    valid_adv = torch.masked_select(advantages, response_mask)
    valid_returns = torch.masked_select(returns, response_mask)

    metrics: dict[str, Any] = {}
    metrics.update(_scalar_stats("critic/score", sequence_score[non_aborted_mask]))
    metrics.update(_scalar_stats("critic/rewards", sequence_reward[non_aborted_mask]))
    metrics.update(_scalar_stats("critic/advantages", valid_adv))
    metrics.update(_scalar_stats("critic/returns", valid_returns))

    if use_critic:
        valid_values = torch.masked_select(batch.batch["values"], response_mask)
        return_diff_var = torch.var(valid_returns - valid_values)
        return_var = torch.var(valid_returns)
        metrics.update(_scalar_stats("critic/values", valid_values))
        # 1e-5 keeps the ratio finite when the returns have zero variance.
        metrics["critic/vf_explained_var"] = (1.0 - return_diff_var / (return_var + 1e-5)).detach().item()

    metrics.update(_scalar_stats("response_length", response_length))
    metrics["response_length/clip_ratio"] = _length_clip_ratio(response_length, max_response_length)
    # Non-aborted statistics exclude zero-length responses so they do not skew the lengths.
    metrics.update(_scalar_stats("response_length_non_aborted", non_aborted_response_length))
    metrics["response_length_non_aborted/clip_ratio"] = _length_clip_ratio(
        non_aborted_response_length, max_response_length
    )
    metrics["response/aborted_ratio"] = torch.mean(aborted_mask.float()).detach().item()
    metrics.update(_scalar_stats("prompt_length", prompt_length))
    metrics["prompt_length/clip_ratio"] = _length_clip_ratio(prompt_length, max_prompt_length)

    # Multi-turn conversation counters stored as non-tensor arrays.
    for source_key, prefix in (
        ("__num_turns__", "num_turns"),
        ("__assistant_turns__", "assistant_turns"),
        ("tool_call_counts", "tool_call_counts"),
    ):
        if source_key in batch.non_tensor_batch:
            counts = batch.non_tensor_batch[source_key]
            metrics[f"{prefix}/min"] = counts.min()
            metrics[f"{prefix}/max"] = counts.max()
            metrics[f"{prefix}/mean"] = counts.mean()

    metrics.update(_turn_success_metrics(batch))
    metrics.update(_turn_reached_metrics(batch))
    metrics.update(_step_reward_std_metrics(batch))
    metrics.update(_step_output_length_metrics(batch, response_mask))
    metrics.update(_step_weight_metrics(batch, response_mask, adv_estimator, algo_config))

    return metrics


def compute_timing_metrics(batch: DataProto, timing_raw: dict[str, float]) -> dict[str, Any]:
    """
    Report raw stage timings and per-token timings for a PPO training step.

    Every entry of ``timing_raw`` is emitted as ``timing_s/{name}``. Stages with a known token
    denominator additionally get ``timing_per_token_ms/{name}`` = seconds * 1000 / tokens:
        - "gen" is divided by the number of response tokens;
        - "ref", "values", "adv", "update_critic", "update_actor" are divided by the number of
          prompt + response tokens.

    Args:
        batch: A DataProto object containing batch data with responses and attention masks.
        timing_raw: Mapping from stage name to elapsed seconds.

    Returns:
        The ``timing_s/*`` entries followed by the ``timing_per_token_ms/*`` entries.

    Note:
        NOTE(review): a zero token count for a timed stage raises ZeroDivisionError.
    """
    info = _compute_response_info(batch)
    prompt_tokens = torch.sum(info["prompt_length"]).item()
    response_tokens = torch.sum(info["response_length"]).item()
    total_tokens = prompt_tokens + response_tokens

    token_denominators = {"gen": response_tokens}
    for stage in ("ref", "values", "adv", "update_critic", "update_actor"):
        token_denominators[stage] = total_tokens

    result = {f"timing_s/{stage}": seconds for stage, seconds in timing_raw.items()}
    for stage in set(token_denominators.keys()) & set(timing_raw.keys()):
        result[f"timing_per_token_ms/{stage}"] = timing_raw[stage] * 1000 / token_denominators[stage]
    return result


def compute_throughout_metrics(batch: DataProto, timing_raw: dict[str, float], n_gpus: int) -> dict[str, Any]:
    """
    Summarize token throughput for one PPO training step.

    Args:
        batch: A DataProto whose ``meta_info["global_token_num"]`` is an iterable of token counts
            that are summed to get the step's total.
        timing_raw: Stage timings in seconds; must contain the ``"step"`` key.
        n_gpus: Number of GPUs the step ran on.

    Returns:
        A dictionary containing:
            - perf/total_num_tokens: sum of ``global_token_num``
            - perf/time_per_step: ``timing_raw["step"]``
            - perf/throughput: tokens / (step seconds * n_gpus), i.e. tokens per second per GPU
    """
    total_num_tokens = sum(batch.meta_info["global_token_num"])
    step_seconds = timing_raw["step"]
    per_gpu_throughput = total_num_tokens / (step_seconds * n_gpus)
    return {
        "perf/total_num_tokens": total_num_tokens,
        "perf/time_per_step": step_seconds,
        "perf/throughput": per_gpu_throughput,
    }


def bootstrap_metric(
    data: list[Any],
    subset_size: int,
    reduce_fns: list[Callable[[np.ndarray], float]],
    n_bootstrap: int = 1000,
    seed: int = 42,
) -> list[tuple[float, float]]:
    """
    Performs bootstrap resampling to estimate statistics of metrics.

    This function uses bootstrap resampling to estimate the mean and standard deviation
    of metrics computed by the provided reduction functions on random subsets of the data.

    Args:
        data: List of data points to bootstrap from.
        subset_size: Size of each bootstrap sample (drawn with replacement).
        reduce_fns: List of functions that compute a metric from a subset of data. Each is
            called with a Python ``list`` holding the sampled elements.
        n_bootstrap: Number of bootstrap iterations. Defaults to 1000.
        seed: Random seed for reproducibility. Defaults to 42.

    Returns:
        A list of tuples, where each tuple contains (mean, std) for a metric
        corresponding to each reduction function in reduce_fns.

    Raises:
        ValueError: If ``data`` is empty (nothing to sample from).

    Note:
        Sampling uses a private ``np.random.RandomState(seed)``. It draws exactly the same indices
        as seeding the global legacy NumPy RNG with ``seed`` would, but leaves the global RNG
        state untouched so callers' own randomness is not silently reset.

    Example:
        >>> data = [1, 2, 3, 4, 5]
        >>> reduce_fns = [np.mean, np.max]
        >>> bootstrap_metric(data, 3, reduce_fns)
        [(3.0, 0.5), (4.5, 0.3)]  # Example values
    """
    rng = np.random.RandomState(seed)
    population = len(data)

    samples_per_fn = [[] for _ in range(len(reduce_fns))]
    for _ in range(n_bootstrap):
        bootstrap_idxs = rng.choice(population, size=subset_size, replace=True)
        bootstrap_data = [data[i] for i in bootstrap_idxs]
        for collected, reduce_fn in zip(samples_per_fn, reduce_fns):
            collected.append(reduce_fn(bootstrap_data))
    return [(np.mean(values), np.std(values)) for values in samples_per_fn]


def calc_maj_val(data: list[dict[str, Any]], vote_key: str, val_key: str) -> float:
    """
    Return the value attached to the most frequent vote.

    Records are grouped by ``record[vote_key]``. The largest group wins; on a tie, the vote that
    first appears in ``data`` wins. The result is ``val_key`` of the first record in the winning group.

    Args:
        data: List of dictionaries, each holding both ``vote_key`` and ``val_key``.
        vote_key: Key whose (hashable) value is counted as the vote.
        val_key: Key whose value is returned for the majority vote.

    Returns:
        ``val_key`` of the first record that cast the majority vote.

    Raises:
        ValueError: If ``data`` is empty.
        KeyError: If a record lacks ``vote_key`` or ``val_key``.

    Example:
        >>> data = [
        ...     {"pred": "A", "val": 0.9},
        ...     {"pred": "B", "val": 0.8},
        ...     {"pred": "A", "val": 0.7}
        ... ]
        >>> calc_maj_val(data, vote_key="pred", val_key="val")
        0.9  # Returns the first "val" for the majority vote "A"
    """
    grouped: dict[Any, list[Any]] = defaultdict(list)
    for record in data:
        grouped[record[vote_key]].append(record[val_key])

    # max() keeps the first maximal key in insertion order, which breaks ties by first appearance.
    winner = max(grouped, key=lambda vote: len(grouped[vote]))
    return grouped[winner][0]


def process_validation_metrics(
    data_sources: list[str], sample_uids: list[str], infos_dict: dict[str, list[Any]], seed: int = 42
) -> dict[str, dict[str, dict[str, float]]]:
    """
    Process validation metrics into a structured format with statistical analysis.

    Samples are grouped by data source and then by prompt uid. For each numeric variable in
    ``infos_dict``, statistics are first computed per uid and then averaged over the uids of
    the same data source. Variables whose first value in a group is a ``str`` (for example
    "pred") are not summarized themselves. A "pred" variable, if present, is still used for
    majority voting on the other variables.

    For groups with N > 1 responses, bootstrap statistics are computed for subset sizes
    2, 4, 8, ... (powers of two strictly below N) plus N itself. Every bootstrap call reseeds
    NumPy's global RNG with ``seed``.

    Args:
        data_sources: List of data source identifiers for each sample.
        sample_uids: List of sample uids corresponding to each sample.
        infos_dict: Dictionary mapping variable names to lists of values for each sample.
            Each list is indexed by sample position, parallel to ``data_sources``.
        seed: Random seed for bootstrap sampling. Defaults to 42.

    Returns:
        A nested dictionary with the structure:
        {
            data_source: {
                variable_name: {
                    metric_name: value
                }
            }
        }
        The returned mappings are nested ``defaultdict`` instances.

        Where metric_name includes:
        - "mean@N": Mean value across N samples
        - "std@N": Standard deviation across N samples
        - "best@N/mean": Mean of the best values in bootstrap samples of size N
        - "best@N/std": Standard deviation of the best values in bootstrap samples
        - "worst@N/mean": Mean of the worst values in bootstrap samples
        - "worst@N/std": Standard deviation of the worst values in bootstrap samples
        - "maj@N/mean": Mean of majority voting results in bootstrap samples (if "pred" exists)
        - "maj@N/std": Standard deviation of majority voting results (if "pred" exists)

        N is the number of responses for a uid. Uids with different response counts
        therefore contribute to different metric names.

    Example:
        >>> data_sources = ["source1", "source1", "source2"]
        >>> sample_uids = ["uid1", "uid1", "uid2"]
        >>> infos_dict = {"score": [0.8, 0.9, 0.7], "pred": ["A", "A", "B"]}
        >>> result = process_validation_metrics(data_sources, sample_uids, infos_dict)
        >>> # result will contain statistics for each data source and variable
    """
    # Group raw per-sample values by data source, then prompt uid, then variable name.
    data_src2uid2var2vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    for sample_idx, data_source in enumerate(data_sources):
        uid = sample_uids[sample_idx]
        var2vals = data_src2uid2var2vals[data_source][uid]
        for var_name, var_vals in infos_dict.items():
            var2vals[var_name].append(var_vals[sample_idx])

    # Calculate metrics for each (data source, uid, variable) group
    data_src2uid2var2metric = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
    for data_source, uid2var2vals in data_src2uid2var2vals.items():
        for uid, var2vals in uid2var2vals.items():
            # `.get` does not insert into the defaultdict, so this lookup is safe to hoist.
            preds = var2vals.get("pred", None)
            for var_name, var_vals in var2vals.items():
                # String-valued variables (e.g. predictions) cannot be averaged.
                if isinstance(var_vals[0], str):
                    continue

                metric = {}
                n_resps = len(var_vals)
                metric[f"mean@{n_resps}"] = np.mean(var_vals)

                if n_resps > 1:
                    metric[f"std@{n_resps}"] = np.std(var_vals)

                    # Bootstrap subset sizes: powers of two strictly below n_resps, then n_resps.
                    ns = []
                    n = 2
                    while n < n_resps:
                        ns.append(n)
                        n *= 2
                    ns.append(n_resps)

                    # The (val, pred) pairs and the voting reducer do not depend on the
                    # subset size, so build them once instead of once per size.
                    vote_data = None
                    maj_reduce_fn = None
                    if preds is not None:
                        vote_data = [{"val": val, "pred": pred} for val, pred in zip(var_vals, preds, strict=True)]
                        maj_reduce_fn = partial(calc_maj_val, vote_key="pred", val_key="val")

                    for n in ns:
                        [(bon_mean, bon_std), (won_mean, won_std)] = bootstrap_metric(
                            data=var_vals, subset_size=n, reduce_fns=[np.max, np.min], seed=seed
                        )
                        metric[f"best@{n}/mean"], metric[f"best@{n}/std"] = bon_mean, bon_std
                        metric[f"worst@{n}/mean"], metric[f"worst@{n}/std"] = won_mean, won_std
                        if vote_data is not None:
                            [(maj_n_mean, maj_n_std)] = bootstrap_metric(
                                data=vote_data,
                                subset_size=n,
                                reduce_fns=[maj_reduce_fn],
                                seed=seed,
                            )
                            metric[f"maj@{n}/mean"], metric[f"maj@{n}/std"] = maj_n_mean, maj_n_std

                data_src2uid2var2metric[data_source][uid][var_name] = metric

    # Collect each per-uid metric value under its (data source, variable, metric name).
    data_src2var2metric2uid_vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    for data_source, uid2var2metric in data_src2uid2var2metric.items():
        for uid, var2metric in uid2var2metric.items():
            for var_name, metric in var2metric.items():
                for metric_name, metric_val in metric.items():
                    data_src2var2metric2uid_vals[data_source][var_name][metric_name].append(metric_val)

    # Average every metric over the uids of its data source.
    data_src2var2metric2val = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))
    for data_source, var2metric2uid_vals in data_src2var2metric2uid_vals.items():
        for var_name, metric2uid_vals in var2metric2uid_vals.items():
            for metric_name, uid_vals in metric2uid_vals.items():
                data_src2var2metric2val[data_source][var_name][metric_name] = np.mean(uid_vals)

    return data_src2var2metric2val
