import torch
from torch_geometric.nn import fps, nearest
from torch_geometric.nn import global_max_pool, global_mean_pool
from torch_geometric.nn import PointNetConv
from ip.utils.common_utils import PositionalEncoder
from torch_geometric.nn import MLP
from typing import Optional, Union
from torch_geometric.typing import OptTensor, PairOptTensor, Adj, PairTensor, Tensor


class SceneEncoder(torch.nn.Module):
    """Two-stage point-cloud encoder built from stacked ``SAModule`` layers.

    Args:
        num_freqs: Number of positional-encoding frequencies forwarded to
            both stages.
        embd_dim: Width of the final layer of the second stage's global MLP.
        num_layers: When equal to 1, ``forward`` returns the output of the
            first stage only; any other value runs both stages. Both stages
            are always constructed.
        scale: Accepted for interface compatibility but not used; each stage
            is built with its own fixed scale (``1 / 0.05`` and ``1 / 0.2``).
    """

    def __init__(self, num_freqs, embd_dim=256, num_layers=2, scale=1.):
        super().__init__()
        self.embd_dim = embd_dim
        self.num_layers = num_layers

        # Stage 1: keeps 12.5% of the input points.
        self.sa1_module = SAModule(
            0.125,
            [3, 128, 128, 128],
            global_nn_dims=[128, 256, 256],
            num_freqs=num_freqs,
            scale=1 / 0.05,
            norm=None,
        )
        # Stage 2: keeps 6.25% of the stage-1 points and ends in a plain
        # (activation-free) layer of width `embd_dim`.
        self.sa2_module = SAModule(
            0.0625,
            [256 + 3, 512, 512, 512],
            global_nn_dims=[512, 512, embd_dim],
            num_freqs=num_freqs,
            scale=1 / 0.2,
            norm=None,
            plain_last=True,
        )

    def forward(self, x, pos, batch):
        """Encode a batched point cloud.

        Args:
            x: Per-point features, or ``None``.
            pos: Per-point positions.
            batch: Per-point batch assignment.

        Returns:
            The ``[x, pos, batch]`` list produced by the last executed stage.
        """
        first_stage = self.sa1_module(x, pos, batch)
        if self.num_layers == 1:
            return first_stage

        return self.sa2_module(*first_stage)


class PointNetConvPE(PointNetConv):
    """``PointNetConv`` whose messages use a positional encoding of the
    scaled relative position ``(pos_j - pos_i) * scale``.

    Args:
        nn_dims: Layer widths of the local (per-message) MLP. A copy is
            widened by ``3 * (2 * num_freqs)`` at index 0; the caller's list
            is left untouched.
            NOTE(review): presumably ``PositionalEncoder`` also emits the 3
            raw coordinates and callers already count them in
            ``nn_dims[0]`` - confirm against ``PositionalEncoder``.
        global_nn_dims: Layer widths of the MLP applied after aggregation,
            or ``None`` for no global MLP. When ``cat_pos`` is set, a copy
            is widened by ``3 * (2 * num_freqs + 1)`` at index 0.
        aggr: Aggregation scheme forwarded to ``PointNetConv``.
        num_freqs: Number of frequencies given to ``PositionalEncoder``.
        cat_pos: If true, the encoded destination positions are
            concatenated to the aggregated features before the global MLP.
        scale: Multiplier applied to relative positions before encoding.
        plain_last: Forwarded to the global MLP only; the local MLP always
            uses ``plain_last=False``.
        residual_mls: Use ``ResidualMLP`` instead of ``MLP`` for both nets.
        norm: Forwarded to ``ResidualMLP`` only; the ``MLP`` variant is
            always built with ``norm=None``.
    """

    def __init__(self, nn_dims, global_nn_dims=None, aggr='mean', num_freqs=4, cat_pos=False,
                 scale=1., plain_last=False, residual_mls=False, norm=None):
        # Work on copies: the original in-place `+=` silently changed the
        # caller's lists, breaking any reuse of the same dims list.
        nn_dims = list(nn_dims)
        nn_dims[0] += 3 * (2 * num_freqs)

        if global_nn_dims is not None:
            global_nn_dims = list(global_nn_dims)
            if cat_pos:
                global_nn_dims[0] += 3 * (2 * num_freqs + 1)

        # The local net is built before the global net, matching the
        # original construction order.
        if residual_mls:
            nn = ResidualMLP(nn_dims, norm=norm, act=torch.nn.GELU(approximate='tanh'),
                             plain_last=False)
            global_nn = None if global_nn_dims is None else \
                ResidualMLP(global_nn_dims, norm=norm, act=torch.nn.GELU(approximate='tanh'),
                            plain_last=plain_last)
        else:
            nn = MLP(nn_dims, norm=None, act=torch.nn.GELU(approximate='tanh'), plain_last=False)
            global_nn = None if global_nn_dims is None else \
                MLP(global_nn_dims, norm=None, act=torch.nn.GELU(approximate='tanh'),
                    plain_last=plain_last)

        super().__init__(nn, global_nn=global_nn, add_self_loops=False, aggr=aggr)

        # Plain attributes are assigned after Module.__init__ so they never
        # depend on Module.__setattr__ tolerating an uninitialised module.
        self.scale = scale
        self.cat_pos = cat_pos
        self.pe = PositionalEncoder(3, num_freqs)

    def message(self, x_j: Optional[Tensor], pos_i: Tensor,
                pos_j: Tensor) -> Tensor:
        """Build per-edge messages.

        Encodes the scaled relative position, prepends the source features
        when present, and passes the result through the local MLP.
        """
        msg = self.pe((pos_j - pos_i) * self.scale)
        if x_j is not None:
            msg = torch.cat([x_j, msg], dim=1)
        if self.local_nn is not None:
            msg = self.local_nn(msg)
        return msg

    def forward(self, x: Union[OptTensor, PairOptTensor],
                pos: Union[Tensor, PairTensor], edge_index: Adj) -> Tensor:
        """Run message passing followed by the optional global MLP.

        Args:
            x: Features, or a ``(source, destination)`` pair of features.
            pos: Positions, or a ``(source, destination)`` pair of positions.
            edge_index: Graph connectivity.

        Returns:
            Aggregated (and, if configured, globally transformed) features.
        """
        if not isinstance(x, tuple):
            x: PairOptTensor = (x, None)

        if isinstance(pos, Tensor):
            pos: PairTensor = (pos, pos)

        # propagate_type: (x: PairOptTensor, pos: PairTensor)
        out = self.propagate(edge_index, x=x, pos=pos, size=None)

        if self.global_nn is not None:
            if self.cat_pos:
                # Destination positions are encoded unscaled here.
                out = torch.cat([out, self.pe(pos[1])], dim=1)
            out = self.global_nn(out)

        return out


class ResidualMLP(torch.nn.Module):
    """MLP in which every layer after the first adds its output to its input.

    Args:
        nn_dims: Layer widths; each consecutive pair defines one ``Linear``
            layer. Layers after the first compute ``x + layer(x)``, so their
            input and output widths must be compatible for that addition.
        act: Activation applied after every non-final layer, or ``None`` to
            skip activations entirely.
        norm: When not ``None``, a ``LayerNorm`` is applied after every
            non-final layer (before the activation). The value is only used
            as an on/off flag.
        plain_last: When false, the activation is also applied to the output
            of the final layer.
    """

    def __init__(self, nn_dims, act=None, norm=None, plain_last=False):
        super().__init__()
        self.act = act
        # ModuleList (not a plain list) so the LayerNorm parameters are
        # registered: trained, moved by .to(), and saved in state_dict().
        self.norms = None
        if norm is not None:
            self.norms = torch.nn.ModuleList(
                [torch.nn.LayerNorm(width) for width in nn_dims[1:-1]])
        self.plain_last = plain_last
        self.linear_layers = torch.nn.ModuleList(
            [torch.nn.Linear(nn_dims[i], nn_dims[i + 1]) for i in range(len(nn_dims) - 1)])

    def forward(self, x):
        """Apply the layers in order and return the transformed tensor."""
        last = len(self.linear_layers) - 1
        for i, layer in enumerate(self.linear_layers):
            # The first layer changes width, so it has no skip connection.
            x = layer(x) if i == 0 else x + layer(x)
            if i != last:
                # norms[i] matches the output width of layer i.
                if self.norms is not None:
                    x = self.norms[i](x)
                if self.act is not None:
                    x = self.act(x)

        if not self.plain_last and self.act is not None:
            x = self.act(x)

        return x


class SAModule(torch.nn.Module):
    """Set-abstraction stage: subsample points, then pool features onto them.

    Args:
        ratio: Fraction of points kept by ``fps``.
        nn_dims: Layer widths of the local MLP, forwarded to
            ``PointNetConvPE``.
        global_nn_dims: Layer widths of the global MLP, forwarded to
            ``PointNetConvPE``.
        num_freqs: Number of positional-encoding frequencies.
        aggr: Aggregation scheme used by the convolution.
        cat_pos: Forwarded to ``PointNetConvPE``.
        scale: Forwarded to ``PointNetConvPE``.
        plain_last: Forwarded to ``PointNetConvPE``.
        norm: Forwarded to ``PointNetConvPE``.
    """

    def __init__(self, ratio, nn_dims, global_nn_dims=None, num_freqs=4, aggr='mean', cat_pos=False,
                 scale=1., plain_last=False, norm=None):
        super().__init__()
        self.cat_pos = cat_pos
        self.ratio = ratio
        self.conv = PointNetConvPE(
            nn_dims,
            global_nn_dims,
            aggr=aggr,
            num_freqs=num_freqs,
            cat_pos=cat_pos,
            scale=scale,
            plain_last=plain_last,
            norm=norm,
        )

    def forward(self, x, pos, batch):
        """Downsample the cloud and aggregate features onto the kept points.

        Args:
            x: Per-point features, or ``None``.
            pos: Per-point positions.
            batch: Per-point batch assignment.

        Returns:
            ``[x, pos, batch]`` restricted to the sampled points.
        """
        # Pick the subset of points that survives this stage.
        kept = fps(pos, batch, ratio=self.ratio)
        kept_pos = pos[kept]
        kept_batch = batch[kept]

        # Every input point sends one message to its nearest kept point
        # within the same batch element.
        targets = nearest(pos, kept_pos, batch, kept_batch)
        sources = torch.arange(0, pos.shape[0], dtype=torch.long, device=pos.device)
        edge_index = torch.stack([sources, targets], dim=0)

        kept_x = x[kept] if x is not None else None
        pooled = self.conv((x, kept_x), (pos, kept_pos), edge_index)
        return [pooled, kept_pos, kept_batch]


class GlobalSAModule(torch.nn.Module):
    """Global pooling stage: per-point MLP followed by a per-batch reduction.

    Args:
        nn_dims: Layer widths of the per-point MLP. A copy is widened by
            ``3 * (2 * num_freqs + 1)`` at index 0 to make room for the
            encoded positions; the caller's list is left untouched.
            NOTE(review): this assumes ``PositionalEncoder(3, num_freqs)``
            outputs that many channels - confirm against its definition.
        global_pool: ``'max'`` selects max pooling; any other value selects
            mean pooling.
        num_freqs: Number of frequencies given to ``PositionalEncoder``.
    """

    def __init__(self, nn_dims, global_pool='mean', num_freqs=0):
        super().__init__()
        # Copy before widening: the original in-place `+=` modified the
        # caller's list, so reusing it for another module gave wrong sizes.
        nn_dims = list(nn_dims)
        nn_dims[0] += 3 * (2 * num_freqs + 1)
        self.nn = MLP(nn_dims, plain_last=True, act=torch.nn.GELU(approximate='tanh'), norm=None)
        self.global_pool = global_pool
        self.pe = PositionalEncoder(3, num_freqs)

    def forward(self, x, pos, batch):
        """Reduce each batch element to a single feature vector.

        Args:
            x: Per-point features.
            pos: Per-point positions.
            batch: Per-point batch assignment.

        Returns:
            ``[x, pos, batch]`` with one row per pooled output: the pooled
            features, an all-zero position, and a fresh ``0..n-1`` batch
            index.
        """
        x = self.nn(torch.cat([x, self.pe(pos)], dim=1))
        if self.global_pool == 'max':
            x = global_max_pool(x, batch)
        else:
            x = global_mean_pool(x, batch)
        # Pooled outputs have no spatial location; use zeros as placeholder.
        pos = pos.new_zeros((x.size(0), 3))
        batch = torch.arange(x.size(0), device=batch.device)
        return [x, pos, batch]
