from torch import nn

from SNN.Layers import LNM

from SNN.models.classification.VGG import LIFVGGSNN
from SNN.models.classification.ResNet import ITLIFResNet34, LNMResNet19

from SNN.LearnableMembrane import LearnableMembrane

def split_weights(net: nn.Module, *, dynamics_lr: "float | None" = None):
    """Split a network's parameters into three optimizer parameter groups.

    Groups, in order:
      1. decay: the ``weight`` of every ``nn.Conv2d`` / ``nn.Linear``.
         No per-group options, so the optimizer's default weight decay applies.
      2. no_decay: biases of conv/linear layers, plus the ``weight`` /
         ``bias`` Parameters of every other module that owns them (for
         example normalization layers). Weight decay is set to 0.
      3. learnable_dynamics: the ``weight`` / ``bias`` Parameters of
         ``LearnableMembrane`` modules. Weight decay is set to 0 and the
         group gets its own learning rate.

    ``LNM`` modules are skipped: parameters they own directly are not put in
    any group. Their child modules are still visited by ``net.modules()``.
    Attributes that are ``None`` (e.g. ``bias=False`` or
    ``BatchNorm(affine=False)``) or are not ``nn.Parameter`` are ignored.
    A parameter shared between modules is added only once, to the group of
    the first module that references it.

    Args:
        net: network whose parameters are split.
        dynamics_lr: learning rate for the learnable_dynamics group. If
            ``None`` (default), uses 1e-2 when ``net`` is a ``LIFVGGSNN``,
            ``ITLIFResNet34`` or ``LNMResNet19``, and 1e-1 otherwise.

    Returns:
        A list of three dicts, usable as the ``params`` argument of a
        ``torch.optim.Optimizer``.

    Raises:
        AssertionError: if the groups do not contain exactly as many
            parameters as ``net.parameters()`` yields. This presumably means
            some module type not handled here (e.g. ``LNM``) owns parameters
            directly.
    """
    decay = []
    no_decay = []
    learnable_dynamics = []
    seen = set()

    def _add(group, param):
        # Ignore missing/None attributes and plain tensors, and add each
        # (possibly tied) Parameter only once: torch.optim rejects a param
        # that appears in more than one group.
        if isinstance(param, nn.Parameter) and id(param) not in seen:
            seen.add(id(param))
            group.append(param)

    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            _add(decay, m.weight)
            _add(no_decay, m.bias)
        elif isinstance(m, LearnableMembrane):
            _add(learnable_dynamics, getattr(m, 'weight', None))
            _add(learnable_dynamics, getattr(m, 'bias', None))
        elif isinstance(m, LNM):
            continue
        else:
            # Container modules (and the root net) usually have neither
            # attribute; getattr with a default makes that a no-op.
            _add(no_decay, getattr(m, 'weight', None))
            _add(no_decay, getattr(m, 'bias', None))

    # Explicit check (not `assert`) so it still runs under `python -O`.
    n_params = len(list(net.parameters()))
    n_grouped = len(decay) + len(no_decay) + len(learnable_dynamics)
    if n_params != n_grouped:
        raise AssertionError(
            f"split_weights grouped {n_grouped} parameters but the network "
            f"has {n_params} (decay={len(decay)}, no_decay={len(no_decay)}, "
            f"learnable_dynamics={len(learnable_dynamics)})"
        )

    if dynamics_lr is None:
        dynamics_lr = 1e-2 if isinstance(net, (LIFVGGSNN, ITLIFResNet34, LNMResNet19)) else 1e-1

    return [
        dict(params=decay),
        dict(params=no_decay, weight_decay=0),
        dict(params=learnable_dynamics, weight_decay=0.0, lr=dynamics_lr),
    ]