import copy
import math
import torch
import wandb

@torch.no_grad()
def _to_fp32_cpu(t: torch.Tensor) -> torch.Tensor:
    return t.detach().to(dtype=torch.float32, device="cpu")

def rel_minus_change(Ak, A0):
    """Return the relative change ``||Ak - A0||_2 / ||A0||_2`` as a Python float.

    Both operands are flattened first, so for matrices this is the ratio of
    Frobenius norms. There is no zero guard: an all-zero ``A0`` yields
    ``inf`` (or ``nan`` when ``Ak`` equals ``A0``).
    """
    reference = A0.flatten()
    delta = (Ak - A0).flatten()
    ratio = torch.linalg.vector_norm(delta, ord=2) / torch.linalg.vector_norm(reference, ord=2)
    return ratio.item()

# ==========================================
# CONFIGURATION FOR DETAILED CONSERVATION METRICS
# ==========================================
# Each group name maps to a list of callables. precompute_h0_values and
# conservation_log call the MHA callables with per-head stacks Q, K, V, O
# shaped [H, D_h, D] (O is permuted into that layout), and the FFN callables
# with (A, B, C) = (down_proj.weight, gate_proj.weight or zeros, up_proj.weight).


def _batched_gram(x):
    """Batched Gram matrix ``x @ x^T`` over the last two dims (via torch.bmm)."""
    return torch.bmm(x, x.transpose(-1, -2))


# MHA group 1, per head i (each result is [H, D_h, D_h]):
#   Q_i Q_i^T + K_i K_i^T   and   V_i V_i^T + O_i O_i^T
# Note these are sums of Gram matrices, not the conserved differences.
DETAILED_MHA_METRIC_GROUPS = {
    "mha_group1": [
        lambda Q, K, V, O: _batched_gram(Q) + _batched_gram(K),
        lambda Q, K, V, O: _batched_gram(V) + _batched_gram(O),
    ],
}

# FFN group 1: elementwise A∘A - Cᵀ∘Cᵀ, i.e. entry (j, i) is A[j, i]^2 - C[i, j]^2.
# This is NOT reduced to per-unit norms; the result keeps A's full shape.
DETAILED_FFN_METRIC_GROUPS = {
    "ffn_group1": [lambda A, B, C: A * A - (C * C).transpose(-1, -2)],
}

@torch.no_grad()
def precompute_h0_values(start_model, config):
    """
    Precompute and cache all h0 values from the start model.

    Call this once at the beginning of training and pass the result to
    ``conservation_log`` at every logging step.

    Args:
        start_model: Reference model. It must expose ``vit.encoder.layer[l]`` with
            ``attention.attention.{query,key,value}``, ``attention.output.dense``
            and ``ffn.{down_proj,up_proj}``; ``ffn.gate_proj`` is optional.
        config: Model config; reads ``num_hidden_layers``,
            ``num_attention_heads`` and ``hidden_size``.

    Returns:
        dict with keys
          'mha_h0_stack': float32 CPU tensor [L, H, D_h, D_h] holding, per head,
              Q Q^T - K K^T + V V^T - O O^T.
          'ffn_h0_stack': float32 CPU tensor [L, d_1] holding, per intermediate
              unit i, ||A[:, i]||^2 - ||C[i, :]||^2 (A = down_proj, C = up_proj).
          'N_mha_count': L * H.
          'N_ffn_count': number of elements of 'ffn_h0_stack'.
          'detailed_mha_stacks' / 'detailed_ffn_stacks': for each group in
              DETAILED_MHA_METRIC_GROUPS / DETAILED_FFN_METRIC_GROUPS, a list
              with one stacked tensor (leading dim L) per metric function.
    """
    num_layers = config.num_hidden_layers
    num_heads = config.num_attention_heads
    hidden = config.hidden_size
    head_dim = hidden // num_heads
    layers0 = start_model.vit.encoder.layer

    # Per-layer terms for each detailed metric, stacked at the end.
    detailed_mha = {name: [[] for _ in funcs] for name, funcs in DETAILED_MHA_METRIC_GROUPS.items()}
    detailed_ffn = {name: [[] for _ in funcs] for name, funcs in DETAILED_FFN_METRIC_GROUPS.items()}

    # === MHA Processing ===
    mha_layers_list = []
    for l in range(num_layers):
        attn0 = layers0[l].attention.attention
        out0 = layers0[l].attention.output.dense

        # Per-head stacks [H, D_h, D]. W_O is reshaped to [D, H, D_h] and permuted
        # so O0_heads[i] is the transpose of head i's column block of W_O.
        # NOTE(review): assumes head-major ordering of projection rows/columns.
        Q0_heads = _to_fp32_cpu(attn0.query.weight).view(num_heads, head_dim, hidden)
        K0_heads = _to_fp32_cpu(attn0.key.weight).view(num_heads, head_dim, hidden)
        V0_heads = _to_fp32_cpu(attn0.value.weight).view(num_heads, head_dim, hidden)
        O0_heads = _to_fp32_cpu(out0.weight).view(hidden, num_heads, head_dim).permute(1, 2, 0)

        # Main conserved quantity, per head via bmm: [H, D_h, D] x [H, D, D_h] -> [H, D_h, D_h]
        h10 = torch.bmm(Q0_heads, Q0_heads.transpose(-1, -2)) - torch.bmm(K0_heads, K0_heads.transpose(-1, -2))
        h20 = torch.bmm(V0_heads, V0_heads.transpose(-1, -2)) - torch.bmm(O0_heads, O0_heads.transpose(-1, -2))
        mha_layers_list.append(h10 + h20)

        for group_name, funcs in DETAILED_MHA_METRIC_GROUPS.items():
            for i, func in enumerate(funcs):
                detailed_mha[group_name][i].append(func(Q0_heads, K0_heads, V0_heads, O0_heads))

    mha_h0_stack = torch.stack(mha_layers_list)  # [L, H, D_h, D_h]

    # === FFN Processing ===
    ffn_layers_list = []
    for l in range(num_layers):
        ffn0 = layers0[l].ffn
        A0 = _to_fp32_cpu(ffn0.down_proj.weight)  # [hidden_size, intermediate_size]
        C0 = _to_fp32_cpu(ffn0.up_proj.weight)
        # Models without a gate projection get a zero placeholder so the
        # detailed metric callables always receive (A, B, C).
        B0 = _to_fp32_cpu(ffn0.gate_proj.weight) if hasattr(ffn0, "gate_proj") else torch.zeros_like(C0)

        # Theorem 3.2: h_i = ||A_{:,i}||^2 - ||C_{i,:}||^2 for every intermediate unit i
        ffn_layers_list.append((A0 * A0).sum(dim=0) - (C0 * C0).sum(dim=1))

        for group_name, funcs in DETAILED_FFN_METRIC_GROUPS.items():
            for i, func in enumerate(funcs):
                detailed_ffn[group_name][i].append(func(A0, B0, C0))

    ffn_h0_stack = torch.stack(ffn_layers_list)  # [L, intermediate_size]

    return {
        'mha_h0_stack': mha_h0_stack,
        'ffn_h0_stack': ffn_h0_stack,
        'N_mha_count': mha_h0_stack.shape[0] * mha_h0_stack.shape[1],
        'N_ffn_count': ffn_h0_stack.numel(),
        'detailed_mha_stacks': {
            name: [torch.stack(terms) for terms in per_func] for name, per_func in detailed_mha.items()
        },
        'detailed_ffn_stacks': {
            name: [torch.stack(terms) for terms in per_func] for name, per_func in detailed_ffn.items()
        },
    }

@torch.no_grad()
def conservation_log(global_step, cached_h0, current_model, log_type, config, args, *, epsilon=0.0):
    """
    Compute conservation metrics using precomputed h0 values and log them to wandb.

    Every reported value is "Method 2", the mean relative deviation
    ``||h_k - h_0|| / (||h_0|| + epsilon)`` over:

    * MHA: each (layer, head), using the Frobenius norm of the [D_h, D_h]
      matrix ``Q Q^T - K K^T + V V^T - O O^T``.
    * FFN: each (layer, intermediate unit i), using the absolute value of the
      scalar ``||A[:, i]||^2 - ||C[i, :]||^2`` (A = down_proj, C = up_proj).
    * Detailed groups:
      - DETAILED_MHA_METRIC_GROUPS terms use a per-head Frobenius norm.
      - DETAILED_FFN_METRIC_GROUPS terms use a per-element absolute value.

    Args:
        global_step: Current training step; passed to wandb as ``step``.
        cached_h0: Dict returned by precompute_h0_values() for the start model.
        current_model: Current model state; must expose the same
            ``vit.encoder.layer[l]`` structure as the start model.
        log_type: Prefix for every metric key (e.g. 'step' or 'epoch').
        config: Model configuration; reads ``num_hidden_layers``,
            ``num_attention_heads`` and ``hidden_size``.
        args: Training arguments (not read here; kept for API compatibility).
        epsilon: Added to every denominator.
            - The default 0.0 reproduces the historical behaviour: any h0
              entry that is exactly zero turns the corresponding mean into
              inf/nan.
            - Pass a small positive value (e.g. 1e-12) to avoid that.

    Returns:
        The metrics dict. It is logged to wandb only when a run is active.
    """
    metrics = {}
    layersk = current_model.vit.encoder.layer
    num_layers = config.num_hidden_layers
    num_heads = config.num_attention_heads
    hidden = config.hidden_size
    head_dim = hidden // num_heads

    # Running sum/count of relative errors per detailed group
    # (MHA and FFN group names share one namespace).
    detailed_metrics_acc = {
        name: {'m2_sum': 0.0, 'count': 0}
        for name in list(DETAILED_MHA_METRIC_GROUPS.keys()) + list(DETAILED_FFN_METRIC_GROUPS.keys())
    }

    # === MHA Processing ===
    mha_layers_list = []
    for l in range(num_layers):
        attnk = layersk[l].attention.attention
        outk = layersk[l].attention.output.dense

        # Per-head stacks [H, D_h, D]. W_O is reshaped to [D, H, D_h] and permuted
        # so Ok_heads[i] is the transpose of head i's column block of W_O.
        # NOTE(review): assumes head-major ordering of projection rows/columns.
        Qk_heads = _to_fp32_cpu(attnk.query.weight).view(num_heads, head_dim, hidden)
        Kk_heads = _to_fp32_cpu(attnk.key.weight).view(num_heads, head_dim, hidden)
        Vk_heads = _to_fp32_cpu(attnk.value.weight).view(num_heads, head_dim, hidden)
        Ok_heads = _to_fp32_cpu(outk.weight).view(hidden, num_heads, head_dim).permute(1, 2, 0)

        # Main conserved quantity per head, [H, D_h, D_h]
        h1k = torch.bmm(Qk_heads, Qk_heads.transpose(-1, -2)) - torch.bmm(Kk_heads, Kk_heads.transpose(-1, -2))
        h2k = torch.bmm(Vk_heads, Vk_heads.transpose(-1, -2)) - torch.bmm(Ok_heads, Ok_heads.transpose(-1, -2))
        mha_layers_list.append(h1k + h2k)

        # Detailed MHA terms: relative change of each head's matrix (Frobenius norm)
        for group_name, funcs in DETAILED_MHA_METRIC_GROUPS.items():
            acc = detailed_metrics_acc[group_name]
            for i, func in enumerate(funcs):
                term_k = func(Qk_heads, Kk_heads, Vk_heads, Ok_heads)
                # Cached stack is [Layers, Heads, ...]; select this layer -> [Heads, ...]
                term_0_layer = cached_h0['detailed_mha_stacks'][group_name][i][l]
                diff_norms = torch.linalg.norm((term_k - term_0_layer).flatten(start_dim=1), dim=1)
                h0_norms = torch.linalg.norm(term_0_layer.flatten(start_dim=1), dim=1)
                m2_vals = diff_norms / (h0_norms + epsilon)
                acc['m2_sum'] += m2_vals.sum().item()
                acc['count'] += m2_vals.numel()

    # === FFN Processing ===
    ffn_layers_list = []
    for l in range(num_layers):
        ffnk = layersk[l].ffn
        Ak = _to_fp32_cpu(ffnk.down_proj.weight)  # [hidden_size, intermediate_size]
        Ck = _to_fp32_cpu(ffnk.up_proj.weight)
        # Zero placeholder keeps the detailed callables' (A, B, C) signature uniform.
        Bk = _to_fp32_cpu(ffnk.gate_proj.weight) if hasattr(ffnk, "gate_proj") else torch.zeros_like(Ck)

        # h_i = ||A_{:,i}||^2 - ||C_{i,:}||^2 for every intermediate unit i
        ffn_layers_list.append((Ak * Ak).sum(dim=0) - (Ck * Ck).sum(dim=1))

        # Detailed FFN terms: elementwise relative error |h_k - h_0| / |h_0|
        for group_name, funcs in DETAILED_FFN_METRIC_GROUPS.items():
            acc = detailed_metrics_acc[group_name]
            for i, func in enumerate(funcs):
                term_k = func(Ak, Bk, Ck)
                term_0 = cached_h0['detailed_ffn_stacks'][group_name][i][l]
                m2_vals = (term_k - term_0).abs() / (term_0.abs() + epsilon)
                acc['m2_sum'] += m2_vals.sum().item()
                acc['count'] += term_k.numel()

    mha_hk_stack = torch.stack(mha_layers_list)  # [L, H, D_h, D_h]
    ffn_hk_stack = torch.stack(ffn_layers_list)  # [L, intermediate_size]

    # Log Detailed Metrics
    for name, acc in detailed_metrics_acc.items():
        if acc['count'] > 0:
            metrics[f"{log_type}/detailed/{name}"] = acc['m2_sum'] / acc['count']

    # ============ METHOD 2 (Vectorized) ============
    # M2 = 1/N * sum_i ||hk_i - h0_i|| / (||h0_i|| + epsilon)
    mha_h0_stack = cached_h0['mha_h0_stack']
    ffn_h0_stack = cached_h0['ffn_h0_stack']

    # MHA: Frobenius norm per (layer, head) -> [L, H]
    mha_rel_m2 = torch.linalg.norm(mha_hk_stack - mha_h0_stack, dim=(-2, -1)) / (
        torch.linalg.norm(mha_h0_stack, dim=(-2, -1)) + epsilon
    )
    # FFN: per scalar element -> [L, intermediate_size]
    ffn_rel_m2 = (ffn_hk_stack - ffn_h0_stack).abs() / (ffn_h0_stack.abs() + epsilon)

    mha_layer_sums_m2 = mha_rel_m2.sum(dim=1)  # [L]
    ffn_layer_sums_m2 = ffn_rel_m2.sum(dim=1)  # [L]

    # Layer-wise metrics
    N_mha_layer = num_heads
    N_ffn_layer = ffn_h0_stack.shape[1]
    for l in range(num_layers):
        mha_val_m2 = mha_layer_sums_m2[l].item()
        ffn_val_m2 = ffn_layer_sums_m2[l].item()
        metrics[f"{log_type}/layer_{l}"] = (mha_val_m2 + ffn_val_m2) / (N_mha_layer + N_ffn_layer)
        metrics[f"{log_type}/mha/layer_{l}"] = mha_val_m2 / N_mha_layer
        if N_ffn_layer > 0:
            metrics[f"{log_type}/ffn/layer_{l}"] = ffn_val_m2 / N_ffn_layer

    # Global metrics; each mean is skipped (not logged) when its count is zero
    mha_total_sum_m2 = mha_layer_sums_m2.sum().item()
    ffn_total_sum_m2 = ffn_layer_sums_m2.sum().item()
    N_mha_total = cached_h0['N_mha_count']
    N_ffn_total = cached_h0['N_ffn_count']
    N_virtual_total = N_mha_total + N_ffn_total

    if N_virtual_total > 0:
        metrics[f"{log_type}/model"] = (mha_total_sum_m2 + ffn_total_sum_m2) / N_virtual_total
    if N_mha_total > 0:
        metrics[f"{log_type}/mha"] = mha_total_sum_m2 / N_mha_total
    if N_ffn_total > 0:
        metrics[f"{log_type}/ffn"] = ffn_total_sum_m2 / N_ffn_total

    if wandb.run is not None:
        wandb.log(metrics, step=global_step)
    return metrics