import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


# Nominal output stride (input size / feature-map size) of each encoder
# stage: down1 halves the resolution and every later stage halves it again.
net_stride = {f'down{level}': 2 ** level for level in range(1, 6)}
# Number of channels produced by each encoder stage.
feat_dim = dict(zip(
    ('down1', 'down2', 'down3', 'down4', 'down5'),
    (64, 256, 512, 1024, 2048),
))


class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Each convolution uses kernel_size=3 and padding=1 with the default
    stride of 1, so the spatial size of the input is preserved.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the second convolution.
        mid_channels: channels between the two convolutions; any falsy
            value (``None`` or ``0``) falls back to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels = None):
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels
        layers = []
        # Layers are created in the same order as a hand-written Sequential,
        # so parameter initialisation and state_dict indices are unchanged.
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            layers.extend((
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ))
        self.doubleconv = nn.Sequential(*layers)

    def forward(self, X):
        return self.doubleconv(X)


class Up(nn.Module):
    """U-Net decoder step: upsample, align to the skip tensor, fuse.

    ``forward(X1, X2)`` bilinearly upsamples ``X1`` by 2x, zero-pads it to
    the spatial size of the skip tensor ``X2``, concatenates ``[X2, X1]``
    along the channel dimension and applies a ``DoubleConv``.

    Args:
        in_channels: must equal channels(X2) + channels(X1), since the
            concatenation is what the ``DoubleConv`` receives.
        mid_channels: intermediate channels of the ``DoubleConv``.
        out_channels: channels of the returned tensor.
    """

    def __init__(self, in_channels, mid_channels, out_channels):
        super(Up, self).__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.doubleconv = DoubleConv(in_channels, out_channels, mid_channels)

    def forward(self, X1, X2):
        X1 = self.up(X1)
        # After 2x upsampling X1 can still differ from X2 by a pixel when the
        # encoder input size is not a multiple of the total stride. The
        # differences are kept as plain Python ints: F.pad takes a sequence
        # of ints, and wrapping them in torch.tensor allocated two CPU
        # tensors per call and is not accepted by TorchScript.
        diff_y = X2.size(2) - X1.size(2)
        diff_x = X2.size(3) - X1.size(3)
        # F.pad order for the last two dims is (left, right, top, bottom);
        # the odd extra pixel goes on the right/bottom.
        X1 = F.pad(X1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])
        X = torch.cat([X2, X1], dim=1)
        return self.doubleconv(X)


class ResnetUpSample(nn.Module):
    """ResNet-50 feature extractor whose deepest feature map is upsampled 2x.

    The input is passed through the five encoder stages ``down1`` .. ``down5``
    built from torchvision's ResNet-50. The ``down5`` output (2048 channels,
    nominal stride 32) is bilinearly upsampled by a factor of 2 and returned.
    Intermediate stage outputs are discarded.

    Args:
        pretrained: forwarded unchanged to ``torchvision.models.resnet50``.
    """

    def __init__(self, pretrained):
        super().__init__()
        self.pretrained = pretrained
        # NOTE(review): recent torchvision releases deprecate ``pretrained=``
        # in favour of ``weights=``; confirm against the installed version.
        self.net = models.resnet50(pretrained=pretrained)
        self.upsample = nn.Upsample(
            scale_factor=2, mode='bilinear', align_corners=True
        )
        backbone = self.net
        # Stem without the max-pool; the pool is deferred to down2.
        self.down1 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu)
        self.down2 = nn.Sequential(backbone.maxpool, backbone.layer1)
        self.down3 = backbone.layer2
        self.down4 = backbone.layer3
        self.down5 = backbone.layer4

    def forward(self, X):
        features = X
        for stage in (self.down1, self.down2, self.down3,
                      self.down4, self.down5):
            features = stage(features)
        return self.upsample(features)
    

class ResUNetPrePre(nn.Module):
    """ResNet-50 encoder followed by a single U-Net decoder stage.

    The encoder reuses torchvision's ResNet-50 stem and residual layers.
    ``dc1`` reduces the deepest (2048-channel) features to 1024 channels, and
    ``up1`` fuses them with the 1024-channel ``down4`` skip. ``forward``
    returns the 512-channel output of ``up1``, spatially matched to the
    ``down4`` features.

    Args:
        pretrained: forwarded unchanged to ``torchvision.models.resnet50``.
    """

    def __init__(self, pretrained):
        # Bug fix: this previously called super(ResUNetPre, self). This class
        # does not inherit from ResUNetPre, so construction always raised
        # TypeError. Zero-argument super() cannot name the wrong class.
        super().__init__()
        self.pretrained = pretrained
        self.net = models.resnet50(pretrained = self.pretrained)
        # Stem without the max-pool; the pool is deferred to down2.
        self.down1 = nn.Sequential(
            self.net.conv1,
            self.net.bn1,
            self.net.relu,
        )
        self.down2 = nn.Sequential(
            self.net.maxpool,
            self.net.layer1,
        )
        self.down3 = self.net.layer2
        self.down4 = self.net.layer3
        self.down5 = self.net.layer4
        # Bottleneck: 2048 -> 1024 channels.
        self.dc1 = DoubleConv(2048, 1024)

        # 1024 (dc1, upsampled) + 1024 (down4 skip) -> 512.
        self.up1 = Up(2048, 1024, 512)

    def forward(self, X):
        X1 = self.down1(X)
        X2 = self.down2(X1)
        X3 = self.down3(X2)
        X4 = self.down4(X3)
        X5 = self.down5(X4)
        X5 = self.dc1(X5)
        X4 = self.up1(X5, X4)
        return X4


class ResUNetPre(nn.Module):
    """ResNet-50 encoder with a two-stage U-Net decoder.

    The encoder reuses torchvision's ResNet-50 stem and residual layers.
    ``dc1`` reduces the deepest features to 1024 channels. ``up1`` fuses them
    with the ``down4`` skip, and ``up2`` fuses the result with the ``down3``
    skip. ``forward`` returns the 256-channel output of ``up2``, spatially
    matched to the ``down3`` features (nominal stride 8).

    Args:
        pretrained: forwarded unchanged to ``torchvision.models.resnet50``.
    """

    def __init__(self, pretrained):
        super().__init__()
        self.pretrained = pretrained
        self.net = models.resnet50(pretrained=pretrained)
        backbone = self.net
        # Stem without the max-pool; the pool is deferred to down2.
        self.down1 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu)
        self.down2 = nn.Sequential(backbone.maxpool, backbone.layer1)
        self.down3 = backbone.layer2
        self.down4 = backbone.layer3
        self.down5 = backbone.layer4
        # Bottleneck: 2048 -> 1024 channels.
        self.dc1 = DoubleConv(2048, 1024)
        # Decoder: in_channels = upsampled channels + skip channels.
        self.up1 = Up(2048, 1024, 512)  # 1024 (dc1) + 1024 (down4) -> 512
        self.up2 = Up(1024, 512, 256)   # 512 (up1) + 512 (down3) -> 256

    def forward(self, X):
        encoded = []
        hidden = X
        for stage in (self.down1, self.down2, self.down3,
                      self.down4, self.down5):
            hidden = stage(hidden)
            encoded.append(hidden)
        _, _, enc3, enc4, enc5 = encoded
        decoded = self.up1(self.dc1(enc5), enc4)
        return self.up2(decoded, enc3)



class ResUNet(nn.Module):
    """ResNet-50 encoder with a full four-stage U-Net decoder.

    The encoder reuses torchvision's ResNet-50 stem and residual layers.
    ``dc1`` reduces the deepest features to 1024 channels. ``up1`` .. ``up4``
    then fuse with the ``down4``, ``down3``, ``down2`` and ``down1`` skips in
    turn. ``forward`` returns the 128-channel output of ``up4``, spatially
    matched to the ``down1`` features (nominal stride 2).

    Args:
        pretrained: forwarded unchanged to ``torchvision.models.resnet50``.
    """

    def __init__(self, pretrained):
        super().__init__()
        self.pretrained = pretrained
        self.net = models.resnet50(pretrained=pretrained)
        backbone = self.net
        # Stem without the max-pool; the pool is deferred to down2.
        self.down1 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu)
        self.down2 = nn.Sequential(backbone.maxpool, backbone.layer1)
        self.down3 = backbone.layer2
        self.down4 = backbone.layer3
        self.down5 = backbone.layer4
        # Bottleneck: 2048 -> 1024 channels.
        self.dc1 = DoubleConv(2048, 1024)
        # Decoder: in_channels = upsampled channels + skip channels.
        self.up1 = Up(2048, 1024, 512)  # 1024 (dc1) + 1024 (down4) -> 512
        self.up2 = Up(1024, 512, 256)   # 512 (up1) + 512 (down3) -> 256
        self.up3 = Up(512, 256, 64)     # 256 (up2) + 256 (down2) -> 64
        self.up4 = Up(128, 128, 128)    # 64 (up3) + 64 (down1) -> 128

    def forward(self, X):
        encoded = []
        hidden = X
        for stage in (self.down1, self.down2, self.down3,
                      self.down4, self.down5):
            hidden = stage(hidden)
            encoded.append(hidden)
        enc1, enc2, enc3, enc4, enc5 = encoded
        decoded = self.dc1(enc5)
        decoded = self.up1(decoded, enc4)
        decoded = self.up2(decoded, enc3)
        decoded = self.up3(decoded, enc2)
        return self.up4(decoded, enc1)
        
        
        