"""
Compute tracker for accumulating FLOPs and other compute metrics across training/evaluation.

This module provides a thread-safe accumulator for tracking router compute usage
across different phases (training vs evaluation) and experiences.
"""

import threading
from collections import defaultdict
from typing import Any, Dict, Optional


class ComputeTracker:
    """
    Thread-safe accumulator for router compute metrics.

    Tracks FLOPs and other compute metrics separately for:
    - Training vs Evaluation
    - Different experiences (apibench, mllm, etc.)
    - Per-batch accumulation

    Every public method acquires an internal ``threading.Lock``, so one tracker
    instance can be shared between threads.

    Usage:
        tracker = ComputeTracker()

        # During training/evaluation
        compute_metrics = router.forward(..., return_compute_metrics=True)[1]
        tracker.accumulate(compute_metrics, phase="training", experience="mllm")

        # Get totals
        totals = tracker.get_totals()
    """

    # FLOP breakdown components copied from the incoming ``compute_metrics``
    # dict under the same key name (a missing key counts as 0.0).
    _COMPONENT_KEYS = ("prompt_encoding_flops", "normalization_flops", "bmm_flops")

    def __init__(self):
        """Initialize empty tracker."""
        self._lock = threading.Lock()
        self._reset()

    @staticmethod
    def _empty_totals() -> Dict[str, Any]:
        """Return a fresh zeroed accumulator (one per phase/experience/overall)."""
        return {
            "total_flops": 0.0,
            "total_examples": 0,
            "total_batches": 0,
            "prompt_encoding_flops": 0.0,
            "normalization_flops": 0.0,
            "bmm_flops": 0.0,
        }

    @staticmethod
    def _with_flops_per_example(totals: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of ``totals`` with an added ``flops_per_example`` field.

        ``flops_per_example`` is 0.0 when no examples have been recorded, to
        avoid a division by zero.
        """
        result = dict(totals)
        examples = result["total_examples"]
        result["flops_per_example"] = (
            result["total_flops"] / examples if examples > 0 else 0.0
        )
        return result

    def _reset(self):
        """Reset all accumulators (internal use; called from __init__ or under the lock)."""
        # Bind the plain function once so the defaultdict factories do not
        # capture ``self`` (avoids a self-referencing cycle).
        empty = self._empty_totals

        # Per-phase totals, keyed by phase name (e.g. "training" / "evaluation").
        self._phase_totals: Dict[str, Dict[str, Any]] = defaultdict(empty)

        # Per-experience totals: nested dict [phase][experience].
        self._experience_totals: Dict[str, Dict[str, Dict[str, Any]]] = defaultdict(
            lambda: defaultdict(empty)
        )

        # Overall totals across every phase and experience.
        self._overall_totals: Dict[str, Any] = empty()

    def accumulate(
        self,
        compute_metrics: Dict[str, float],
        phase: str = "training",
        experience: Optional[str] = None,
    ):
        """
        Accumulate compute metrics from a single forward pass.

        The update is all-or-nothing: every value is read and converted before
        any accumulator is touched. A malformed entry therefore raises without
        leaving the phase, experience and overall totals out of sync.

        Args:
            compute_metrics: Dictionary from router.forward() with return_compute_metrics=True.
                Reads "flops", "batch_size" (default 1) and the component keys
                "prompt_encoding_flops", "normalization_flops", "bmm_flops";
                missing FLOP keys count as 0.0.
            phase: "training" or "evaluation"
            experience: Optional experience name (e.g., "mllm", "apibench")

        Raises:
            TypeError / ValueError: if a metric value cannot be converted to a
                number (``float()`` / ``int()``).
        """
        # float() also turns numpy/torch scalars into plain Python floats, so
        # the accumulated totals remain JSON-serializable for get_summary().
        increment: Dict[str, Any] = {
            "total_flops": float(compute_metrics.get("flops", 0.0)),
            "total_examples": int(compute_metrics.get("batch_size", 1)),
            "total_batches": 1,
        }
        for key in self._COMPONENT_KEYS:
            increment[key] = float(compute_metrics.get(key, 0.0))

        with self._lock:
            targets = [self._phase_totals[phase]]
            if experience is not None:
                targets.append(self._experience_totals[phase][experience])
            targets.append(self._overall_totals)

            for totals in targets:
                for key, value in increment.items():
                    totals[key] += value

    def get_totals(self) -> Dict[str, Any]:
        """
        Get all accumulated totals.

        Every returned totals dict is a copy with an added
        ``flops_per_example`` field (0.0 when no examples were recorded).

        Returns:
            Dictionary with:
            - overall: Overall totals across all phases/experiences
            - by_phase: Totals per phase (training/evaluation)
            - by_experience: Nested dict [phase][experience] with totals
        """
        with self._lock:
            overall = self._with_flops_per_example(self._overall_totals)
            by_phase = {
                phase: self._with_flops_per_example(totals)
                for phase, totals in self._phase_totals.items()
            }
            by_experience = {
                phase: {
                    exp: self._with_flops_per_example(totals)
                    for exp, totals in experiences.items()
                }
                for phase, experiences in self._experience_totals.items()
            }

        return {
            "overall": overall,
            "by_phase": by_phase,
            "by_experience": by_experience,
        }

    def get_summary(self) -> Dict[str, Any]:
        """
        Get a summary suitable for logging/saving to JSON.

        Returns:
            Dictionary with overall totals (raw FLOPs and GFLOPs), example and
            batch counts, per-example FLOPs, a component breakdown, and the
            per-phase / per-experience totals from get_totals().
        """
        totals = self.get_totals()
        overall = totals["overall"]

        # Format for paper reporting
        return {
            "total_flops": overall["total_flops"],
            "total_flops_gflops": overall["total_flops"] / 1e9,
            "total_examples": overall["total_examples"],
            "flops_per_example": overall["flops_per_example"],
            "total_batches": overall["total_batches"],
            "breakdown": {
                "prompt_encoding": overall["prompt_encoding_flops"],
                "normalization": overall["normalization_flops"],
                "bmm": overall["bmm_flops"],
            },
            "by_phase": totals["by_phase"],
            "by_experience": totals["by_experience"],
        }

    def reset(self):
        """Reset all accumulators."""
        with self._lock:
            self._reset()
