def convnext_ln(backbone):
    """Unfreeze every block's ``norm`` and each stage's ``downsample[0]``.

    For each stage in ``backbone.stages`` this calls ``requires_grad_(True)``
    on ``stage.downsample[0]`` (presumably the normalization layer of a
    timm-style ConvNeXt downsample -- TODO confirm) and then on ``block.norm``
    for every block in ``stage.blocks``.

    A stage whose ``downsample`` is missing, is not indexable, or has no
    element 0 is skipped with a "Skipping" message. Only ``AttributeError``,
    ``IndexError`` and ``TypeError`` are treated that way; anything else
    (including ``KeyboardInterrupt``) propagates to the caller.

    Args:
        backbone: Object exposing ``stages`` (each with ``blocks`` and,
            optionally, ``downsample``) and ``head.fc.in_features``.

    Returns:
        The value of ``backbone.head.fc.in_features``.
    """
    for stage in backbone.stages:
        try:
            stage.downsample[0].requires_grad_(True)
        except (AttributeError, IndexError, TypeError):
            # Best-effort: this stage has no usable downsample[0].
            print("Skipping")
        for block in stage.blocks:
            block.norm.requires_grad_(True)

    return backbone.head.fc.in_features


def convnext_conv_dw(backbone):
    """Unfreeze the ``conv_dw`` and ``norm`` submodules of every block.

    Walks ``backbone.stages`` and, for each entry of ``stage.blocks``, calls
    ``requires_grad_(True)`` on ``block.conv_dw`` and then on ``block.norm``.

    NOTE(review): despite the name, ``norm`` is unfrozen as well, so this
    touches the same layers as ``convnext_conv_dw_ln`` (only the call order
    differs) -- confirm that is intended.

    Args:
        backbone: Object exposing ``stages`` (each with ``blocks``) and
            ``head.fc.in_features``.

    Returns:
        The value of ``backbone.head.fc.in_features``.
    """
    for stage in backbone.stages:
        for blk in stage.blocks:
            for attr_name in ("conv_dw", "norm"):
                getattr(blk, attr_name).requires_grad_(True)

    classifier_fc = backbone.head.fc
    return classifier_fc.in_features

def convnext_conv_dw_ln(backbone):
    """Unfreeze the ``norm`` and ``conv_dw`` submodules of every block.

    For each block found under ``backbone.stages[*].blocks`` this calls
    ``requires_grad_(True)`` first on ``block.norm`` and then on
    ``block.conv_dw``.

    Args:
        backbone: Object exposing ``stages`` (each with ``blocks``) and
            ``head.fc.in_features``.

    Returns:
        The value of ``backbone.head.fc.in_features``.
    """
    for current_stage in backbone.stages:
        for current_block in current_stage.blocks:
            norm_layer = current_block.norm
            norm_layer.requires_grad_(True)
            depthwise = current_block.conv_dw
            depthwise.requires_grad_(True)

    return backbone.head.fc.in_features

def convnext_mlp(backbone):
    """Unfreeze the whole ``mlp`` submodule of every block.

    Calls ``requires_grad_(True)`` on ``block.mlp`` for each block in each
    stage of ``backbone.stages``.

    Args:
        backbone: Object exposing ``stages`` (each with ``blocks``) and
            ``head.fc.in_features``.

    Returns:
        The value of ``backbone.head.fc.in_features``.
    """
    # Lazily flatten stages -> blocks; visiting order matches a nested loop.
    every_block = (blk for stg in backbone.stages for blk in stg.blocks)
    for blk in every_block:
        blk.mlp.requires_grad_(True)

    return backbone.head.fc.in_features

def convnext_mlp_up(backbone):
    """Unfreeze only ``mlp.fc1`` in every block.

    Calls ``requires_grad_(True)`` on ``block.mlp.fc1`` for each block in
    each stage of ``backbone.stages``; ``mlp.fc2`` is not touched here.

    Args:
        backbone: Object exposing ``stages`` (each with ``blocks``) and
            ``head.fc.in_features``.

    Returns:
        The value of ``backbone.head.fc.in_features``.
    """
    for stage in backbone.stages:
        for blk in stage.blocks:
            first_fc = blk.mlp.fc1
            first_fc.requires_grad_(True)

    head_fc = backbone.head.fc
    return head_fc.in_features

def convnext_mlp_down(backbone):
    """Unfreeze only ``mlp.fc2`` in every block.

    Calls ``requires_grad_(True)`` on ``block.mlp.fc2`` for each block in
    each stage of ``backbone.stages``; ``mlp.fc1`` is not touched here.

    Args:
        backbone: Object exposing ``stages`` (each with ``blocks``) and
            ``head.fc.in_features``.

    Returns:
        The value of ``backbone.head.fc.in_features``.
    """
    for stage_blocks in (stg.blocks for stg in backbone.stages):
        for member in stage_blocks:
            member.mlp.fc2.requires_grad_(True)

    return backbone.head.fc.in_features

def convnext_downsample(backbone):
    """Unfreeze ``downsample[1]`` of every stage that has one.

    For each stage in ``backbone.stages`` this calls ``requires_grad_(True)``
    on ``stage.downsample[1]`` (presumably the strided convolution of a
    timm-style ConvNeXt downsample -- TODO confirm). Blocks are not touched.

    A stage whose ``downsample`` is missing, is not indexable, or has no
    element 1 is skipped with a "Skipping" message. Only ``AttributeError``,
    ``IndexError`` and ``TypeError`` are treated that way; anything else
    (including ``KeyboardInterrupt``) propagates to the caller.

    Args:
        backbone: Object exposing ``stages`` (each optionally with
            ``downsample``) and ``head.fc.in_features``.

    Returns:
        The value of ``backbone.head.fc.in_features``.
    """
    for stage in backbone.stages:
        try:
            stage.downsample[1].requires_grad_(True)
        except (AttributeError, IndexError, TypeError):
            # Best-effort: this stage has no usable downsample[1].
            print("Skipping")

    return backbone.head.fc.in_features

def convnext_downsample_conv_dw(backbone):
    """Unfreeze each stage's ``downsample`` plus every block's ``conv_dw``/``norm``.

    For each stage in ``backbone.stages`` this calls ``requires_grad_(True)``
    on the whole ``stage.downsample`` object, and then on ``block.conv_dw``
    and ``block.norm`` for every block in ``stage.blocks``.

    If ``stage.downsample`` (or its ``requires_grad_`` attribute) does not
    exist, the stage's downsample is skipped with a "Skipping" message and
    its blocks are still processed. Only ``AttributeError`` is treated that
    way; anything else (including ``KeyboardInterrupt``) propagates.

    Args:
        backbone: Object exposing ``stages`` (each with ``blocks`` and,
            optionally, ``downsample``) and ``head.fc.in_features``.

    Returns:
        The value of ``backbone.head.fc.in_features``.
    """
    for stage in backbone.stages:
        try:
            stage.downsample.requires_grad_(True)
        except AttributeError:
            # Best-effort: this stage has no downsample to unfreeze.
            print("Skipping")
        for block in stage.blocks:
            block.conv_dw.requires_grad_(True)
            block.norm.requires_grad_(True)

    return backbone.head.fc.in_features

def convnext_downsample_conv_dw_stem(backbone):
    """Unfreeze the stem, each stage's ``downsample`` and every block's ``conv_dw``/``norm``.

    First calls ``requires_grad_(True)`` on ``backbone.stem``. Then, for each
    stage in ``backbone.stages``, calls it on the whole ``stage.downsample``
    object and on ``block.conv_dw`` and ``block.norm`` for every block in
    ``stage.blocks``.

    If ``stage.downsample`` (or its ``requires_grad_`` attribute) does not
    exist, the stage's downsample is skipped with a "Skipping" message and
    its blocks are still processed. Only ``AttributeError`` is treated that
    way; anything else (including ``KeyboardInterrupt``) propagates.

    Args:
        backbone: Object exposing ``stem``, ``stages`` (each with ``blocks``
            and, optionally, ``downsample``) and ``head.fc.in_features``.

    Returns:
        The value of ``backbone.head.fc.in_features``.
    """
    backbone.stem.requires_grad_(True)
    for stage in backbone.stages:
        try:
            stage.downsample.requires_grad_(True)
        except AttributeError:
            # Best-effort: this stage has no downsample to unfreeze.
            print("Skipping")
        for block in stage.blocks:
            block.conv_dw.requires_grad_(True)
            block.norm.requires_grad_(True)

    return backbone.head.fc.in_features


def convnext_classifier(backbone):
    """Report the head's input width without unfreezing anything.

    No ``requires_grad_`` call is made here; the function only reads
    ``backbone.head.fc.in_features``.

    Args:
        backbone: Object exposing ``head.fc.in_features``.

    Returns:
        The value of ``backbone.head.fc.in_features``.
    """
    fc_layer = backbone.head.fc
    return fc_layer.in_features


# Registry mapping a strategy name to the function that applies it. Every key
# is exactly the registered function's own name, so the table is built from
# ``__name__``; the tuple order fixes the dict's iteration order.
CONVNEXT_UNFREEZE_FUNCTIONS = {
    unfreeze_fn.__name__: unfreeze_fn
    for unfreeze_fn in (
        convnext_ln,
        convnext_conv_dw,
        convnext_conv_dw_ln,
        convnext_mlp,
        convnext_mlp_up,
        convnext_mlp_down,
        convnext_downsample,
        convnext_classifier,
        convnext_downsample_conv_dw,
        convnext_downsample_conv_dw_stem,
    )
}
