import torch
from torch import nn

class ColorBallEncoder(nn.Module):
    """Two-layer MLP mapping ``input_len`` features to ``output_len`` features.

    Architecture: Linear(input_len -> 2*output_len) -> ReLU ->
    Linear(2*output_len -> output_len). The final output is not passed
    through an activation.
    """

    def __init__(self, input_len, output_len) -> None:
        super().__init__()
        hidden_len = 2 * output_len
        stages = [
            nn.Linear(input_len, hidden_len),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_len, output_len),
        ]
        # A single Sequential named ``layer`` keeps state_dict keys
        # (layer.0.*, layer.2.*) stable for saved checkpoints.
        self.layer = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the MLP (nn.Linear acts on the last dimension)."""
        return self.layer(x)
    
class PointNet(nn.Module):
    """PointNet-style encoder mapping point sets to global feature vectors.

    Every point goes through a shared per-point MLP, implemented as
    kernel-size-1 Conv1d + BatchNorm1d + ReLU stages with widths
    64-64-64-128-1024. The result is max-pooled over the point axis and
    projected to ``feat_len`` by a Linear + BatchNorm1d + ReLU head.

    Args:
        feat_len: size of the output feature vector.
        input_len: number of channels per point (default 3, presumably
            xyz coordinates).
    """

    def __init__(self, feat_len, input_len=3):
        super().__init__()

        self.conv1 = nn.Conv1d(input_len, 64, 1)
        self.conv2 = nn.Conv1d(64, 64, 1)
        self.conv3 = nn.Conv1d(64, 64, 1)
        self.conv4 = nn.Conv1d(64, 128, 1)
        self.conv5 = nn.Conv1d(128, 1024, 1)

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(64)
        self.bn3 = nn.BatchNorm1d(64)
        self.bn4 = nn.BatchNorm1d(128)
        self.bn5 = nn.BatchNorm1d(1024)

        self.mlp1 = nn.Linear(1024, feat_len)
        self.bn6 = nn.BatchNorm1d(feat_len)

    def forward(self, x):
        """Encode one or more point sets into global feature vectors.

        Args:
            x: tensor of shape ``(*lead, N, input_len)`` with at least one
                leading dim, e.g. ``(B, N, input_len)`` or
                ``(B, P, N, input_len)``.

        Returns:
            Tensor of shape ``(*lead, feat_len)``, e.g. ``(B, feat_len)`` or
            ``(B, P, feat_len)``.

        Raises:
            ValueError: if ``x`` has fewer than 3 dimensions.

        Note:
            Leading dims are flattened into one batch axis. In training mode,
            BatchNorm statistics are therefore computed across all sets.
        """
        if x.dim() < 3:
            raise ValueError(
                "PointNet expects input of shape (*lead, N, input_len), "
                f"got {tuple(x.shape)}"
            )
        lead_shape = x.shape[:-2]
        num_points, num_channels = x.shape[-2], x.shape[-1]
        # Conv1d/BatchNorm1d need (batch, C, N): fold the leading dims into a
        # single batch axis and move channels before points.
        x = x.reshape(-1, num_points, num_channels).permute(0, 2, 1)

        x = torch.relu(self.bn1(self.conv1(x)))
        x = torch.relu(self.bn2(self.conv2(x)))
        x = torch.relu(self.bn3(self.conv3(x)))
        x = torch.relu(self.bn4(self.conv4(x)))
        x = torch.relu(self.bn5(self.conv5(x)))

        # Symmetric max-pool over the point axis -> (batch, 1024).
        x = x.max(dim=-1)[0]

        x = torch.relu(self.bn6(self.mlp1(x)))
        # Restore the caller's leading dims.
        return x.reshape(*lead_shape, x.shape[-1])

class LabelEncoder(nn.Module):
    """Two-layer MLP mapping a ``label_len`` vector to ``output_len`` features.

    Architecture: Linear(label_len -> 2*output_len) -> ReLU ->
    Linear(2*output_len -> output_len) -> ReLU. Unlike a plain projection,
    the output is ReLU-activated and therefore non-negative.
    """

    def __init__(self, label_len, output_len) -> None:
        super().__init__()
        width = output_len * 2
        modules = []
        modules.append(nn.Linear(label_len, width))
        modules.append(nn.ReLU(inplace=True))
        modules.append(nn.Linear(width, output_len))
        modules.append(nn.ReLU(inplace=True))
        # The attribute name ``layer`` and the module order determine the
        # state_dict keys (layer.0.*, layer.2.*), so both are kept fixed.
        self.layer = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through the MLP (nn.Linear acts on the last dimension)."""
        return self.layer(x)
