# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Metrics related to the PPO trainer.
"""
import re
from collections import defaultdict
from functools import partial
from typing import Any, Callable

import numpy as np
import torch

from verl import DataProto
from verl.utils.import_utils import deprecated


@deprecated("verl.utils.metric.reduce_metrics")
def reduce_metrics(metrics: dict[str, list[Any]]) -> dict[str, Any]:
    """
    Backward-compatible shim that forwards to ``verl.utils.metric.reduce_metrics``.

    The actual reduction (which statistic is computed per metric) is defined by
    the target function; this wrapper adds no logic of its own and only exists
    so that old import paths keep working.

    Args:
        metrics (dict[str, list[Any]]):
            Mapping from metric name to the list of values collected for it.
            Example: {"accuracy": [0.8, 0.9, 0.85], "loss": [0.4, 0.3, 0.35]}.

    Returns:
        dict[str, Any]:
            Whatever ``verl.utils.metric.reduce_metrics`` returns for ``metrics``
            (one aggregated value per metric name).
    """
    # Imported lazily inside the function, exactly as before, so importing this
    # module does not require ``verl.utils.metric`` to be importable up front.
    from verl.utils.metric import reduce_metrics as _reduce_metrics_impl

    reduced = _reduce_metrics_impl(metrics)
    return reduced


def _compute_response_info(batch: DataProto) -> dict[str, Any]:
    """
    Split the attention mask into prompt / response parts and count valid tokens.

    The width of the response window is ``batch.batch["responses"].shape[-1]``.
    The last that-many columns of ``batch.batch["attention_mask"]`` are treated as
    the response part; every column before them is the prompt part.

    Args:
        batch: Object whose ``batch`` mapping provides the ``"responses"`` and
            ``"attention_mask"`` tensors. The mask is sliced along its second
            dimension and summed over its last one, i.e. it is used as a
            (batch_size, total_length) tensor.

    Returns:
        A dictionary containing:
            - response_mask: The response-window columns of the attention mask
            - prompt_length: Float tensor with the per-row sum of the prompt mask
            - response_length: Float tensor with the per-row sum of the response mask
    """
    attention_mask = batch.batch["attention_mask"]
    response_width = batch.batch["responses"].shape[-1]

    # Split on an explicit column index instead of negative slicing: with a
    # zero-width response window, `mask[:, :-0]` is empty and `mask[:, -0:]` is the
    # whole mask, which would swap the prompt and response statistics.
    # The clamp keeps the previous result (empty prompt, full-mask response) when
    # the response window is wider than the mask itself.
    prompt_width = max(attention_mask.shape[-1] - response_width, 0)

    prompt_mask = attention_mask[:, :prompt_width]
    response_mask = attention_mask[:, prompt_width:]

    prompt_length = prompt_mask.sum(-1).float()
    response_length = response_mask.sum(-1).float()  # (batch_size,)

    return dict(
        response_mask=response_mask,
        prompt_length=prompt_length,
        response_length=response_length,
    )


def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str, Any]:
    """
    Compute summary metrics from a PPO training batch.

    Token-level scores and rewards are summed over the response dimension to get
    one value per sequence; advantages, returns and (optionally) values are
    reduced over the tokens selected by ``response_mask``. Prompt/response length
    statistics, the aborted-sample ratio, and optional task-specific extras from
    ``batch.non_tensor_batch`` are reported as well.

    A sample counts as *aborted* when its response length (sum of the response
    part of the attention mask) is zero. Score and reward statistics are computed
    on non-aborted samples only.

    Args:
        batch: A DataProto carrying:
            - batch.batch["token_level_scores"]: (B, T_resp) per-token scores
            - batch.batch["token_level_rewards"]: (B, T_resp) per-token rewards
            - batch.batch["advantages"]: (B, T_resp) per-token advantages
            - batch.batch["returns"]: (B, T_resp) per-token returns
            - batch.batch["values"]: (B, T_resp) per-token value estimates (if `use_critic=True`)
            - batch.batch["responses"]: (B, T_resp) response token ids (used for max length)
            - batch.batch["attention_mask"]: (B, T_total) 1 for real tokens, 0 for padding
            - batch.batch["response_mask"]: (B, T_resp) 1 on valid response tokens
        use_critic: If True, include value-function stats and explained variance.

    Returns:
        dict[str, Any]: metric name → scalar. Tensor-derived entries are Python
        numbers (via ``.item()``); entries derived from ``non_tensor_batch`` are
        whatever NumPy / the stored arrays return.
            Key groups include:
            - critic/score/*, critic/rewards/* (non-aborted samples only)
            - critic/advantages/*, critic/returns/*
            - critic/values/* and critic/vf_explained_var (only if `use_critic=True`)
            - response_length/* (overall) and response_length_non_aborted/* (exclude zeros)
            - response/aborted_ratio
            - prompt_length/*
            - num_turns/*, tool_call_counts/* (if present in non-tensor payload)
            - critic/acc/*, critic/code_reward/*, critic/acc_add_code_reward/*,
              critic/format/* when those arrays are provided in non-tensor payload.

    Raises:
        ValueError: If every sample in the batch is aborted (response length 0).
    """

    def _tensor_stats(prefix: str, values: torch.Tensor) -> dict[str, float]:
        # mean / max / min of a tensor as plain Python numbers.
        return {
            f"{prefix}/mean": torch.mean(values).detach().item(),
            f"{prefix}/max": torch.max(values).detach().item(),
            f"{prefix}/min": torch.min(values).detach().item(),
        }

    def _clip_ratio(lengths: torch.Tensor, limit: int) -> float:
        # Fraction of entries whose length equals the allocated window size.
        return torch.mean(torch.eq(lengths, limit).float()).detach().item()

    def _array_stats(prefix: str, values: Any) -> dict[str, Any]:
        # mean / max / min of a non-tensor (array-like) payload via NumPy.
        return {
            f"{prefix}/mean": np.mean(values),
            f"{prefix}/max": np.max(values),
            f"{prefix}/min": np.min(values),
        }

    tensors = batch.batch

    # ---- Sequence-level aggregates -------------------------------------------------
    # Sum token-level quantities over the response time dimension to get per-sequence values.
    sequence_score = tensors["token_level_scores"].sum(-1)
    sequence_reward = tensors["token_level_rewards"].sum(-1)

    advantages = tensors["advantages"]
    returns = tensors["returns"]
    response_mask = tensors["response_mask"].bool()

    # Window sizes: the response window is the trailing part of the attention mask,
    # the prompt window is whatever precedes it.
    max_response_length = tensors["responses"].shape[-1]
    max_prompt_length = max(tensors["attention_mask"].shape[-1] - max_response_length, 0)

    response_info = _compute_response_info(batch)
    prompt_length = response_info["prompt_length"]
    response_length = response_info["response_length"]

    # ---- Aborted samples -----------------------------------------------------------
    aborted_mask = response_length == 0
    non_aborted_mask = ~aborted_mask
    non_aborted_response_length = response_length[non_aborted_mask]

    # Check this BEFORE any max/min reduction over the non-aborted subset:
    # torch.max / torch.min raise on empty tensors, which would otherwise hide
    # this explicit, more informative error.
    if non_aborted_response_length.numel() == 0:
        raise ValueError("All samples are aborted, this should not happen.")

    non_aborted_sequence_score = sequence_score[non_aborted_mask]
    non_aborted_sequence_reward = sequence_reward[non_aborted_mask]

    valid_adv = torch.masked_select(advantages, response_mask)
    valid_returns = torch.masked_select(returns, response_mask)

    metrics = {
        **_tensor_stats("critic/score", non_aborted_sequence_score),
        **_tensor_stats("critic/rewards", non_aborted_sequence_reward),
        **_tensor_stats("critic/advantages", valid_adv),
        **_tensor_stats("critic/returns", valid_returns),
    }

    if use_critic:
        valid_values = torch.masked_select(tensors["values"], response_mask)
        return_diff_var = torch.var(valid_returns - valid_values)
        return_var = torch.var(valid_returns)
        metrics.update(_tensor_stats("critic/values", valid_values))
        # Explained variance of the value function; 1e-5 guards against a zero return variance.
        metrics["critic/vf_explained_var"] = (1.0 - return_diff_var / (return_var + 1e-5)).detach().item()

    # response length (all samples)
    metrics.update(_tensor_stats("response_length", response_length))
    metrics["response_length/clip_ratio"] = _clip_ratio(response_length, max_response_length)

    # response length (non-aborted only): excludes zeros so aborted samples do not skew the stats
    metrics.update(_tensor_stats("response_length_non_aborted", non_aborted_response_length))
    metrics["response_length_non_aborted/clip_ratio"] = _clip_ratio(
        non_aborted_response_length, max_response_length
    )

    # Fraction of samples whose response length is zero
    metrics["response/aborted_ratio"] = torch.mean(aborted_mask.float()).detach().item()

    # prompt length
    metrics.update(_tensor_stats("prompt_length", prompt_length))
    metrics["prompt_length/clip_ratio"] = _clip_ratio(prompt_length, max_prompt_length)

    # ---- Optional extras from the non-tensor payload -------------------------------
    extras = batch.non_tensor_batch

    # multi-turn conversation
    if "__num_turns__" in extras:
        num_turns = extras["__num_turns__"]
        metrics["num_turns/min"] = num_turns.min()
        metrics["num_turns/max"] = num_turns.max()
        metrics["num_turns/mean"] = num_turns.mean()

    if "tool_call_counts" in extras:
        tool_call_counts = extras["tool_call_counts"]
        metrics["tool_call_counts/min"] = tool_call_counts.min()
        metrics["tool_call_counts/max"] = tool_call_counts.max()
        metrics["tool_call_counts/mean"] = tool_call_counts.mean()

    # answer score
    if "acc" in extras:
        metrics.update(_array_stats("critic/acc", extras["acc"]))

    if "code_reward" in extras:
        code_reward = extras["code_reward"]
        metrics.update(_array_stats("critic/code_reward", code_reward))
        # Share of samples whose code reward reaches at least 1.0.
        positive_count = sum(1 for value in code_reward if value >= 1.0)
        metrics["critic/code_reward/exceed_code_numbers_positive_ratio"] = positive_count / len(code_reward)

    if "acc_add_code_reward" in extras:
        metrics.update(_array_stats("critic/acc_add_code_reward", extras["acc_add_code_reward"]))

    # format score
    if "format" in extras:
        metrics.update(_array_stats("critic/format", extras["format"]))

    return metrics


def compute_val_data_metrics(batch: DataProto) -> dict[str, Any]:
    """
    Compute lightweight **validation** metrics for a PPO batch.

    This function summarizes a few quantities on the validation split, mirroring
    the naming used in training but namespaced under ``"val/"``. It omits the
    critic/value-specific diagnostics (advantages, returns, values).

    What it computes:
      - Sequence-level means:
        * ``val/critic/score/mean`` — mean of per-sequence scores
        * ``val/critic/rewards/mean`` — mean of per-sequence rewards
      - Response length stats (in tokens):
        * ``val/response_length/mean``, ``/max``, ``/min``
        * ``val/response_length/clip_ratio`` — fraction of responses that exactly
          hit the allocated maximum response length
      - Prompt length stats (in tokens):
        * ``val/prompt_length/mean``, ``/max``, ``/min``
        * ``val/prompt_length/clip_ratio`` — fraction of prompts that fill the
          allocated prompt window
      - (Optional, if provided) multi-turn metadata:
        * ``val/num_turns/mean``, ``/max``, ``/min``

    Conventions:
      - *Sequence score* and *sequence reward* are computed by summing their
        per-token values over the response time dimension (one scalar per sample).
      - *Clip ratio* is reported as a fraction in ``[0, 1]``.
      - Unlike the training metrics, aborted (zero-length) samples are NOT
        filtered out here.

    Args:
        batch: A ``DataProto`` with at least:
            - ``batch["token_level_scores"]``: (B, T_resp) per-token scores
            - ``batch["token_level_rewards"]``: (B, T_resp) per-token rewards
            - ``batch["responses"]``: (B, T_resp) used to infer max response length
            - ``batch["attention_mask"]``: (B, T_total) 1=real token, 0=pad
          ``batch["response_mask"]`` is not read by this function.
          Optionally, ``non_tensor_batch["__num_turns__"]``: array-like per-sample turn counts.

    Returns:
        dict[str, Any]: Metric name → scalar value as described above.
    """

    def _length_stats(prefix: str, lengths: torch.Tensor, limit: int) -> dict[str, float]:
        # mean / max / min / clip_ratio of per-sample token counts against a window size.
        return {
            f"{prefix}/mean": torch.mean(lengths).detach().item(),
            f"{prefix}/max": torch.max(lengths).detach().item(),
            f"{prefix}/min": torch.min(lengths).detach().item(),
            f"{prefix}/clip_ratio": torch.mean(torch.eq(lengths, limit).float()).detach().item(),
        }

    tensors = batch.batch

    # --- Aggregate sequence-level quantities (sum over response tokens) ----------
    sequence_score = tensors["token_level_scores"].sum(-1)
    sequence_reward = tensors["token_level_rewards"].sum(-1)

    # Window sizes: the response window is the trailing part of the attention mask,
    # the prompt window is whatever precedes it. Only the sizes are needed here, so
    # no mask tensors are materialized (and "response_mask" is not required).
    max_response_length = tensors["responses"].shape[-1]
    max_prompt_length = max(tensors["attention_mask"].shape[-1] - max_response_length, 0)

    response_info = _compute_response_info(batch)
    prompt_length = response_info["prompt_length"]
    response_length = response_info["response_length"]

    metrics = {
        # score
        "val/critic/score/mean": torch.mean(sequence_score).detach().item(),
        # reward
        "val/critic/rewards/mean": torch.mean(sequence_reward).detach().item(),
        # response length
        **_length_stats("val/response_length", response_length, max_response_length),
        # prompt length
        **_length_stats("val/prompt_length", prompt_length, max_prompt_length),
    }

    # multi-turn conversation
    if "__num_turns__" in batch.non_tensor_batch:
        num_turns = batch.non_tensor_batch["__num_turns__"]
        metrics["val/num_turns/min"] = num_turns.min()
        metrics["val/num_turns/max"] = num_turns.max()
        metrics["val/num_turns/mean"] = num_turns.mean()

    return metrics



def count_code_blocks(text: str) -> int:
    """
    Count the ``<code>...</code>`` spans contained in ``text``.

    Matching is case-insensitive on the tag names, spans may extend over
    several lines, and each span ends at the first closing tag that follows
    its opening tag (non-greedy). An opening tag without a closing tag is
    not counted.

    Args:
        text (str): The text to scan.

    Returns:
        int: Number of non-overlapping ``<code>...</code>`` spans found.
    """
    code_span = re.compile(r"<code>(.*?)</code>", re.S | re.I)
    return sum(1 for _ in code_span.finditer(text))

def statistic_text_code_info(inputs, outputs, uids, rewards, max_turns_num=0, save_prefix="code_info"):
    """
    Compute statistics about how often generated outputs use code and how that relates to reward.

    Each sample ``i`` is described by ``outputs[i]`` (generated text), ``uids[i]``
    (identifier of the question it belongs to) and ``rewards[i]``. A sample "uses
    code" when ``count_code_blocks(outputs[i])`` is greater than zero. Nothing is
    written to disk; ``save_prefix`` is only used as the prefix of the returned keys.

    Args:
        inputs (list[str]): Input texts. Only their count is used: it defines how
            many samples are read from the other sequences.
        outputs (list[str]): Generated outputs, one per input.
        uids (list[str]): Hashable question identifier for each sample; samples
            sharing a uid are treated as attempts at the same question.
        rewards (list[float]): Reward for each sample.
        max_turns_num (int, optional): A sample is counted as "exceeding" when its
            number of code blocks is ``>= max_turns_num - 2``. With the default of
            0 the threshold is -2, so every sample is counted.
            NOTE(review): confirm with callers whether 0 was meant to disable this metric.
        save_prefix (str, optional): Prefix for every key of the returned dict.
            Defaults to "code_info".

    Returns:
        dict: ``f"{save_prefix}/<name>"`` → value for the following names:
            - code_ratio: fraction of samples that use code
            - code_numbers_avg: mean number of code blocks per sample
            - exceed_max_turns_ratio: fraction of samples at/over the threshold above
            - ues_code_datas_acc_avg / ues_text_datas_acc_avg: mean reward of the
              samples with / without code (0 if there are none)
            - question_use_all_code_ratio / question_use_all_text_ratio /
              question_use_mix_code_text_ratio: fraction of questions whose attempts
              all use code / never use code / are mixed
            - mix_code0_text1_ratio, mix_code1_text0_ratio, mix_code1_text1_ratio,
              mix_code_than_text_ratio: among mixed questions, the fraction where
              (code reward sum == 0 and text sum > 0), (code sum > 0 and text sum == 0),
              (both sums > 0), and (code sum >= text sum) respectively (0 if no
              question is mixed)
        For empty ``inputs`` every value is 0.
    """

    def _ratio(numerator, denominator):
        # Division that yields 0 instead of raising when there is nothing to divide by.
        return numerator / denominator if denominator else 0

    batch_size = len(inputs)
    exceed_threshold = max_turns_num - 2

    total_code_blocks = 0
    exceed_count = 0
    code_rewards = []  # rewards of samples that contain at least one code block
    text_rewards = []  # rewards of samples without any code block
    # uid -> rewards of that question's attempts, split by whether the attempt used code.
    rewards_by_question = defaultdict(lambda: {"code": [], "text": []})

    # The sample count is defined by `inputs`; the other sequences are indexed in
    # lockstep, so a shorter sequence still raises IndexError as before.
    for idx in range(batch_size):
        n_blocks = count_code_blocks(outputs[idx])
        reward = rewards[idx]

        total_code_blocks += n_blocks
        if n_blocks >= exceed_threshold:
            exceed_count += 1

        if n_blocks > 0:
            code_rewards.append(reward)
            rewards_by_question[uids[idx]]["code"].append(reward)
        else:
            text_rewards.append(reward)
            rewards_by_question[uids[idx]]["text"].append(reward)

    # Per-question classification: all attempts with code, none with code, or mixed.
    all_code_questions = 0
    all_text_questions = 0
    mixed_questions = 0
    # Breakdown of mixed questions by the reward totals of their code / text attempts.
    mix_code0_text1 = 0
    mix_code1_text0 = 0
    mix_code1_text1 = 0
    mix_code_than_text = 0

    for split in rewards_by_question.values():
        if not split["text"]:
            all_code_questions += 1
            continue
        if not split["code"]:
            all_text_questions += 1
            continue

        mixed_questions += 1
        total_code_acc = sum(split["code"])
        total_text_acc = sum(split["text"])

        if total_code_acc == 0 and total_text_acc > 0:
            mix_code0_text1 += 1
        if total_code_acc > 0 and total_text_acc == 0:
            mix_code1_text0 += 1
        if total_code_acc > 0 and total_text_acc > 0:
            mix_code1_text1 += 1
        if total_code_acc >= total_text_acc:
            mix_code_than_text += 1

    question_nums = len(rewards_by_question)

    # Key names keep the original "ues_*" spelling for compatibility with existing dashboards.
    return {
        f"{save_prefix}/code_ratio": _ratio(len(code_rewards), batch_size),
        f"{save_prefix}/code_numbers_avg": _ratio(total_code_blocks, batch_size),
        f"{save_prefix}/exceed_max_turns_ratio": _ratio(exceed_count, batch_size),
        f"{save_prefix}/ues_code_datas_acc_avg": _ratio(sum(code_rewards), len(code_rewards)),
        f"{save_prefix}/ues_text_datas_acc_avg": _ratio(sum(text_rewards), len(text_rewards)),
        f"{save_prefix}/question_use_all_code_ratio": _ratio(all_code_questions, question_nums),
        f"{save_prefix}/question_use_all_text_ratio": _ratio(all_text_questions, question_nums),
        f"{save_prefix}/question_use_mix_code_text_ratio": _ratio(mixed_questions, question_nums),
        f"{save_prefix}/mix_code0_text1_ratio": _ratio(mix_code0_text1, mixed_questions),
        f"{save_prefix}/mix_code1_text0_ratio": _ratio(mix_code1_text0, mixed_questions),
        f"{save_prefix}/mix_code1_text1_ratio": _ratio(mix_code1_text1, mixed_questions),
        f"{save_prefix}/mix_code_than_text_ratio": _ratio(mix_code_than_text, mixed_questions),
    }


def compute_timing_metrics(batch: DataProto, timing_raw: dict[str, float]) -> dict[str, Any]:
    """
    Report stage timings both as raw seconds and normalized per token.

    Args:
        batch: Training batch; its prompt/response token counts are obtained
            through ``_compute_response_info``.
        timing_raw: Mapping of stage name -> duration in seconds.

    Returns:
        A flat dictionary with:
          - ``"timing_s/{name}"`` for every entry of ``timing_raw`` (value unchanged)
          - ``"timing_per_token_ms/{name}"`` = ``seconds * 1000 / tokens`` for the
            stages that have a known token denominator AND appear in ``timing_raw``.

    Token denominators:
      - ``"gen"`` uses the total number of response tokens.
      - ``"ref"``, ``"values"``, ``"adv"``, ``"update_critic"``, ``"update_actor"``
        use prompt + response tokens.
      - Any other stage gets no per-token entry.

    Caveat:
      The denominator is used as-is; a zero token count for a timed stage raises
      ``ZeroDivisionError``.
    """
    lengths = _compute_response_info(batch)

    # The per-sample lengths are float tensors, so these totals are Python floats.
    prompt_token_total = torch.sum(lengths["prompt_length"]).item()
    response_token_total = torch.sum(lengths["response_length"]).item()
    overall_token_total = prompt_token_total + response_token_total

    # Stage name -> token count that normalizes it.
    token_denominator = {"gen": response_token_total}
    for stage in ["ref", "values", "adv", "update_critic", "update_actor"]:
        token_denominator[stage] = overall_token_total

    # Raw per-stage timings (seconds) come first ...
    metrics = {f"timing_s/{stage}": seconds for stage, seconds in timing_raw.items()}

    # ... followed by per-token timings (ms/token) for stages known on both sides.
    for stage in set(token_denominator.keys()) & set(timing_raw.keys()):
        metrics[f"timing_per_token_ms/{stage}"] = timing_raw[stage] * 1000 / token_denominator[stage]

    return metrics


def compute_throughout_metrics(batch: DataProto, timing_raw: dict[str, float], n_gpus: int) -> dict[str, Any]:
    """
    Derive token-throughput figures for one training step.

    Args:
        batch: Object whose ``meta_info["global_token_num"]`` is an iterable of
            token counts; their sum is taken as the step's total token count.
        timing_raw: Mapping of stage name -> seconds. The ``"step"`` entry is
            required and is used as the step duration (``KeyError`` if absent).
        n_gpus: Number of GPUs the throughput is normalized by.

    Returns:
        A dictionary containing:
            - perf/total_num_tokens: Sum of ``global_token_num``
            - perf/time_per_step: ``timing_raw["step"]``, unchanged
            - perf/throughput: ``total_tokens / (step_time * n_gpus)``,
              i.e. tokens per second per GPU

    Note:
        No validation is performed: a zero step time or ``n_gpus == 0`` results
        in ``ZeroDivisionError``.
    """
    token_total = sum(batch.meta_info["global_token_num"])
    step_seconds = timing_raw["step"]

    # Normalize by wall time and GPU count so runs on different cluster sizes compare.
    tokens_per_second_per_gpu = token_total / (step_seconds * n_gpus)

    perf = {}
    perf["perf/total_num_tokens"] = token_total
    perf["perf/time_per_step"] = step_seconds
    perf["perf/throughput"] = tokens_per_second_per_gpu
    return perf



def bootstrap_metric(
    data: list[Any],
    subset_size: int,
    reduce_fns: list[Callable[[np.ndarray], float]],
    n_bootstrap: int = 1000,
    seed: int = 42,
) -> list[tuple[float, float]]:
    """
    Performs bootstrap resampling to estimate statistics of metrics.

    For each of ``n_bootstrap`` iterations, ``subset_size`` items are drawn from ``data``
    with replacement and every function in ``reduce_fns`` is evaluated on that resample.
    The mean and standard deviation of each function's values across all iterations are
    returned.

    Randomness comes from a private generator seeded with ``seed``, so results are
    reproducible and the process-wide ``np.random`` state is left untouched.

    Args:
        data: List of data points to bootstrap from.
        subset_size: Size of each bootstrap sample.
        reduce_fns: List of functions that compute a metric from a subset of data.
            Each function receives the resample as a plain Python list.
        n_bootstrap: Number of bootstrap iterations. Defaults to 1000.
        seed: Random seed for reproducibility. Defaults to 42.

    Returns:
        A list of tuples, where each tuple contains (mean, std) for a metric
        corresponding to each reduction function in reduce_fns. The std is the
        population standard deviation (``np.std`` with its default ``ddof=0``).

    Raises:
        ValueError: Raised by NumPy when ``data`` is empty while ``subset_size`` is non-zero.

    Example:
        >>> data = [1, 2, 3, 4, 5]
        >>> reduce_fns = [np.mean, np.max]
        >>> bootstrap_metric(data, 3, reduce_fns)
        [(3.0, 0.5), (4.5, 0.3)]  # Example values
    """
    # A dedicated legacy RandomState seeded with `seed` produces exactly the same index
    # stream as `np.random.seed(seed)` + `np.random.choice(...)`, but without resetting
    # the global RNG that unrelated code in the process may be relying on.
    rng = np.random.RandomState(seed)
    n_items = len(data)

    # One accumulator per reduction function, filled with one value per iteration.
    per_metric_values = [[] for _ in reduce_fns]

    for _ in range(n_bootstrap):
        # Sampling with replacement: the same element may appear several times.
        sampled_idxs = rng.choice(n_items, size=subset_size, replace=True)
        resample = [data[i] for i in sampled_idxs]

        for values, reduce_fn in zip(per_metric_values, reduce_fns):
            values.append(reduce_fn(resample))

    # Summarize each bootstrap distribution by its mean and spread.
    return [(np.mean(values), np.std(values)) for values in per_metric_values]



def calc_maj_val(data: list[dict[str, Any]], vote_key: str, val_key: str) -> float:
    """
    Return the value attached to the majority vote.

    Every entry of ``data`` casts one vote (``entry[vote_key]``). The vote that occurs
    most often wins; when several votes are tied, the one that appeared first in
    ``data`` wins. The result is the ``val_key`` value of the first entry that cast
    the winning vote.

    Args:
        data: List of dictionaries, where each dictionary contains both vote_key and val_key.
        vote_key: The key in each dictionary used for voting/counting.
        val_key: The key in each dictionary whose value will be returned for the majority vote.

    Returns:
        The value associated with the most common vote.

    Example:
        >>> data = [
        ...     {"pred": "A", "val": 0.9},
        ...     {"pred": "B", "val": 0.8},
        ...     {"pred": "A", "val": 0.7}
        ... ]
        >>> calc_maj_val(data, vote_key="pred", val_key="val")
        0.9
    """
    # Insertion-ordered tallies: how often each vote was seen, and the value that
    # accompanied its first occurrence.
    tally: dict[Any, int] = {}
    first_val: dict[Any, Any] = {}
    for entry in data:
        vote = entry[vote_key]
        val = entry[val_key]
        if vote in tally:
            tally[vote] += 1
        else:
            tally[vote] = 1
            first_val[vote] = val

    # max() keeps the earliest key among equal counts, i.e. the first-seen vote.
    winner = max(tally, key=tally.get)
    return first_val[winner]


def process_validation_metrics(
    data_sources: list[str], sample_inputs: list[str], infos_dict: dict[str, list[Any]], seed: int = 42
) -> dict[str, dict[str, dict[str, float]]]:
    """
    Turn per-sample validation values into per-data-source summary statistics.

    Samples are grouped by data source and prompt. For every non-string variable of
    every prompt, a set of statistics is computed over that prompt's responses; the
    statistics are then averaged across all prompts of the same data source.
    Variables whose first value in a prompt group is a string are skipped.

    Args:
        data_sources: List of data source identifiers for each sample.
        sample_inputs: List of input prompts corresponding to each sample.
        infos_dict: Dictionary mapping variable names to lists of values for each sample.
        seed: Random seed for bootstrap sampling. Defaults to 42.

    Returns:
        A nested mapping (built from ``defaultdict`` objects) with the structure:
        {
            data_source: {
                variable_name: {
                    metric_name: value
                }
            }
        }

        Where metric_name includes:
        - "mean@N": Mean value across the N responses of a prompt
        - "std@N": Standard deviation across the N responses (only when N > 1)
        - "best@n/mean", "best@n/std": Mean / std of the maximum over bootstrap samples of size n
        - "worst@n/mean", "worst@n/std": Mean / std of the minimum over bootstrap samples of size n
        - "maj@n/mean", "maj@n/std": Mean / std of the majority-vote value over bootstrap
          samples of size n (only when a "pred" variable exists)

        Bootstrap sizes n are the powers of two below N, followed by N itself.

    Example:
        >>> data_sources = ["source1", "source1", "source2"]
        >>> sample_inputs = ["prompt1", "prompt1", "prompt2"]
        >>> infos_dict = {"score": [0.8, 0.9, 0.7], "pred": ["A", "A", "B"]}
        >>> result = process_validation_metrics(data_sources, sample_inputs, infos_dict)
        >>> # result will contain statistics for each data source and variable
    """
    # Phase 1: bucket each sample's values as data source -> prompt -> variable -> values.
    grouped = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    for idx, source in enumerate(data_sources):
        prompt = sample_inputs[idx]
        bucket = grouped[source][prompt]
        for name, column in infos_dict.items():
            bucket[name].append(column[idx])

    # Phase 2: compute the statistics of each prompt group and file them directly under
    # data source -> variable -> metric name, one list entry per prompt.
    collected = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    for source, by_prompt in grouped.items():
        for by_var in by_prompt.values():
            # Predictions of this prompt group, if any; they enable majority voting.
            preds = by_var.get("pred", None)

            for name, vals in by_var.items():
                if isinstance(vals[0], str):
                    continue

                count = len(vals)
                stats = {f"mean@{count}": np.mean(vals)}

                if count > 1:
                    stats[f"std@{count}"] = np.std(vals)

                    # Bootstrap sample sizes: 2, 4, 8, ... (strictly below count), then count.
                    sizes = []
                    power = 2
                    while power < count:
                        sizes.append(power)
                        power *= 2
                    sizes.append(count)

                    # The (value, prediction) pairs do not depend on the sample size,
                    # so they are assembled once per variable.
                    vote_data = None
                    if preds is not None:
                        vote_data = [{"val": val, "pred": pred} for val, pred in zip(vals, preds, strict=True)]

                    for size in sizes:
                        best, worst = bootstrap_metric(
                            data=vals, subset_size=size, reduce_fns=[np.max, np.min], seed=seed
                        )
                        stats[f"best@{size}/mean"], stats[f"best@{size}/std"] = best
                        stats[f"worst@{size}/mean"], stats[f"worst@{size}/std"] = worst

                        if vote_data is not None:
                            (majority,) = bootstrap_metric(
                                data=vote_data,
                                subset_size=size,
                                reduce_fns=[partial(calc_maj_val, vote_key="pred", val_key="val")],
                                seed=seed,
                            )
                            stats[f"maj@{size}/mean"], stats[f"maj@{size}/std"] = majority

                for metric_name, value in stats.items():
                    collected[source][name][metric_name].append(value)

    # Phase 3: average every metric over the prompts of its data source.
    summary = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))
    for source, by_var in collected.items():
        for name, by_metric in by_var.items():
            for metric_name, prompt_vals in by_metric.items():
                summary[source][name][metric_name] = np.mean(prompt_vals)

    return summary
