import torch.nn as nn
import torch

class Cat(nn.Module):
    """Module wrapper around ``torch.cat`` that joins two tensors along dim 1."""

    def __init__(self) -> None:
        super(Cat, self).__init__()

    def forward(self, x1, x2):
        """Return ``x1`` followed by ``x2``, concatenated along dimension 1."""
        pair = [x1, x2]
        return torch.cat(pair, 1)

class Add(nn.Module):
    """Module wrapper that adds its two inputs with the ``+`` operator."""

    def __init__(self) -> None:
        super(Add, self).__init__()

    def forward(self, x1, x2):
        """Return ``x1 + x2``."""
        total = x1 + x2
        return total
    
class FlexibleShortcut(nn.Module):
    """Shortcut that trims ``x`` along dim 1 to the size of ``target_tensor``.

    Keeps the first ``target_tensor.shape[1]`` entries of ``x`` along dim 1
    (presumably the channel axis, so a wider residual input can be combined
    with a narrower tensor -- confirm at call sites). If ``x`` is not wider
    than the target along dim 1, the slice keeps all of it.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x, target_tensor):
        """Return ``x[:, :target_tensor.shape[1]]``.

        Works for any ``x`` with at least two dimensions; all trailing
        dimensions are kept whole (for 4-D input this is identical to
        ``x[:, 0:c, :, :]``).
        """
        return x[:, :target_tensor.shape[1]]
    
def _zeros_placeholder(current, shape):
    """Return a zero tensor of ``shape`` to stand in until real values are loaded.

    When ``current`` (the attribute being replaced) is a tensor, its dtype and
    device are reused so a module that was already cast or moved (e.g. with
    ``.double()`` / ``.cuda()``) keeps them; otherwise torch's defaults apply.
    """
    if isinstance(current, torch.Tensor):
        return torch.zeros(shape, dtype=current.dtype, device=current.device)
    return torch.zeros(shape)


def flexible_nn_hook(module, state_dict, *_):
    """Resize ``module``'s tensors in place so they match the shapes in ``state_dict``.

    The ``(module, state_dict, *rest)`` signature suggests this is registered as a
    ``load_state_dict`` pre-hook with ``with_module=True`` -- confirm at the
    registration site. For each key ``'<submodule path>.<attr>'`` the attribute
    on that submodule is replaced by a zero tensor of the checkpoint's shape; the
    actual values are expected to be copied in afterwards by ``load_state_dict``.

    ``weight`` / ``bias`` become ``nn.Parameter``s. ``running_mean`` /
    ``running_var`` are assigned as plain tensors so they stay buffers rather
    than becoming trainable parameters. ``num_batches_tracked`` is skipped
    because its shape is fixed.

    If ``state_dict`` contains a ``'frozen'`` entry (an iterable of key prefixes)
    it is popped, and parameters whose key starts with any of those prefixes are
    created with ``requires_grad=False``.

    NOTE(review): the ``prefix`` hook argument is discarded, so keys are assumed
    to be relative to ``module`` (i.e. the hook sits on the root module).

    Args:
        module: module whose submodules are resized in place.
        state_dict: mapping from dotted key to tensor; mutated (``'frozen'`` is
            removed when present).

    Raises:
        NotImplementedError: if a key names an attribute other than ``weight``,
            ``bias``, ``running_mean``, ``running_var`` or ``num_batches_tracked``.
        KeyError: if a key's submodule path does not exist in ``module``.
    """
    modules_dict = dict(module.named_modules())
    frozen = state_dict.pop('frozen', [])

    for key, value in state_dict.items():
        # 'a.b.weight' -> ('a.b', 'weight'). Splitting on the last dot keeps
        # submodule names that contain 'weight' (e.g. 'weight_conv') intact.
        module_key, _, attr = key.rpartition('.')
        if attr == 'num_batches_tracked':
            # Scalar BatchNorm step counter: its shape never changes.
            continue

        owner = modules_dict[module_key]
        shape = value.shape

        if attr in ('weight', 'bias'):
            if attr == 'weight' and isinstance(owner, torch.nn.Conv2d) and owner.groups != 1:
                # Grouped convs are treated as depthwise: groups follows the new
                # number of output channels. NOTE(review): assumes groups ==
                # out_channels; the in_channels/out_channels attrs are not updated.
                owner.groups = int(shape[0])
            requires_grad = not any(key.startswith(frozen_key) for frozen_key in frozen)
            placeholder = _zeros_placeholder(getattr(owner, attr, None), shape)
            setattr(owner, attr, torch.nn.Parameter(placeholder, requires_grad=requires_grad))
        elif attr in ('running_mean', 'running_var'):
            # Assigning a plain tensor to a registered buffer keeps it a buffer.
            setattr(owner, attr, _zeros_placeholder(getattr(owner, attr, None), shape))
        else:
            raise NotImplementedError(f'flexible_nn_hook: unsupported state_dict key {key!r}')
