"""
Set Pooling Networks to Classify Point Clouds
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

class SetPoolingNet(nn.Module):
    """Permutation-invariant set network: ``spnet(A) = rho(pool([phi(a) for a in A]))``.

    ``phi`` maps every point to a feature vector, the per-point features are
    reduced over the point axis (dim 1) with a symmetric pooling operation,
    and ``rho`` maps the resulting global feature to the output.

    Args:
        phi: module applied to the whole ``[batch, cloudsize, pt_dim]``
            tensor; it should act on the last dimension only (e.g. layers
            built from ``nn.Linear``) so that it is effectively pointwise.
        rho: module applied to the pooled ``[batch, phi_out_dim]`` feature.
        pool: reduction over the point axis -- ``'max'``, ``'ave'`` (mean)
            or ``'sum'``.

    Raises:
        ValueError: if ``pool`` is not one of the supported reductions.
    """

    _POOL_MODES = ('sum', 'ave', 'max')

    def __init__(self, phi: nn.Module, rho: nn.Module, pool='max'):
        super().__init__()
        # Fail fast: an unknown mode would otherwise skip pooling silently
        # and hand a 3-D tensor to rho.
        if pool not in self._POOL_MODES:
            raise ValueError(
                f"pool must be one of {self._POOL_MODES}, got {pool!r}")
        self.phi = phi
        self.rho = rho
        self.pool = pool

    def forward(self, x):
        """Compute pointwise features, pool them over the cloud, apply rho.

        Args:
            x: tensor of shape ``[batch, cloudsize, pt_dim]``.

        Returns:
            Output of ``rho`` applied to the pooled ``[batch, phi_out_dim]``
            feature, i.e. ``[batch, rho_out_dim]`` for an MLP ``rho``.

        Raises:
            ValueError: if ``self.pool`` was reassigned to an unsupported mode.
        """
        # Call the modules (not .forward) so registered hooks still run.
        x = self.phi(x)  # [batch, cloudsize, phi_out_dim]

        if self.pool == 'sum':
            # NOTE(review): the original author reported NaNs during training
            # with sum pooling; the cause has not been identified here.
            x = torch.sum(x, dim=1)  # [batch, phi_out_dim]
        elif self.pool == 'ave':
            x = torch.mean(x, dim=1)  # [batch, phi_out_dim]
        elif self.pool == 'max':
            x, _ = torch.max(x, dim=1)  # [batch, phi_out_dim]
        else:
            # self.pool is a plain attribute and may be changed after
            # construction, so re-check rather than skip pooling.
            raise ValueError(
                f"pool must be one of {self._POOL_MODES}, got {self.pool!r}")

        return self.rho(x)
        
class SimpleNet(nn.Module):
    """Two-layer MLP: ``Linear -> ReLU -> Linear``.

    Args:
        shape: layer widths ``(in_dim, hidden_dim, out_dim)``; only the first
            three entries are read. A tuple default is used instead of a
            mutable list shared across calls.
    """

    def __init__(self, shape=(2, 5, 2)):
        super().__init__()
        in_dim, hidden_dim, out_dim = shape[0], shape[1], shape[2]
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x`` (leading dims are kept)."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)
    
class SimpleSetPoolingNet(SetPoolingNet):
    """SetPoolingNet whose ``phi`` and ``rho`` are both SimpleNet MLPs.

    Args:
        shape: expected five layer widths
            ``[pt_dim, phi_hidden, feat_dim, rho_hidden, out_dim]``.
            ``phi`` is built from ``shape[:3]`` and ``rho`` from ``shape[2:]``,
            so the pooled feature width ``shape[2]`` is shared by both.
        pool: pooling mode forwarded to SetPoolingNet.
    """

    def __init__(self, shape=(1, 2, 3, 2, 1), pool='max'):
        phi = SimpleNet(shape[:3])  # -> pool -> rho
        rho = SimpleNet(shape[2:])
        super().__init__(phi, rho, pool)
        # Assign only after nn.Module.__init__ has set up attribute handling;
        # keep a copy so later mutation of the caller's sequence has no effect.
        self.shape = list(shape)


if __name__ == '__main__':  # Unit Test models
    # Fix the RNG so weight initialisation (and thus the printed outputs)
    # is reproducible between runs.
    torch.manual_seed(42)

    # Layer widths: [pt_dim, phi_hidden, feat_dim, rho_hidden, out_dim]
    shape = [2, 15, 15, 15, 3]

    # Instantiate network architectures
    phi = SimpleNet(shape[:3])
    rho = SimpleNet(shape[2:])
    spnet = SetPoolingNet(phi, rho, 'max')
    N_params_spnet = sum(p.numel() for p in spnet.parameters() if p.requires_grad)
    print(f'The network has {N_params_spnet} parameters')

    # An even simpler network
    spnet2 = SimpleSetPoolingNet(shape)

    # Test data: N_clds clouds of N_pts points with shape[0] coordinates each
    N_clds, N_pts = 4, 500
    N_feats = N_pts * shape[0]
    d = torch.arange(N_clds * N_feats, dtype=torch.float).view(-1, N_pts, shape[0])

    ### Evaluate
    # Inference only: skip autograd bookkeeping and run each forward pass once.
    with torch.no_grad():
        d_flat = d.view(-1, N_feats)
        phi_out = phi(d)
        spnet_out = spnet(d)
        spnet2_out = spnet2(d)

    # Print d as is, and with a different view
    print('Input Tensor d of shape', list(d.shape), ':\n')
    print(d, '\n')
    print('Input Tensor d.view(-1,N_feats) of shape', list(d_flat.shape), ':\n')
    print(d_flat, '\n')

    print('Output of phi on d is shape', list(phi_out.shape), ':\n')

    print('Output of spnet(d) of shape', list(spnet_out.shape), ':\n')
    print(spnet_out, '\n')

    print('Output of spnet2(d) of shape', list(spnet2_out.shape), ':\n')
    print(spnet2_out, '\n')