import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# pointops from the Point Transformer V2 code
import pointops
from datetime import datetime
import numpy as np

# PointBatchNorm from the Point-M2AE code
class PointBatchNorm(nn.Module):
    """
    BatchNorm1d wrapper for channel-last point features.

    Accepts tensors in shape of [B*N, C] or [B*N, L, C]; the channel axis is
    always the last one and the output keeps the input layout.
    """

    def __init__(self, embed_channels):
        super().__init__()
        self.norm = nn.BatchNorm1d(embed_channels)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Normalize `input` over its channel (last) axis.

        Raises:
            NotImplementedError: if `input` is neither 2-D nor 3-D.
        """
        rank = input.dim()

        # [B*N, C]: BatchNorm1d already treats dim 1 as channels.
        if rank == 2:
            return self.norm(input)

        # [B*N, L, C]: BatchNorm1d expects channels on dim 1, so swap to
        # [B*N, C, L], normalize, and swap back.
        if rank == 3:
            channels_first = input.transpose(1, 2).contiguous()
            normalized = self.norm(channels_first)
            return normalized.transpose(1, 2).contiguous()

        raise NotImplementedError

# The computational block of the selection network
class SellBlock(nn.Module):
    """
    One computational block of the selection network.

    A per-point bottleneck projection (`fc1`) is followed by two
    neighbourhood aggregation steps (`fc2` with a residual connection, then
    `fc3`). Each aggregation gathers the features of every point's
    neighbours through `pointops.grouping`, appends the matching neighbour
    distance as one extra channel, and flattens the neighbourhood into a
    single vector per point.

    Args:
        in_channels: width of the incoming per-point features.
        out_channels: width of the returned per-point features; the internal
            bottleneck width is `out_channels // 4`.
        neighbours: number of neighbours per point that `reference_index` /
            `reference_dists` carry in `forward`.
        bias: whether the linear layers use a bias term.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 neighbours,
                 bias=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.neighbours = neighbours

        hidden = out_channels // 4
        # Every neighbour contributes `hidden` features plus one distance value.
        flat_width = neighbours * (hidden + 1)

        self.fc1 = nn.Sequential(
            nn.Linear(in_channels, hidden, bias=bias),
            PointBatchNorm(hidden),
            nn.ReLU(inplace=True),
        )
        self.fc2 = nn.Sequential(
            nn.Linear(flat_width, hidden, bias=bias),
            PointBatchNorm(hidden),
            nn.ReLU(inplace=True),
        )
        self.fc3 = nn.Sequential(
            nn.Linear(flat_width, out_channels, bias=bias),
            PointBatchNorm(out_channels),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _flatten_neighbourhood(point_feat, reference_index, reference_dists, xyz_arg, num_points):
        """Group neighbour features, append distances, flatten to one row per point."""
        grouped = pointops.grouping(reference_index, point_feat, xyz_arg)
        with_dists = torch.cat([grouped, reference_dists.unsqueeze(-1)], dim=2)
        return with_dists.view(num_points, -1)

    def forward(self, feat, reference_index, reference_dists):
        """
        Args:
            feat: per-point features; first dim is the number of points and
                the last dim is `in_channels`.
            reference_index: neighbour indices handed to `pointops.grouping`.
            reference_dists: neighbour distances; assumed to be
                (num_points, neighbours) so that it lines up with the grouped
                features - TODO confirm against the caller.

        Returns:
            Per-point features of width `out_channels`.
        """
        num_points = feat.shape[0]
        # Only a 1x3 slice of `feat` is passed as grouping's third argument;
        # presumably a placeholder for a coordinate tensor that is not needed
        # here - verify against the pointops.grouping signature.
        xyz_arg = feat[:1, :3]

        bottleneck = self.fc1(feat)

        first_pass = self._flatten_neighbourhood(
            bottleneck, reference_index, reference_dists, xyz_arg, num_points)
        # Residual connection around the first aggregation.
        bottleneck = self.fc2(first_pass) + bottleneck

        second_pass = self._flatten_neighbourhood(
            bottleneck, reference_index, reference_dists, xyz_arg, num_points)
        return self.fc3(second_pass)

# The actual loss computation. This can be decoupled from the selection computation.
# dist_vals, sel_vec, comp_ind, and max_ind are computed in NetSel and need to be passed to selloss once imp_weight is known.

def selloss(dist_vals, sel_vec, comp_ind, max_ind, imp_weight):
    """Compute the selection loss for the points chosen by the selection network.

    All helper tensors are created on ``dist_vals.device`` (the original
    hard-coded ``'cuda'``), so the function also works on CPU and on a
    non-default GPU.

    Args:
        dist_vals: neighbour distances, shape (n_batch, nk, n_nodes).
        sel_vec: per-point selection scores with n_batch * n_nodes elements
            (e.g. (n_batch, 1, n_nodes) as produced by ``NetSel``).
        comp_ind: neighbour indices into the flattened (n_batch * n_nodes)
            point list, with n_batch * n_nodes * nk elements laid out as
            (point, neighbour).
        max_ind: indices of the selected points along the node axis; must
            broadcast against (n_batch, nk, 1), e.g. (n_batch, 1, k).
        imp_weight: per-point importance embedding, shape
            (n_batch, n_nodes, C).

    Returns:
        Scalar tensor: mean squared error between the weighted neighbour
        scores gathered at the selected points and a target that is 1 for
        neighbour slot 0 and 0 for every other slot.

    Note:
        Both min-max normalisations divide by (max - min); an input that is
        constant along dim 1 therefore yields NaN, as in the original code.
    """
    device = dist_vals.device
    n_batch, nk, n_nodes = dist_vals.shape
    neighbour_idx = comp_ind.long()

    # Min-max normalise the embedding over the node axis; the 0.001 offset
    # keeps the logarithm finite at the minimum.
    emb_min = imp_weight.min(1, keepdim=True)[0]
    emb_max = imp_weight.max(1, keepdim=True)[0]
    embn = (imp_weight - emb_min) / (emb_max - emb_min) + 0.001

    # Collapse channels to one score per point and normalise it to [0, 1].
    score = torch.log(embn).mean(-1)
    score_min = score.min(1, keepdim=True)[0]
    score_max = score.max(1, keepdim=True)[0]
    score = (score - score_min) / (score_max - score_min)

    # Map the score to a per-point factor in [0.2, 2.4]; a higher score gives
    # a smaller factor.
    point_weight = 2.4 - 2.2 * score

    # Per-point distance cut-off: the 25% quantile (per batch element) of the
    # last neighbour slot's distances, scaled by the point's factor.
    max_dists = torch.quantile(dist_vals[:, -1], 0.25, dim=1, keepdim=True) * point_weight
    max_dists = max_dists.unsqueeze(1)  # (n_batch, 1, n_nodes)

    # Factor of every neighbour, arranged as (n_batch, nk, n_nodes); slot 0 is
    # forced to weight 1.
    nbr_weight = (point_weight.reshape(n_batch * n_nodes)[neighbour_idx]
                  .reshape(n_batch, n_nodes, nk).transpose(2, 1).contiguous())
    nbr_weight[:, 0, :] = 1

    within_cutoff = dist_vals < max_dists
    closeness = ((max_dists - dist_vals) / max_dists) ** 2

    # Selection score of every neighbour, masked and weighted.
    knn_sel = (sel_vec.reshape(n_batch * n_nodes)[neighbour_idx]
               .reshape(n_batch, n_nodes, nk).transpose(2, 1))
    knn_sel = knn_sel * within_cutoff * closeness * nbr_weight

    # Gather, for every selected point, the values of all its neighbour slots.
    bat_ind = torch.arange(n_batch, device=device).view(n_batch, 1, 1)
    knn_ind_map = torch.arange(nk, device=device).view(1, nk, 1)
    knn_out_sel = knn_sel[bat_ind, knn_ind_map, max_ind]

    # Explicit full-size target (1 in slot 0, 0 elsewhere) instead of relying
    # on implicit broadcasting inside mse_loss; the mean is unchanged.
    target = torch.zeros_like(knn_out_sel)
    target[:, 0, :] = 1.0
    return F.mse_loss(knn_out_sel, target)


# The selection network. Returns the selected points and selection indices,
# along with additional variables required for loss computation (used only during training).
class NetSel(nn.Module):
    """
    The selection network.

    Scores every point of a batch of 3-D point clouds from its k-nearest-
    neighbour distances and returns the top-scoring points.

    Args:
        in_chan: width of the per-point input features fed to the first
            SellBlock. `forward` builds features of width
            `_NUM_NEIGHBOURS + 1` (neighbour distances plus one quantile
            value), so `in_chan` has to match that.
        num_heads: accepted for interface compatibility; not used here.
        decrease_factor: stored as `dec_factor`; `ceil(dec_factor)` is the
            number of points selected when `forward` gets no `num_pts`.
        first: accepted for interface compatibility; not used here.
        lev: stored as `self.lev`; not used inside this class.
    """

    # Number of neighbours requested from the kNN query and consumed by every
    # SellBlock (presumably including the query point itself - confirm against
    # pointops.knn_query). Kept in one place so the two cannot drift apart.
    _NUM_NEIGHBOURS = 32

    def __init__(
            self,
            in_chan,
            num_heads=6,
            decrease_factor = 6,
            first=False,
            lev=0
    ):
        super().__init__()
        self.dec_factor = decrease_factor
        self.red_chan = in_chan // self.dec_factor
        self.lev = lev

        self.chans = 32
        self.comp = nn.ModuleList(
            [SellBlock(in_chan, self.chans, self._NUM_NEIGHBOURS)]
            + [SellBlock(self.chans, self.chans, self._NUM_NEIGHBOURS) for _ in range(5)]
        )

        # Scoring head: 1x1 conv to a single channel, batch norm whose weight
        # and bias both start at 0.5, then ReLU (scores are non-negative).
        layers = [nn.Conv1d(self.chans, 1, kernel_size=1)]
        bn = nn.BatchNorm1d(1)
        bn.bias.data.fill_(0.5)
        bn.weight.data.fill_(0.5)
        layers.append(bn)
        layers.append(nn.ReLU())
        self.sel_conv = nn.Sequential(*layers)

    def forward(self, x, num_pts=None):
        """
        Args:
            x: point coordinates, shape (n_batch, n_nodes, 3).
            num_pts: number of points to select per cloud; defaults to
                `ceil(self.dec_factor)`.
                NOTE(review): the default selects `dec_factor` points, not
                `n_nodes / dec_factor` - confirm this is intended.

        Returns:
            out: selected points, shape (n_batch, k, 3).
            loss: in training mode the tuple
                `(dist_vals, sel_vec, comp_ind, max_ind)` that `selloss`
                needs; otherwise the integer 0.
            max_ind: indices of the selected points, shape (n_batch, 1, k).
        """
        n_batch = x.shape[0]
        n_nodes = x.shape[1]
        n_neigh = self._NUM_NEIGHBOURS

        x = x.contiguous()
        flat_xyz = x.view(n_batch * n_nodes, 3)

        # Built directly on x's device: `.cuda()` would first allocate on the
        # CPU and then copy to the *current* CUDA device, which need not be
        # the device x lives on.
        batch_range = torch.arange(n_batch, device=x.device)

        # Cumulative end offset of every cloud in the flattened point list.
        offsets = (batch_range + 1) * n_nodes
        comp_ind, comp_dist = pointops.knn_query(n_neigh, flat_xyz, offsets)
        dist_vals = comp_dist.view(n_batch, n_nodes, n_neigh).transpose(2, 1).contiguous()

        # Per-cloud 25% quantile of the last neighbour slot's distances,
        # appended to every point's distance vector as one extra feature.
        max_dists = torch.quantile(dist_vals[:, -1], 0.25, dim=1, keepdim=True)
        xa = torch.cat((dist_vals.transpose(2, 1),
                        max_dists.unsqueeze(-1).expand(-1, n_nodes, -1)), 2).contiguous()

        sel_out = xa.reshape(n_batch * n_nodes, -1).contiguous()
        for block in self.comp:
            sel_out = block(sel_out, comp_ind, comp_dist)
        sel_out = sel_out.reshape(n_batch, n_nodes, self.chans).transpose(2, 1).contiguous()
        sel_vec = self.sel_conv(sel_out)  # (n_batch, 1, n_nodes)

        k = math.ceil(self.dec_factor) if num_pts is None else num_pts
        _, max_ind = torch.topk(sel_vec, k)

        loss = 0
        if self.training:
            # Not a loss value yet: the ingredients selloss() consumes later.
            loss = (dist_vals, sel_vec, comp_ind, max_ind)

        # Convert per-cloud indices to indices into the flattened point list.
        max_indl = (max_ind.squeeze(1) + n_nodes * batch_range.unsqueeze(-1)).view(-1)
        out = flat_xyz[max_indl].view(n_batch, max_ind.shape[2], 3)

        return out, loss, max_ind


