import torch
from torch import nn
import torch.nn.functional as F


class ShapeGFEncoder(nn.Module):
    """PointNet-style encoder mapping a batch of point clouds to latent codes.

    Per-point features come from shared 1x1 convolutions (with BatchNorm).
    They are max-pooled over the point dimension and then projected to
    ``zdim`` by a small MLP. Only the deterministic variant is implemented.
    """

    def __init__(self, model_cfg, runtime_cfg):
        """Build the encoder layers.

        Args:
            model_cfg: mapping with keys ``"zdim"`` (latent size),
                ``"use_deterministic_encoder"`` (must be truthy) and
                ``"input_dim"`` (number of per-point input channels).
            runtime_cfg: mutable mapping; ``"num_point_feature"`` is written
                into it.

        Raises:
            NotImplementedError: if ``use_deterministic_encoder`` is falsy.
        """
        super().__init__()
        self.zdim = model_cfg["zdim"]
        self.use_deterministic_encoder = model_cfg["use_deterministic_encoder"]
        # Use item access like the keys above. Attribute access would raise
        # AttributeError on a plain dict config.
        self.input_dim = model_cfg["input_dim"]

        # Shared per-point MLP implemented as kernel-size-1 convolutions.
        self.conv1 = nn.Conv1d(self.input_dim, 128, 1)
        self.conv2 = nn.Conv1d(128, 128, 1)
        self.conv3 = nn.Conv1d(128, 256, 1)
        self.conv4 = nn.Conv1d(256, 512, 1)
        self.bn1 = nn.BatchNorm1d(128)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(256)
        self.bn4 = nn.BatchNorm1d(512)

        if self.use_deterministic_encoder:
            # Global-feature head: 512 -> 256 -> 128 -> zdim.
            self.fc1 = nn.Linear(512, 256)
            self.fc2 = nn.Linear(256, 128)
            self.fc_bn1 = nn.BatchNorm1d(256)
            self.fc_bn2 = nn.BatchNorm1d(128)
            self.fc3 = nn.Linear(128, self.zdim)
        else:
            raise NotImplementedError

        # NOTE(review): forward() returns zdim features, not 128. This value
        # matches the fc2 width instead. Confirm what downstream consumers
        # expect "num_point_feature" to mean.
        runtime_cfg["num_point_feature"] = 128

    def forward(self, points_input):
        """Encode each point cloud into a latent code.

        Args:
            points_input: tensor of shape ``[B, N, input_dim]``. The previous
                docstring described this as ``[8B, N, 3]`` (clouds expressed
                in 8 coordinate systems); confirm against callers.

        Returns:
            Tensor of shape ``[B, zdim]``: one latent vector per point cloud.

        Raises:
            NotImplementedError: if the deterministic encoder is disabled.
        """
        # [B, N, C] -> [B, C, N] so Conv1d treats per-point channels as features.
        x = points_input.transpose(1, 2)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        # [B, 512, N]; no ReLU is applied after bn4.
        x = self.bn4(self.conv4(x))
        # Max over the point dimension (order-invariant pooling) -> [B, 512].
        x = x.max(dim=-1).values

        if not self.use_deterministic_encoder:
            raise NotImplementedError

        ms = F.relu(self.fc_bn1(self.fc1(x)))
        ms = F.relu(self.fc_bn2(self.fc2(ms)))
        return self.fc3(ms)

