from typing import Any
import numpy as np

from verl import DataProto
import verl.trainer.ppo.metric_utils as _metric_utils


# Key in ``DataProto.non_tensor_batch`` under which the per-trajectory tool-call
# counters are read by ``compute_fileagent_data_metrics``.
FILEAGENT_TOOL_METRICS_KEY = "fileagent_tool_metrics"


def build_tool_metric(tool_name: str, succeeded: bool) -> dict[str, Any]:
    """Return the metric record describing a single tool invocation.

    Args:
        tool_name: Name of the tool that was called.
        succeeded: Whether the call succeeded.

    Returns:
        A dict ``{"name": tool_name, "succeeded": succeeded}``; these are the
        keys ``update_trajectory_tool_metrics`` reads.
    """
    return dict(name=tool_name, succeeded=succeeded)


def init_trajectory_tool_metrics(tool_names: list[str]) -> dict[str, dict[str, Any]]:
    """Create zeroed tool-call counters for one trajectory.

    Args:
        tool_names: Names of the tools that get their own counter bucket.

    Returns:
        A dict mapping each tool name, followed by the aggregate ``"Overall"``
        bucket and the catch-all ``"Unknown"`` bucket, to an independent
        ``{"call_count": 0, "success_count": 0}`` dict. A tool literally named
        ``"Overall"`` or ``"Unknown"`` shares that bucket.
    """
    buckets = [*tool_names, "Overall", "Unknown"]
    # Each bucket gets its own fresh counter dict so updates never alias.
    return {bucket: {"call_count": 0, "success_count": 0} for bucket in buckets}


def update_trajectory_tool_metrics(trajectory_metrics: dict[str, dict[str, Any]], tool_metric: dict[str, Any]) -> None:
    """Record one tool call into a trajectory's counters, in place.

    Args:
        trajectory_metrics: Counter buckets for one trajectory. They must contain
            ``"Overall"``, plus ``"Unknown"`` whenever the call's tool has no
            bucket of its own.
        tool_metric: A record with optional ``"name"`` and ``"succeeded"``
            keys. A missing or unrecognised name is counted under ``"Unknown"``.
    """
    name = tool_metric.get("name", "Unknown")
    bucket = name if name in trajectory_metrics else "Unknown"
    targets = ("Overall", bucket)

    # Every call counts toward both the aggregate and the tool's own bucket.
    for key in targets:
        trajectory_metrics[key]["call_count"] += 1

    if not tool_metric.get("succeeded"):
        return

    for key in targets:
        trajectory_metrics[key]["success_count"] += 1


# Bound at import time so the call below keeps reaching verl's original
# implementation even if ``_metric_utils.compute_data_metrics`` is reassigned
# later (presumably to ``compute_fileagent_data_metrics`` — confirm at the
# patch site; rebinding it that way without this capture would recurse).
_compute_data_metrics = _metric_utils.compute_data_metrics
def compute_fileagent_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str, Any]:
    """Compute verl's standard data metrics plus per-tool FileAgent call statistics.

    Args:
        batch: Rollout batch. Per-trajectory tool counters are read from
            ``batch.non_tensor_batch[FILEAGENT_TOOL_METRICS_KEY]``, one dict per
            trajectory mapping tool name -> ``{"call_count", "success_count"}``
            (presumably built with ``init_trajectory_tool_metrics`` /
            ``update_trajectory_tool_metrics`` — confirm at the rollout site).
        use_critic: Forwarded unchanged to verl's ``compute_data_metrics``.

    Returns:
        verl's metrics dict. For every tool name seen in any trajectory of the
        batch, it is extended with:

        - ``tools/<name>/call_count``: total calls in the batch.
        - ``tools/<name>/success_count``: total successful calls.
        - ``tools/<name>/call_ratio``: mean calls per trajectory.
        - ``tools/<name>/success_ratio``: successes / calls, or 0.0 when the
          tool was never called.

        If the batch carries no tool counters (missing key or zero
        trajectories), verl's metrics are returned unchanged.
    """
    metrics = _compute_data_metrics(batch, use_critic)

    # Compute FileAgent tool metrics; tolerate batches without them (e.g. when
    # the key was never populated) instead of crashing the metric step.
    batch_tool_metrics = batch.non_tensor_batch.get(FILEAGENT_TOOL_METRICS_KEY)
    if batch_tool_metrics is None or len(batch_tool_metrics) == 0:
        return metrics

    # Union of tool names across all trajectories (first-seen order), so a tool
    # that is absent from the first trajectory is still reported.
    tool_names = list(dict.fromkeys(
        name for tool_metrics in batch_tool_metrics for name in tool_metrics
    ))

    for tool_name in tool_names:
        # A trajectory without this tool's bucket counts as zero calls.
        per_trajectory = [tool_metrics.get(tool_name, {}) for tool_metrics in batch_tool_metrics]
        call_count = np.array([m.get("call_count", 0) for m in per_trajectory])
        success_count = np.array([m.get("success_count", 0) for m in per_trajectory])

        total_calls = int(call_count.sum())
        total_successes = int(success_count.sum())
        metrics[f"tools/{tool_name}/call_count"] = total_calls
        metrics[f"tools/{tool_name}/success_count"] = total_successes
        metrics[f"tools/{tool_name}/call_ratio"] = float(call_count.mean())
        # Exact ratio (no epsilon bias); 0.0 when the tool was never called.
        metrics[f"tools/{tool_name}/success_ratio"] = (
            total_successes / total_calls if total_calls else 0.0
        )

    return metrics
