from collections import defaultdict, deque
from torch.utils.tensorboard import SummaryWriter

class StatsTracker:
    """Accumulate scalar statistics and report several running summaries.

    For every key passed to :meth:`update` the tracker keeps:

    * ``latest``   - the most recent value,
    * ``ema``      - an exponential moving average, seeded with the first value,
    * ``avg``      - the mean of every value seen since construction / :meth:`reset`,
    * ``step_avg`` - the mean of the last ``log_interval`` values of that key.

    Args:
        ema_alpha: Weight of the newest value in the EMA update
            ``ema = alpha * val + (1 - alpha) * ema``.
        log_interval: Window length (number of values per key) used for
            ``step_avg``.
        tensorboard: If not ``None``, passed as the log directory to
            ``SummaryWriter``; every value is then also written as the scalar
            ``stats/<key>`` at global step ``total_steps``.
    """

    def __init__(self, ema_alpha=0.1, log_interval=100, tensorboard=None):
        self.ema_alpha = ema_alpha
        self.log_interval = log_interval
        # Number of update() calls so far; also the TensorBoard global step.
        self.total_steps = 0

        self._init_accumulators()

        if tensorboard is not None:
            self.summary_writer = SummaryWriter(tensorboard, flush_secs=10)
        else:
            self.summary_writer = None

    def _init_accumulators(self):
        """(Re)create the per-key accumulators, discarding all recorded values."""
        self.sum_dict = defaultdict(float)
        self.count_dict = defaultdict(int)
        self.ema_dict = dict()
        self.latest_dict = dict()

        # Sliding window of the most recent ``log_interval`` values per key,
        # used for the interval ("step") average.
        window = self.log_interval
        self.step_buffer = defaultdict(lambda: deque(maxlen=window))

    def update(self, stats: dict):
        """Record one step of statistics.

        Args:
            stats: Mapping of key -> numeric value for this step. Keys may
                differ between calls; each key is tracked independently.
        """
        self.total_steps += 1
        alpha = self.ema_alpha

        for key, val in stats.items():
            self.latest_dict[key] = val
            self.sum_dict[key] += val
            self.count_dict[key] += 1

            # EMA: the first observation seeds the average directly.
            if key not in self.ema_dict:
                self.ema_dict[key] = val
            else:
                self.ema_dict[key] = alpha * val + (1 - alpha) * self.ema_dict[key]

            # Step buffer for recent average
            self.step_buffer[key].append(val)

            # tensorboard logging
            if self.summary_writer is not None:
                self.summary_writer.add_scalar(f"stats/{key}", val, self.total_steps)

    def get_stats(self):
        """Return ``{key: {"latest", "ema", "avg", "step_avg"}}`` for every seen key."""
        stats = {}
        for key in self.latest_dict:
            count = self.count_dict[key]
            avg = self.sum_dict[key] / count if count > 0 else 0.0
            window = self.step_buffer[key]
            step_avg = sum(window) / len(window) if window else 0.0
            stats[key] = {
                "latest": self.latest_dict[key],
                "ema": self.ema_dict.get(key, 0.0),
                "avg": avg,
                "step_avg": step_avg,
            }
        return stats

    def format(self, precision=4, split=', '):
        """Render each key's ``step_avg`` as ``"key: value"`` joined by *split*."""
        stats = self.get_stats()
        parts = [f"{k}: {v['step_avg']:.{precision}f}" for k, v in stats.items()]
        return split.join(parts)

    def reset(self):
        """Clear all accumulated statistics.

        The TensorBoard writer (if any) and ``total_steps`` are kept, so logging
        continues after a reset on a monotonically increasing global step.
        """
        # Re-running __init__ here would drop the writer (it is not passed
        # back in) and restart the global step; only the accumulators are
        # rebuilt instead.
        self._init_accumulators()

    def close(self):
        """Close the TensorBoard writer, if one exists, and stop logging to it."""
        if self.summary_writer is not None:
            self.summary_writer.close()
            self.summary_writer = None
