import math
import torch.nn as nn


class Adapter(nn.Module):
    """Residual bottleneck adapter built from two 1x1 convolutions.

    The input is projected from ``dim`` channels to ``dim // ratio``
    channels, passed through GELU, projected back to ``dim`` channels and
    finally added to the unmodified input.

    Args:
        dim: Number of input (and output) channels.
        ratio: Channel reduction factor of the bottleneck.
    """

    def __init__(self, dim, ratio=4):
        super().__init__()

        hidden = dim // ratio
        reduce = nn.Conv2d(dim, hidden, kernel_size=1, stride=1)
        activation = nn.GELU()
        expand = nn.Conv2d(hidden, dim, kernel_size=1, stride=1)

        # Kept under the attribute name ``op`` and in this order so the
        # parameter names (``op.0.*``, ``op.2.*``) stay the same.
        self.op = nn.Sequential(reduce, activation, expand)

    def forward(self, x):
        """Return ``x`` plus the bottleneck output computed from ``x``.

        ``x`` must have ``dim`` channels in a layout ``nn.Conv2d`` accepts.
        """
        residual = self.op(x)
        return x + residual
    
class AdapterMLP(nn.Module):
    """Residual bottleneck adapter built from linear layers.

    The last dimension of the input is projected from ``dim`` to
    ``dim // ratio`` features, passed through GELU, projected back to
    ``dim`` features and layer-normalised; the result is added to the
    unmodified input.

    Args:
        dim: Size of the input (and output) feature dimension.
        ratio: Feature reduction factor of the bottleneck.
    """

    def __init__(self, dim, ratio=4):
        super().__init__()

        hidden = dim // ratio
        down_proj = nn.Linear(dim, hidden)
        activation = nn.GELU()
        up_proj = nn.Linear(hidden, dim)
        norm = nn.LayerNorm(dim)

        # Kept under the attribute name ``op`` and in this order so the
        # parameter names (``op.0.*``, ``op.2.*``, ``op.3.*``) stay the same.
        self.op = nn.Sequential(down_proj, activation, up_proj, norm)

    def forward(self, x):
        """Return ``x`` plus the normalised bottleneck output of ``x``.

        The last dimension of ``x`` must have size ``dim``.
        """
        residual = self.op(x)
        return x + residual

def convnext_conv(backbone):
    """Collect the convolutions of a ConvNeXt-style backbone to adapt.

    For every stage in ``backbone.stages`` the element at index 1 of
    ``stage.downsample`` is recorded when it can be reached, followed by
    the ``conv_dw`` of every block in ``stage.blocks``.

    Args:
        backbone: Object exposing ``stages``. Each stage exposes ``blocks``
            (whose items have a ``conv_dw`` with ``out_channels``) and,
            optionally, an indexable ``downsample``.

    Returns:
        A list of ``(module, out_channels, Adapter)`` tuples in traversal
        order.
    """
    adapted_modules = []

    for stage in backbone.stages:
        try:
            downsample_conv = stage.downsample[1]
            channels = downsample_conv.out_channels
        except (AttributeError, LookupError, TypeError):
            # The stage has no downsample, it is not indexable, or it has no
            # element at index 1 with ``out_channels``. Only these lookup
            # failures are tolerated; anything else (including
            # KeyboardInterrupt) propagates.
            print("Skipping")
        else:
            adapted_modules.append((downsample_conv, channels, Adapter))

        adapted_modules.extend(
            (block.conv_dw, block.conv_dw.out_channels, Adapter)
            for block in stage.blocks
        )

    return adapted_modules

# Registry mapping a string name to a function that takes a backbone and
# returns the list of (module, out_channels, adapter class) entries to adapt.
CONVNEXT_MODULE_ITERATORS = {
    "convnext_conv": convnext_conv,
}