"""
Synaptic Pruning for Spiking Neural Networks

Implements global magnitude-based (L1) synaptic pruning, a key contribution of
the paper for refining unsupervised spiking representations.

The pruning criterion is motivated by the cosine similarity between the original
and pruned weight vectors: removing the smallest-magnitude synapses preserves the
direction of the weight vector (and hence of the representational geometry) as
far as possible while removing redundant synaptic connections. The resulting
cosine similarity is reported after pruning.

Usage:
    python main_prune.py --spiking --model resnet18 --model-path <path> --prune-rate 0.3
"""

from __future__ import print_function

import os
import sys
import argparse
import torch
import torch.backends.cudnn as cudnn
import torch.nn.utils.prune as prune

from networks.resnet_ann import SupConResNet
from networks.resnet_snn import SupConResNetSNN
from torch.nn.utils import parameters_to_vector as Params2Vec


def vectorise_model(model):
    """
    Flatten every parameter of a model into one contiguous vector.

    Args:
        model: PyTorch model; its tensors are concatenated in the order
            yielded by ``model.parameters()``.

    Returns:
        1D tensor holding all model parameters back to back
    """
    all_params = model.parameters()
    flat = Params2Vec(all_params)
    return flat


def cosine_similarity(base_weights, model_weights):
    """
    Compute the cosine of the angle between two flat weight vectors.

    This measures how well the pruned synaptic weights keep the direction
    of the original ones. Per Proposition 1 in the paper, maximizing this
    quantity corresponds to preserving representational geometry.

    Args:
        base_weights: Original (unpruned) 1D weight vector
        model_weights: Pruned 1D weight vector of the same length

    Returns:
        0-dim tensor clamped to [-1, 1]; 0 when either vector has zero norm
    """
    numerator = torch.dot(base_weights, model_weights)
    denominator = torch.linalg.norm(base_weights) * torch.linalg.norm(model_weights)
    # Guard against tiny floating-point excursions outside the valid range.
    bounded = torch.clamp(numerator / denominator, min=-1, max=1)
    # A zero-norm vector yields 0/0 = NaN, which is reported as similarity 0.
    return torch.nan_to_num(bounded, nan=0)


def global_prune_without_masks(model, amount):
    """
    Apply global unstructured L1-magnitude pruning and make it permanent.

    Every ``weight`` and ``bias`` that is an ``nn.Parameter`` on any submodule
    (normalization layers included, not only conv/linear layers) is pooled
    into a single ranking. The ``amount`` fraction of entries with the
    smallest absolute value across that whole pool is set to zero. According
    to Proposition 1, removing the smallest-magnitude synapses preserves the
    direction of the weight vector to the greatest extent (maximizes cosine
    similarity).

    After pruning, the mask reparameterization is removed. Each pruned
    tensor becomes a plain ``nn.Parameter`` again, with zeros baked in, and
    no ``*_orig`` / ``*_mask`` attributes remain.

    Args:
        model: PyTorch model to prune in place.
        amount: Fraction of the pooled weight/bias entries to prune (0 to 1).
            ``torch.nn.utils.prune`` rejects out-of-range values.

    Returns:
        None. A model without any weight/bias Parameters is left untouched.
    """
    # Collect all prunable parameters
    parameters_to_prune = []
    for mod in model.modules():
        for name in ("weight", "bias"):
            # getattr default covers modules lacking the attribute; the
            # isinstance check skips None (e.g. bias=False) and plain tensors.
            if isinstance(getattr(mod, name, None), torch.nn.Parameter):
                parameters_to_prune.append((mod, name))

    if not parameters_to_prune:
        # global_unstructured concatenates the selected tensors and fails on
        # an empty selection; nothing to prune is simply a no-op.
        return

    # Apply global L1 unstructured pruning
    prune.global_unstructured(
        tuple(parameters_to_prune),
        pruning_method=prune.L1Unstructured,
        amount=amount,
    )

    # Remove pruning reparameterization (make pruning permanent). Weight is
    # removed before bias, matching the order they were collected in.
    for mod in model.modules():
        for name in ("weight", "bias"):
            if isinstance(getattr(mod, name + "_orig", None), torch.nn.Parameter):
                prune.remove(mod, name)


def parse_option(argv=None):
    """
    Parse command-line options for the pruning script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, as ``argparse`` normally does.

    Returns:
        argparse.Namespace with the parsed options.

    Exits via ``parser.error`` (SystemExit) when ``--prune-rate`` is not in
    [0, 1].
    """
    parser = argparse.ArgumentParser('Synaptic Pruning')

    parser.add_argument('--spiking', action='store_true', help='Use SNN model')
    parser.add_argument('--timesteps', type=int, default=4, help='Number of timesteps')
    parser.add_argument('--model', type=str, default='resnet18',
                        choices=['resnet18', 'resnet34', 'resnet50'])
    parser.add_argument('--dataset', type=str, default='cifar10',
                        choices=['cifar10', 'tinyimagenet'])
    parser.add_argument('--model-path', type=str, required=True,
                        help='Path to pre-trained model checkpoint')
    parser.add_argument('--prune-rate', type=float, default=0.3,
                        help='Fraction of synapses to prune (0 to 1)')
    parser.add_argument('--gpu-id', default='0', type=str)

    opt = parser.parse_args(argv)
    # Fail fast, before the checkpoint is loaded, instead of letting the
    # pruning utilities reject the value later. NaN also fails this check.
    if not 0.0 <= opt.prune_rate <= 1.0:
        parser.error(f'--prune-rate must be in [0, 1], got {opt.prune_rate}')
    return opt


def set_model(opt):
    """
    Build the ANN or SNN SupCon ResNet and load pretrained weights into it.

    Args:
        opt: Parsed options. Reads ``spiking``, ``model``, ``timesteps``
            (SNN only) and ``model_path``.

    Returns:
        The model, with ``encoder`` wrapped in ``torch.nn.DataParallel`` and
        the checkpoint weights loaded (checkpoint mapped to CPU).

    The checkpoint may store its weights under ``'state_dict'`` or
    ``'model'``. Any other checkpoint is treated as a bare state dict.
    ``load_state_dict`` is strict by default, so a mismatched checkpoint
    raises instead of silently leaving the network at its random
    initialisation.
    """
    if opt.spiking:
        model = SupConResNetSNN(name=opt.model, timestep=opt.timesteps)
    else:
        model = SupConResNet(name=opt.model)
    # Wrapped before loading, presumably because checkpoints were saved from a
    # DataParallel encoder ("encoder.module.*" keys) — confirm with training code.
    model.encoder = torch.nn.DataParallel(model.encoder)

    # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True,
    # which may reject checkpoints that pickle non-tensor objects (e.g. an
    # argparse Namespace) — confirm against how checkpoints are saved.
    checkpoint = torch.load(opt.model_path, map_location="cpu")
    if 'state_dict' in checkpoint:
        state_dict = checkpoint["state_dict"]
    elif 'model' in checkpoint:
        state_dict = checkpoint["model"]
    else:
        # Previously this case loaded nothing yet still reported success.
        state_dict = checkpoint
    model.load_state_dict(state_dict)

    print('Successfully loaded pretrained model')
    return model


def save_model(model, opt, cosine_sim, save_file):
    """
    Save the pruned model together with pruning metadata.

    Args:
        model: Pruned PyTorch model; its ``state_dict()`` is stored.
        opt: Parsed options; ``opt.prune_rate`` is recorded.
        cosine_sim: Cosine similarity between original and pruned weights.
        save_file: Destination path. Missing parent directories are created.
            A bare filename is written to the current working directory.
    """
    save_dir = os.path.dirname(save_file)
    # dirname is '' for a bare filename; os.makedirs('') would raise.
    # exist_ok avoids a race between the existence check and creation.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    torch.save({
        'state_dict': model.state_dict(),
        'prune_rate': opt.prune_rate,
        'cosine_sim': cosine_sim,
    }, save_file)


def count_parameters(model):
    """
    Tally parameter entries of a model.

    Args:
        model: PyTorch model

    Returns:
        (total, nonzero) tuple of ints: the number of scalar entries across
        all parameters, and how many of those are non-zero.
    """
    params = list(model.parameters())
    total = sum(p.numel() for p in params)
    nonzero = sum(int(torch.count_nonzero(p)) for p in params)
    return total, nonzero


def main():
    """
    Load a checkpoint, prune it globally, report statistics and save it.

    The pruned checkpoint is written next to ``--model-path`` as
    ``pruned_rate<prune-rate>.pth.tar``.
    """
    opt = parse_option()
    # Only takes effect if CUDA has not been initialised yet in this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_id

    # Load model
    model = set_model(opt)
    model.to("cpu")

    # Count parameters before pruning
    total_before, nonzero_before = count_parameters(model)
    print(f'Parameters before pruning: {total_before:,} total, {nonzero_before:,} non-zero')

    # Remember parameter names in their pre-pruning order. prune.remove
    # re-registers each pruned tensor on its module, which may move it within
    # that module's parameter dict. Re-gathering by name after pruning keeps
    # both vectors element-aligned for the cosine similarity.
    param_names = [name for name, _ in model.named_parameters()]
    # Snapshot of the original weights; no autograd graph is needed.
    with torch.no_grad():
        model_vec = vectorise_model(model)

    # Perform pruning
    print(f'\nPruning model with prune ratio: {opt.prune_rate}')
    global_prune_without_masks(model, opt.prune_rate)

    # Count parameters after pruning
    total_after, nonzero_after = count_parameters(model)
    # Guard against a model with no non-zero parameters to begin with.
    actual_prune_rate = (1 - nonzero_after / nonzero_before) if nonzero_before else 0.0
    print(f'Parameters after pruning: {total_after:,} total, {nonzero_after:,} non-zero')
    print(f'Actual pruning ratio: {actual_prune_rate:.4f}')

    # Calculate cosine similarity (measure of representational preservation)
    with torch.no_grad():
        pruned_params = dict(model.named_parameters())
        prune_model_vec = Params2Vec(pruned_params[name] for name in param_names)
    cosine_sim = cosine_similarity(model_vec, prune_model_vec).item()
    print(f'Cosine similarity after pruning: {cosine_sim:.4f}')
    print(f'(Higher cosine similarity indicates better preservation of representational geometry)')

    # Save pruned model
    save_dir = os.path.dirname(opt.model_path)
    save_file = os.path.join(save_dir, f'pruned_rate{opt.prune_rate}.pth.tar')

    save_model(model, opt, cosine_sim, save_file)
    print(f'\nPruned model saved to: {save_file}')


# Run the pruning pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
