
import torch.nn as nn
from src.models.jcgel.layers.layer import JCGConv2d
from src.models.jcgel.layers.utils import ColorGeometricPooling


class JCG_Encoder3C_Imbalance(nn.Module):
    """Convolutional encoder built from a stack of ``JCGConv2d`` layers.

    The network consists of seven ``JCGConv2d`` layers (the first one is the
    lifting layer and takes a 3-channel input) followed by a
    ``ColorGeometricPooling`` module that pools over both the colour and the
    geometric axes.  All layers are stored, in forward order, in
    ``self.hidden_layers``.

    The following attributes are read from ``config``: ``latent_dim``,
    ``hidden_states``, ``num_sampling``, ``c_rot``, ``g_rot``, ``n_flip``,
    ``temperature``, ``normalization`` and ``soft``.
    """

    # (out_channels, kernel_size, stride) of every JCGConv2d, in forward order.
    # Encoder design: Factor-VAE reference.
    _CONV_SPECS = (
        (20, 3, 1),
        (20, 3, 1),
        (20, 3, 2),
        (20, 3, 1),
        (20, 3, 2),
        (20, 3, 2),
        (64, 4, 2),
    )

    def __init__(self, config):
        """Build the layer stack from ``config``.

        Args:
            config: object exposing the attributes listed in the class
                docstring.
        """
        super().__init__()
        self.latent_dim = config.latent_dim
        self.hidden_states = config.hidden_states
        self.num_sampling = config.num_sampling

        self.c_rot = config.c_rot
        self.g_rot = config.g_rot
        self.n_flip = config.n_flip
        self.temp = config.temperature
        # NOTE(review): ``normalize`` and ``soft`` are stored but the layers
        # below hard-code ``normalization=True`` and ``soft=True``.  Confirm
        # whether the config values are meant to be forwarded instead.
        self.normalize = config.normalization
        self.soft = config.soft

        # NOTE: no normalisation module is placed between the convolutions
        # (a per-layer BatchNorm6d is currently disabled).
        modules = []
        in_channels = 3
        for index, (out_channels, kernel_size, stride) in enumerate(self._CONV_SPECS):
            # Only the first convolution lifts the plain image; it therefore
            # has a single input colour rotation.
            is_lifting = index == 0
            modules.append(JCGConv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                in_c_rotations=1 if is_lifting else self.c_rot,
                out_c_rotations=self.c_rot,
                g_rotations=self.g_rot,
                is_lifting=is_lifting,
                stride=stride,
                padding=0,
                temperature=self.temp,
                normalization=True,
                soft=True,
            ))
            in_channels = out_channels

        modules.append(ColorGeometricPooling(pool_color=True, pool_geom=True))
        self.hidden_layers = nn.ModuleList(modules)

    def forward(self, input):
        """Run ``input`` through every layer and flatten the result.

        Args:
            input: tensor fed to the first layer.

        Returns:
            A 2-tuple ``(output, all_hidden_states)``.  ``output`` is the
            last layer's result reshaped to ``(input batch size, -1)``.
            ``all_hidden_states`` is a tuple holding ``input`` followed by
            the (unflattened) output of every layer when
            ``self.hidden_states`` is truthy, and an empty tuple otherwise.
        """
        keep_states = self.hidden_states
        collected = [input] if keep_states else []

        output = input
        for hidden_layer in self.hidden_layers:
            output = hidden_layer(output)
            if keep_states:
                collected.append(output)

        # Collapse everything but the leading (batch) dimension.
        output = output.contiguous().view(output.size(0), -1)
        return (output, tuple(collected))