# import torch
import torch.nn as nn
from models.LAM import StackedLAM

class FCN8s(nn.Module):
    """FCN-8s style segmentation network on a VGG-like convolutional backbone.

    The backbone has five stages. Each stage is a stack of
    3x3 conv -> BatchNorm -> ReLU layers followed by a 2x2 max-pool, so the
    deepest feature map is at 1/32 of the backbone input resolution. Class
    scores from stages 5, 4 and 3 are fused with learned transposed
    convolutions (2x, 2x, then 8x) back to the backbone input resolution.

    Args:
        n_classes: Number of output channels (per-class score maps).
        n_channels: Number of channels in the input tensor.
        layer_num: Base width. The five stages use 1x, 2x, 4x, 8x and 8x this
            many filters.
        LAM: If True, the input is first passed through ``StackedLAM``.
        num_layers_lam: ``num_layers`` forwarded to ``StackedLAM``.
        num_components_lam: ``num_components`` forwarded to ``StackedLAM``.

    Inputs whose height/width are not multiples of 32 are zero-padded on the
    bottom/right before the backbone. The output is then cropped back, so it
    always matches the spatial size of the backbone input.
    """

    # Total downsampling of the backbone: five stages, each halving H and W.
    _OUTPUT_STRIDE = 32

    def __init__(self, n_classes=2, n_channels=4, layer_num=32, LAM=False,
                 num_layers_lam=1, num_components_lam=3):
        super().__init__()

        def conv_block(in_c, out_c, n=2):
            """Return ``n`` x (3x3 conv -> BatchNorm -> ReLU), then a 2x2 max-pool."""
            layers = []
            for _ in range(n):
                # padding=1 keeps H and W unchanged; only the pool downsamples.
                layers.append(nn.Conv2d(in_c, out_c, 3, padding=1))
                layers.append(nn.BatchNorm2d(out_c))
                layers.append(nn.ReLU(inplace=True))
                in_c = out_c
            layers.append(nn.MaxPool2d(2, 2))
            return nn.Sequential(*layers)

        self.LAM = LAM
        if self.LAM:
            self.stacked_lam = StackedLAM(
                input_dim=n_channels,
                num_layers=num_layers_lam,
                num_components=num_components_lam,
            )
            # NOTE(review): assumes StackedLAM outputs its input channels plus
            # one extra channel per layer; confirm against models.LAM.
            total_input_dim = n_channels + num_layers_lam
        else:
            total_input_dim = n_channels

        # VGG-lite backbone (second argument of conv_block = conv layers per stage).
        self.stage1 = conv_block(total_input_dim, layer_num, 6)
        self.stage2 = conv_block(layer_num, layer_num*2, 7)
        self.stage3 = conv_block(layer_num*2, layer_num*4, 10)
        self.stage4 = conv_block(layer_num*4, layer_num*8, 10)
        self.stage5 = conv_block(layer_num*8, layer_num*8, 10)

        # 1x1 classifiers turning stage 3/4/5 features into class scores.
        self.score_pool3 = nn.Conv2d(layer_num*4, n_classes, kernel_size=1)
        self.score_pool4 = nn.Conv2d(layer_num*8, n_classes, kernel_size=1)
        self.score = nn.Conv2d(layer_num*8, n_classes, kernel_size=1)

        # Learned upsamplers. With these kernel/stride/padding choices the
        # output is exactly 2x (first two) and 8x (last) the input size.
        self.upscore2_1 = nn.ConvTranspose2d(n_classes, n_classes, 4, stride=2, padding=1, bias=False)
        self.upscore2_2 = nn.ConvTranspose2d(n_classes, n_classes, 4, stride=2, padding=1, bias=False)
        self.upscore8 = nn.ConvTranspose2d(n_classes, n_classes, 16, stride=8, padding=4, bias=False)

    def forward(self, x):
        """Return per-pixel class scores with shape (N, n_classes, H, W).

        H and W are the spatial size of the backbone input, i.e. of ``x``
        after the optional LAM stage.
        """
        if self.LAM:
            x = self.stacked_lam(x)

        # The skip connections add score maps at 1/8, 1/16 and 1/32 scale.
        # Their shapes only agree when H and W are multiples of 32, so pad up
        # to the next multiple here and crop the result at the end.
        height, width = x.shape[-2:]
        stride = self._OUTPUT_STRIDE
        pad_h = -height % stride
        pad_w = -width % stride
        padded = bool(pad_h or pad_w)
        if padded:
            # Pad order is (left, right, top, bottom) for the last two dims.
            x = nn.functional.pad(x, (0, pad_w, 0, pad_h))

        x1 = self.stage1(x)  # 1/2
        x2 = self.stage2(x1) # 1/4
        x3 = self.stage3(x2) # 1/8
        x4 = self.stage4(x3) # 1/16
        x5 = self.stage5(x4) # 1/32

        s5 = self.score(x5)
        s5_up = self.upscore2_1(s5)        # 1/32 -> 1/16

        s4 = self.score_pool4(x4)
        fuse4 = s5_up + s4

        fuse4_up = self.upscore2_2(fuse4)  # 1/16 -> 1/8
        s3 = self.score_pool3(x3)
        fuse3 = fuse4_up + s3

        out = self.upscore8(fuse3)         # 1/8 -> 1/1
        if padded:
            out = out[..., :height, :width]
        return out

# if __name__ == "__main__":
#     model = FCN8s(n_classes=2, n_channels=4, layer_num=64)

#     x = torch.randn(1, 4, 512, 512)
#     output = model(x)
#     print(output.shape)