from torch import nn

# Maps a model key (looked up via ``args.model``) to the name of the child
# module that the helpers below treat as the classifier ("cls") layer.
cls_layer = dict(
    vgg11_bn="final",
    vgg19_bn="final",
    resnet18="fc",
    atmf_xt="fc",
)

# Maps a model key (looked up via ``args.model``) to the name of the child
# module at which the ``dfs_freeze_ps*`` helpers switch their flag to True,
# i.e. the first module whose parameters are marked trainable.
boundary_layer = dict(
    resnet18="conv5_x",
    vgg11_bn="21",
)


def dfs_freeze_fea(model, args):
    """Freeze every child of ``model`` except the classifier layer.

    Walks the module tree depth-first. A child whose name equals
    ``cls_layer[args.model]`` is left untouched and not descended into;
    every other child has ``requires_grad`` set to ``False`` on all of its
    parameters before the walk continues inside it.

    Note that ``Module.parameters()`` also yields the parameters of nested
    sub-modules, so a module with the classifier's name that sits *below* a
    non-classifier child is frozen together with that child.

    Args:
        model: Module whose children are processed in place.
        args: Object exposing ``model``, the key used to look up ``cls_layer``.

    Raises:
        KeyError: If ``args.model`` is not a key of ``cls_layer`` and
            ``model`` has at least one child.
    """
    for child_name, submodule in model.named_children():
        if child_name != cls_layer[args.model]:
            for weight in submodule.parameters():
                weight.requires_grad = False
            dfs_freeze_fea(submodule, args)


def dfs_freeze_fea_rewind_cls(model, args):
    """Freeze the feature layers and re-initialise the classifier layer.

    Walks the module tree depth-first. For the child whose name equals
    ``cls_layer[args.model]``, every parameter with more than one dimension
    is re-initialised in place with ``nn.init.kaiming_uniform_``
    (one-dimensional parameters such as biases are left as they are). All
    other children have ``requires_grad`` set to ``False`` on all of their
    parameters and are then descended into.

    The classifier itself is not descended into, mirroring
    ``dfs_freeze_fea``: its sub-modules do not carry the classifier's name,
    so recursing would freeze the very weights that were just re-initialised.

    Args:
        model: Module whose children are processed in place.
        args: Object exposing ``model``, the key used to look up ``cls_layer``.

    Raises:
        KeyError: If ``args.model`` is not a key of ``cls_layer`` and
            ``model`` has at least one child.
    """
    for name, child in model.named_children():
        if cls_layer[args.model] == name:
            # parameters() already covers nested sub-modules of the
            # classifier, so a single pass re-initialises all of them.
            for param in child.parameters():
                if param.dim() > 1:
                    nn.init.kaiming_uniform_(param)
            # Do not recurse: that would freeze the classifier's children.
            continue
        for param in child.parameters():
            param.requires_grad = False
        dfs_freeze_fea_rewind_cls(child, args)


def dfs_freeze_ps(model, flag, args):
    """Freeze children before the boundary layer and unfreeze those after it.

    Iterates over the children of ``model`` in registration order. Once a
    child named ``boundary_layer[args.model]`` is reached, ``flag`` becomes
    ``True`` for that child and every later sibling. Each child's parameters
    (including those of nested sub-modules) get ``requires_grad = flag``,
    and the walk then continues inside the child with the current flag.

    No weights are re-initialised here; the recursion stays within this
    function rather than delegating to ``dfs_freeze_ps_rewind_pr``, which
    would Kaiming-initialise nested trainable weights.

    NOTE(review): ``flag`` is local to each call and nothing is returned, so
    a boundary found inside a nested container only affects the remaining
    children of that container; modules that follow the container at an
    outer level keep the caller's flag. Confirm this is intended.

    Args:
        model: Module whose children are processed in place.
        flag: Initial ``requires_grad`` value applied until the boundary
            layer is encountered.
        args: Object exposing ``model``, the key used to look up
            ``boundary_layer``.

    Raises:
        KeyError: If ``args.model`` is not a key of ``boundary_layer`` and
            ``model`` has at least one child.
    """
    for name, child in model.named_children():
        if boundary_layer[args.model] == name:
            flag = True
        for param in child.parameters():
            param.requires_grad = flag
        dfs_freeze_ps(child, flag, args)


def dfs_freeze_ps_rewind_pr(model, flag, args):
    """Freeze children before the boundary layer; re-initialise and unfreeze the rest.

    Iterates over the children of ``model`` in registration order. From the
    child named ``boundary_layer[args.model]`` onwards the flag is ``True``.
    For every parameter of a child (nested sub-modules included):

    * if the flag is set and the parameter has more than one dimension, it
      is re-initialised in place with ``nn.init.kaiming_uniform_``;
    * ``requires_grad`` is set to the current flag.

    The walk then continues inside the child with the current flag, so
    nested trainable weights are re-initialised again at each deeper level.

    Args:
        model: Module whose children are processed in place.
        flag: Initial flag value used until the boundary layer is reached.
        args: Object exposing ``model``, the key used to look up
            ``boundary_layer``.

    Raises:
        KeyError: If ``args.model`` is not a key of ``boundary_layer`` and
            ``model`` has at least one child.
    """
    trainable = flag
    for child_name, submodule in model.named_children():
        if child_name == boundary_layer[args.model]:
            trainable = True
        for weight in submodule.parameters():
            if trainable and weight.dim() > 1:
                nn.init.kaiming_uniform_(weight)
            weight.requires_grad = trainable
        dfs_freeze_ps_rewind_pr(submodule, trainable, args)


def repalce_cls_weights(model, weights, args):
    """Overwrite the classifier layer's parameters with ``weights``.

    Walks the module tree depth-first. For every child named
    ``cls_layer[args.model]``, each parameter whose shape equals the shape
    of ``weights`` receives a copy of ``weights``; every other parameter of
    that child (e.g. a bias) is set to zero. All children, including the
    classifier, are then descended into.

    The values are copied into the existing parameter storage rather than
    rebinding ``param.data``: the parameter therefore keeps its own device
    and dtype and does not share memory with the caller's ``weights``
    tensor, so later in-place updates of the model leave ``weights`` intact.

    The misspelt function name is kept as-is for existing callers.

    Args:
        model: Module whose classifier parameters are modified in place.
        weights: Tensor holding the replacement values.
        args: Object exposing ``model``, the key used to look up ``cls_layer``.

    Raises:
        KeyError: If ``args.model`` is not a key of ``cls_layer`` and
            ``model`` has at least one child.
    """
    for name, child in model.named_children():
        if cls_layer[args.model] == name:
            for param in child.parameters():
                if param.data.size() == weights.data.size():
                    # Copy values instead of aliasing the caller's tensor.
                    param.data.copy_(weights.data)
                else:
                    # zero_() also clears NaN/inf, which ``*= 0`` would keep.
                    param.data.zero_()
        repalce_cls_weights(child, weights, args)
