import torch
import torch.nn as nn
from src.models.jcgel.layers.layer import JCGConv2d

class JCG_Encoder3C_Disent(nn.Module):
    """Convolutional encoder producing a reparameterized latent sample.

    Architecture (as built in ``__init__``):
      * four ``JCGConv2d`` layers (kernel 4, stride 2, padding 0), each
        followed by an in-place ReLU; the first one takes 3 input channels
        and is built with ``is_lifting=True``;
      * a dense layer on the flattened conv output;
      * two linear heads giving ``mu`` and ``logvar`` of size ``latent_dim``;
      * a sample ``z = mu + exp(0.5 * logvar) * eps`` with Gaussian ``eps``.

    ``config`` must provide: ``latent_dim``, ``hidden_states``,
    ``num_sampling``, ``c_rot``, ``g_rot``, ``n_flip``, ``temperature``,
    ``normalization``, ``soft`` and ``dense_dim`` (a 2-item sequence).
    """

    # (in_channels, out_channels) of the four JCGConv2d stages, in order.
    _CONV_CHANNELS = ((3, 32), (32, 32), (32, 64), (64, 64))

    def __init__(self, config):
        super().__init__()
        self.latent_dim = config.latent_dim
        self.hidden_states = config.hidden_states  # if truthy, forward() also returns every intermediate activation
        self.num_sampling = config.num_sampling

        self.c_rot = config.c_rot
        self.g_rot = config.g_rot
        self.n_flip = config.n_flip  # only used below to size the dense layer
        self.temp = config.temperature
        # NOTE(review): `normalize` and `soft` are stored but NOT forwarded to the
        # conv layers, which hard-code normalization=True, soft=True. Confirm
        # whether config.normalization / config.soft were meant to be passed on.
        self.normalize = config.normalization
        self.soft = config.soft

        modules = []
        for stage, (c_in, c_out) in enumerate(self._CONV_CHANNELS):
            is_first = stage == 0
            conv_kwargs = dict(
                in_channels=c_in,
                out_channels=c_out,
                kernel_size=4,
                stride=2,
                padding=0,
                # The first (lifting) layer consumes a single rotation channel;
                # later layers consume the c_rot rotations produced upstream.
                in_c_rotations=1 if is_first else self.c_rot,
                out_c_rotations=self.c_rot,
                g_rotations=self.g_rot,
                temperature=self.temp,
                normalization=True,
                soft=True,
            )
            if is_first:
                # Only the first layer sets is_lifting; later layers rely on
                # JCGConv2d's default, exactly as before.
                conv_kwargs["is_lifting"] = True
            modules.append(JCGConv2d(**conv_kwargs))
            modules.append(nn.ReLU(inplace=True))
        self.hidden_layers = nn.ModuleList(modules)

        # NOTE(review): assumes dense_dim[0] is the flattened feature count per
        # (c_rot, g_rot, flip) slice of the last conv output -- confirm against
        # JCGConv2d's output layout.
        self.dense = nn.Linear(
            config.dense_dim[0] * self.c_rot * self.g_rot * (self.n_flip + 1),
            config.dense_dim[1],
        )
        self.mu = nn.Linear(config.dense_dim[1], self.latent_dim)
        self.logvar = nn.Linear(config.dense_dim[1], self.latent_dim)

    def forward(self, input):
        """Encode ``input`` and draw a latent sample.

        Args:
            input: tensor fed to the first conv layer (which has
                ``in_channels=3``); presumably ``[batch, 3, H, W]`` -- TODO confirm.

        Returns:
            Tuple ``(z, mu, logvar, all_hidden_states)``. ``mu``/``logvar`` are
            ``[batch, latent_dim]``; ``z`` is as returned by
            :meth:`reparameterization`. ``all_hidden_states`` holds the input
            followed by the output of every layer in ``hidden_layers`` when
            ``self.hidden_states`` is truthy, otherwise it is ``()``.
        """
        all_hidden_states = ()

        output = input
        if self.hidden_states:
            all_hidden_states += (output,)
        for hidden_layer in self.hidden_layers:
            output = hidden_layer(output)
            if self.hidden_states:
                all_hidden_states += (output,)

        # Flatten everything but the batch axis before the dense layer.
        output = self.dense(torch.flatten(output, start_dim=1))
        mu = self.mu(output)  # [batch, latent_dim]
        logvar = self.logvar(output)  # [batch, latent_dim]

        z = self.reparameterization(mu, logvar)
        return (z, mu, logvar, all_hidden_states)

    def reparameterization(self, mu, logvar):
        """Sample ``z = mu + exp(0.5 * logvar) * eps`` with ``eps ~ N(0, I)``.

        Args:
            mu: mean tensor.
            logvar: log-variance tensor, same shape as ``mu``.

        Returns:
            ``z`` with the shape of ``mu`` when ``self.num_sampling == 1``;
            otherwise ``num_sampling`` independent samples stacked on a new
            leading axis, i.e. shape ``(num_sampling, *mu.shape)``.
        """
        std = torch.exp(0.5 * logvar)
        if self.num_sampling == 1:
            eps = torch.randn_like(logvar)
        else:
            # Draw the noise on the same device/dtype as the statistics: a bare
            # torch.randn lands on the CPU in the default dtype, which breaks
            # (device mismatch) when mu lives on a GPU and silently upcasts
            # half-precision inputs.
            eps = torch.randn(
                (self.num_sampling, *std.shape),
                dtype=std.dtype,
                device=std.device,
            )
        return mu + std * eps
