import torch
import torch.nn as nn

class PointNetReg(nn.Module):
    """PointNet-style regressor: shared per-point MLP, max pooling, two heads.

    Every point is embedded independently by ``mlp1``. Its kernel-size-1
    ``Conv1d`` layers act as an MLP whose weights are shared across points.
    The per-point features are then max-pooled over the point axis into one
    256-d global feature per sample. Two small MLP heads (``fc_pos``,
    ``fc_rot``) map that feature to ``pos_dim`` and ``rot_dim`` values, which
    are concatenated. The only cross-point operation is a max, so the output
    does not depend on the order of the points.

    NOTE(review): the ``pos``/``rot`` names suggest position and rotation
    targets, but their parameterization (e.g. why 12 and 16 by default) is
    not visible here. Confirm against the training/loss code.

    Args:
        in_channels: Number of features per input point (default 6).
        pos_dim: Output width of the ``fc_pos`` head (default 12).
        rot_dim: Output width of the ``fc_rot`` head (default 16).
    """

    def __init__(self, in_channels: int = 6, pos_dim: int = 12, rot_dim: int = 16):
        super().__init__()
        feat_dim = 256  # width of the pooled global feature

        # Attribute names and Sequential layer order define the state_dict
        # keys (e.g. "mlp1.0.weight"); keep them stable so checkpoints load.
        self.mlp1 = nn.Sequential(
            nn.Conv1d(in_channels, 64, 1),
            nn.ReLU(),
            nn.Conv1d(64, 128, 1),
            nn.ReLU(),
            nn.Conv1d(128, feat_dim, 1),
            nn.ReLU(),
            nn.Conv1d(feat_dim, feat_dim, 1),
            nn.ReLU(),
        )

        self.fc_pos = nn.Sequential(
            nn.Linear(feat_dim, 128),
            nn.ReLU(),
            nn.Linear(128, pos_dim),
        )

        self.fc_rot = nn.Sequential(
            nn.Linear(feat_dim, 128),
            nn.ReLU(),
            nn.Linear(128, rot_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Regress the concatenated head outputs for a batch of point sets.

        Args:
            x: Tensor of shape ``(B, N, in_channels)``: B samples, N points
                each, channels-last.

        Returns:
            Tensor of shape ``(B, pos_dim + rot_dim)``. It holds the ``fc_pos``
            output followed by the ``fc_rot`` output along dim 1.
        """
        # Conv1d expects channels-first input: (B, N, C) -> (B, C, N).
        x = x.permute(0, 2, 1)
        feat = self.mlp1(x)  # (B, 256, N)
        # Max over the point axis: a symmetric (order-independent) pooling.
        feat = feat.max(dim=2).values  # (B, 256)
        pos = self.fc_pos(feat)
        rot = self.fc_rot(feat)
        return torch.cat([pos, rot], dim=1)