# Copyright (c) 2026 Anonymous
# All Rights Reserved
# This codebase is provided for peer review purposes only.

import torch
from collections import defaultdict


class Telemetry:
    """Accumulate scalar metrics per key and report their (cross-rank) mean.

    Each rank keeps a running sum and a count per key. ``get`` returns the
    local mean, averaged (unweighted) across ranks when a ``torch.distributed``
    process group is initialized.

    Args:
        device: Device holding the accumulators. Defaults to ``"cuda"``
            (the original behavior); pass ``"cpu"`` for CPU-only runs or for
            backends such as gloo that reduce CPU tensors.
    """

    def __init__(self, device="cuda"):
        def _create_tensor():
            # Fresh 0-dim float32 accumulator, created lazily on first use of a key.
            return torch.zeros((), dtype=torch.float32, device=device)

        self._sum = defaultdict(_create_tensor)    # key -> running sum
        self._count = defaultdict(_create_tensor)  # key -> number of add() calls

    def reset(self):
        """Drop all accumulated keys, sums and counts."""
        self._sum.clear()
        self._count.clear()

    def add(self, key, value):
        """Accumulate one scalar ``value`` under ``key``.

        Assumption: values carry the same weight across ranks, because
        ``get`` performs a simple (not weighted) average across ranks.

        Args:
            key: Hashable metric name.
            value: 0-dim tensor. It is detached, cast to float32 and moved to
                the accumulator's device; the caller's tensor is not modified.

        Raises:
            TypeError: if ``value`` is not a tensor.
            ValueError: if ``value`` is not 0-dimensional.
        """
        # Validate before touching the defaultdicts so bad input creates no entry.
        if not torch.is_tensor(value):
            raise TypeError(
                f"Telemetry.add expects a torch.Tensor, got {type(value).__name__}"
            )
        if value.ndim != 0:
            raise ValueError(
                f"Telemetry.add expects a 0-dim tensor, got shape {tuple(value.shape)}"
            )
        total = self._sum[key]
        # Align device/dtype explicitly so a value living on another device
        # cannot break the in-place accumulation.
        total.add_(value.detach().to(device=total.device, dtype=torch.float32))
        self._count[key] += 1.0

    def get(self, key) -> float:
        """Return the mean of ``key``, averaged across ranks when distributed.

        When a process group is initialized this is a collective call: every
        rank should call ``get`` for the same keys in the same order, otherwise
        the all_reduce may hang.

        Raises:
            KeyError: if ``key`` has not been added since the last ``reset``.
        """
        import torch.distributed as dist

        # Membership test (not indexing): indexing the defaultdict would insert
        # a zero sum/count pair and yield 0/0 = NaN.
        if key not in self._sum:
            raise KeyError(key)
        # Division allocates a new tensor, so the in-place all_reduce below
        # cannot clobber the running sum.
        out = self._sum[key] / self._count[key]
        if dist.is_available() and dist.is_initialized():
            # SUM followed by division instead of ReduceOp.AVG, which is only
            # supported by the NCCL backend.
            dist.all_reduce(out, op=dist.ReduceOp.SUM)
            out = out / dist.get_world_size()
        return float(out.item())

    def keys(self):
        """Return a live view of the keys accumulated since the last reset."""
        return self._sum.keys()
