# Copyright 2024 Bytedance Ltd. and/or its affiliates
from typing import Any
from collections import defaultdict

import numpy as np
import torch

from ..protocol import DataProto


def reduce_metrics(metrics: dict[str, list[Any]]) -> dict[str, Any]:
    """Collapse every metric's list of samples into its arithmetic mean.

    Args:
        metrics: Mapping from metric name to the values collected for it.

    Returns:
        A dict with the same keys, each mapped to ``np.mean`` of its values.
    """
    averaged = {}
    for name, samples in metrics.items():
        averaged[name] = np.mean(samples)
    return averaged


def compute_length_metrics(batch: DataProto) -> dict[str, Any]:
    """Report mean/max/min token counts and clip ratios for prompts and responses.

    The trailing ``responses.size(-1)`` columns of ``attention_mask`` are
    counted as the response region and all preceding columns as the prompt
    region. A row's length is the sum of its mask entries inside the region;
    ``clip_ratio`` is the fraction of rows whose length equals the region width.

    Args:
        batch: Object whose ``batch`` mapping holds ``responses`` and
            ``attention_mask`` tensors.

    Returns:
        A dict of Python scalars keyed ``response_length/*`` and
        ``prompt_length/*``.
    """
    response_width = batch.batch["responses"].size(-1)
    attention_mask = batch.batch["attention_mask"]
    prompt_width = attention_mask.size(-1) - response_width

    prompt_tokens = attention_mask[:, :-response_width].sum(-1).float()
    response_tokens = attention_mask[:, -response_width:].sum(-1).float()

    # A row is "clipped" when it fills its whole region.
    response_clipped = torch.eq(response_tokens, response_width).float()
    prompt_clipped = torch.eq(prompt_tokens, prompt_width).float()

    metrics = {}
    metrics["response_length/mean"] = response_tokens.mean().detach().item()
    metrics["response_length/max"] = response_tokens.max().detach().item()
    metrics["response_length/min"] = response_tokens.min().detach().item()
    metrics["response_length/clip_ratio"] = response_clipped.mean().detach().item()
    metrics["prompt_length/mean"] = prompt_tokens.mean().detach().item()
    metrics["prompt_length/max"] = prompt_tokens.max().detach().item()
    metrics["prompt_length/min"] = prompt_tokens.min().detach().item()
    metrics["prompt_length/clip_ratio"] = prompt_clipped.mean().detach().item()
    return metrics


def compute_grpo_group_statistics(batch: DataProto) -> dict[str, Any]:
    """Summarize per-group reward spread and absolute-advantage statistics.

    Samples are grouped by ``batch.non_tensor_batch["uid"]``; without a ``uid``
    entry every sample forms its own group. A sample's reward is the sum of its
    ``token_level_rewards``; its advantage is the mean of ``advantages`` over
    the response positions enabled in the trailing ``responses.size(-1)``
    columns of ``attention_mask``.

    Args:
        batch: Object exposing a ``batch`` mapping (``token_level_rewards``,
            ``advantages``, ``responses``, ``attention_mask`` tensors) and a
            ``non_tensor_batch`` mapping.

    Returns:
        A flat dict of ``grpo_stats/...`` scalars:

        * ``group_reward_std/*``: distribution of the per-group reward standard
          deviation; the ``p1e-3``/``p1e-2``/``p1e-1`` entries are the
          *fraction* of groups whose std is below that threshold.
        * ``all_zero_group_ratio`` / ``all_one_group_ratio``: share of groups
          whose rewards are all ~0 or all ~1; ``all_equal_group_ratio`` is the
          sum of those two (other constant values are not counted).
        * ``advantage_abs_*``: statistics of ``|advantage|`` over enabled
          response tokens; omitted when no response token is enabled.
        * ``corr_group_std_adv_abs``: correlation between per-group reward std
          and per-group mean absolute advantage, 0.0 when undefined.

        The dict is empty when the batch contains no samples.
    """
    sequence_reward = batch.batch["token_level_rewards"].sum(-1)
    advantages = batch.batch["advantages"]

    max_response_length = batch.batch["responses"].size(-1)
    response_mask = batch.batch["attention_mask"][:, -max_response_length:].bool()

    num_samples = sequence_reward.shape[0]
    if num_samples == 0:
        return {}

    if "uid" in batch.non_tensor_batch:
        group_ids = batch.non_tensor_batch["uid"]
    else:
        group_ids = range(num_samples)

    id2rewards = defaultdict(list)
    id2advantages = defaultdict(list)
    # One host transfer for all rewards instead of a device sync per sample.
    for i, reward in enumerate(sequence_reward.detach().cpu().tolist()):
        group_id = group_ids[i]
        # 0-dim tensors hash by identity, so equal ids stored as tensors would
        # never land in the same group; convert them to plain Python scalars.
        if isinstance(group_id, torch.Tensor):
            group_id = group_id.item()
        id2rewards[group_id].append(reward)
        seq_adv = advantages[i][response_mask[i]]
        if seq_adv.numel() > 0:
            id2advantages[group_id].append(torch.mean(seq_adv).detach().cpu().item())

    group_reward_stds = []
    # Parallel lists restricted to groups that have at least one advantage.
    stds_with_adv = []
    group_adv_abs_means = []
    all_zero_groups = 0
    all_one_groups = 0

    for group_id, group_rewards in id2rewards.items():
        rewards = np.array(group_rewards)
        group_std = np.std(rewards)
        group_reward_stds.append(group_std)

        if np.allclose(rewards, 0.0, atol=1e-6):
            all_zero_groups += 1
        elif np.allclose(rewards, 1.0, atol=1e-6):
            all_one_groups += 1

        group_adv = id2advantages.get(group_id)
        if group_adv:
            group_adv_abs_means.append(np.mean(np.abs(group_adv)))
            stds_with_adv.append(group_std)

    group_reward_stds = np.array(group_reward_stds)
    total_groups = len(id2rewards)  # > 0 because num_samples > 0

    stats = {
        "grpo_stats/group_reward_std/mean": float(np.mean(group_reward_stds)),
        "grpo_stats/group_reward_std/std": float(np.std(group_reward_stds)),
        "grpo_stats/group_reward_std/min": float(np.min(group_reward_stds)),
        "grpo_stats/group_reward_std/max": float(np.max(group_reward_stds)),
        "grpo_stats/group_reward_std/median": float(np.median(group_reward_stds)),
        "grpo_stats/group_reward_std/p25": float(np.percentile(group_reward_stds, 25)),
        "grpo_stats/group_reward_std/p75": float(np.percentile(group_reward_stds, 75)),
        "grpo_stats/group_reward_std/p1e-3": float(np.mean(group_reward_stds < 1e-3)),
        "grpo_stats/group_reward_std/p1e-2": float(np.mean(group_reward_stds < 1e-2)),
        "grpo_stats/group_reward_std/p1e-1": float(np.mean(group_reward_stds < 1e-1)),
        "grpo_stats/all_zero_group_ratio": all_zero_groups / total_groups,
        "grpo_stats/all_one_group_ratio": all_one_groups / total_groups,
        "grpo_stats/all_equal_group_ratio": (all_zero_groups + all_one_groups) / total_groups,
        "grpo_stats/total_groups": total_groups,
    }
    stats.update(_abs_advantage_stats(torch.masked_select(advantages, response_mask)))

    correlation = float("nan")
    if len(group_adv_abs_means) > 1:
        correlation = np.corrcoef(stds_with_adv, group_adv_abs_means)[0, 1]
    # corrcoef yields NaN when either series is constant; report 0.0 instead.
    stats["grpo_stats/corr_group_std_adv_abs"] = 0.0 if np.isnan(correlation) else float(correlation)

    return stats


def _abs_advantage_stats(valid_adv: torch.Tensor) -> dict[str, float]:
    """Return the ``grpo_stats/advantage_abs_*`` scalars, or ``{}`` for empty input."""
    if valid_adv.numel() == 0:
        # torch.max / torch.min / torch.median raise on an empty tensor.
        return {}

    abs_adv = torch.abs(valid_adv.detach())
    abs_adv_np = abs_adv.cpu().numpy()
    return {
        "grpo_stats/advantage_abs_mean": float(np.mean(abs_adv_np)),
        "grpo_stats/advantage_abs_std": float(np.std(abs_adv_np)),
        "grpo_stats/advantage_abs_max": float(torch.max(abs_adv).item()),
        "grpo_stats/advantage_abs_min": float(torch.min(abs_adv).item()),
        "grpo_stats/advantage_abs_median": float(torch.median(abs_adv).item()),
        # Fractions of tokens whose |advantage| is below each threshold.
        "grpo_stats/advantage_abs_p0.1": float(torch.mean((abs_adv < 0.1).float()).item()),
        "grpo_stats/advantage_abs_p0.5": float(torch.mean((abs_adv < 0.5).float()).item()),
        "grpo_stats/advantage_abs_p1.0": float(torch.mean((abs_adv < 1.0).float()).item()),
        "grpo_stats/advantage_abs_p2.0": float(torch.mean((abs_adv < 2.0).float()).item()),
    }


def compute_data_metrics(batch: DataProto, use_critic: bool = False) -> dict[str, Any]:
    """Gather score, reward, advantage, return, length and GRPO group metrics.

    Scores and rewards are summed over the last dimension per sample;
    advantages, returns and (optionally) values are reduced over the response
    positions enabled in the trailing ``responses.size(-1)`` columns of
    ``attention_mask``.

    Args:
        batch: Object whose ``batch`` mapping holds ``token_level_scores``,
            ``token_level_rewards``, ``advantages``, ``returns``, ``responses``
            and ``attention_mask`` (plus ``values`` when ``use_critic`` is set).
        use_critic: When true, also report value statistics and
            ``critic/vf_explained_var``.

    Returns:
        A flat dict of ``critic/...`` scalars merged with the results of
        ``compute_length_metrics`` and ``compute_grpo_group_statistics``.
    """
    tensors = batch.batch
    score_per_sample = tensors["token_level_scores"].sum(-1)
    reward_per_sample = tensors["token_level_rewards"].sum(-1)

    adv = tensors["advantages"]
    ret = tensors["returns"]

    response_width = tensors["responses"].size(-1)
    mask = tensors["attention_mask"][:, -response_width:].bool()

    valid_adv = torch.masked_select(adv, mask)
    valid_returns = torch.masked_select(ret, mask)

    if use_critic:
        valid_values = torch.masked_select(tensors["values"], mask)
        return_diff_var = torch.var(valid_returns - valid_values)
        return_var = torch.var(valid_returns)

    group_metrics = compute_grpo_group_statistics(batch)

    metrics = {
        "critic/score/mean": score_per_sample.mean().detach().item(),
        "critic/score/max": score_per_sample.max().detach().item(),
        "critic/score/min": score_per_sample.min().detach().item(),
        "critic/rewards/mean": reward_per_sample.mean().detach().item(),
        "critic/rewards/max": reward_per_sample.max().detach().item(),
        "critic/rewards/min": reward_per_sample.min().detach().item(),
        "critic/advantages/mean": valid_adv.mean().detach().item(),
        "critic/advantages/max": valid_adv.max().detach().item(),
        "critic/advantages/min": valid_adv.min().detach().item(),
        "critic/returns/mean": valid_returns.mean().detach().item(),
        "critic/returns/max": valid_returns.max().detach().item(),
        "critic/returns/min": valid_returns.min().detach().item(),
    }

    if use_critic:
        explained_var = 1.0 - return_diff_var / (return_var + 1e-5)
        metrics["critic/values/mean"] = valid_values.mean().detach().item()
        metrics["critic/values/max"] = valid_values.max().detach().item()
        metrics["critic/values/min"] = valid_values.min().detach().item()
        metrics["critic/vf_explained_var"] = explained_var.detach().item()

    # Later updates win on key collisions, matching the original merge order.
    metrics.update(compute_length_metrics(batch))
    metrics.update(group_metrics)
    return metrics


def compute_timing_metrics(batch: DataProto, timing_raw: dict[str, float]) -> dict[str, Any]:
    """Convert raw section timings into per-section and per-token metrics.

    Every entry of ``timing_raw`` is re-emitted as ``timing_s/<name>``. For the
    sections listed below a ``timing_per_token_ms/<name>`` entry is added:
    ``gen`` and ``reward`` are normalized by the sum of ``response_mask``, the
    remaining sections by the sum of ``meta_info["global_token_num"]``.

    Args:
        batch: Object exposing ``batch["response_mask"]`` and
            ``meta_info["global_token_num"]``.
        timing_raw: Mapping from section name to its measured duration.

    Returns:
        A flat dict with the ``timing_s/*`` and ``timing_per_token_ms/*`` keys.
    """
    response_token_count = torch.sum(batch.batch["response_mask"]).item()
    overall_token_count = sum(batch.meta_info["global_token_num"])

    tokens_by_section = {}
    for section in ("gen", "reward"):
        tokens_by_section[section] = response_token_count
    for section in ("ref", "old", "values", "adv", "update_critic", "update_actor"):
        tokens_by_section[section] = overall_token_count

    metrics = {}
    for name, seconds in timing_raw.items():
        metrics[f"timing_s/{name}"] = seconds
    # Per-token figures exist only for sections with a known token count.
    for name, seconds in timing_raw.items():
        if name in tokens_by_section:
            metrics[f"timing_per_token_ms/{name}"] = seconds * 1000 / tokens_by_section[name]
    return metrics


def compute_throughout_metrics(batch: DataProto, timing_raw: dict[str, float], num_gpus: int) -> dict[str, Any]:
    """Report token volume, step time and per-GPU throughput for one step.

    Args:
        batch: Object exposing ``meta_info["global_token_num"]``, an iterable
            of token counts that is summed.
        timing_raw: Timing mapping; only its ``"step"`` entry is read.
        num_gpus: Divisor applied together with the step time.

    Returns:
        ``perf/total_num_tokens``, ``perf/time_per_step`` and
        ``perf/throughput`` (total tokens / (step time * ``num_gpus``)).
    """
    token_total = sum(batch.meta_info["global_token_num"])
    step_duration = timing_raw["step"]

    metrics = {
        "perf/total_num_tokens": token_total,
        "perf/time_per_step": step_duration,
    }
    metrics["perf/throughput"] = token_total / (step_duration * num_gpus)
    return metrics
