import torch 
import torch.nn as nn
import torch.nn.functional as F

from models.LinearModel import KronLinear, KronConv2d

def linear2kronlinear(model):
    """Placeholder for a Linear -> KronLinear conversion helper.

    Not implemented yet: *model* is ignored and nothing is modified.
    Judging only by the name, this is presumably meant to swap ``nn.Linear``
    layers for ``KronLinear`` — TODO confirm intent before implementing.

    Returns:
        None.
    """
    return None

def convlinear2kronlinear(model, rank=5):
    """Replace ``nn.Conv2d`` / ``nn.Linear`` layers in *model* with Kron layers.

    Walks the module tree recursively. It swaps each ``nn.Conv2d`` for a
    ``KronConv2d`` and each ``nn.Linear`` for a ``KronLinear``, built from the
    original layer's hyper-parameters. The replacement is done in place on
    *model*.

    Note: the new layers are constructed from hyper-parameters only, so the
    original weights are NOT transferred. The new modules are also created
    with whatever device/dtype their constructors default to.

    Args:
        model: the ``nn.Module`` whose submodules are converted (modified in
            place).
        rank: rank forwarded to ``KronConv2d`` / ``KronLinear``. Defaults to
            5, the value previously hard-coded.

    Returns:
        The same *model* object, to allow call chaining.
    """
    # Snapshot the children first: entries are reassigned while walking them.
    for name, module in list(model.named_children()):
        if isinstance(module, nn.Conv2d):
            new_module = KronConv2d(
                in_channels=module.in_channels,
                out_channels=module.out_channels,
                kernel_size=module.kernel_size,
                stride=module.stride,
                padding=module.padding,
                dilation=module.dilation,
                groups=module.groups,
                # Pass a flag, not the bias Parameter itself. Assumes KronConv2d
                # takes a boolean ``bias`` like nn.Conv2d does — confirm.
                bias=module.bias is not None,
                rank=rank,
            )
        elif isinstance(module, nn.Linear):
            new_module = KronLinear(
                in_features=module.in_features,
                out_features=module.out_features,
                # Same assumption: boolean ``bias`` flag as in nn.Linear.
                bias=module.bias is not None,
                rank=rank,
            )
        else:
            # Containers (Sequential, custom blocks, ...) may hold conv/linear
            # layers deeper down, so recurse into them.
            convlinear2kronlinear(module, rank=rank)
            continue
        # Rebinding a local variable would be a no-op; the replacement must be
        # registered on the parent module under the same attribute name.
        setattr(model, name, new_module)
    return model


def get_group_lasso():
    """Placeholder for a group-lasso helper.

    Not implemented yet. It takes no arguments and performs no computation.
    The intended semantics (e.g. a regularization term) are not visible here
    — TODO confirm before implementing.

    Returns:
        None.
    """
    return None