def pvt_v2_inner_attn_ln(backbone):
    """Unfreeze the attention-internal norm (``block.attn.norm``) of every block.

    Calls ``requires_grad_(True)`` on ``block.attn.norm`` for each block of
    each stage in ``backbone.stages``. A block whose attention has no usable
    ``norm`` (the access raises ``AttributeError``) is skipped with a
    "Skipped" message, matching how the ``pvt_v2_attn_sr*`` helpers in this
    file treat the same attribute, instead of aborting the whole traversal
    and leaving the backbone half-updated.

    Args:
        backbone: Object exposing ``stages`` (each with ``blocks``) and
            ``head.in_features``.

    Returns:
        The value of ``backbone.head.in_features``.
    """
    for stage in backbone.stages:
        for block in stage.blocks:
            try:
                block.attn.norm.requires_grad_(True)
            except AttributeError:
                # NOTE(review): presumably attention layers without spatial
                # reduction carry no inner norm -- confirm against the model.
                print("Skipped")

    return backbone.head.in_features

def pvt_v2_attn_ln(backbone):
    """Enable gradients on ``norm1`` of every block in every stage.

    Args:
        backbone: Object exposing ``stages`` (each with ``blocks``) and
            ``head.in_features``.

    Returns:
        The value of ``backbone.head.in_features``.
    """
    # Lazily flatten stages -> blocks so a single loop visits every block.
    every_block = (blk for stg in backbone.stages for blk in stg.blocks)
    for blk in every_block:
        blk.norm1.requires_grad_(True)

    return backbone.head.in_features

def pvt_v2_mlp_ln(backbone):
    """Enable gradients on ``norm2`` of every block in every stage.

    Args:
        backbone: Object exposing ``stages`` (each with ``blocks``) and
            ``head.in_features``.

    Returns:
        The value of ``backbone.head.in_features``.
    """
    for level in backbone.stages:
        for layer in level.blocks:
            second_norm = layer.norm2
            second_norm.requires_grad_(True)

    feature_dim = backbone.head.in_features
    return feature_dim

def pvt_v2_ln(backbone):
    """Enable gradients on both ``norm1`` and ``norm2`` of every block.

    Args:
        backbone: Object exposing ``stages`` (each with ``blocks``) and
            ``head.in_features``.

    Returns:
        The value of ``backbone.head.in_features``.
    """
    flat_blocks = (blk for stg in backbone.stages for blk in stg.blocks)
    for blk in flat_blocks:
        # norm1 first, then norm2 -- same order for every block.
        blk.norm1.requires_grad_(True)
        blk.norm2.requires_grad_(True)

    head = backbone.head
    return head.in_features

def pvt_v2_attn_proj(backbone):
    """Enable gradients on ``attn.proj`` of every block in every stage.

    Args:
        backbone: Object exposing ``stages`` (each with ``blocks``) and
            ``head.in_features``.

    Returns:
        The value of ``backbone.head.in_features``.
    """
    for stg in backbone.stages:
        for blk in stg.blocks:
            attention = blk.attn
            attention.proj.requires_grad_(True)

    head = backbone.head
    return head.in_features

def pvt_v2_attn_q(backbone):
    """Enable gradients on ``attn.q`` of every block in every stage.

    Args:
        backbone: Object exposing ``stages`` (each with ``blocks``) and
            ``head.in_features``.

    Returns:
        The value of ``backbone.head.in_features``.
    """
    all_blocks = (blk for stg in backbone.stages for blk in stg.blocks)
    for blk in all_blocks:
        query = blk.attn.q
        query.requires_grad_(True)

    return backbone.head.in_features

def pvt_v2_attn(backbone):
    """Enable gradients on the whole ``attn`` submodule of every block.

    Args:
        backbone: Object exposing ``stages`` (each with ``blocks``) and
            ``head.in_features``.

    Returns:
        The value of ``backbone.head.in_features``.
    """
    for level in backbone.stages:
        for layer in level.blocks:
            attention = layer.attn
            attention.requires_grad_(True)

    out_dim = backbone.head.in_features
    return out_dim

def pvt_v2_mlp(backbone):
    """Enable gradients on the whole ``mlp`` submodule of every block.

    Args:
        backbone: Object exposing ``stages`` (each with ``blocks``) and
            ``head.in_features``.

    Returns:
        The value of ``backbone.head.in_features``.
    """
    every_block = (blk for stg in backbone.stages for blk in stg.blocks)
    for blk in every_block:
        blk.mlp.requires_grad_(True)

    head = backbone.head
    return head.in_features

def pvt_v2_mlp_fc1(backbone):
    """Enable gradients on ``mlp.fc1`` of every block in every stage.

    Args:
        backbone: Object exposing ``stages`` (each with ``blocks``) and
            ``head.in_features``.

    Returns:
        The value of ``backbone.head.in_features``.
    """
    for stg in backbone.stages:
        for blk in stg.blocks:
            feed_forward = blk.mlp
            feed_forward.fc1.requires_grad_(True)

    return backbone.head.in_features

def pvt_v2_mlp_fc2(backbone):
    """Enable gradients on ``mlp.fc2`` of every block in every stage.

    Args:
        backbone: Object exposing ``stages`` (each with ``blocks``) and
            ``head.in_features``.

    Returns:
        The value of ``backbone.head.in_features``.
    """
    flat_blocks = (blk for stg in backbone.stages for blk in stg.blocks)
    for blk in flat_blocks:
        second_fc = blk.mlp.fc2
        second_fc.requires_grad_(True)

    feature_dim = backbone.head.in_features
    return feature_dim

def pvt_v2_mlp_dwconv(backbone):
    """Enable gradients on ``mlp.dwconv`` of every block in every stage.

    Args:
        backbone: Object exposing ``stages`` (each with ``blocks``) and
            ``head.in_features``.

    Returns:
        The value of ``backbone.head.in_features``.
    """
    for level in backbone.stages:
        for layer in level.blocks:
            feed_forward = layer.mlp
            feed_forward.dwconv.requires_grad_(True)

    head = backbone.head
    return head.in_features

def pvt_v2_attn_proj_dwconv(backbone):
    """Enable gradients on ``attn.proj`` and ``mlp.dwconv`` of every block.

    Args:
        backbone: Object exposing ``stages`` (each with ``blocks``) and
            ``head.in_features``.

    Returns:
        The value of ``backbone.head.in_features``.
    """
    every_block = (blk for stg in backbone.stages for blk in stg.blocks)
    for blk in every_block:
        # Attention projection first, then the MLP depthwise conv.
        blk.attn.proj.requires_grad_(True)
        blk.mlp.dwconv.requires_grad_(True)

    return backbone.head.in_features

def pvt_v2_attn_sr_norm(backbone):
    """Enable gradients on ``attn.sr`` and ``attn.norm`` where they exist.

    For each block, ``attn.sr`` is handled first and ``attn.norm`` second.
    If either access raises ``AttributeError`` the rest of that block is
    skipped, "Skipped" is printed, and the traversal continues.

    Args:
        backbone: Object exposing ``stages`` (each with ``blocks``) and
            ``head.in_features``.

    Returns:
        The value of ``backbone.head.in_features``.
    """
    flat_blocks = (blk for stg in backbone.stages for blk in stg.blocks)
    for blk in flat_blocks:
        try:
            attention = blk.attn
            attention.sr.requires_grad_(True)
            attention.norm.requires_grad_(True)
        except AttributeError:
            print("Skipped")

    return backbone.head.in_features

def pvt_v2_attn_sr(backbone):
    """Enable gradients on ``attn.sr`` of every block that has one.

    A block where the access raises ``AttributeError`` is skipped: "Skipped"
    is printed and the traversal continues with the next block.

    Args:
        backbone: Object exposing ``stages`` (each with ``blocks``) and
            ``head.in_features``.

    Returns:
        The value of ``backbone.head.in_features``.
    """
    for level in backbone.stages:
        for layer in level.blocks:
            try:
                reduction = layer.attn.sr
                reduction.requires_grad_(True)
            except AttributeError:
                print("Skipped")

    feature_dim = backbone.head.in_features
    return feature_dim

def pvt_v2_attn_sr_norm_proj(backbone):
    """Unfreeze ``attn.sr``, ``attn.norm`` (where present) and ``attn.proj``.

    For each block, ``attn.sr`` and ``attn.norm`` are unfrozen on a
    best-effort basis: if either access raises ``AttributeError``, "Skipped"
    is printed and those two are left alone for that block. ``attn.proj`` is
    then unfrozen unconditionally, so a block lacking ``sr``/``norm`` no
    longer silently keeps its projection frozen.

    Args:
        backbone: Object exposing ``stages`` (each with ``blocks``) and
            ``head.in_features``.

    Returns:
        The value of ``backbone.head.in_features``.
    """
    for stage in backbone.stages:
        for block in stage.blocks:
            # Keep the try body limited to the optional submodules.
            try:
                block.attn.sr.requires_grad_(True)
                block.attn.norm.requires_grad_(True)
            except AttributeError:
                print("Skipped")

            # Outside the try: a missing sr/norm must not skip the projection.
            block.attn.proj.requires_grad_(True)

    return backbone.head.in_features

def pvt_v2_attn_sr_norm_proj_dwconv(backbone):
    """Unfreeze ``attn.sr``/``attn.norm`` (where present), ``attn.proj`` and ``mlp.dwconv``.

    For each block, ``attn.sr`` and ``attn.norm`` are unfrozen on a
    best-effort basis: if either access raises ``AttributeError``, "Skipped"
    is printed and those two are left alone for that block. ``attn.proj`` and
    ``mlp.dwconv`` are then unfrozen unconditionally, so a block lacking
    ``sr``/``norm`` no longer silently keeps them frozen.

    Args:
        backbone: Object exposing ``stages`` (each with ``blocks``) and
            ``head.in_features``.

    Returns:
        The value of ``backbone.head.in_features``.
    """
    for stage in backbone.stages:
        for block in stage.blocks:
            # Keep the try body limited to the optional submodules.
            try:
                block.attn.sr.requires_grad_(True)
                block.attn.norm.requires_grad_(True)
            except AttributeError:
                print("Skipped")

            # Outside the try: a missing sr/norm must not skip these two.
            block.attn.proj.requires_grad_(True)
            block.mlp.dwconv.requires_grad_(True)

    return backbone.head.in_features

def pvt_v2_classifier(backbone):
    """Return ``backbone.head.in_features`` without touching any block.

    Unlike the other helpers in this file, no ``requires_grad_`` call is made
    on the backbone here.

    Args:
        backbone: Object exposing ``head.in_features``.

    Returns:
        The value of ``backbone.head.in_features``.
    """
    head = backbone.head
    return head.in_features

# Registry of unfreeze strategies, keyed by strategy name. Every key is the
# function's own name, so the mapping is derived from ``__name__``; the tuple
# order fixes the dictionary's insertion order.
PVT_V2_UNFREEZE_FUNCTIONS = {
    unfreeze_fn.__name__: unfreeze_fn
    for unfreeze_fn in (
        pvt_v2_attn_proj,
        pvt_v2_attn,
        pvt_v2_attn_q,
        pvt_v2_inner_attn_ln,
        pvt_v2_attn_ln,
        pvt_v2_mlp_ln,
        pvt_v2_ln,
        pvt_v2_mlp,
        pvt_v2_mlp_fc1,
        pvt_v2_mlp_fc2,
        pvt_v2_mlp_dwconv,
        pvt_v2_classifier,
        pvt_v2_attn_proj_dwconv,
        pvt_v2_attn_sr_norm,
        pvt_v2_attn_sr,
        pvt_v2_attn_sr_norm_proj,
        pvt_v2_attn_sr_norm_proj_dwconv,
    )
}
