
import torch
import torch.nn as nn
from src.models.jcgel.layers.layer import JCGConv2d
from src.models.jcgel.layers.utils import CRBatchNorm2d


class JCGConvBlock(nn.Module):
    """ResNet-style bottleneck block built from JCGConv2d layers.

    The main path is a 1x1 convolution, a 3x3 convolution (which carries
    ``stride``) and a final 1x1 convolution that widens the output to
    ``out_channels * expansion`` channels. Each convolution is followed by
    a CRBatchNorm2d, and the shortcut output is added before the last ReLU.

    Args:
        in_channels: Input channel count given to the first convolution and
            to the projection shortcut.
        out_channels: Channel count of the first two convolutions; the block
            emits ``out_channels * expansion`` channels.
        stride: Stride of the 3x3 convolution and of the projection shortcut.
        c_rotations: Forwarded to every JCGConv2d as both ``in_c_rotations``
            and ``out_c_rotations``.
        g_rotations: Forwarded unchanged to every JCGConv2d.
        n_flip: Forwarded unchanged to every JCGConv2d.
        temperature: Forwarded unchanged to every JCGConv2d.
        soft: Forwarded unchanged to every JCGConv2d.
    """
    expansion = 4

    def __init__(self, in_channels,
                 out_channels,
                 stride,
                 c_rotations,
                 g_rotations,
                 n_flip=0,
                 temperature=0.01,
                 soft=True):
        super().__init__()

        widened = out_channels * self.expansion

        # Keyword arguments that every JCGConv2d in this block shares.
        shared = dict(
            in_c_rotations=c_rotations,
            out_c_rotations=c_rotations,
            g_rotations=g_rotations,
            bias=False,
            n_flip=n_flip,
            temperature=temperature,
            soft=soft,
        )

        # Main path: reduce (1x1) -> spatial (3x3, strided) -> expand (1x1).
        self.conv1 = JCGConv2d(in_channels, out_channels, kernel_size=1, **shared)
        self.bn1 = CRBatchNorm2d(out_channels)

        self.conv2 = JCGConv2d(out_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, **shared)
        self.bn2 = CRBatchNorm2d(out_channels)

        self.conv3 = JCGConv2d(out_channels, widened, kernel_size=1, **shared)
        self.bn3 = CRBatchNorm2d(widened)

        self.relu = nn.ReLU(inplace=True)

        # Shortcut path: identity unless the spatial size or the channel
        # count changes, in which case a strided 1x1 convolution (without a
        # normalization layer) projects the input to the output shape.
        needs_projection = stride != 1 or in_channels != widened
        if needs_projection:
            self.shortcut = nn.Sequential(
                JCGConv2d(in_channels, widened, kernel_size=1, stride=stride, **shared),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Run the main path, add the shortcut of ``x``, and apply ReLU."""
        hidden = x
        # The first two stages are conv -> norm -> ReLU.
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            hidden = self.relu(norm(conv(hidden).contiguous()))

        # The last stage delays its ReLU until after the residual addition.
        hidden = self.bn3(self.conv3(hidden).contiguous())
        hidden += self.shortcut(x)
        return self.relu(hidden)
