import torch.nn as nn
import torch.nn.functional as F
from .modules import ConvBlk, UpConvBlk, ConvBlkCond, UpConvBlkCond

class UNet(nn.Module):
    """Five-stage U-Net assembled from ``ConvBlk`` / ``UpConvBlk`` blocks.

    Encoder widths run ``ic -> 64 -> 128 -> 256 -> 512 -> 1024``. Each
    encoder stage's output is kept as a skip feature and then 2x2
    max-pooled with stride 2. A ``ConvBlk(1024, 1024)`` is applied to the
    pooled bottleneck. Five ``UpConvBlk`` stages then walk back up through
    widths ``1024 -> 512 -> 256 -> 128 -> 64 -> 32``, each also receiving
    the skip feature of the mirrored encoder stage. A 1x1 convolution
    maps the final 32 channels to ``oc``.

    NOTE(review): input is presumably an (N, C, H, W) batch with H and W
    divisible by 2**5 so skip features line up with the upsampled path.
    Whether other sizes work depends on ``UpConvBlk``; confirm.

    Args:
        ic: number of input channels.
        oc: number of output channels.
    """

    def __init__(self, ic, oc):
        super().__init__()
        mid_ch = 1024
        enc_ch = (ic, 64, 128, 256, 512, mid_ch)
        dec_ch = (mid_ch, 512, 256, 128, 64, 32)
        # Skip widths consumed by decoder stage k, deepest first: the
        # encoder output widths reversed, without the raw input width ``ic``.
        skip_ch = enc_ch[:0:-1]

        # Construction/registration order (e0, d0, middle, final) is kept
        # fixed so parameter ordering and seeded initialisation are stable.
        self.e0 = nn.ModuleList(
            [ConvBlk(c_in, c_out) for c_in, c_out in zip(enc_ch[:-1], enc_ch[1:])]
        )
        self.d0 = nn.ModuleList(
            [
                UpConvBlk(c_in, c_skip, c_out)
                for c_in, c_skip, c_out in zip(dec_ch[:-1], skip_ch, dec_ch[1:])
            ]
        )
        self.middle = ConvBlk(mid_ch, mid_ch)
        self.final = nn.Conv2d(dec_ch[-1], oc, kernel_size=1, stride=1, padding=0)

    def encoder(self, x):
        """Run every encoder stage on ``x``.

        Returns:
            A pair ``(pooled, skips)``. ``pooled`` is the max-pooled output
            of the last stage. ``skips`` is the list of pre-pooling stage
            outputs, shallowest first.
        """
        skips = []
        h = x
        for stage in self.e0:
            feat = stage(h)
            skips.append(feat)
            h = F.max_pool2d(feat, kernel_size=2, stride=2)
        return h, skips

    def decoder(self, x, f):
        """Run the decoder stages starting from ``x``, using skips ``f``.

        Stage ``k`` is called as ``stage(h, f[-1 - k])``, so the deepest
        skip feature is consumed first.
        """
        h = x
        for depth, stage in enumerate(self.d0):
            h = stage(h, f[-1 - depth])
        return h

    def forward(self, x):
        """Encode, apply the bottleneck block, decode, and project to ``oc``."""
        bottleneck, skips = self.encoder(x)
        decoded = self.decoder(self.middle(bottleneck), skips)
        return self.final(decoded)


class UNetCond(nn.Module):
    """Conditional variant of the five-stage U-Net.

    Same channel layout as the plain U-Net. Encoder widths run
    ``ic -> 64 -> 128 -> 256 -> 512 -> 1024``, each stage output is kept as
    a skip and then 2x2 max-pooled with stride 2. Decoder widths run
    ``1024 -> 512 -> 256 -> 128 -> 64 -> 32``, followed by a 1x1 conv to
    ``oc``. Here the encoder and decoder stages are ``ConvBlkCond`` /
    ``UpConvBlkCond``, and each one is also called with the conditioning
    input ``cond``.

    NOTE(review): ``middle`` is a plain ``ConvBlk`` and ``final`` is a plain
    ``nn.Conv2d``, so ``cond`` never reaches the bottleneck or the output
    projection. Confirm this is intended.

    Args:
        ic: number of input channels.
        oc: number of output channels.
    """

    def __init__(self, ic, oc):
        super().__init__()
        mid_ch = 1024
        enc_ch = (ic, 64, 128, 256, 512, mid_ch)
        dec_ch = (mid_ch, 512, 256, 128, 64, 32)
        # Skip widths consumed by decoder stage k, deepest first: the
        # encoder output widths reversed, without the raw input width ``ic``.
        skip_ch = enc_ch[:0:-1]

        # Construction/registration order (e0, d0, middle, final) is kept
        # fixed so parameter ordering and seeded initialisation are stable.
        self.e0 = nn.ModuleList(
            [ConvBlkCond(c_in, c_out) for c_in, c_out in zip(enc_ch[:-1], enc_ch[1:])]
        )
        self.d0 = nn.ModuleList(
            [
                UpConvBlkCond(c_in, c_skip, c_out)
                for c_in, c_skip, c_out in zip(dec_ch[:-1], skip_ch, dec_ch[1:])
            ]
        )
        self.middle = ConvBlk(mid_ch, mid_ch)
        self.final = nn.Conv2d(dec_ch[-1], oc, kernel_size=1, stride=1, padding=0)

    def encoder(self, x, cond):
        """Run every conditional encoder stage on ``x``.

        Each stage is called as ``stage(h, cond)``.

        Returns:
            A pair ``(pooled, skips)``. ``pooled`` is the max-pooled output
            of the last stage. ``skips`` is the list of pre-pooling stage
            outputs, shallowest first.
        """
        skips = []
        h = x
        for stage in self.e0:
            feat = stage(h, cond)
            skips.append(feat)
            h = F.max_pool2d(feat, kernel_size=2, stride=2)
        return h, skips

    def decoder(self, x, f, cond):
        """Run the conditional decoder stages starting from ``x``.

        Stage ``k`` is called as ``stage(h, f[-1 - k], cond)``, so the
        deepest skip feature is consumed first.
        """
        h = x
        for depth, stage in enumerate(self.d0):
            h = stage(h, f[-1 - depth], cond)
        return h

    def forward(self, x, cond):
        """Encode and decode ``x`` under ``cond``, then project to ``oc``.

        ``cond`` is forwarded unchanged to every encoder and decoder stage.
        Its expected type/shape is defined by ``ConvBlkCond`` /
        ``UpConvBlkCond``, which are not visible here.
        """
        bottleneck, skips = self.encoder(x, cond)
        decoded = self.decoder(self.middle(bottleneck), skips, cond)
        return self.final(decoded)


class DoubleBranchUNet(nn.Module):
    """Five-stage U-Net assembled from ``ConvBlk`` / ``UpConvBlk`` blocks.

    Encoder widths run ``ic -> 64 -> 128 -> 256 -> 512 -> 1024``. Each
    encoder stage's output is kept as a skip and then 2x2 max-pooled with
    stride 2. A ``ConvBlk(1024, 1024)`` processes the bottleneck, and five
    ``UpConvBlk`` stages decode through ``1024 -> 512 -> 256 -> 128 -> 64
    -> 32``. A 1x1 convolution projects to ``oc`` channels.

    NOTE(review): despite the name, the layers and forward pass defined here
    are identical to ``UNet`` in this file. There is a single encoder/decoder
    path and no second branch. Confirm whether a second branch is still
    meant to be added.

    Args:
        ic: number of input channels.
        oc: number of output channels.
    """

    def __init__(self, ic, oc):
        super().__init__()
        mid_ch = 1024
        enc_ch = (ic, 64, 128, 256, 512, mid_ch)
        dec_ch = (mid_ch, 512, 256, 128, 64, 32)
        # Skip widths consumed by decoder stage k, deepest first: the
        # encoder output widths reversed, without the raw input width ``ic``.
        skip_ch = enc_ch[:0:-1]

        # Construction/registration order (e0, d0, middle, final) is kept
        # fixed so parameter ordering and seeded initialisation are stable.
        self.e0 = nn.ModuleList(
            [ConvBlk(c_in, c_out) for c_in, c_out in zip(enc_ch[:-1], enc_ch[1:])]
        )
        self.d0 = nn.ModuleList(
            [
                UpConvBlk(c_in, c_skip, c_out)
                for c_in, c_skip, c_out in zip(dec_ch[:-1], skip_ch, dec_ch[1:])
            ]
        )
        self.middle = ConvBlk(mid_ch, mid_ch)
        self.final = nn.Conv2d(dec_ch[-1], oc, kernel_size=1, stride=1, padding=0)

    def encoder(self, x):
        """Run every encoder stage on ``x``.

        Returns:
            A pair ``(pooled, skips)``. ``pooled`` is the max-pooled output
            of the last stage. ``skips`` is the list of pre-pooling stage
            outputs, shallowest first.
        """
        skips = []
        h = x
        for stage in self.e0:
            feat = stage(h)
            skips.append(feat)
            h = F.max_pool2d(feat, kernel_size=2, stride=2)
        return h, skips

    def decoder(self, x, f):
        """Run the decoder stages starting from ``x``, using skips ``f``.

        Stage ``k`` is called as ``stage(h, f[-1 - k])``, so the deepest
        skip feature is consumed first.
        """
        h = x
        for depth, stage in enumerate(self.d0):
            h = stage(h, f[-1 - depth])
        return h

    def forward(self, x):
        """Encode, apply the bottleneck block, decode, and project to ``oc``."""
        bottleneck, skips = self.encoder(x)
        decoded = self.decoder(self.middle(bottleneck), skips)
        return self.final(decoded)