
import os
import glob
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import sys

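# Add the current working directory and its "train" subdirectory to sys.path so
# the `train.*` imports below resolve; this assumes the script is launched from
# the directory that contains `train/`.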
sys.path.insert(0, os.getcwd())
sys.path.insert(0, os.path.join(os.getcwd(), "train"))

from train.hyper.core import hyperActor
from train.hyper.model import MlpNetwork

def get_magnitude_profile(weight_matrix):
    """
    Reduce a weight tensor to a per-column magnitude profile.

    Absolute values are summed along dim 0, so entry ``j`` of the result is the
    total absolute weight found in column ``j``. For an ``nn.Linear`` weight
    (as passed in by the callers in this file) that is one value per input
    neuron.

    Returns a NumPy array (detached from autograd and moved to CPU) with one
    entry per column of ``weight_matrix``.
    """
    abs_weights = weight_matrix.abs()
    column_totals = abs_weights.sum(dim=0)
    # Detach and move to host memory before handing the values to NumPy.
    return column_totals.detach().cpu().numpy()

def compute_heatmap_matrix(profiles_list):
    """
    Stack a list of magnitude profiles into a 2-D array for plotting.

    Each profile becomes one row, giving an array of shape (N_Seeds, Fan_In).
    Values are the raw magnitudes; no normalization is applied.
    When ``profiles_list`` is empty (or otherwise falsy) a (1, 1) array of
    zeros is returned as a placeholder.
    """
    if profiles_list:
        return np.array(profiles_list)

    # Nothing was collected: hand back a 1x1 zero placeholder.
    return np.zeros((1, 1))

def _collect_ghn_profiles(actor, ckpt_paths, indices_vector, shape_inds, layer_indices, device):
    """
    Gather magnitude profiles from networks predicted by each GHN checkpoint.

    For every checkpoint the GHN state dict is loaded into ``actor.ghn``, the
    actor is asked to build the target network via ``actor.set_graph`` and the
    weights of the ``nn.Linear`` layers found at ``layer_indices`` are reduced
    with ``get_magnitude_profile``.

    Extraction is best-effort: a checkpoint that raises is reported on stdout
    and skipped, and positions that are out of range or not ``nn.Linear`` are
    ignored.

    Returns a dict mapping each layer index to a list of profiles (one per
    checkpoint that yielded that layer).
    """
    profiles = {i: [] for i in layer_indices}
    for ckpt_path in ckpt_paths:
        try:
            # NOTE(review): torch.load unpickles the file; only use trusted
            # checkpoints. Newer PyTorch releases default to weights_only=True,
            # which may reject checkpoints holding non-tensor objects - confirm.
            checkpoint = torch.load(ckpt_path, map_location=device)
            actor.ghn.load_state_dict(checkpoint['ghn_state_dict'])
            actor.ghn.eval()

            # Predict weights. Only detached values are read afterwards, so
            # there is no need to record an autograd graph.
            with torch.no_grad():
                actor.set_graph(indices_vector, shape_inds)
            model = actor.current_model[0]

            # Locate the layer container: prefer `classifier`, then `fc`,
            # otherwise index the model itself.
            if hasattr(model, 'classifier'):
                container = model.classifier
            elif hasattr(model, 'fc'):
                container = model.fc
            else:
                container = model

            for i in layer_indices:
                if i < len(container):
                    layer = container[i]
                    if isinstance(layer, nn.Linear):
                        profiles[i].append(get_magnitude_profile(layer.weight))
        except Exception as e:
            print(f"Error extracting from {ckpt_path}: {e}")
    return profiles


def _collect_mlp_profiles(ckpt_paths, layer_indices, device):
    """
    Gather magnitude profiles straight from the MLP baseline state dicts.

    Weights are looked up under the keys ``classifier.<i>.weight``; missing
    keys are skipped. A checkpoint that raises is reported on stdout and
    skipped.

    Returns a dict mapping each layer index to a list of profiles.
    """
    profiles = {i: [] for i in layer_indices}
    for ckpt_path in ckpt_paths:
        try:
            checkpoint = torch.load(ckpt_path, map_location=device)
            state_dict = checkpoint['model_state_dict']

            for i in layer_indices:
                key = f"classifier.{i}.weight"
                if key in state_dict:
                    profiles[i].append(get_magnitude_profile(state_dict[key]))
        except Exception as e:
            print(f"Error extracting from {ckpt_path}: {e}")
    return profiles


def _render_heatmaps(models, layer_indices, layer_names, zoom_key, zoom_width, output_filename):
    """
    Draw one row of heatmaps per model and save the figure to disk.

    ``models`` is a list of ``(display_name, profiles_dict)`` pairs. Each row
    shows one panel per entry of ``layer_indices`` followed by a final panel
    with the first ``zoom_width`` columns of layer ``zoom_key``.
    ``layer_names`` supplies the column titles (one per panel, zoom included).
    """
    # Each panel is (layer key, number of leading columns to keep or None).
    panels = [(key, None) for key in layer_indices] + [(zoom_key, zoom_width)]
    last_row = len(models) - 1

    # squeeze=False keeps `axes` two-dimensional regardless of the grid size.
    fig, axes = plt.subplots(
        len(models), len(panels), figsize=(30, 8), sharex='col', squeeze=False
    )
    fig.subplots_adjust(wspace=0.3, hspace=0.1)

    for row_idx, (model_name, profiles_dict) in enumerate(models):
        for col_idx, (key, width) in enumerate(panels):
            ax = axes[row_idx, col_idx]

            data = compute_heatmap_matrix(profiles_dict[key])
            if width is not None:
                data = data[:, :width]  # zoom: keep only the leading columns

            ax.imshow(data, cmap='viridis', aspect='auto', interpolation='nearest')

            if row_idx == 0:
                ax.set_title(layer_names[col_idx], fontsize=28, fontweight='bold', pad=20)

            if col_idx == 0:
                ax.set_ylabel(f"{model_name}\n(Seeds)", fontsize=28, fontweight='bold')
            else:
                ax.set_yticks([])  # Hide inner y ticks

            if row_idx == last_row:
                ax.tick_params(axis='x', labelsize=18)

            ax.tick_params(axis='y', labelsize=18)

    # Common x label for the whole figure.
    fig.text(0.5, 0.02, 'Neuron Index', ha='center', fontsize=28, fontweight='bold')

    print(f"Saving plot to {output_filename}...")
    fig.savefig(output_filename, dpi=150, bbox_inches='tight')
    # Release the figure's resources once it has been written.
    plt.close(fig)
    print("Done.")


def main():
    """
    Compare input-neuron weight magnitudes of GHN-predicted and SGD-trained MLPs.

    Steps:
      1. Discover GHN and MLP baseline checkpoints under ``runs_dagger/``
         relative to the current working directory.
      2. Build a ``hyperActor`` sized from the first GHN checkpoint's config
         and resolve the index of the target architecture.
      3. Extract per-layer magnitude profiles from every checkpoint.
      4. Plot the profiles as heatmaps (one row per model family) and save
         them to ``ghn_vs_mlp_4layer_zoom.png``.

    Problems that prevent the analysis (no checkpoints, unreadable config,
    unknown architecture) are reported on stdout and the function returns
    without raising.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # ---------------------------------------------------------
    # 1. Setup & Discovery
    # ---------------------------------------------------------
    target_arch = (256, 256, 256)

    # GHN Checkpoints
    ghn_ckpt_pattern = os.path.join(os.getcwd(), "runs_dagger/pickcube_3x256_ghn/*/ghn_final_ckpt.pt")
    ghn_ckpts = sorted(glob.glob(ghn_ckpt_pattern))
    print(f"Searching for GHN with: {ghn_ckpt_pattern}")
    print(f"Found {len(ghn_ckpts)} GHN checkpoints.")

    # MLP Checkpoints. Every match is used; the GHN and MLP counts are not
    # required to be equal.
    mlp_ckpt_pattern = os.path.join(os.getcwd(), "runs_dagger/pickcube_3x256_mlp_baseline/*/mlp_final_ckpt.pt")
    mlp_ckpts = sorted(glob.glob(mlp_ckpt_pattern))
    print(f"Searching for MLP with: {mlp_ckpt_pattern}")
    print(f"Found {len(mlp_ckpts)} MLP baseline checkpoints.")

    if not ghn_ckpts or not mlp_ckpts:
        print("Error: checkpoints missing. See paths above.")
        return

    # ---------------------------------------------------------
    # 2. Initialize Models
    # ---------------------------------------------------------
    # Use first GHN checkpoint to get config for dimensions
    try:
        first_ghn_ckpt = torch.load(ghn_ckpts[0], map_location=device)
        ghn_config = first_ghn_ckpt['ghn_config']
        obs_dim = ghn_config['num_observations']
        act_dim = ghn_config['num_classes']
        print(f"Obs Dim: {obs_dim}, Act Dim: {act_dim}")
    except Exception as e:
        print(f"Error loading GHN config: {e}")
        return

    # Initialize HyperActor (for GHN)
    actor = hyperActor(
        act_dim=act_dim,
        obs_dim=obs_dim,
        meta_batch_size=1,
        device=device,
        architecture_sampling_mode="uniform",
        multi_gpu=False
    )

    try:
        arch_idx = actor.list_of_arcs.index(target_arch)
    except ValueError:
        print(f"Error: Target architecture {target_arch} not found in actor's list.")
        return
    indices_vector = np.array([arch_idx])
    shape_inds = torch.stack([actor.list_of_shape_inds[i] for i in indices_vector])

    # ---------------------------------------------------------
    # 3. Data Extraction
    # ---------------------------------------------------------
    # Positions of the layers to inspect inside the layer container.
    # Presumably the even slots hold the Linear layers - TODO confirm.
    layer_indices = [0, 2, 4, 6]

    print("Extracting GHN weights...")
    ghn_profiles = _collect_ghn_profiles(
        actor, ghn_ckpts, indices_vector, shape_inds, layer_indices, device
    )

    print("Extracting MLP weights...")
    mlp_profiles = _collect_mlp_profiles(mlp_ckpts, layer_indices, device)

    models = [("GHN", ghn_profiles), ("SGD", mlp_profiles)]

    # Make silent gaps visible: an empty list is rendered as a 1x1 zero
    # placeholder by compute_heatmap_matrix.
    for model_name, profiles_dict in models:
        for i in layer_indices:
            if not profiles_dict[i]:
                print(
                    f"Warning: no {model_name} profiles collected for layer index {i}; "
                    "its heatmap will be an empty placeholder."
                )

    # ---------------------------------------------------------
    # 4. Metrics & Visualization
    # ---------------------------------------------------------
    # Column titles: L1, L2, L3, L4 and a zoomed view of L2.
    layer_names = ["Layer 1", "Layer 2", "Layer 3", "Layer 4", "Layer 2 (Zoom 0-50)"]
    output_filename = "ghn_vs_mlp_4layer_zoom.png"

    _render_heatmaps(
        models,
        layer_indices,
        layer_names,
        zoom_key=2,      # Layer 2 lives at container index 2
        zoom_width=50,   # show columns 0-49 in the zoom panel
        output_filename=output_filename,
    )

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
