import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
import torch.nn as nn
from torch.func import functional_call, vmap, jacrev
from torchvision import datasets, transforms
from sklearn.model_selection import train_test_split
import seaborn as sns
from collections import defaultdict

# Compute device used for models and data (GPU when CUDA is available).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Display names for pruning methods, used as plot legend labels.
PRUNER_NAMES = {
    'rand': 'Random',
    'random': 'Random',
    'snip': 'SNIP',
    'synflow': 'Synflow',
    'grasp': 'GraSP',
}

def compute_ntk_spectral_metrics(kernel_matrix):
    """
    Compute spectral summary statistics of a kernel matrix.

    Args:
        kernel_matrix: (N, N) kernel as a numpy array or torch tensor (any device;
            tensors that require grad are detached). It is symmetrised as
            (K + K^T) / 2 before the eigendecomposition.

    Returns:
        Dictionary with:
            'eigenvalues': eigenvalues in descending order.
            'decay_exponent': alpha of a power-law fit lambda_k = C * k^(-alpha) over the
                eigenvalues > 1e-10 (NaN when there are 10 or fewer of them).
            'effective_rank': sum(lambda) / lambda_1.
            'cumulative_energy': cumsum(lambda) / sum(lambda).
            'energy_in_top1', 'energy_in_top5', 'energy_in_top10': fraction of the
                eigenvalue sum held by the top 1 / 5 / 10 eigenvalues.
            'spectral_gap': lambda_1 / lambda_2 (inf for a 1x1 kernel).
            'num_outliers': number of eigenvalues above 0.1 * lambda_1.
    """
    if isinstance(kernel_matrix, torch.Tensor):
        # detach() so kernels that still carry autograd history can be converted.
        kernel_matrix = kernel_matrix.detach().cpu().numpy()

    # Remove numerical asymmetry so the symmetric eigen-solver is valid.
    kernel_matrix = (kernel_matrix + kernel_matrix.T) / 2

    # eigvalsh returns ascending order; reverse to descending.
    eigenvalues = np.linalg.eigvalsh(kernel_matrix)[::-1]
    total_eigensum = eigenvalues.sum()

    # Power-law decay fit lambda_k ∝ k^(-alpha) on the numerically non-zero eigenvalues.
    k_values = np.arange(1, len(eigenvalues) + 1)
    valid_indices = eigenvalues > 1e-10
    num_valid = int(np.count_nonzero(valid_indices))

    def power_law(x, C, alpha):
        return C * (x ** (-alpha))

    if num_valid > 10:  # need enough points for a meaningful fit
        try:
            params, _ = curve_fit(power_law, k_values[valid_indices],
                                  eigenvalues[valid_indices],
                                  bounds=([0, 0], [np.inf, np.inf]))
            decay_exponent = params[1]
        except (RuntimeError, ValueError, np.linalg.LinAlgError):
            # Fit failed: two-point estimate from lambda_1 and lambda_k at the last valid
            # index. The spectrum is sorted descending, so the valid eigenvalues form a
            # prefix and the last one sits at 1-based index k = num_valid (not N).
            decay_exponent = (np.log(eigenvalues[0] / eigenvalues[valid_indices][-1])
                              / np.log(num_valid))
    else:
        decay_exponent = np.nan

    effective_rank = total_eigensum / eigenvalues[0]

    cumulative_energy = np.cumsum(eigenvalues) / total_eigensum

    energy_in_top1 = eigenvalues[0] / total_eigensum
    energy_in_top5 = np.sum(eigenvalues[:5]) / total_eigensum
    energy_in_top10 = np.sum(eigenvalues[:10]) / total_eigensum

    spectral_gap = eigenvalues[0] / eigenvalues[1] if len(eigenvalues) > 1 else np.inf

    # Eigenvalues within an order of magnitude of the largest one.
    num_outliers = np.sum(eigenvalues > 0.1 * eigenvalues[0])

    return {
        'eigenvalues': eigenvalues,
        'decay_exponent': decay_exponent,
        'effective_rank': effective_rank,
        'cumulative_energy': cumulative_energy,
        'energy_in_top1': energy_in_top1,
        'energy_in_top5': energy_in_top5,
        'energy_in_top10': energy_in_top10,
        'spectral_gap': spectral_gap,
        'num_outliers': num_outliers
    }

def plot_eigenvalue_decay(metrics_dict, method_name, ax=None, log_scale=True):
    """
    Plot the eigenvalue spectrum against its index, annotated with key metrics.

    Args:
        metrics_dict: Output of compute_ntk_spectral_metrics; reads 'eigenvalues',
            'decay_exponent', 'effective_rank', 'spectral_gap' and 'energy_in_top5'.
        method_name: Legend label for the curve.
        ax: Axes to draw on; a new 8x6 figure is created when None.
        log_scale: True for log-log axes, False for a log-y / linear-x plot.

    Returns:
        The Axes that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(8, 6))

    eigenvalues = metrics_dict['eigenvalues']
    k_values = np.arange(1, len(eigenvalues) + 1)

    plot = ax.loglog if log_scale else ax.semilogy
    plot(k_values, eigenvalues, label=method_name)

    ax.set_xlabel('Index $k$')
    # Raw string: '\l' is not a valid escape sequence (SyntaxWarning on Python 3.12+);
    # the resulting text is unchanged.
    ax.set_ylabel(r'Eigenvalue $\lambda_k$')
    ax.set_title(f'Eigenvalue Decay (α ≈ {metrics_dict["decay_exponent"]:.2f})')

    # Key metrics in the lower-left corner (axes coordinates).
    text = (f"Eff. Rank: {metrics_dict['effective_rank']:.2f}\n"
            f"Spec. Gap: {metrics_dict['spectral_gap']:.2f}\n"
            f"Top-5 Energy: {metrics_dict['energy_in_top5']*100:.1f}%")

    ax.text(0.05, 0.05, text, transform=ax.transAxes,
            bbox=dict(facecolor='white', alpha=0.7))

    return ax

def plot_cumulative_energy(metrics_dict, method_name, ax=None):
    """
    Plot the cumulative fraction of spectral energy captured by the top-k eigenvalues.

    Dashed reference lines are drawn at 50/75/90/95 %. Each line is labelled with the
    number of eigenvalues needed to reach it. The x-axis shows at most the first
    50 indices.

    Args:
        metrics_dict: Output of compute_ntk_spectral_metrics; reads 'cumulative_energy'.
        method_name: Legend label for the curve.
        ax: Axes to draw on; a new 8x6 figure is created when None.

    Returns:
        The Axes that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(8, 6))

    cum_energy = metrics_dict['cumulative_energy']
    k_values = np.arange(1, len(cum_energy) + 1)

    ax.plot(k_values, cum_energy, label=method_name)

    # x in axes coordinates, y in data coordinates. This keeps the labels at the right
    # edge of the *visible* area even though xlim is truncated to 50 below. Placing them
    # at x = 0.95 * N in data coordinates pushed them off-screen for N > ~52.
    label_transform = ax.get_yaxis_transform()
    for threshold in [0.5, 0.75, 0.9, 0.95]:
        ax.axhline(threshold, color='gray', linestyle='--', alpha=0.3)
        reached = np.flatnonzero(cum_energy >= threshold)
        if reached.size == 0:
            # Threshold never reached (e.g. NaN energies). np.argmax would return 0 and
            # mislabel this as a single mode.
            continue
        ax.text(0.95, threshold, f'{reached[0] + 1} modes', transform=label_transform,
                verticalalignment='bottom', horizontalalignment='right')

    ax.set_xlabel('Number of Eigenvalues')
    ax.set_ylabel('Cumulative Energy Fraction')
    ax.set_title('Spectral Energy Concentration')
    ax.set_xlim(1, min(50, len(cum_energy)))  # Focus on top 50 eigenvalues
    ax.set_ylim(0, 1)

    return ax



class GraphonLinear(nn.Module):
    """Linear layer whose weight matrix is multiplied elementwise by a fixed mask.

    The mask is produced once at construction by ``graphon_fn(out_features, in_features)``.
    It is all ones when ``graphon_fn`` is None. It is registered as a buffer, so it
    follows ``.to(device)`` and appears in ``state_dict`` but is not trained.
    """
    def __init__(self, in_features, out_features, graphon_fn, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.graphon_fn = graphon_fn

        # Uninitialised storage; values are set in reset_parameters().
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)

        self.register_buffer('mask', self._create_graphon_mask())
        self.reset_parameters()

    def _create_graphon_mask(self):
        """Return the (out_features, in_features) connectivity mask as a tensor."""
        if self.graphon_fn is None:
            return torch.ones((self.out_features, self.in_features))
        mask = self.graphon_fn(self.out_features, self.in_features)
        # Accept numpy arrays and other dtypes from graphon functions. register_buffer
        # rejects non-tensors, and a float64 mask would promote the masked weight and
        # make F.linear fail on float32 inputs.
        return torch.as_tensor(mask, dtype=self.weight.dtype)

    def reset_parameters(self):
        """Draw weights from N(0, 1/fan_in) and biases from U(-1/sqrt(fan_in), 1/sqrt(fan_in))."""
        fan_in = self.in_features
        bound = 1 / np.sqrt(fan_in)
        # The scale must be passed to the in-place initializer. Dividing its return
        # value creates a new tensor and leaves the parameter unscaled at N(0, 1).
        nn.init.normal_(self.weight, mean=0.0, std=bound)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        # Zero out pruned connections before the affine map.
        masked_weight = self.weight * self.mask
        return nn.functional.linear(input, masked_weight, self.bias)

class GraphonMLP(nn.Module):
    """MLP built from GraphonLinear layers.

    Args:
        layer_sizes: Widths [d_0, d_1, ..., d_L]; layer i maps d_i -> d_{i+1}.
        graphon_fns: Mask generator for each layer (entry i is used for layer i;
            None gives a dense layer).
        activation: Module applied after every layer except the last. Defaults to a
            fresh nn.ReLU() for each model.
    """
    def __init__(self, layer_sizes, graphon_fns, activation=None):
        super().__init__()
        self.layers = nn.ModuleList()
        # A module instance used as a default argument is created once, at definition
        # time, and would be registered as a submodule of every GraphonMLP. Build one
        # per model instead.
        self.activation = nn.ReLU() if activation is None else activation

        for i in range(len(layer_sizes) - 1):
            self.layers.append(
                GraphonLinear(layer_sizes[i], layer_sizes[i + 1], graphon_fns[i])
            )

    def forward(self, x):
        # Hidden layers use the activation; the output layer is left linear.
        for layer in self.layers[:-1]:
            x = self.activation(layer(x))
        return self.layers[-1](x)

def compute_empirical_graphon_ntk(model, x):
    """
    Return the empirical neural tangent kernel of `model` on the batch `x`.

    Per-sample Jacobians of the output with respect to every named parameter are
    computed with torch.func (jacrev vmapped over the batch). For each pair of samples
    (n, m) and output indices (a, b), the inner products
    <d f_a(x_n)/d theta, d f_b(x_m)/d theta> are summed over all parameters. The
    resulting (N, N, O, O) tensor is reduced to (N, N) by averaging over all O*O output
    pairs, including the off-diagonal a != b pairs.

    Args:
        model: nn.Module. Assumed to map a (1, ...) input to a (1, O) output; TODO
            confirm before using models with other output shapes.
        x: Batch of inputs with the sample dimension first.

    Returns:
        (N, N) kernel tensor.
    """
    # Detached copies: the Jacobian is taken functionally, never via backward().
    frozen = {name: p.detach() for name, p in model.named_parameters()}

    def forward_one(theta, sample):
        # Add, then remove, a batch dimension so jacrev differentiates a single example.
        out = functional_call(model, theta, (sample.unsqueeze(0),))
        return out.squeeze(0)

    jacobians = vmap(jacrev(forward_one), in_dims=(None, 0))(frozen, x)
    # Under the output-shape assumption above, each entry is (N, O, *param.shape) -> (N, O, P).
    flat = [blk.flatten(2) for blk in jacobians.values()]

    per_param = [torch.einsum('Naf,Mbf->NMab', blk, blk) for blk in flat]
    full_kernel = torch.stack(per_param).sum(0)  # (N, N, O, O)

    n_pairs = full_kernel.size(-1) * full_kernel.size(-2)
    return full_kernel.sum((-2, -1)) / n_pairs


def _histogram_path(method, sparsity):
    """Return the pre-computed score-histogram file used for `method` at `sparsity`.

    Histograms exist for compression 0.5, 0.75 and 1.0 only, so the sparsity is
    bucketed as: <= 0.75 -> 0.5, (0.75, 0.85] -> 0.75, > 0.85 -> 1.0. The returned
    path is relative, i.e. resolved against the current working directory.
    """
    if sparsity <= 0.75:
        compression = '0.5'
    elif sparsity <= 0.85:
        compression = '0.75'
    elif sparsity > 0.85:
        compression = '1.0'
    else:
        # Only reachable for NaN, which fails every comparison.
        raise ValueError(f'Invalid sparsity {sparsity!r} for {method} histogram lookup')
    return (f'../histogram/histogram_results/fc_{method}_L_4_N_2000_compression_'
            f'{compression}_sorted/histogram_{method}.npy')


def _histogram_graphon(method, sparsity):
    """Build a mask generator that samples masks from `method`'s saved score histogram."""
    from histogram import generate_mask_from_histogram

    histogram = np.load(_histogram_path(method, sparsity))

    def histogram_graphon(out_features, in_features, sparsity=sparsity):
        mask = generate_mask_from_histogram(histogram, shape=(out_features, in_features), sparsity=sparsity)
        return torch.tensor(mask, dtype=torch.float32)

    # Keep the per-method function names (synflow_graphon, snip_graphon, grasp_graphon).
    histogram_graphon.__name__ = f'{method}_graphon'
    return histogram_graphon


def generate_graphon_functions(method, sparsity):
    """
    Return a mask-generating ("graphon") function for a pruning method at a given sparsity.

    Args:
        method: One of:
            'random': independent Bernoulli (Erdős–Rényi) mask.
            'synflow', 'snip', 'grasp': masks sampled from a pre-computed score
                histogram loaded from disk.
            Any other value: a dense all-ones mask; `sparsity` is ignored.
        sparsity: Fraction of connections removed (0.0 = dense, 1.0 = fully pruned).

    Returns:
        Callable ``fn(out_features, in_features)`` returning a float mask tensor meant
        to have shape (out_features, in_features).

    Raises:
        ValueError: if a histogram-based method is given a NaN sparsity.
    """
    if method == 'random':
        def er_graphon(out_features, in_features, sparsity=sparsity):
            # Keep each connection independently with probability 1 - sparsity.
            keep_prob = torch.ones((out_features, in_features)) * (1 - sparsity)
            return torch.bernoulli(keep_prob)
        return er_graphon

    if method in ('synflow', 'snip', 'grasp'):
        return _histogram_graphon(method, sparsity)

    # Default: no pruning (dense graphon).
    def dense_graphon(out_features, in_features):
        return torch.ones((out_features, in_features))
    return dense_graphon





def run_sparsity_experiment(dataset='mnist', sample_size=500, sparsity_levels=None, methods=None, 
                           hidden_dim=128, saved_folders='gntk_results', seeds=None):
    """
    Compare NTK spectral properties across pruning methods and sparsity levels,
    repeated over several random seeds.

    For every (seed, method, sparsity) triple, a GraphonMLP
    (input_dim -> hidden_dim -> hidden_dim -> 10) is built with masks from
    `generate_graphon_functions`. Its empirical NTK is computed on `sample_size` random
    training images, and spectral metrics are recorded. Kernel / eigenvalue figures are
    saved for the first seed only. Aggregated metrics are saved to
    `aggregated_results.npy` and plotted with error bars.

    Args:
        dataset: Dataset name; only 'mnist' (case-insensitive) is implemented.
        sample_size: Number of training images used to build the kernel.
        sparsity_levels: Sparsity levels to test (default [0.5, 0.6, 0.7, 0.8, 0.9, 0.95]).
        methods: Pruning methods to compare (default ['random', 'synflow', 'snip']).
        hidden_dim: Width of the two hidden layers.
        saved_folders: Output root; files go to ./{saved_folders}/{dataset}_dim_{hidden_dim}.
        seeds: Random seeds to use (default 0..19).

    Returns:
        Dict with 'raw_results' (all_results[method][sparsity][seed] -> metrics dict)
        and 'aggregated_results' (output of process_multi_seed_results).

    Raises:
        NotImplementedError: if `dataset` is not 'mnist'.
    """
    if sparsity_levels is None:
        sparsity_levels = [0.5, 0.6, 0.7, 0.8, 0.9, 0.95]
    if methods is None:
        methods = ['random', 'synflow', 'snip']
    if seeds is None:
        seeds = list(np.arange(20))

    # Only MNIST loading exists. Fail up front instead of crashing later on an
    # undefined `X` / `input_dim` (there is no CIFAR-10 loader).
    if dataset.lower() != 'mnist':
        raise NotImplementedError(
            f"Dataset {dataset!r} is not supported yet; only 'mnist' is implemented.")

    # Structure: all_results[method][sparsity][seed] = metrics_dict
    all_results = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))

    saved_folders = f'./{saved_folders}/{dataset}_dim_{hidden_dim}'
    os.makedirs(saved_folders, exist_ok=True)

    # The full training set does not depend on the seed, so build it once; each seed
    # only draws a different random subset from it.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    data = datasets.MNIST('./data', train=True, download=True, transform=transform)

    for seed_idx, seed in enumerate(seeds):
        print(f"Running with seed {seed} ({seed_idx+1}/{len(seeds)})...")

        # torch drives the subset draw, Bernoulli masks and weight init. numpy is seeded
        # too, presumably for the histogram mask sampler (not visible here).
        torch.manual_seed(seed)
        np.random.seed(seed)

        # Random subset loaded as a single batch, flattened to (N, 784).
        indices = torch.randperm(len(data))[:sample_size]
        data_subset = torch.utils.data.Subset(data, indices)
        loader = torch.utils.data.DataLoader(data_subset, batch_size=sample_size)
        X, _ = next(iter(loader))
        X = X.view(X.size(0), -1).to(device)
        input_dim = X.shape[1]

        layer_sizes = [input_dim, hidden_dim, hidden_dim, 10]

        for method in methods:
            for sparsity in sparsity_levels:
                print(f"Computing NTK for {method} pruning at sparsity {sparsity:.2f} (seed {seed})...")

                graphon_fn = generate_graphon_functions(method, sparsity)

                # The same graphon function is passed for every layer.
                model = GraphonMLP(
                    layer_sizes=layer_sizes,
                    graphon_fns=[graphon_fn] * (len(layer_sizes) - 1)
                ).to(device)

                ntk = compute_empirical_graphon_ntk(model, X)
                metrics = compute_ntk_spectral_metrics(ntk)
                all_results[method][sparsity][seed] = metrics

                # Per-seed figures only for the first seed, to avoid clutter.
                if seed_idx == 0:
                    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

                    im = ax1.imshow(ntk.cpu().numpy(), cmap='viridis')
                    ax1.set_title(f'{method.capitalize()} Pruning (s={sparsity:.2f})')
                    fig.colorbar(im, ax=ax1)

                    plot_eigenvalue_decay(metrics, method, ax=ax2)

                    fig.tight_layout()
                    fig.savefig(f'{saved_folders}/{dataset}_{method}_sparsity{sparsity:.2f}_seed{seed}.png')
                    plt.close(fig)

    aggregated_results = process_multi_seed_results(all_results, seeds)

    # A dict is stored as a pickled object array; reload with
    # np.load(..., allow_pickle=True).item().
    np.save(f'{saved_folders}/aggregated_results.npy', aggregated_results)

    visualize_aggregated_results(aggregated_results, methods, sparsity_levels, saved_folders, dataset)

    return {
        'raw_results': all_results,
        'aggregated_results': aggregated_results
    }

def process_multi_seed_results(all_results, seeds):
    """
    Summarise per-seed spectral metrics as mean and (population) standard deviation.

    Args:
        all_results: Nested mapping all_results[method][sparsity][seed] -> metrics dict,
            as produced by compute_ntk_spectral_metrics.
        seeds: Seeds to aggregate over; every one must be present for every
            (method, sparsity) pair.

    Returns:
        Dict aggregated[method][sparsity] = {'mean': {...}, 'std': {...}} holding the
        scalar metrics, plus 'eigenvalues' (element-wise statistics over the spectra
        truncated to the shortest one across seeds).
    """
    scalar_keys = ('decay_exponent', 'effective_rank', 'spectral_gap',
                   'energy_in_top1', 'energy_in_top5', 'energy_in_top10',
                   'num_outliers')

    summary = {}
    for method, by_sparsity in all_results.items():
        summary[method] = {}
        for sparsity, by_seed in by_sparsity.items():
            means, stds = {}, {}
            summary[method][sparsity] = {'mean': means, 'std': stds}

            runs = [by_seed[s] for s in seeds]
            for key in scalar_keys:
                samples = [run[key] for run in runs]
                means[key] = np.mean(samples)
                stds[key] = np.std(samples)

            # Spectra can differ in length between seeds; align on the shortest one.
            shortest = min(len(run['eigenvalues']) for run in runs)
            stacked = np.array([run['eigenvalues'][:shortest] for run in runs])
            means['eigenvalues'] = np.mean(stacked, axis=0)
            stds['eigenvalues'] = np.std(stacked, axis=0)

    return summary

def visualize_aggregated_results(aggregated_results, methods, sparsity_levels, saved_folders, dataset):
    """
    Plot multi-seed spectral metrics (mean ± std) versus sparsity, plus one
    eigenvalue-decay figure per sparsity level.

    Files written to `saved_folders`:
        {dataset}_spectral_metrics_with_error_bars.png
        {dataset}_eigenvalue_decay_sparsity{s:.2f}.png (one per sparsity level)

    Args:
        aggregated_results: Output of process_multi_seed_results, indexed as
            aggregated_results[method][sparsity]['mean' | 'std'][metric].
        methods: Pruning methods to draw.
        sparsity_levels: Sparsity levels to plot.
        saved_folders: Output directory.
        dataset: Prefix for the output file names.
    """
    palette = {'random': 'blue', 'snip': 'orange', 'synflow': 'green', 'grasp': 'red'}
    # (metric key, panel title, y-axis label), left to right.
    panels = (
        ('decay_exponent', 'Eigenvalue Decay Rate', 'Decay Exponent α'),
        ('effective_rank', 'Effective Rank', 'Effective Rank'),
        ('spectral_gap', 'Spectral Gap', 'Spectral Gap (λ₁/λ₂)'),
        ('energy_in_top5', 'Energy Concentration', 'Energy in Top-5 Eigenvalues'),
    )
    xs = sorted(sparsity_levels)

    # --- Summary panel: one subplot per metric, one error-bar line per method. ---
    fig, axes = plt.subplots(1, 4, figsize=(12, 3))
    flat_axes = axes.flatten()

    for ax, (key, title, ylabel) in zip(flat_axes, panels):
        for method in methods:
            per_level = [aggregated_results[method][s] for s in xs]
            centers = [entry['mean'][key] for entry in per_level]
            spreads = [entry['std'][key] for entry in per_level]
            ax.errorbar(xs, centers, yerr=spreads,
                        fmt='o-', label=PRUNER_NAMES.get(method, method), capsize=4,
                        color=palette.get(method, None))

        ax.set_xlabel('Sparsity Level')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(True, alpha=0.3)

    # A single legend on the first panel keeps the row uncluttered.
    flat_axes[0].legend()

    plt.tight_layout()
    plt.savefig(f"{saved_folders}/{dataset}_spectral_metrics_with_error_bars.png", dpi=300)
    plt.close()

    # --- Mean eigenvalue spectrum per sparsity level, log-log with ±1 std band. ---
    for sparsity in sparsity_levels:
        fig, ax = plt.subplots(figsize=(10, 6))

        for method in methods:
            stats = aggregated_results[method][sparsity]
            mean_eigs = stats['mean']['eigenvalues']
            std_eigs = stats['std']['eigenvalues']
            ranks = np.arange(1, len(mean_eigs) + 1)
            shade = palette.get(method, None)

            ax.loglog(ranks, mean_eigs, 'o-', label=PRUNER_NAMES.get(method, method), color=shade)
            # A shaded band reads better than error bars on log-log axes.
            ax.fill_between(ranks,
                            mean_eigs - std_eigs,
                            mean_eigs + std_eigs,
                            alpha=0.2, color=shade)

        ax.set_xlabel('Index k')
        ax.set_ylabel('Eigenvalue λₖ')
        ax.set_title(f'Eigenvalue Decay (Sparsity {sparsity:.2f})')
        ax.legend()
        ax.grid(True, which="both", alpha=0.3)

        plt.tight_layout()
        plt.savefig(f"{saved_folders}/{dataset}_eigenvalue_decay_sparsity{sparsity:.2f}.png", dpi=300)
        plt.close()
    

def main():
    """Run the NTK spectral-analysis sweep for each configured hidden width.

    For each width, `run_sparsity_experiment` writes per-width results under
    {saved_folders}/{dataset}_dim_{hidden_dim}/. The saved aggregated results are then
    reloaded and the summary plots are written again into the top-level
    `saved_folders`. Those plot names are prefixed with the width so that a sweep over
    several widths does not overwrite them.
    """
    saved_folders = 'gntk_analysis_seeds_final'
    hidden_dims = [1024]  # e.g. [128, 256, 512, 1024] for a width sweep
    seeds = None  # None -> run_sparsity_experiment's default seed list
    dataset = 'mnist'
    train_size = 128  # images used to build each NTK; increase for more robust results
    sparsity_levels = [0.5, 0.6, 0.7, 0.8, 0.9, 0.95]
    methods = ['random', 'snip', 'synflow']  # 'grasp' is also supported

    for hidden_dim in hidden_dims:
        print(f"Running spectral analysis experiment on {dataset.upper()}...")
        run_sparsity_experiment(
            dataset=dataset,
            sample_size=train_size,
            sparsity_levels=sparsity_levels,
            methods=methods,
            hidden_dim=hidden_dim,
            saved_folders=saved_folders,
            seeds=seeds
        )

        # Reload the artifact written above (allow_pickle is needed because a dict was
        # saved; only load files this script produced) and re-plot at the top level.
        results_path = f'{saved_folders}/{dataset}_dim_{hidden_dim}/aggregated_results.npy'
        aggregated_results = np.load(results_path, allow_pickle=True).item()
        visualize_aggregated_results(aggregated_results, methods, sparsity_levels,
                                     saved_folders, f'{dataset}_dim_{hidden_dim}')

        print("\nExperiments completed. Results saved as PNG files.")
        

if __name__ == "__main__":
    # Script entry point: run the configured experiment sweep.
    main()