from models import AntiSymm20Model, SimpleMLP20, PermEquiv20Model

def count_parameters(model):
    """
    Count the scalar parameters held by a PyTorch model.

    Iterates ``model.parameters()`` once, adding every parameter's element
    count to a running total and, separately, adding the element count of
    each parameter whose ``requires_grad`` flag is set to a trainable total.

    Args:
    model (torch.nn.Module): The model whose parameters are counted

    Returns:
    tuple: ``(total_params, trainable_params)`` where ``total_params`` (int)
    is the number of elements across all parameters and
    ``trainable_params`` (int) counts only those with ``requires_grad``
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        # Only gradient-tracked parameters contribute to the trainable count.
        if param.requires_grad:
            trainable += size
    return total, trainable

# Constructor argument shared by every model below; its meaning is defined in
# the ``models`` module (NOTE(review): presumably a problem size -- confirm).
n = 4

# Each entry is (model instance, display name, short label). The display name
# embeds ``n``; the short label is a plain string and is not used by the loop
# below.
models_info = [
    (AntiSymm20Model(n), f"AntiSymm20Model({n})", "AntiSymmPermEquiv"),
    (PermEquiv20Model(n), f"PermEquiv20Model({n})", "PermEquiv"),
    (SimpleMLP20(n), f"SimpleMLP20({n})", "MLP"),
]


# Print each model's display name followed by its
# (total_params, trainable_params) tuple from count_parameters.
for model, model_name, _ in models_info:
    print(model_name, count_parameters(model))
