import torch 
import torch.nn as nn
import torch.nn.functional as F


import numpy as np

from models.LinearModel  import KronLinear

def get_kronlinear_sparsity(model: nn.Module, threshold=1e-3):
    """Compute the fraction of near-zero entries in the ``s`` parameters of KronLinear layers.

    Every KronLinear module reachable from ``model`` is inspected. This includes
    ``model`` itself and modules nested at any depth. Modules whose ``s`` is None
    are skipped.

    Args:
        model (nn.Module): The model to inspect.
        threshold (float, optional): Entries with ``|s| < threshold`` are counted as
            sparse. Defaults to 1e-3.

    Returns:
        float: The number of sparse ``s`` entries divided by the total number of ``s``
        entries. Returns 0.0 when the model contains no KronLinear layer with a
        non-None ``s``, instead of dividing by zero.
    """
    sparse = 0
    total = 0
    # model.modules() walks the whole module tree (including the root). Iterating
    # model._modules would only see direct children and miss nested KronLinear layers.
    for module in model.modules():
        if isinstance(module, KronLinear) and module.s is not None:
            sparse += torch.sum(torch.abs(module.s) < threshold).item()
            total += module.s.numel()
    if total == 0:
        return 0.0
    return sparse / total




def get_param_num(model):
    """Count the scalar elements across all parameters of a model.

    Args:
        model (nn.Module): Model whose parameters are counted.

    Returns:
        int: Sum of ``numel()`` over every tensor yielded by
        ``model.named_parameters()``. This is 0 for a model without parameters.
    """
    return sum(tensor.numel() for _, tensor in model.named_parameters())


def test_accuracy(model, test_loader, transform=None, device='cpu'):
    """Test the accuracy of a model on a test data loader

    Args:
        model (nn.Module): The model to test
        test_data_loader (DataLoader): The test data loader
        
    """
    correct = 0 
    total = 0
    for _, (x, y) in enumerate(test_loader):
        x = x.to(device)
        y = y.to(device)
        if transform is not None:
            x = transform(x)
        
        outputs = model(x)
        total += y.size(0)
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == y).sum().item() 
    return correct/total
    