# profiling/timer_profiler.py
"""
Timing and lightweight FLOP estimator utilities.

Includes:
  - TimerProfiler: accumulate timings, report mean/median/p95
  - time_it(fn, *args, **kwargs): helper to time a call
  - estimate_simple_flops(module, input_shape): very coarse heuristic FLOP estimator
    (counts multiply-add approximations for Linear and Conv2d layers from the shapes
    observed during one dummy forward pass).
Note:
  - For rigorous FLOP counting use external tools (fvcore, ptflops). The estimator here is a cheap heuristic.
"""
import time
import numpy as np
from typing import Callable, Any, Dict
import torch
import torch.nn as nn

class TimerProfiler:
    """
    Accumulates elapsed-time samples (in seconds) and reports summary statistics.

    Attributes:
      - samples: list of elapsed durations, one entry per successful time_fn call.
    """

    def __init__(self):
        self.samples = []

    def time_fn(self, fn: Callable, *args, **kwargs):
        """
        Run fn(*args, **kwargs) and record the elapsed time in seconds.
        Returns fn's result.

        If fn raises, the exception propagates and no sample is recorded.
        """
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # when the system clock is adjusted and may be too coarse for short calls.
        start = time.perf_counter()
        out = fn(*args, **kwargs)
        self.samples.append(time.perf_counter() - start)
        return out

    def summary(self) -> Dict[str, float]:
        """
        Return count, mean, median, 95th percentile and total of the recorded samples.

        With no samples, "count" is 0 and every statistic is reported as 0.0.
        """
        # A single 0.0 placeholder keeps the numpy reductions well-defined when empty.
        arr = np.array(self.samples) if self.samples else np.array([0.0])
        return {
            "count": len(self.samples),
            "mean_s": float(arr.mean()),
            "median_s": float(np.median(arr)),
            "p95_s": float(np.percentile(arr, 95)),
            "total_s": float(arr.sum()),
        }

def time_it(fn: Callable, *args, n_runs: int = 10, **kwargs) -> Dict[str, Any]:
    """
    Call fn(*args, **kwargs) n_runs times, timing every call with a fresh TimerProfiler.

    Returns a dict with:
      - "summary": the profiler's summary statistics over all runs
      - "last_result": the value returned by the final call (None if no call was made)
    """
    timer = TimerProfiler()
    # Keys are created up front so the result keeps the "summary", "last_result" order.
    report: Dict[str, Any] = {"summary": None, "last_result": None}
    for _run in range(n_runs):
        report["last_result"] = timer.time_fn(fn, *args, **kwargs)
    report["summary"] = timer.summary()
    return report

def estimate_simple_flops(module: nn.Module, input_shape: tuple) -> float:
    """
    Very coarse FLOP estimator for a model given input shape (B,C,H,W) or (B, D).

    Runs one forward pass on an all-zero dummy input (under torch.no_grad, in eval
    mode) with forward hooks attached to every nn.Conv2d and nn.Linear submodule.

    Heuristic rules (a multiply-add counts as 2 FLOPs):
      - Linear layer: 2 * in_features * out_features per input row, where every
        leading dimension of the input counts as a row (so (B, T, D) inputs count B*T rows)
      - Conv2d: 2 * (Cin / groups) * Kx * Ky per output element, i.e.
        2 * Cout * (Cin / groups) * Kx * Ky * out_H * out_W per sample
      - Bias terms and all other layer types (BatchNorm/Activation/Pool, ...) are ignored
    This is only a rough proxy and should NOT replace dedicated FLOP counters.

    Args:
      - module: the model to inspect. Its hooks and per-submodule train/eval flags
        are restored before returning.
      - input_shape: shape of the dummy input tensor passed to module.

    Returns:
      The estimated FLOP count as a float. If the dummy forward pass fails for any
      reason, falls back to 2 * (total number of parameters).
    """
    total_flops = 0

    def conv_hook(layer, inputs, output):
        nonlocal total_flops
        k_h, k_w = layer.kernel_size
        # Each output element only sees Cin / groups input channels.
        cin_per_group = layer.in_channels // layer.groups
        # output.numel() == B * Cout * out_H * out_W
        total_flops += 2 * cin_per_group * k_h * k_w * output.numel()

    def linear_hook(layer, inputs, output):
        nonlocal total_flops
        # inputs[0].numel() == rows * in_features, covering any leading dims.
        total_flops += 2 * inputs[0].numel() * layer.out_features

    # Place the dummy input on the model's device regardless of requires_grad,
    # so frozen models on an accelerator are not fed a CPU tensor.
    anchor = next(module.parameters(), None)
    if anchor is None:
        anchor = next(module.buffers(), None)
    device = anchor.device if anchor is not None else "cpu"

    # Remember each submodule's mode so eval() below leaves no lasting side effect.
    saved_modes = [(sub, sub.training) for sub in module.modules()]
    hooks = []
    try:
        for sub in module.modules():
            if isinstance(sub, nn.Conv2d):
                hooks.append(sub.register_forward_hook(conv_hook))
            elif isinstance(sub, nn.Linear):
                hooks.append(sub.register_forward_hook(linear_hook))

        module.eval()
        with torch.no_grad():
            module(torch.zeros(input_shape, device=device))
    except Exception:
        # Deliberate best-effort fallback: estimate by counting params
        # (~2 FLOPs per parameter), discarding any partial hook-based count.
        total_flops = 2 * sum(p.numel() for p in module.parameters())
    finally:
        # Always detach hooks and restore modes, even when the forward pass raised.
        for handle in hooks:
            handle.remove()
        for sub, was_training in saved_modes:
            sub.training = was_training

    return float(total_flops)
