import torch
import numpy as np
from typing import Dict, List, Tuple, Optional, Any
from collections import defaultdict
import copy

@torch.no_grad()
def _to_fp32_cpu(t: torch.Tensor) -> torch.Tensor:
    """Return ``t`` detached from autograd, cast to float32 and placed on CPU.

    When ``t`` is already a float32 CPU tensor no copy is made: the result
    shares storage with ``t``. Callers that keep the result while the source
    tensor is modified in place must ``clone()`` it.
    """
    detached = t.detach()
    return detached.to(device="cpu", dtype=torch.float32)


def precompute_h0_values(start_model, config):
    """Snapshot the initial (t = 0) values of the tracked weight quantities.

    For every layer ``l`` and attention head ``i``, the per-head Q/K
    projection rows are split into ``head_dim // 2`` two-row blocks, pairing
    row ``j`` with row ``j + head_dim // 2``. For each block two scalars are
    cached:

    * ``attn_conservation_layer_{l}_head_{i}_block_{j}``:
      ``||Q_blk||_F^2 - ||K_blk||_F^2``
    * ``attn_conservation_layer_{l}_head_{i}_block_{j}_mha_group_1_h0``:
      ``||q_0||^2 + ||k_0||^2 - ||q_1||^2 - ||k_1||^2``

    When ``config.ffn_type`` is ``"dmoe"`` or ``"smoe"``, two tensors per
    layer are cached as well:

    * ``moe_conservation_layer_{l}``: column sums of the router weight
      (``weight.sum(dim=0)``)
    * ``moe_conservation_layer_{l}_gating_group_1_h0``: an independent copy
      of the first row of the router weight

    The per-block arithmetic mirrors ``conservation_log`` step for step so
    that an unchanged model reports exactly zero drift.

    Args:
        start_model: Model exposing ``model.layers[l].self_attn.{q,k}_proj``
            and, for MoE configs, ``model.layers[l].ffn.router``.
        config: Provides ``hidden_size``, ``num_attention_heads``,
            ``num_hidden_layers`` and ``ffn_type``.

    Returns:
        Dict mapping the keys above to Python floats (attention) or float32
        CPU tensors (MoE).
    """
    h0_values = {}

    layers0 = start_model.model.layers
    num_heads = config.num_attention_heads
    hidden_size = config.hidden_size
    head_dim = hidden_size // num_heads
    # RoPE rotates row j together with row j + head_dim // 2; head_dim is
    # expected to be even (an odd trailing row would be ignored).
    num_blocks = head_dim // 2

    # ========== MHA Processing ==========
    for l in range(config.num_hidden_layers):
        attn0 = layers0[l].self_attn

        # Only Q and K enter the tracked quantities, so V/O are not copied.
        # (num_heads * head_dim, hidden_size) -> (num_heads, head_dim, hidden_size)
        Q0 = _to_fp32_cpu(attn0.q_proj.weight).view(num_heads, head_dim, hidden_size)
        K0 = _to_fp32_cpu(attn0.k_proj.weight).view(num_heads, head_dim, hidden_size)

        for i in range(num_heads):
            Q_head = Q0[i]  # (head_dim, hidden_size)
            K_head = K0[i]  # (head_dim, hidden_size)

            for j in range(num_blocks):
                # j-th 2D block: rows [j, j + head_dim // 2] -> (2, hidden_size)
                Q_block = torch.stack([Q_head[j, :], Q_head[j + num_blocks, :]], dim=0)
                K_block = torch.stack([K_head[j, :], K_head[j + num_blocks, :]], dim=0)

                # Conservation law: ||Q||_F^2 - ||K||_F^2 (scalar)
                Q_frob_sq = torch.sum(Q_block ** 2).item()
                K_frob_sq = torch.sum(K_block ** 2).item()
                key = f"attn_conservation_layer_{l}_head_{i}_block_{j}"
                h0_values[key] = Q_frob_sq - K_frob_sq

                # MHA group 1:
                # ||Q_block[0]||^2 + ||K_block[0]||^2 - ||Q_block[1]||^2 - ||K_block[1]||^2
                Q_block_0_norm_sq = torch.sum(Q_block[0] ** 2).item()
                K_block_0_norm_sq = torch.sum(K_block[0] ** 2).item()
                Q_block_1_norm_sq = torch.sum(Q_block[1] ** 2).item()
                K_block_1_norm_sq = torch.sum(K_block[1] ** 2).item()
                h0_values[f"{key}_mha_group_1_h0"] = (
                    Q_block_0_norm_sq + K_block_0_norm_sq - Q_block_1_norm_sq - K_block_1_norm_sq
                )

    # ========== MoE Gating Processing ==========
    if config.ffn_type in ["dmoe", "smoe"]:
        for l in range(config.num_hidden_layers):
            # Router weight; presumably (num_experts, hidden_size) — TODO confirm.
            gate0 = _to_fp32_cpu(layers0[l].ffn.router.weight)

            key = f"moe_conservation_layer_{l}"
            # Conservation law: column sums (sum() allocates a fresh tensor).
            h0_values[key] = gate0.sum(dim=0)

            # Gating group 1: the first row of the router weight (not a norm).
            # gate0 may share storage with the live weight (_to_fp32_cpu does
            # not copy float32 CPU tensors), so take an independent copy of
            # just this row.
            h0_values[f"{key}_gating_group_1_h0"] = gate0[0, :].clone()

    return h0_values


def _relative_l2_change(current: torch.Tensor, initial: torch.Tensor, epsilon: float) -> float:
    """Return ``||current - initial||_2 / (||initial||_2 + epsilon)`` as a float."""
    diff_norm = torch.norm(current - initial, p=2).item()
    initial_norm = torch.norm(initial, p=2).item()
    return diff_norm / (initial_norm + epsilon)


@torch.no_grad()
def conservation_log(global_step, cached_h0, current_model, log_type, config, args):
    """Measure drift of the tracked weight quantities and log it.

    Current values are compared against the initial values cached by
    ``precompute_h0_values``. Keys missing from ``cached_h0`` are skipped.

    Attention (every layer ``l``, head ``i``, two-row block ``j`` built from
    rows ``j`` and ``j + head_dim // 2``):

    * conservation law ``||Q_blk||_F^2 - ||K_blk||_F^2``: relative error
      logged as ``attn_conservation_layer_{l}_head_{i}_block_{j}_error``
    * MHA group 1 ``||q_0||^2 + ||k_0||^2 - ||q_1||^2 - ||k_1||^2``: raw
      value logged per block, relative change averaged into
      ``avg_mha_group_1``

    MoE router (only when ``config.ffn_type`` is ``"dmoe"`` or ``"smoe"``):

    * conservation law, the column sums of the router weight: relative L2
      error logged as ``moe_conservation_layer_{l}_error``
    * gating group 1, the first row of the router weight: the row itself is
      logged, and its relative L2 change is averaged into ``avg_moe_group_1``

    All relative quantities use ``|h(t) - h(0)| / (|h(0)| + 1e-9)``, with
    ``|.|`` the L2 norm (the absolute value for scalars).

    Args:
        global_step: Step passed to ``wandb.log``.
        cached_h0: Dict produced by ``precompute_h0_values``.
        current_model: Model exposing ``model.layers[l].self_attn.{q,k}_proj``
            and, for MoE configs, ``model.layers[l].ffn.router``.
        log_type: Not used by this function; kept for interface compatibility.
        config: Provides ``hidden_size``, ``num_attention_heads``,
            ``num_hidden_layers`` and ``ffn_type``.
        args: Not used by this function; kept for interface compatibility.

    Returns:
        ``(avg_conservation_error, metrics)``. The average is taken over all
        conservation laws that had a cached initial value (0.0 if none).
        ``metrics`` is also sent to wandb when wandb is importable and a run
        is active.
    """
    epsilon = 1e-9
    metrics = {}

    attn_errors = []  # relative errors of the attention conservation law
    moe_errors = []  # relative errors of the router conservation law
    # Relative changes of the non-conserved tracked quantities
    mha_group_1_list = []
    gating_group_1_list = []

    layersk = current_model.model.layers
    num_heads = config.num_attention_heads
    hidden_size = config.hidden_size
    head_dim = hidden_size // num_heads
    num_blocks = head_dim // 2

    # ========== MHA Processing ==========
    for l in range(config.num_hidden_layers):
        attnk = layersk[l].self_attn

        # Only Q and K enter the tracked quantities, so V/O are not copied.
        # (num_heads * head_dim, hidden_size) -> (num_heads, head_dim, hidden_size)
        Qk = _to_fp32_cpu(attnk.q_proj.weight).view(num_heads, head_dim, hidden_size)
        Kk = _to_fp32_cpu(attnk.k_proj.weight).view(num_heads, head_dim, hidden_size)

        for i in range(num_heads):
            Q_head = Qk[i]  # (head_dim, hidden_size)
            K_head = Kk[i]  # (head_dim, hidden_size)

            for j in range(num_blocks):
                # j-th 2D block: rows [j, j + head_dim // 2] -> (2, hidden_size)
                Q_block = torch.stack([Q_head[j, :], Q_head[j + num_blocks, :]], dim=0)
                K_block = torch.stack([K_head[j, :], K_head[j + num_blocks, :]], dim=0)

                # ---- Conservation law: ||Q||_F^2 - ||K||_F^2 (scalar) ----
                # Same arithmetic as precompute_h0_values, so an unchanged
                # model yields exactly zero error.
                Q_frob_sq = torch.sum(Q_block ** 2).item()
                K_frob_sq = torch.sum(K_block ** 2).item()
                h_current = Q_frob_sq - K_frob_sq

                key = f"attn_conservation_layer_{l}_head_{i}_block_{j}"
                if key in cached_h0:
                    h0 = cached_h0[key]
                    relative_error = abs(h_current - h0) / (abs(h0) + epsilon)
                    attn_errors.append(relative_error)
                    metrics[f"{key}_error"] = relative_error

                # ---- MHA group 1 (per block) ----
                # ||Q_block[0]||^2 + ||K_block[0]||^2 - ||Q_block[1]||^2 - ||K_block[1]||^2
                Q_block_0_norm_sq = torch.sum(Q_block[0] ** 2).item()
                K_block_0_norm_sq = torch.sum(K_block[0] ** 2).item()
                Q_block_1_norm_sq = torch.sum(Q_block[1] ** 2).item()
                K_block_1_norm_sq = torch.sum(K_block[1] ** 2).item()
                mha_group_1_val = (
                    Q_block_0_norm_sq + K_block_0_norm_sq - Q_block_1_norm_sq - K_block_1_norm_sq
                )
                metrics[f"attn_layer_{l}_head_{i}_block_{j}_mha_group_1"] = mha_group_1_val

                h0_key = f"{key}_mha_group_1_h0"
                if h0_key in cached_h0:
                    h0_val = cached_h0[h0_key]
                    # Scalar quantity: the L2 norm reduces to abs().
                    mha_group_1_list.append(
                        abs(mha_group_1_val - h0_val) / (abs(h0_val) + epsilon)
                    )

    # ========== MoE Gating Processing ==========
    if config.ffn_type in ["dmoe", "smoe"]:
        for l in range(config.num_hidden_layers):
            gatek = _to_fp32_cpu(layersk[l].ffn.router.weight)

            # ---- Conservation law: column sums (vector) ----
            h_current = gatek.sum(dim=0)
            key = f"moe_conservation_layer_{l}"
            if key in cached_h0:
                relative_error = _relative_l2_change(h_current, cached_h0[key], epsilon)
                moe_errors.append(relative_error)
                metrics[f"{key}_error"] = relative_error

            # ---- Gating group 1: first row of the router weight (vector) ----
            # gatek may share storage with the live weight (_to_fp32_cpu does
            # not copy float32 CPU tensors); clone so the returned/logged row
            # is not changed by later optimizer steps.
            gating_group_1_val = gatek[0, :].clone()
            metrics[f"moe_layer_{l}_gating_group_1"] = gating_group_1_val

            h0_key = f"moe_conservation_layer_{l}_gating_group_1_h0"
            if h0_key in cached_h0:
                # Vector quantity: use the relative L2 change, like the
                # conservation law above. An element-wise ratio would be
                # dominated by near-zero entries of h(0).
                gating_group_1_list.append(
                    _relative_l2_change(gating_group_1_val, cached_h0[h0_key], epsilon)
                )

    # ========== Compute Average Conservation Error ==========
    all_errors = attn_errors + moe_errors
    metrics['num_conservation_laws'] = len(all_errors)
    avg_conservation_error = sum(all_errors) / len(all_errors) if all_errors else 0.0
    metrics['avg_conservation_error'] = avg_conservation_error

    # Separate averages for attention and MoE
    if attn_errors:
        metrics['avg_mha_conservation_error'] = sum(attn_errors) / len(attn_errors)
    if moe_errors:
        metrics['avg_moe_conservation_error'] = sum(moe_errors) / len(moe_errors)

    # Averages for the non-conserved quantities
    if mha_group_1_list:
        metrics['avg_mha_group_1'] = float(np.mean(mha_group_1_list))
    if gating_group_1_list:
        metrics['avg_moe_group_1'] = float(np.mean(gating_group_1_list))

    # ========== Wandb Logging (optional, best effort) ==========
    try:
        import wandb
    except ImportError:
        wandb = None
    if wandb is not None and wandb.run is not None:
        wandb.log(metrics, step=global_step)

    return avg_conservation_error, metrics