import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Union

import numpy as np
from einops import rearrange
from models.LinearModel import KronLinear

def group_pattern(n: int, m: int, mat: Union[torch.Tensor, np.ndarray]) -> Union[torch.Tensor, np.ndarray]:
    """Tile a 2D matrix into ``n x m`` blocks and flatten each block into a row.

    The matrix is cut into a grid of ``(rows // n) x (cols // m)`` contiguous
    blocks, each of size ``(n, m)``.  Block ``(i, j)`` is
    ``mat[i*n:(i+1)*n, j*m:(j+1)*m]``.  The blocks are enumerated row-major
    over the grid and each one is flattened row-major into one output row.

    Example:
        For ``W`` of shape ``(8, 6)`` with ``n = 2`` and ``m = 3`` the grid is
        ``4 x 2``::

            [[G00, G01],
             [G10, G11],
             [G20, G21],
             [G30, G31]]

        with ``Gij = W[2i:2i+2, 3j:3j+3]``.  The result has shape ``(8, 6)``;
        its row ``2*i + j`` is ``Gij`` flattened, e.g. row 0 is
        ``[W[0,0], W[0,1], W[0,2], W[1,0], W[1,1], W[1,2]]``.

    Args:
        n (int): number of rows in each block; must divide ``mat.shape[0]``.
        m (int): number of columns in each block; must divide ``mat.shape[1]``.
        mat (Union[torch.Tensor, np.ndarray]): the 2D matrix to regroup.

    Returns:
        Union[torch.Tensor, np.ndarray]: array of the same type as ``mat`` with
        shape ``((rows // n) * (cols // m), n * m)``; one row per block.

    Raises:
        AssertionError: if ``mat`` is not 2D or its shape is not divisible by
            ``(n, m)``.
    """
    mat_shape = mat.shape
    assert len(mat_shape) == 2, "The input matrix should be 2D"
    assert mat_shape[0] % n == 0 and mat_shape[1] % m == 0, "The input matrix should be divisible by n and m"
    n1 = mat_shape[0] // n
    m1 = mat_shape[1] // m

    # Split both axes into (grid index, in-block index), bring the two grid
    # axes to the front, then collapse to (block, element).  reshape/swapaxes
    # exist on both torch tensors and numpy arrays, so the input type is kept.
    blocks = mat.reshape(n1, n, m1, m).swapaxes(1, 2)
    return blocks.reshape(n1 * m1, n * m)
    

def get_group_lasso(model, pattern='dim', *args, **kwargs):
    """Accumulate an L2-norm penalty over every ``weight`` parameter of a model.

    Parameters whose name contains ``'weight'`` are visited.  2D and 4D
    weights are first regrouped with ``group_pattern`` using block sizes
    chosen by ``pattern``; weights of any other rank are used as they are.
    The 2-norm of each (possibly regrouped) weight is added to the total.

    Args:
        model: object exposing ``named_parameters()`` (e.g. an ``nn.Module``).
        pattern (str): ``'dim'`` uses block size ``(shape[0], shape[1])``;
            ``'channel'`` uses ``(shape[2], shape[3])`` for 4D weights and
            ``(shape[1], shape[0])`` for 2D weights.
        *args, **kwargs: accepted for call compatibility and ignored.

    Returns:
        The summed penalty: a scalar tensor, or the int ``0`` when the model
        has no ``weight`` parameter.

    Raises:
        ValueError: if ``pattern`` is neither ``'dim'`` nor ``'channel'`` and
            a 2D or 4D weight is encountered.
        AssertionError: propagated from ``group_pattern`` when the weight is
            not 2D or is not divisible by the chosen block size.
    """
    group_loss = 0
    for name, param in model.named_parameters():
        if 'weight' not in name:
            continue

        rank = len(param.shape)
        if rank in (2, 4):
            if pattern == 'dim':
                block_rows, block_cols = param.shape[0], param.shape[1]
            elif pattern == 'channel':
                if rank == 4:
                    block_rows, block_cols = param.shape[2], param.shape[3]
                else:
                    block_rows, block_cols = param.shape[1], param.shape[0]
            else:
                raise ValueError("Pattern should be either 'dim' or 'channel'")
            # NOTE(review): group_pattern only accepts 2D input, so 4D (conv)
            # weights currently fail its assertion; the intended 4D -> 2D
            # layout is not visible here and needs to be confirmed.
            param = group_pattern(block_rows, block_cols, param)

        # NOTE(review): the norm is taken over the whole regrouped matrix, so
        # regrouping does not change the value. If a per-group penalty was
        # intended, this should be torch.norm(param, p=2, dim=1).sum() for
        # grouped weights -- confirm with the author.
        group_loss += torch.norm(param, p=2)

    return group_loss
            
        
    
    
def regularizer(model, p=1):
    """Sum the size-normalised p-norms of every ``KronLinear.s`` under a model.

    The direct sub-modules of ``model`` are walked.  A sub-module that has
    children of its own is recursed into; a childless ``KronLinear``
    contributes ``||s||_p / s.numel()`` to the loss and ``s.numel()`` to the
    parameter count.  All other sub-modules are ignored.

    Args:
        model: module whose ``_modules`` mapping is traversed.
        p: order of the norm passed to ``torch.norm``.

    Returns:
        tuple: ``(reg_loss, total_params)`` -- the accumulated loss (a tensor,
        or the int ``0`` when no ``KronLinear`` was found) and the total
        number of elements in the ``s`` tensors that were visited.
    """
    reg_loss = 0
    total_params = 0
    for module in model._modules.values():
        if module is None:
            # A sub-module slot may be registered as None; nothing to visit.
            continue
        if list(module.children()):
            # Recurse once and unpack, rather than once per returned field.
            child_loss, child_params = regularizer(module, p)
            reg_loss += child_loss
            total_params += child_params
        elif isinstance(module, KronLinear):
            reg_loss += torch.norm(module.s, p=p) / module.s.numel()
            total_params += module.s.numel()
    return reg_loss, total_params

def _lenet_norm_and_count(model, p):
    """Return (sum of p-norms of ``KronLinear.s``, total ``s`` elements) under ``model``."""
    norm_sum = 0
    count = 0
    for module in model._modules.values():
        if module is None:
            # A sub-module slot may be registered as None; nothing to visit.
            continue
        if list(module.children()):
            child_norm, child_count = _lenet_norm_and_count(module, p)
            norm_sum += child_norm
            count += child_count
        elif isinstance(module, KronLinear):
            norm_sum += torch.norm(module.s, p=p)
            count += module.s.numel()
    return norm_sum, count


def lenet_regularizer(model, p=1):
    """Return the summed p-norm of all ``KronLinear.s`` divided by their total size.

    Unlike a per-layer normalisation, the raw norms of all ``s`` tensors are
    added first and the sum is divided once by the total number of ``s``
    elements.  Sub-modules that have children are traversed recursively so
    nested ``KronLinear`` layers contribute to both the sum and the count.

    Args:
        model: module whose ``_modules`` mapping is traversed.
        p: order of the norm passed to ``torch.norm``.

    Returns:
        The normalised penalty (a scalar tensor), or the int ``0`` when the
        model contains no ``KronLinear`` layer.
    """
    reg_loss, total_params = _lenet_norm_and_count(model, p)
    if total_params == 0:
        # No KronLinear found: there is nothing to penalise and no valid
        # denominator, so report a zero penalty instead of dividing by zero.
        return reg_loss
    return reg_loss / total_params

def regularization(model: nn.Module, mode: str):
    """Average a sparsity penalty over every ``mask_scores`` parameter.

    For each parameter whose name contains ``"mask_scores"`` a per-element
    mean penalty is computed and the results are averaged over the number of
    such parameters.

    Args:
        model (nn.Module): model whose named parameters are scanned.
        mode (str): ``"l1"`` penalises ``sigmoid(scores)``; ``"l0"`` penalises
            ``sigmoid(scores - 2/3 * log(0.1 / 1.1))``.

    Returns:
        The averaged penalty (a scalar tensor), or the int ``0`` when the
        model has no ``mask_scores`` parameter.

    Raises:
        ValueError: if ``mode`` is neither ``"l1"`` nor ``"l0"``.
    """
    if mode not in ("l1", "l0"):
        raise ValueError("Don't know this mode.")

    # Constant shift applied to the scores in "l0" mode; hoisted out of the loop.
    l0_shift = float(2 / 3 * np.log(0.1 / 1.1))

    regu, counter = 0, 0
    for name, param in model.named_parameters():
        if "mask_scores" not in name:
            continue
        if mode == "l1":
            regu += torch.norm(torch.sigmoid(param), p=1) / param.numel()
        else:
            regu += torch.sigmoid(param - l0_shift).sum() / param.numel()
        counter += 1

    if counter == 0:
        # Nothing to regularise; avoid dividing by zero.
        return regu
    return regu / counter

def get_s_norm(model, p=1):
    """Sum the p-norms of every parameter named ``s`` in a model.

    A parameter is selected when the last component of its dotted name is
    exactly ``s`` (e.g. ``"0.s"`` or ``"fc1.s"``).  Matching the whole leaf
    name keeps unrelated parameters whose names merely contain the letter
    "s" (``bias``, ``mask_scores``, ...) out of the total.

    Args:
        model: object exposing ``named_parameters()``.
        p: order of the norm passed to ``torch.norm``.

    Returns:
        The summed norm (a scalar tensor), or the int ``0`` when no parameter
        named ``s`` exists.
    """
    total = 0
    for name, param in model.named_parameters():
        if name.rsplit('.', 1)[-1] == 's':
            total += torch.norm(param, p=p)
    return total


def get_sparsity(model):
    """Measure how many entries of the ``s`` parameters are (numerically) zero.

    A parameter is selected when the last component of its dotted name is
    exactly ``s``.  An entry counts as zero when its absolute value is below
    ``1e-5``, so large negative entries are not mistaken for zeros.

    Args:
        model: object exposing ``named_parameters()``.

    Returns:
        tuple: ``(ratio, zero_params, total)`` -- the fraction of near-zero
        entries, their count, and the total number of entries inspected.
        ``ratio`` is ``0.0`` when no parameter named ``s`` exists.
    """
    total = 0
    zero_params = 0
    for name, param in model.named_parameters():
        if name.rsplit('.', 1)[-1] == 's':
            total += param.numel()
            zero_params += torch.sum(param.abs() < 1e-5).item()
    if total == 0:
        # No matching parameter: report zero sparsity rather than dividing by zero.
        return 0.0, zero_params, total
    return zero_params / total, zero_params, total

if __name__ == "__main__":
    # Small manual smoke check: build three KronLinear layers and print their
    # `s` shape, all named parameters and the sparsity statistics.
    import os
    import sys

    # Add the parent directory of this file to the import path.
    # NOTE(review): the module-level import of KronLinear has already run by
    # the time this line executes, so the path tweak only affects the
    # re-import below -- confirm that this is what was intended.
    parent_dir = os.path.join(os.path.dirname(__file__), '..')
    sys.path.append(parent_dir)
    from models.LinearModel import KronLinear

    layers = [
        KronLinear(6, 4, patchsize=[3, 4], structured_sparse=True)
        for _ in range(3)
    ]
    model = nn.ModuleList(layers)

    print(model[0].s.shape)
    print(*model.named_parameters())
    print(get_sparsity(model))