from models import AntiSymm21Model, SimpleMLP21, PermEquiv21Model

def count_parameters(model):
    """
    Tally the parameter elements of a PyTorch model in a single pass.

    Args:
    model (torch.nn.Module): The model whose parameters are counted

    Returns:
    total_params (int): Element count summed over every parameter
    trainable_params (int): Element count summed over the parameters
        whose requires_grad flag is set
    """
    n_all = 0
    n_trainable = 0
    for param in model.parameters():
        size = param.numel()
        n_all += size
        # Only parameters flagged for gradient updates count as trainable.
        if param.requires_grad:
            n_trainable += size

    return n_all, n_trainable

# Size argument handed to every model constructor below.
n = 8

# One (model instance, constructor-call label, short tag) triple per model.
# Only the labels interpolate ``n``; the short tags are plain strings.
models_info = [
    (AntiSymm21Model(n), f"AntiSymm21Model({n})", "AntiSymmPermEquiv"),
    (PermEquiv21Model(n), f"PermEquiv21Model({n})", "PermEquiv"),
    (SimpleMLP21(n), f"SimpleMLP21({n})", "MLP"),
]


# Print each label followed by its (total, trainable) parameter counts.
# The short tag is not needed for this report, so it is discarded.
for model, model_name, _ in models_info:
    print(model_name, count_parameters(model))
