# import torch
import torch.nn as nn
import torch.nn.functional as F
from models.LAM import StackedLAM

class SegNet(nn.Module):
    def __init__(self, n_classes = 2, n_channels = 4, layer_num = 64, LAM = False, num_layers_lam=1, num_components_lam=3, BN_momentum=0.5):
        super(SegNet, self).__init__()
        
        self.LAM = LAM
        if self.LAM:
            self.stacked_lam = StackedLAM(input_dim=n_channels, num_layers=num_layers_lam, num_components=num_components_lam)
            total_input_dim = n_channels + num_layers_lam
        else:
            total_input_dim = n_channels
        #SegNet Architecture
        #Takes input of size in_chn = 3 (RGB images have 3 channels)
        #Outputs size label_chn (N # of classes)

        #ENCODING consists of 5 stages
        #Stage 1, 2 has 2 layers of Convolution + Batch Normalization + Max Pool respectively
        #Stage 3, 4, 5 has 3 layers of Convolution + Batch Normalization + Max Pool respectively

        #General Max Pool 2D for ENCODING layers
        #Pooling indices are stored for Upsampling in DECODING layers

        self.in_chn = total_input_dim
        self.out_chn = n_classes

        self.MaxEn = nn.MaxPool2d(2, stride=2, return_indices=True) 

        self.ConvEn11 = nn.Conv2d(self.in_chn, layer_num, kernel_size=3, padding=1)
        self.BNEn11 = nn.BatchNorm2d(layer_num, momentum=BN_momentum)
        self.ConvEn12 = nn.Conv2d(layer_num, layer_num, kernel_size=3, padding=1)
        self.BNEn12 = nn.BatchNorm2d(layer_num, momentum=BN_momentum)

        self.ConvEn21 = nn.Conv2d(layer_num, layer_num*2, kernel_size=3, padding=1)
        self.BNEn21 = nn.BatchNorm2d(layer_num*2, momentum=BN_momentum)
        self.ConvEn22 = nn.Conv2d(layer_num*2, layer_num*2, kernel_size=3, padding=1)
        self.BNEn22 = nn.BatchNorm2d(layer_num*2, momentum=BN_momentum)

        self.ConvEn31 = nn.Conv2d(layer_num*2, layer_num*4, kernel_size=3, padding=1)
        self.BNEn31 = nn.BatchNorm2d(layer_num*4, momentum=BN_momentum)
        self.ConvEn32 = nn.Conv2d(layer_num*4, layer_num*4, kernel_size=3, padding=1)
        self.BNEn32 = nn.BatchNorm2d(layer_num*4, momentum=BN_momentum)
        self.ConvEn33 = nn.Conv2d(layer_num*4, layer_num*4, kernel_size=3, padding=1)
        self.BNEn33 = nn.BatchNorm2d(layer_num*4, momentum=BN_momentum)

        self.ConvEn41 = nn.Conv2d(layer_num*4, layer_num*8, kernel_size=3, padding=1)
        self.BNEn41 = nn.BatchNorm2d(layer_num*8, momentum=BN_momentum)
        self.ConvEn42 = nn.Conv2d(layer_num*8, layer_num*8, kernel_size=3, padding=1)
        self.BNEn42 = nn.BatchNorm2d(layer_num*8, momentum=BN_momentum)
        self.ConvEn43 = nn.Conv2d(layer_num*8, layer_num*8, kernel_size=3, padding=1)
        self.BNEn43 = nn.BatchNorm2d(layer_num*8, momentum=BN_momentum)

        self.ConvEn51 = nn.Conv2d(layer_num*8, layer_num*8, kernel_size=3, padding=1)
        self.BNEn51 = nn.BatchNorm2d(layer_num*8, momentum=BN_momentum)
        self.ConvEn52 = nn.Conv2d(layer_num*8, layer_num*8, kernel_size=3, padding=1)
        self.BNEn52 = nn.BatchNorm2d(layer_num*8, momentum=BN_momentum)
        self.ConvEn53 = nn.Conv2d(layer_num*8, layer_num*8, kernel_size=3, padding=1)
        self.BNEn53 = nn.BatchNorm2d(layer_num*8, momentum=BN_momentum)


        #DECODING consists of 5 stages
        #Each stage corresponds to their respective counterparts in ENCODING

        #General Max Pool 2D/Upsampling for DECODING layers
        self.MaxDe = nn.MaxUnpool2d(2, stride=2) 

        self.ConvDe53 = nn.Conv2d(layer_num*8, layer_num*8, kernel_size=3, padding=1)
        self.BNDe53 = nn.BatchNorm2d(layer_num*8, momentum=BN_momentum)
        self.ConvDe52 = nn.Conv2d(layer_num*8, layer_num*8, kernel_size=3, padding=1)
        self.BNDe52 = nn.BatchNorm2d(layer_num*8, momentum=BN_momentum)
        self.ConvDe51 = nn.Conv2d(layer_num*8, layer_num*8, kernel_size=3, padding=1)
        self.BNDe51 = nn.BatchNorm2d(layer_num*8, momentum=BN_momentum)

        self.ConvDe43 = nn.Conv2d(layer_num*8, layer_num*8, kernel_size=3, padding=1)
        self.BNDe43 = nn.BatchNorm2d(layer_num*8, momentum=BN_momentum)
        self.ConvDe42 = nn.Conv2d(layer_num*8, layer_num*8, kernel_size=3, padding=1)
        self.BNDe42 = nn.BatchNorm2d(layer_num*8, momentum=BN_momentum)
        self.ConvDe41 = nn.Conv2d(layer_num*8, layer_num*4, kernel_size=3, padding=1)
        self.BNDe41 = nn.BatchNorm2d(layer_num*4, momentum=BN_momentum)

        self.ConvDe33 = nn.Conv2d(layer_num*4, layer_num*4, kernel_size=3, padding=1)
        self.BNDe33 = nn.BatchNorm2d(layer_num*4, momentum=BN_momentum)
        self.ConvDe32 = nn.Conv2d(layer_num*4, layer_num*4, kernel_size=3, padding=1)
        self.BNDe32 = nn.BatchNorm2d(layer_num*4, momentum=BN_momentum)
        self.ConvDe31 = nn.Conv2d(layer_num*4, layer_num*2, kernel_size=3, padding=1)
        self.BNDe31 = nn.BatchNorm2d(layer_num*2, momentum=BN_momentum)

        self.ConvDe22 = nn.Conv2d(layer_num*2, layer_num*2, kernel_size=3, padding=1)
        self.BNDe22 = nn.BatchNorm2d(layer_num*2, momentum=BN_momentum)
        self.ConvDe21 = nn.Conv2d(layer_num*2, layer_num, kernel_size=3, padding=1)
        self.BNDe21 = nn.BatchNorm2d(layer_num, momentum=BN_momentum)

        self.ConvDe12 = nn.Conv2d(layer_num, layer_num, kernel_size=3, padding=1)
        self.BNDe12 = nn.BatchNorm2d(layer_num, momentum=BN_momentum)
        self.ConvDe11 = nn.Conv2d(layer_num, self.out_chn, kernel_size=3, padding=1)
        self.BNDe11 = nn.BatchNorm2d(self.out_chn, momentum=BN_momentum)

    def forward(self, x):
        if self.LAM:
            x = self.stacked_lam(x)
        #ENCODE LAYERS
        #Stage 1
        x = F.relu(self.BNEn11(self.ConvEn11(x))) 
        x = F.relu(self.BNEn12(self.ConvEn12(x))) 
        x, ind1 = self.MaxEn(x)
        size1 = x.size()

        #Stage 2
        x = F.relu(self.BNEn21(self.ConvEn21(x))) 
        x = F.relu(self.BNEn22(self.ConvEn22(x))) 
        x, ind2 = self.MaxEn(x)
        size2 = x.size()

        #Stage 3
        x = F.relu(self.BNEn31(self.ConvEn31(x))) 
        x = F.relu(self.BNEn32(self.ConvEn32(x))) 
        x = F.relu(self.BNEn33(self.ConvEn33(x)))   
        x, ind3 = self.MaxEn(x)
        size3 = x.size()

        #Stage 4
        x = F.relu(self.BNEn41(self.ConvEn41(x))) 
        x = F.relu(self.BNEn42(self.ConvEn42(x))) 
        x = F.relu(self.BNEn43(self.ConvEn43(x)))   
        x, ind4 = self.MaxEn(x)
        size4 = x.size()

        #Stage 5
        x = F.relu(self.BNEn51(self.ConvEn51(x))) 
        x = F.relu(self.BNEn52(self.ConvEn52(x))) 
        x = F.relu(self.BNEn53(self.ConvEn53(x)))   
        x, ind5 = self.MaxEn(x)

        #DECODE LAYERS
        #Stage 5
        x = self.MaxDe(x, ind5, output_size=size4)
        x = F.relu(self.BNDe53(self.ConvDe53(x)))
        x = F.relu(self.BNDe52(self.ConvDe52(x)))
        x = F.relu(self.BNDe51(self.ConvDe51(x)))

        #Stage 4
        x = self.MaxDe(x, ind4, output_size=size3)
        x = F.relu(self.BNDe43(self.ConvDe43(x)))
        x = F.relu(self.BNDe42(self.ConvDe42(x)))
        x = F.relu(self.BNDe41(self.ConvDe41(x)))

        #Stage 3
        x = self.MaxDe(x, ind3, output_size=size2)
        x = F.relu(self.BNDe33(self.ConvDe33(x)))
        x = F.relu(self.BNDe32(self.ConvDe32(x)))
        x = F.relu(self.BNDe31(self.ConvDe31(x)))

        #Stage 2
        x = self.MaxDe(x, ind2, output_size=size1)
        x = F.relu(self.BNDe22(self.ConvDe22(x)))
        x = F.relu(self.BNDe21(self.ConvDe21(x)))

        #Stage 1
        x = self.MaxDe(x, ind1)
        x = F.relu(self.BNDe12(self.ConvDe12(x)))
        x = self.ConvDe11(x)

        x = F.softmax(x, dim=1)

        return x