# --- NEW: flop_utils.py (or place near top of your file) ---------------------
import math
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ModelShape:
    """Transformer dimensions used by the FLOP estimators in this module."""

    # Hidden width; `_shape_from_hf_config` reads it from `n_embd` / `hidden_size`.
    d_model: int
    # Number of transformer layers (`n_layer` / `num_hidden_layers`).
    n_layer: int
    # Number of attention heads (`n_head` / `num_attention_heads`).
    n_head: int
    # Feed-forward inner width (`n_inner` / `intermediate_size`).
    d_ff: int
    # Vocabulary size; the output width of the LM-head matmul.
    vocab_size: int


def _shape_from_hf_config(model) -> ModelShape:
    """Infer a ``ModelShape`` from ``model.config``.

    Three attribute layouts are tried in order:

    1. GPT-2 naming: ``n_embd``, ``n_layer``, ``n_head``, ``n_inner``.
    2. LLaMA naming: ``hidden_size``, ``num_hidden_layers``,
       ``num_attention_heads``, ``intermediate_size``.
    3. Best-effort mix of both naming schemes.

    A missing or falsy feed-forward width defaults to ``4 * d_model``.
    The vocabulary size is taken from ``cfg.vocab_size``; only when that is
    missing or falsy is ``model.get_output_embeddings()`` consulted, so a
    model without an output-embedding module still works as long as its
    config carries ``vocab_size``.

    Args:
        model: Object exposing ``.config`` and, as a fallback source for the
            vocabulary size, ``.get_output_embeddings()``.

    Returns:
        The inferred ``ModelShape``.

    Raises:
        ValueError: If the fallback path cannot determine the model width,
            layer count, or head count.
        AttributeError: If the config matches the GPT-2 or LLaMA layout only
            partially (e.g. ``n_embd``/``n_head`` without ``n_layer``).
    """
    cfg = model.config

    def _vocab_size():
        # Resolve lazily: a plain ``getattr(cfg, name, <expr>)`` evaluates
        # <expr> even when the attribute exists, which needlessly requires
        # every model to expose output embeddings.
        vocab = getattr(cfg, "vocab_size", None)
        if vocab:
            return vocab
        return model.get_output_embeddings().weight.size(0)

    def _ffn_width(attr_name, d_model):
        # Treat an attribute that is present but None/0 the same as missing.
        return getattr(cfg, attr_name, None) or 4 * d_model

    # GPT-2
    if hasattr(cfg, "n_embd") and hasattr(cfg, "n_head"):
        d_model = cfg.n_embd
        return ModelShape(
            d_model,
            cfg.n_layer,
            cfg.n_head,
            _ffn_width("n_inner", d_model),
            _vocab_size(),
        )

    # LLaMA (meta)
    if hasattr(cfg, "hidden_size") and hasattr(cfg, "num_hidden_layers"):
        d_model = cfg.hidden_size
        return ModelShape(
            d_model,
            cfg.num_hidden_layers,
            cfg.num_attention_heads,
            _ffn_width("intermediate_size", d_model),
            _vocab_size(),
        )

    # Fallback (best effort)
    d_model = getattr(cfg, "hidden_size", getattr(cfg, "n_embd", None))
    n_layer = getattr(cfg, "num_hidden_layers", getattr(cfg, "n_layer", None))
    n_head = getattr(cfg, "num_attention_heads", getattr(cfg, "n_head", None))
    # Validate before deriving d_ff: ``4 * None`` would otherwise surface as
    # an unrelated TypeError. An explicit raise also survives ``python -O``,
    # unlike an ``assert``.
    if d_model is None or n_layer is None or n_head is None:
        raise ValueError("Cannot infer model shape.")
    return ModelShape(
        d_model,
        n_layer,
        n_head,
        _ffn_width("intermediate_size", d_model),
        _vocab_size(),
    )


def _human(x: float) -> str:
    units = ["FLOPs", "KFLOPs", "MFLOPs", "GFLOPs", "TFLOPs", "PFLOPs", "EFLOPs"]
    i = 0
    while x >= 1000 and i < len(units) - 1:
        x /= 1000.0
        i += 1
    return f"{x:.3g} {units[i]}"


def per_layer_forward_flops(B: int, T: int, shape: ModelShape) -> float:
    """
    Estimate the forward-pass FLOPs of a single transformer layer.

    The input is a batch of ``B`` sequences of length ``T``. Every matmul of
    shapes [m,n] x [n,p] is costed as ``2*m*n*p``. The per-head width is
    ``d_model // n_head``.

    Counted terms: Q/K/V projections, QK^T scores, softmax (one op per score),
    the attention-weighted sum over V, the output projection, and a two-matmul
    MLP (up and down).
    """
    width = shape.d_model
    heads = shape.n_head
    head_dim = width // heads
    ffn_width = shape.d_ff

    token_count = B * T
    score_count = B * heads * T * T  # one entry per (batch, head, query, key)

    # A single [B,T,d] x [d,d] matmul; used by Q, K, V and the output projection.
    square_proj = 2 * token_count * width * width
    qkv_cost = 3 * square_proj

    # QK^T and the attention-weighted sum over V have the same cost.
    scores_cost = 2 * score_count * head_dim
    weighted_sum_cost = 2 * score_count * head_dim
    softmax_cost = score_count  # cheap; kept for completeness

    # MLP: [B,T,d] x [d,dff] followed by [B,T,dff] x [dff,d].
    mlp_up = 2 * token_count * width * ffn_width
    mlp_down = 2 * token_count * ffn_width * width
    mlp_cost = mlp_up + mlp_down

    return (
        qkv_cost
        + scores_cost
        + softmax_cost
        + weighted_sum_cost
        + square_proj
        + mlp_cost
    )


def lm_head_flops(
    B: int, T: int, shape: ModelShape, assume_full_sequence: bool = True
) -> float:
    """
    Estimate the FLOPs of the LM-head matmul [B,t,d] x [d,V].

    With ``assume_full_sequence`` set, logits are costed for all ``T``
    positions; otherwise for a single position per sequence.
    """
    if assume_full_sequence:
        positions = T
    else:
        positions = 1
    hidden = shape.d_model
    vocab = shape.vocab_size
    return 2 * B * positions * hidden * vocab


def full_model_forward_flops(
    B: int, T: int, shape: ModelShape, assume_full_sequence_lm_head: bool = True
) -> float:
    """
    Estimate the FLOPs of one whole-model forward pass.

    The total is ``n_layer`` copies of the per-layer cost plus the LM head.
    ``assume_full_sequence_lm_head`` is forwarded to ``lm_head_flops``.
    """
    single_layer = per_layer_forward_flops(B, T, shape)
    layer_stack = shape.n_layer * single_layer
    return layer_stack + lm_head_flops(
        B, T, shape, assume_full_sequence=assume_full_sequence_lm_head
    )


# ---- scan accounting ---------------------------------------------------------


@dataclass
class ScanSpec:
    """Named list of forward passes whose FLOPs make up one scan."""

    # Label for the scan, e.g. "LayerWise".
    name: str
    # Each entry is (B, T, multiplicity): batch size, sequence length, and how
    # many forward passes of that size are counted.
    passes: List[tuple]  # list of (B, T, multiplicity)


def _ceil_div(a, b):
    return (a + b - 1) // b


def spec_layer_wise(batch_size: int, seq_len: int, n_layers_scanned: int) -> ScanSpec:
    """Build the pass list for a layer-wise scan.

    Counts one baseline forward (_llm_forward) plus one patched forward per
    scanned layer, all at ``batch_size`` x ``seq_len``.
    """
    baseline = (batch_size, seq_len, 1)
    patched = (batch_size, seq_len, n_layers_scanned)
    return ScanSpec(name="LayerWise", passes=[baseline, patched])


def spec_microsaccades(batch_size: int, seq_len: int) -> ScanSpec:
    """Build the pass list for a microsaccades scan.

    Counts one baseline forward plus one intervention forward (all layers are
    re-run once), both at ``batch_size`` x ``seq_len``.
    """
    baseline = (batch_size, seq_len, 1)
    intervention = (batch_size, seq_len, 1)
    return ScanSpec(name="Microsaccades", passes=[baseline, intervention])


def spec_token_wise_per_string(seq_len: int, batch_tokens: int) -> ScanSpec:
    """Build the pass list for a token-wise scan over one string.

    Per string there is one baseline forward at batch size 1, followed by
    ``ceil(seq_len / batch_tokens)`` patched forwards. Every patched forward
    has batch size ``batch_tokens`` except the last, which holds only the
    leftover ``seq_len % batch_tokens`` entries. That final batch is counted
    at its real size rather than as a full batch, so the estimate is not
    inflated when ``seq_len`` is not a multiple of ``batch_tokens``.

    Args:
        seq_len: Sequence length of the string.
        batch_tokens: Batch size of the patched forwards; must be positive.

    Returns:
        A ``ScanSpec`` whose passes are the baseline, then the full batches
        (omitted when there are none), then the leftover batch (omitted when
        ``seq_len`` divides evenly).

    Raises:
        ZeroDivisionError: If ``batch_tokens`` is 0.
    """
    full_batches, leftover = divmod(seq_len, batch_tokens)
    passes = [(1, seq_len, 1)]
    if full_batches:
        passes.append((batch_tokens, seq_len, full_batches))
    if leftover:
        # The trailing partial batch runs once at its actual (smaller) size.
        passes.append((leftover, seq_len, 1))
    return ScanSpec("TokenWise(per-string)", passes=passes)


def estimate_scan_flops(
    model, seq_len: int, scan_spec: ScanSpec, assume_full_sequence_lm_head=True
) -> float:
    """Sum the forward-pass FLOPs of every pass listed in ``scan_spec``.

    The model shape is inferred from ``model`` via ``_shape_from_hf_config``.
    Each ``(B, T, multiplicity)`` entry contributes ``multiplicity`` whole-model
    forward passes at that size. ``seq_len`` is accepted but not read here;
    sequence lengths come from the entries of ``scan_spec.passes``.
    """
    dims = _shape_from_hf_config(model)
    running = 0.0
    for batch, tokens, repeats in scan_spec.passes:
        one_pass = full_model_forward_flops(
            batch, tokens, dims, assume_full_sequence_lm_head
        )
        running = running + repeats * one_pass
    return running


# --- END flop_utils.py --------------------------------------------------------
