import torch.nn as nn
import torch
import torch.nn.functional as F
from pointnet_utils import PointNetSetAbstractionMsg,PointNetSetAbstraction,PointNetFeaturePropagation

from pointnet_lib.point_transformer_modules import PointTransformerResBlock, PointTransformerDownBlock, PointTransformerUpBlock, MLP

class PointTransformer_seg(nn.Module):
    """Point Transformer encoder/decoder producing a per-point output.

    Structure (every width listed in ``cfg`` is multiplied by
    ``channel_mult``):

    * ``pre_module``: an ``MLP`` on the 3 input coordinates followed by one
      ``PointTransformerResBlock``.
    * ``down_module``: one ``PointTransformerDownBlock`` per entry of
      ``down_module.attn_channel`` (``npoint`` = 256, 64, 32, 16).
    * ``up_module``: one ``PointTransformerUpBlock`` per down block, walking
      the encoder levels in reverse and receiving the finer level's features
      as a skip input.
    * ``heads``: a single head ``'R'`` built as an ``MLP`` with widths
      ``[128, finalout_channel]``; ``forward`` returns its output.

    Output: presumably ``finalout_channel`` channels (the last width of the
    ``'R'`` head); 128 is only that head's hidden width.

    Args:
        finalout_channel: last width of the ``'R'`` head.
    """
    def __init__(self, finalout_channel=6):
        super(PointTransformer_seg, self).__init__()
        cfg = {
            "channel_mult": 4,
            "div": 4,
            "pos_mlp_hidden_dim": 64,
            "attn_mlp_hidden_mult": 4,
            "pre_module": {
                "channel": 16,
                "nsample": 16
            },
            "down_module": {
                "npoint": [256, 64, 32, 16],
                "nsample": [10, 16, 16, 16],
                "attn_channel": [16, 32, 64, 64],
                "attn_num": [2, 2, 2, 2]
            },
            "up_module": {
                "attn_num": [1, 1, 1, 1]
            },
            "heads": {
                # [hidden width, output width, last activation]
                "R": [128, finalout_channel, None],
            }
        }
        k = cfg['channel_mult']  # multiplier applied to every channel count
        div = cfg["div"]
        pos_mlp_hidden_dim = cfg["pos_mlp_hidden_dim"]
        attn_mlp_hidden_mult = cfg["attn_mlp_hidden_mult"]
        pre_module_channel = cfg["pre_module"]["channel"]
        pre_module_nsample = cfg["pre_module"]["nsample"]
        self.pre_module = nn.ModuleList([
            MLP(dim=1, in_channel=3, mlp=[pre_module_channel * k] * 2, use_bn=True, skip_last=False),
            PointTransformerResBlock(dim=pre_module_channel * k,
                                     div=div, pos_mlp_hidden_dim=pos_mlp_hidden_dim,
                                     attn_mlp_hidden_mult=attn_mlp_hidden_mult,
                                     num_neighbors=pre_module_nsample)
        ])

        # Encoder: stage i maps last_channel * k -> attn_channel[i] * k.
        self.down_module = nn.ModuleList()
        down_cfg = cfg["down_module"]
        attn_channel = down_cfg['attn_channel']
        down_sample = down_cfg['nsample']
        last_channel = pre_module_channel
        for npoint, nsample, out_channel, num_attn in zip(
                down_cfg['npoint'], down_sample, attn_channel, down_cfg['attn_num']):
            self.down_module.append(PointTransformerDownBlock(npoint=npoint,
                                                              nsample=nsample,
                                                              in_channel=last_channel * k,
                                                              out_channel=out_channel * k,
                                                              num_attn=num_attn,
                                                              div=div, pos_mlp_hidden_dim=pos_mlp_hidden_dim,
                                                              attn_mlp_hidden_mult=attn_mlp_hidden_mult))
            last_channel = out_channel

        # Decoder: mirror the encoder widths, ending at the pre-module width
        # so the heads see pre_module_channel * k channels.
        up_channel = attn_channel[::-1] + [pre_module_channel]
        up_sample = down_sample[::-1]
        up_attn_num = cfg["up_module"]['attn_num']
        self.up_module = nn.ModuleList()
        for nsample, in_ch, out_ch, num_attn in zip(up_sample, up_channel[:-1], up_channel[1:], up_attn_num):
            self.up_module.append(PointTransformerUpBlock(nsample, in_ch * k, out_ch * k, num_attn,
                                                          div=div, pos_mlp_hidden_dim=pos_mlp_hidden_dim,
                                                          attn_mlp_hidden_mult=attn_mlp_hidden_mult))

        self.heads = nn.ModuleDict()
        for key, mlp in cfg['heads'].items():
            # mlp = [*widths, last_activation]
            self.heads[key] = MLP(dim=1, in_channel=pre_module_channel * k, mlp=mlp[:-1], use_bn=True, skip_last=True,
                                  last_acti=mlp[-1])

    def forward(self, xyz):  # xyz: [B, 3, N]
        """Run the encoder/decoder and return the output of the ``'R'`` head."""
        points = self.pre_module[0](xyz)
        points = self.pre_module[1](xyz, points)
        # Keep every level's coordinates/features for the decoder skips.
        xyz_list, points_list = [xyz], [points]

        for down in self.down_module:
            xyz, points = down(xyz, points)
            xyz_list.append(xyz)
            points_list.append(points)

        # Stage i goes from encoder level -(i + 1) to the finer level -(i + 2),
        # passing that finer level's features as the skip input.
        for i, up in enumerate(self.up_module):
            points = up(xyz_list[-(i + 1)], xyz_list[-(i + 2)], points, points_list[-(i + 2)])

        output = {key: head(points) for key, head in self.heads.items()}
        return output['R']

class PointTransformer_cls(nn.Module):
    """Point Transformer encoder with a global-pooling classification head.

    The encoder (``pre_module`` + ``down_module``) matches the one in the
    segmentation variant. The last level's features are max-pooled over
    points and fed to a two-layer MLP.

    Output: ``finalout_channel`` values per cloud. The shape is presumably
    ``[B, finalout_channel]``, assuming the down blocks return
    channels-first features ``[B, C, npoint]``. 128 is only the head's
    hidden width.

    Args:
        finalout_channel: output width of the classification head.
    """
    def __init__(self, finalout_channel=6):
        super(PointTransformer_cls, self).__init__()
        cfg = {
            "channel_mult": 4,
            "div": 4,
            "pos_mlp_hidden_dim": 64,
            "attn_mlp_hidden_mult": 4,
            "pre_module": {
                "channel": 16,
                "nsample": 16
            },
            "down_module": {
                "npoint": [256, 64, 32, 16],
                "nsample": [10, 16, 16, 16],
                "attn_channel": [16, 32, 64, 64],
                "attn_num": [2, 2, 2, 2]
            },
            "heads": {
                # [hidden width, output width, last activation]
                "R": [128, finalout_channel, None],
            }
        }
        k = cfg['channel_mult']  # multiplier applied to every channel count
        div = cfg["div"]
        pos_mlp_hidden_dim = cfg["pos_mlp_hidden_dim"]
        attn_mlp_hidden_mult = cfg["attn_mlp_hidden_mult"]
        pre_module_channel = cfg["pre_module"]["channel"]
        pre_module_nsample = cfg["pre_module"]["nsample"]
        self.pre_module = nn.ModuleList([
            MLP(dim=1, in_channel=3, mlp=[pre_module_channel * k] * 2, use_bn=True, skip_last=False),
            PointTransformerResBlock(dim=pre_module_channel * k,
                                     div=div, pos_mlp_hidden_dim=pos_mlp_hidden_dim,
                                     attn_mlp_hidden_mult=attn_mlp_hidden_mult,
                                     num_neighbors=pre_module_nsample)
        ])

        # Encoder: stage i maps last_channel * k -> attn_channel[i] * k.
        self.down_module = nn.ModuleList()
        down_cfg = cfg["down_module"]
        last_channel = pre_module_channel
        for npoint, nsample, out_channel, num_attn in zip(
                down_cfg['npoint'], down_cfg['nsample'], down_cfg['attn_channel'], down_cfg['attn_num']):
            self.down_module.append(PointTransformerDownBlock(npoint=npoint,
                                                              nsample=nsample,
                                                              in_channel=last_channel * k,
                                                              out_channel=out_channel * k,
                                                              num_attn=num_attn,
                                                              div=div, pos_mlp_hidden_dim=pos_mlp_hidden_dim,
                                                              attn_mlp_hidden_mult=attn_mlp_hidden_mult))
            last_channel = out_channel

        # The head's input width follows the last down block rather than a
        # hard-coded 256 (which only equals 64 * channel_mult for this cfg).
        head_hidden, head_out, _ = cfg["heads"]["R"]
        self.pool = nn.AdaptiveMaxPool1d(output_size=1)
        self.head = nn.Sequential(
            nn.Linear(last_channel * k, head_hidden),
            nn.LeakyReLU(),
            nn.Linear(head_hidden, head_out)
        )

    def forward(self, xyz):  # xyz: [B, 3, N]
        """Encode the cloud, max-pool the coarsest features, and classify."""
        points = self.pre_module[0](xyz)
        points = self.pre_module[1](xyz, points)

        for down in self.down_module:
            xyz, points = down(xyz, points)

        # Pool over the last axis, then drop the singleton length-1 axis.
        out = self.pool(points).squeeze(-1)
        return self.head(out)

class PointNet2_seg(nn.Module):
    """PointNet++ (multi-scale grouping) per-point segmentation network.

    Three set-abstraction levels (the last groups all points) are followed by
    three feature-propagation levels back to the input points, and finally a
    1x1-conv head producing ``out_channel`` values per point.

    The raw coordinates are used both as positions and as input features.

    Args:
        out_channel: number of output channels per point.
    """

    def __init__(self, out_channel):
        super().__init__()
        # Encoder. The in_channel values assume the MSG layers concatenate the
        # last width of each branch (e.g. 64 + 128 + 128 for sa1).
        self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [32, 64, 128], 3, [[32, 32, 64], [64, 64, 128], [64, 96, 128]])
        self.sa2 = PointNetSetAbstractionMsg(128, [0.4,0.8], [64, 128], 128+128+64, [[128, 128, 256], [128, 196, 256]])
        self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=512 + 3, mlp=[256, 512, 1024], group_all=True)
        # Decoder: 1536 = 1024 + 512, 576 = 256 + 320, 134 = 128 + 3 (xyz) + 3 (features).
        self.fp3 = PointNetFeaturePropagation(in_channel=1536, mlp=[256, 256])
        self.fp2 = PointNetFeaturePropagation(in_channel=576, mlp=[256, 128])
        self.fp1 = PointNetFeaturePropagation(in_channel=134, mlp=[128, 128])
        # Per-point head.
        self.conv1 = nn.Conv1d(128, 128, 1)
        self.bn1 = nn.BatchNorm1d(128)
        self.drop1 = nn.Dropout(0.5)
        self.conv2 = nn.Conv1d(128, out_channel, 1)

    def forward(self, xyz):
        """Return per-point predictions for a ``[B, 3, N]`` coordinate batch."""
        _batch, _channels, _num_points = xyz.shape  # rejects non-3-D input
        xyz0 = feats0 = xyz

        # Set abstraction: fine -> coarse.
        xyz1, feats1 = self.sa1(xyz0, feats0)
        xyz2, feats2 = self.sa2(xyz1, feats1)
        xyz3, feats3 = self.sa3(xyz2, feats2)

        # Feature propagation: coarse -> fine.
        up2 = self.fp3(xyz2, xyz3, feats2, feats3)
        up1 = self.fp2(xyz1, xyz2, feats1, up2)
        skip0 = torch.cat([xyz0, feats0], 1)
        up0 = self.fp1(xyz0, xyz1, skip0, up1)

        hidden = self.drop1(F.relu(self.bn1(self.conv1(up0))))
        return self.conv2(hidden)


class PointNet_cls(nn.Module):
    """PointNet-style classifier without input/feature transform networks.

    Shared point-wise 1x1 convolutions lift every point to 1024 features. A
    max-pool over the point axis reduces them to one vector per cloud, and a
    two-layer MLP maps that vector to ``out_channel`` values.

    Args:
        out_channel: output width per cloud.
    """

    def __init__(self, out_channel):
        super().__init__()
        point_layers = [
            nn.Conv1d(3, 64, kernel_size=1), nn.LeakyReLU(),
            nn.Conv1d(64, 128, kernel_size=1), nn.LeakyReLU(),
            nn.Conv1d(128, 1024, kernel_size=1),
            # Max over points: order-invariant global descriptor.
            nn.AdaptiveMaxPool1d(output_size=1),
        ]
        # Attribute name (including its spelling) is kept for state_dict compatibility.
        self.feature_extracter = nn.Sequential(*point_layers)

        head_layers = [nn.Linear(1024, 512), nn.LeakyReLU(), nn.Linear(512, out_channel)]
        self.mlp = nn.Sequential(*head_layers)

    def forward(self, x):
        """Map a ``[B, 3, N]`` batch to a ``[B, out_channel]`` tensor."""
        pooled = self.feature_extracter(x)  # [B, 1024, 1]
        flat = pooled.view(x.size(0), -1)
        return self.mlp(flat)


class PointNet2_cls(nn.Module):
    """PointNet++ (multi-scale grouping) classifier.

    Two multi-scale set-abstraction levels are followed by a group-all level
    producing a 1024-wide global feature, then a two-layer MLP.
    The raw coordinates are used both as positions and as input features.

    Args:
        out_channel: output width per cloud.
    """

    def __init__(self, out_channel):
        super().__init__()
        self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [32, 64, 128], 3, [[32, 32, 64], [64, 64, 128], [64, 96, 128]])
        self.sa2 = PointNetSetAbstractionMsg(128, [0.4,0.8], [64, 128], 128+128+64, [[128, 128, 256], [128, 196, 256]])
        self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=512 + 3, mlp=[256, 512, 1024], group_all=True)

        head_layers = [nn.Linear(1024, 512), nn.LeakyReLU(), nn.Linear(512, out_channel)]
        self.mlp = nn.Sequential(*head_layers)

    def forward(self, xyz):
        """Classify a ``[B, 3, N]`` coordinate batch."""
        _batch, _channels, _num_points = xyz.shape  # rejects non-3-D input
        centers, feats = self.sa1(xyz, xyz)
        centers, feats = self.sa2(centers, feats)
        _, global_feat = self.sa3(centers, feats)
        # Presumably [B, 1024, 1] from the group-all level; drop the last axis.
        return self.mlp(global_feat.squeeze(-1))


class PointNet_seg(nn.Module):
    """PointNet-style per-point segmentation network.

    Each point gets a 64-wide local feature (``f1``). A 1024-wide global
    feature is max-pooled over all points (``f2``). The global feature is
    concatenated back onto every point (64 + 1024 = 1088 channels) and
    mapped by 1x1 convolutions to ``out_channel`` values per point.

    Args:
        out_channel: output width per point.
    """
    def __init__(self, out_channel):
        super(PointNet_seg, self).__init__()
        self.f1 = nn.Sequential(
            nn.Conv1d(3, 64, kernel_size=1),
            nn.LeakyReLU()
        )
        self.f2 = nn.Sequential(
            nn.Conv1d(64, 128, kernel_size=1),
            nn.LeakyReLU(),
            nn.Conv1d(128, 1024, kernel_size=1),
            nn.AdaptiveMaxPool1d(output_size=1)
        )
        self.mlp = nn.Sequential(
            nn.Conv1d(1088, 512, kernel_size=1),
            nn.LeakyReLU(),
            nn.Conv1d(512, 128, kernel_size=1),
            nn.LeakyReLU(),
            nn.Conv1d(128, out_channel, kernel_size=1)
        )

    def forward(self, x):
        """Predict ``out_channel`` values for every point.

        Args:
            x: point batch of shape ``[B, 3, N]``; any ``N >= 1`` is accepted.

        Returns:
            Tensor of shape ``[B, out_channel, N]``.
        """
        local_feat = self.f1(x)            # [B, 64, N]
        global_feat = self.f2(local_feat)  # [B, 1024, 1]
        # Broadcast the global descriptor to all N points. N comes from the
        # input rather than a hard-coded 1024, so other cloud sizes work.
        global_feat = global_feat.expand(-1, -1, local_feat.size(2))
        return self.mlp(torch.cat([local_feat, global_feat], 1))
