from .r2gconv import *
import torch
import torch.nn as nn

# Activation class instantiated by R2LiftGCBA, ER2GCBA and PR2GCBA; swap it
# here to change their nonlinearity in one place.
Act = torch.nn.SiLU

"""
-R2GCBA -> -R2GConv + BN + Act
"""


class R2LiftGCBA(nn.Module):
    """``R2Lift`` convolution followed by BatchNorm3d and the ``Act`` activation.

    Args:
        c1: input channels passed to ``R2Lift``.
        c2: output channels of ``R2Lift`` and feature count of the BatchNorm3d.
        k, s, p: forwarded positionally to ``R2Lift`` (presumably kernel size,
            stride and padding — confirm against ``r2gconv``).
    """

    def __init__(self, c1, c2, k, s, p):
        super().__init__()
        # Attribute names double as state_dict keys; keep them stable.
        self.gcv = R2Lift(c1, c2, k, s, p)
        self.bn3d = nn.BatchNorm3d(c2)
        self.act = Act()

    def forward(self, x):
        feats = self.gcv(x)
        feats = self.bn3d(feats)
        return self.act(feats)


class ER2GCBA(nn.Module):
    """``ER2GConv`` followed by BatchNorm3d and the ``Act`` activation.

    Args:
        c1: input channels passed to ``ER2GConv``.
        c2: output channels of ``ER2GConv`` and feature count of the BatchNorm3d.
        k, s, p: forwarded positionally to ``ER2GConv`` (presumably kernel size,
            stride and padding — confirm against ``r2gconv``).
    """

    def __init__(self, c1, c2, k, s, p):
        super().__init__()
        # Registration order and names match the state_dict layout.
        self.gcv = ER2GConv(c1, c2, k, s, p)
        self.bn3d = nn.BatchNorm3d(c2)
        self.act = Act()

    def forward(self, x):
        normalized = self.bn3d(self.gcv(x))
        return self.act(normalized)


class PR2GCBA(nn.Module):
    """``PR2GConv`` followed by BatchNorm3d and the ``Act`` activation.

    Unlike the other *GCBA blocks, the convolution here takes only the channel
    counts (no kernel/stride/padding arguments are passed).

    Args:
        c1: input channels passed to ``PR2GConv``.
        c2: output channels of ``PR2GConv`` and feature count of the BatchNorm3d.
    """

    def __init__(self, c1, c2):
        super().__init__()
        self.gcv = PR2GConv(c1, c2)
        self.bn3d = nn.BatchNorm3d(c2)
        self.act = Act()

    def forward(self, x):
        out = self.gcv(x)
        out = self.bn3d(out)
        out = self.act(out)
        return out


class Bottleneck(nn.Module):
    """Two stacked ``ER2GCBA`` layers with an optional residual connection.

    Args:
        c1: input channels.
        c2: output channels.
        shortcut: add the input to the output; only honoured when ``c1 == c2``.
        k: kernel size of both layers. Both layers use stride 1 and padding
            ``k // 2``. Assuming ``ER2GConv`` follows the usual convolution size
            arithmetic, this keeps the spatial size for odd ``k``, so the
            residual add stays shape-compatible.
        e: hidden-channel ratio; the hidden width is ``int(c2 * e)``.
    """

    def __init__(self, c1, c2, shortcut=True, k=3, e=0.5):
        super(Bottleneck, self).__init__()
        c_ = int(c2 * e)  # reduce channels
        # Padding used to be hard-coded to 1, which only matched k=3; any other
        # kernel size changed H/W and broke the residual add.
        p = k // 2
        self.gcv1 = ER2GCBA(c1, c_, k, 1, p)
        self.gcv2 = ER2GCBA(c_, c2, k, 1, p)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        y = self.gcv2(self.gcv1(x))
        return x + y if self.add else y


class R2NetBlock(nn.Module):
    """Split / hierarchical-bottleneck / concat block built from PR2GCBA layers.

    ``gcv1`` expands the input to ``c_ * (n + 1)`` channels, which are split
    along dim 1 into ``n + 1`` groups of ``c_`` channels. The first group is
    passed through unchanged. Each later group ``i`` is added to the previous
    output and fed through ``blocks[i - 1]``. All outputs are concatenated
    and projected to ``c2`` channels by ``gcv2``.

    Args:
        c1: input channels.
        c2: output channels.
        n: number of ``Bottleneck`` stages. The group width ratio is 0.5 for
            n in {1, 2, 3} and 0.25 otherwise.
        shortcut: forwarded to every ``Bottleneck``.
    """

    def __init__(self, c1, c2, n=1, shortcut=False):
        super(R2NetBlock, self).__init__()
        ratio = 0.5 if n in (1, 2, 3) else 0.25
        c_ = int(c2 * ratio)  # width of one channel group
        hidden = c_ * (n + 1)
        self.c1 = c1
        self.c2 = c2
        self.gcv1 = PR2GCBA(c1, hidden)
        self.gcv2 = PR2GCBA(hidden, c2)
        self.blocks = nn.ModuleList(
            [Bottleneck(c_, c_, shortcut) for _ in range(n)]
        )
        self.C = n  # number of bottleneck stages (= number of groups - 1)

    def forward(self, x):
        # channel split into C + 1 groups
        groups = torch.chunk(self.gcv1(x), chunks=self.C + 1, dim=1)
        outs = [groups[0]]
        # NOTE(review): stage 1 consumes groups[0] + groups[1], whereas R2SPPF
        # pools its second group without adding the first; confirm intended.
        for idx, stage in enumerate(self.blocks, start=1):
            outs.append(stage(outs[-1] + groups[idx]))
        # channel concat
        merged = torch.cat(outs, dim=1)
        return self.gcv2(merged)


class MaxPooling(nn.Module):
    """Apply ``nn.MaxPool2d`` to every plane of a group feature map.

    All axes between the batch axis and the last two (H, W) axes are folded
    into one channel axis, so ``nn.MaxPool2d`` pools each plane independently.
    The result is unfolded to ``(B, -1, g, H', W')``. This presumably expects
    input shaped ``(B, C, g, H, W)`` — confirm against callers.

    Args:
        k, s, p: kernel size, stride and padding of the ``nn.MaxPool2d``.
        g: size of the group axis restored after pooling; defaults to
            ``g_order`` from ``r2gconv``.
    """

    def __init__(self, k, s, p, g=g_order):
        super(MaxPooling, self).__init__()
        self.g = g
        self.m = nn.MaxPool2d(kernel_size=k, stride=s, padding=p)

    def forward(self, x):
        b, h, w = x.shape[0], x.shape[-2], x.shape[-1]
        # reshape instead of view: identical when a view is possible, but it
        # also accepts non-view-compatible inputs (e.g. channels_last_3d).
        x = self.m(x.reshape(b, -1, h, w))
        return x.reshape(b, -1, self.g, x.shape[-2], x.shape[-1])


class R2SPPF(nn.Module):
    """Spatial pyramid pooling over four hierarchical channel groups.

    The input is split along dim 1 into four equal groups ``x0..x3``:
    ``y0 = x0``, ``y1 = pool(x1)``, ``y2 = pool(y1 + x2)`` and
    ``y3 = pool(y2 + x3)``. The ``y`` groups are concatenated back to ``c1``
    channels and projected to ``c2`` channels by ``gcv``.

    Args:
        c1: input channels; must be divisible by 4.
        c2: output channels.
        k: max-pool kernel size. Stride is 1 and padding is ``k // 2``, so an
            odd ``k`` keeps H/W and the running sums stay shape-compatible.

    Raises:
        ValueError: if ``c1`` is not divisible by 4. ``torch.chunk`` then cannot
            produce four equal groups, which previously broke ``forward``.
    """

    _SPLITS = 4  # number of channel groups

    def __init__(self, c1, c2, k=5):
        super(R2SPPF, self).__init__()
        if c1 % self._SPLITS != 0:
            raise ValueError(
                f"R2SPPF splits its input into {self._SPLITS} equal channel "
                f"groups; c1={c1} is not divisible by {self._SPLITS}"
            )
        self.block = MaxPooling(k, 1, k // 2)
        self.gcv = PR2GCBA(c1, c2)

    def forward(self, x):
        # channel split
        x = torch.chunk(x, chunks=self._SPLITS, dim=1)
        # group 0 is identity; group 1 is pooled alone; later groups are added
        # to the previous pooled result before pooling.
        ys = [x[0], self.block(x[1])]
        for part in x[2:]:
            ys.append(self.block(ys[-1] + part))
        # channel concat
        return self.gcv(torch.cat(ys, dim=1))


class TransferBlock(nn.Module):
    """Fold a group feature map into a plain 2-D feature map ("Cn -> Z2").

    ``gcv`` is built with ``c2 // g`` output channels. The BatchNorm3d implies
    a 5-D ``(B, C, D, H, W)`` tensor, and ``forward`` merges C and D into one
    channel axis. Presumably D == g, giving ``(c2 // g) * g`` channels, which
    equals ``c2`` when ``g`` divides it — confirm against ``ER2GConv``.

    Args:
        c1: input channels.
        c2: intended number of channels after folding.
        g: group order; defaults to ``g_order`` from ``r2gconv``.
    """

    def __init__(self, c1, c2, g=g_order):
        super(TransferBlock, self).__init__()
        # Divide the channel count by the group order g (not a fixed 4) so that
        # folding the group axis into channels yields about c2 channels.
        self.gcv = ER2GCBA(c1, c2 // g, 3, 1, 1)
        # NOTE(review): ER2GCBA already ends in BatchNorm3d + Act, so features
        # are normalized and activated twice here. If a bare ER2GConv was
        # intended, switching changes the state_dict keys (gcv.gcv.* -> gcv.*)
        # and breaks existing checkpoints — confirm before changing.
        self.bn3d = nn.BatchNorm3d(c2 // g)
        self.act = Act()  # same activation alias as the *GCBA blocks

    def forward(self, x):
        x = self.act(self.bn3d(self.gcv(x)))
        # reshape: Cn -> Z2 (merge the channel and group axes)
        return x.view(x.shape[0], -1, x.shape[-2], x.shape[-1])

