import torch
from torch import nn
import torch.nn.functional as F
import math
from knn_cuda import KNN


def conv(ni, no, activation=True, bn_zero=False):
    """
    Build a point-wise (kernel size 1) convolution stage as an ``nn.Sequential``.

    The stage is ``Conv1d -> LeakyReLU -> BatchNorm1d`` when ``activation`` is
    true and ``Conv1d -> BatchNorm1d`` otherwise.

    ni: number of input channels of the convolution.
    no: number of output channels of the convolution and of the batch norm.
    activation=True: whether the LeakyReLU is inserted between conv and batch norm.
    bn_zero=False: whether the batch-norm scale (weight) is initialised to zero,
        so that the stage initially outputs only the batch-norm bias.
    """
    pointwise = nn.Conv1d(ni, no, kernel_size=1)
    stages = [pointwise, nn.LeakyReLU()] if activation else [pointwise]

    norm = nn.BatchNorm1d(no)
    if bn_zero:
        nn.init.zeros_(norm.weight)

    return nn.Sequential(*stages, norm)


def gather_knn(feat_map, nn_ind, k, subtract=True):
    """
    Stack the features of each point's k neighbours behind the point's own features.

    feat_map: (#batches, #features, #points) feature tensor.
    nn_ind: (#batches, k + 1, #points) integer tensor of point indices. The first
        of the k + 1 slots is gathered but then discarded (presumably it holds
        the point itself -- confirm against the KNN implementation in use).
    k: number of neighbours that are kept per point.
    subtract=True: if true, the point's own features are subtracted from each of
        its neighbours' features, i.e. relative features are returned.

    Returns a tensor of shape (#batches, #features * (k + 1), #points): the first
    #features channels are feat_map unchanged, followed by k groups of #features
    channels, one group per remaining neighbour slot.

    All index tensors are created on feat_map's device, so the function works
    for CPU tensors and for tensors on any GPU.
    """
    device = feat_map.device
    n_batch = feat_map.shape[0]
    n_feats = feat_map.shape[1]

    # Broadcastable index grids: batch index (B, 1, 1) and channel index
    # (1, C * (k + 1), 1) cycling 0..C-1 once per neighbour slot.
    bat_ind = torch.arange(n_batch, device=device)[:, None, None]
    feat_ind = torch.arange(n_feats, device=device).repeat(k + 1)[None, :, None]

    # Every neighbour slot is repeated C times so that it lines up with feat_ind.
    point_ind = nn_ind.repeat_interleave(n_feats, dim=1)
    gathered = feat_map[bat_ind, feat_ind, point_ind]

    # Drop the first neighbour slot; only the remaining k slots are returned.
    neighbours = gathered[:, n_feats:, :]
    if subtract:
        neighbours = neighbours - feat_map.repeat((1, k, 1))

    # Put the features of every point itself back in front of its neighbours.
    return torch.cat((feat_map, neighbours), 1)


class ResPointBlock(nn.Module):
    """
    Residual block on point features that keeps the number of points unchanged.

    The residual branch is: point-wise conv down to in_chan // 4 channels,
    gathering of the n_knn neighbours of every point, point-wise conv back to
    in_chan channels and a final point-wise conv to out_chan channels whose
    batch-norm scale is zero initialised. The shortcut is the input itself when
    in_chan == out_chan and a point-wise conv + batch norm otherwise.
    """

    def __init__(self, in_chan, out_chan, n_knn=10):
        super().__init__()

        self.k = n_knn
        # One more neighbour than n_knn is requested; gather_knn discards the
        # first slot of the returned indices.
        self.knn = KNN(k=n_knn + 1)
        self.same_chan = in_chan == out_chan

        bottleneck = in_chan // 4
        # in_chan -> in_chan // 4
        self.reduce_conv = conv(in_chan, bottleneck)
        # stacked (point + n_knn neighbours) bottleneck features -> in_chan
        self.conv = conv(bottleneck * (self.k + 1), in_chan)
        # in_chan -> out_chan, no activation, batch-norm scale starts at zero
        self.expand_conv = conv(in_chan, out_chan, activation=False, bn_zero=True)
        # shortcut projection, only applied in forward when the channel counts differ
        self.skip_conv = conv(in_chan, out_chan, activation=False)

    def forward(self, x, ind=None, act=True, subtract=True):
        """
        Apply the residual block.

        x: input features; the neighbour indices are computed from x itself
            when ind is not given.
        ind=None: precomputed neighbour indices, passed on to gather_knn.
        act=True: whether a leaky ReLU is applied to the sum of residual
            branch and shortcut.
        subtract=True: forwarded to gather_knn (relative neighbour features).
        """
        if ind is None:
            # No neighbour indices supplied: compute them from the input.
            _dist, ind = self.knn(x, x)

        reduced = self.reduce_conv(x)
        grouped = gather_knn(reduced, ind, self.k, subtract=subtract)
        residual = self.expand_conv(self.conv(grouped))

        if self.same_chan:
            shortcut = x
        else:
            shortcut = self.skip_conv(x)

        if act:
            return F.leaky_relu(residual + shortcut)
        return residual + shortcut


class ResPointStrideBlock(nn.Module):
    """
    Residual block that reduces the number of points by roughly decrease_factor.

    A small scoring sub-network assigns one score to every point; the
    ceil(#points / decrease_factor) highest scoring points per batch element
    are kept. The forward pass slices the first decrease_factor + 3 rows of
    the neighbour index tensor for scoring, so the supplied indices must have
    at least that many rows.
    """

    def __init__(
            self,
            in_chan,
            out_chan,
            n_knn=10,
            decrease_factor=6
    ):
        super().__init__()
        self.k = n_knn
        self.dec_factor = decrease_factor
        self.red_chan = in_chan // self.dec_factor  # not used inside this block
        self.same_chan = in_chan == out_chan

        # residual branch: reduce -> gather neighbours -> conv -> expand
        self.reduce_conv = conv(in_chan, in_chan // 4)
        self.conv = conv(in_chan // 4 * (self.k + 1), in_chan)
        self.expand_conv = conv(in_chan, out_chan, activation=False, bn_zero=True)

        # scoring sub-network: two reduce/gather/conv/expand stages that work
        # on each point plus dec_factor + 2 of its neighbours
        self.sel_reduce_conv1 = conv(in_chan, in_chan // 4)
        self.sel_conv1 = conv(in_chan // 4 * (self.dec_factor + 3), in_chan)
        self.sel_expand_conv1 = conv(in_chan, out_chan)
        self.sel_reduce_conv2 = conv(out_chan, out_chan // 4)
        self.sel_conv2 = conv(out_chan // 4 * (self.dec_factor + 3), out_chan)
        self.sel_expand_conv2 = conv(out_chan, out_chan)

        # Final scoring layer. Its input is the output of sel_expand_conv2 and
        # therefore has out_chan channels (the layer was previously built with
        # in_chan inputs, which only worked when in_chan == out_chan).
        layers = [nn.Conv1d(out_chan, 1, kernel_size=1)]
        bn = nn.BatchNorm1d(1)
        bn.bias.data.fill_(0.5)
        bn.weight.data.fill_(0.5)
        layers.append(bn)
        layers.append(nn.LeakyReLU())
        self.sel_conv = nn.Sequential(*layers)

        self.skip_conv = conv(in_chan, out_chan, activation=False)

    def forward(self, x, ind):
        """
        Score, select and transform the points of x.

        x: (#batches, in_chan, #points) features.
        ind: (#batches, n_knn + 1, #points) neighbour indices for x.

        Returns a tuple (out, loss, max_ind, out_unsel):
            out: leaky ReLU of residual branch + shortcut at the selected points.
            loss: selection loss (see below).
            max_ind: indices of the selected points as returned by torch.topk
                on the per-point scores.
            out_unsel: output of the residual branch for all points, before
                the selection is applied.

        The selection loss is the sum of two MSE terms on the scores gathered
        over the first dec_factor neighbour slots: summed over the selected
        points, slot 0 should add up to the number of selected points and every
        other slot to 0; summed over the slots, every point should add up to 1
        (this term is weighted by the number of selected points).

        Index tensors are created on x's device, so the block is not tied to
        the current CUDA device.
        """
        device = x.device
        n_batch = x.shape[0]
        n_nodes = x.shape[2]
        n_sel = math.ceil(n_nodes / self.dec_factor)
        bat_ind = torch.arange(n_batch, device=device).unsqueeze(1).unsqueeze(1)

        # score every point
        score_k = self.dec_factor + 2
        score_ind = ind[:, :score_k + 1, :]
        sel_out = self.sel_reduce_conv1(x)
        sel_out = self.sel_conv1(gather_knn(sel_out, score_ind, k=score_k, subtract=False))
        sel_out = self.sel_expand_conv1(sel_out)
        sel_out = self.sel_reduce_conv2(sel_out)
        sel_out = self.sel_conv2(gather_knn(sel_out, score_ind, k=score_k, subtract=False))
        sel_out = self.sel_expand_conv2(sel_out)
        sel_vec = self.sel_conv(sel_out)

        # select the best scoring points
        _, max_ind = torch.topk(sel_vec, n_sel)
        knn_ind_map = torch.arange(self.dec_factor, device=device).unsqueeze(1).unsqueeze(0)
        knn_sel = gather_knn(sel_vec, ind[:, :self.dec_factor, :], k=self.dec_factor - 1, subtract=False)
        knn_out_sel = knn_sel[bat_ind, knn_ind_map, max_ind]

        # compute the loss
        col_sel_sum = knn_out_sel.sum(dim=-1, keepdim=True)
        row_sel_sum = knn_sel.sum(dim=-2, keepdim=True)
        # target per neighbour slot: n_sel for slot 0, zero for all other slots
        col_target = torch.ones_like(col_sel_sum) * torch.tensor(
            [[[n_sel]] + [[0]] * (self.dec_factor - 1)], device=device)
        loss_part1 = F.mse_loss(col_sel_sum, col_target)
        loss_part2 = F.mse_loss(row_sel_sum, torch.ones_like(row_sel_sum))
        loss = loss_part1 + loss_part2 * n_sel

        # compute the residual branch for all points, then keep the selected ones
        out = gather_knn(self.reduce_conv(x), ind, self.k)
        out_unsel = self.expand_conv(self.conv(out))
        feat_ind = torch.arange(out_unsel.shape[1], device=device).unsqueeze(1).unsqueeze(0)
        out = out_unsel[bat_ind, feat_ind, max_ind]

        # shortcut: mean over each point and its neighbours, taken at the selected points
        skip = x if self.same_chan else self.skip_conv(x)
        skip = gather_knn(skip, ind, self.k, subtract=False)
        skip = torch.stack(skip.split(out.shape[1], dim=1), dim=-1)
        skip = skip.mean(-1)[bat_ind, feat_ind, max_ind]

        return F.leaky_relu(out + skip), loss, max_ind, out_unsel


class ResPointUnstrideBlock(nn.Module):
    """
    Up-sampling residual block: every input point produces increase_factor
    output points, each built from the point and one of its neighbours.
    """

    def __init__(
            self,
            in_chan,
            n_knn=10,
            increase_factor=6
    ):
        super().__init__()
        self.k = n_knn
        self.inc_factor = increase_factor

        # input is the point's features concatenated with relative neighbour features
        self.conv = conv(in_chan * 2, in_chan)
        self.expand_conv = conv(in_chan, in_chan, activation=False, bn_zero=True)

    def forward(self, x, ind):
        """
        The transposed convolution block.

        x: (#batches, #features, #points) features.
        ind: (#batches, >= inc_factor, #points) neighbour indices for x; only
            the first inc_factor rows are read.

        Returns a (#batches, #features, #points * inc_factor) tensor. Output
        position p * inc_factor + j combines point p with the point stored in
        ind[:, j, p].

        Index tensors are created on x's device, so the block also works on CPU
        and on GPUs other than the current one.
        """
        device = x.device
        n_batch = x.shape[0]
        n_feats = x.shape[1]
        n_points = x.shape[2]

        bat_ind = torch.arange(n_batch, device=device).unsqueeze(1).unsqueeze(1)
        feat_ind = torch.arange(n_feats, device=device).unsqueeze(1).unsqueeze(0)
        # 0,0,..,0,1,1,..: every point index repeated inc_factor times
        pts_ind = torch.arange(n_points, device=device).repeat_interleave(self.inc_factor)
        # 0,1,..,inc_factor-1,0,1,..: neighbour slot used by each output position
        kfeat_flat_ind = torch.arange(self.inc_factor, device=device).repeat([1, 1, n_points])

        nn_ind_flat = ind[bat_ind, kfeat_flat_ind, pts_ind]

        # features of the source point and of its paired neighbour, relative to the point
        upper_transpose_t = x[bat_ind, feat_ind, pts_ind]
        lower_transpose_t = x[bat_ind, feat_ind, nn_ind_flat] - upper_transpose_t
        out = torch.cat((upper_transpose_t, lower_transpose_t), 1)

        out = self.conv(out)
        out = self.expand_conv(out)

        # shortcut: midpoint between the point's and the neighbour's features
        skip = upper_transpose_t + (lower_transpose_t / 2)

        return F.leaky_relu(out + skip)


class ResPointLevel(nn.Module):
    """
    One encoder level: two ResPointBlocks followed by a ResPointStrideBlock.

    The first block maps in_chan to out_chan channels; the second block and the
    strided block keep out_chan channels.
    """

    def __init__(self, in_chan, out_chan, n_knn=10, decrease_factor=6):
        super().__init__()
        self.k = n_knn
        self.knn = KNN(k=n_knn + 1)
        self.res1_block = ResPointBlock(in_chan, out_chan, n_knn=self.k)
        self.res2_block = ResPointBlock(out_chan, out_chan, n_knn=self.k)
        self.res_stride_block = ResPointStrideBlock(
            in_chan=out_chan,
            out_chan=out_chan,
            n_knn=self.k,
            decrease_factor=decrease_factor,
        )

    def forward(self, x, ind=None):
        """
        Run the level on x.

        x: input features.
        ind=None: neighbour indices for x; computed from x when omitted. The
            same indices are used by all three blocks.

        Returns the four values produced by the strided block, in order: the
        down-sampled output, the selection loss, the indices of the selected
        points and the strided block's fourth return value.
        """
        if ind is None:
            _dist, ind = self.knn(x, x)

        feats = self.res1_block(x, ind)
        feats = self.res2_block(feats, ind)

        strided, loss, max_ind, sel_vec = self.res_stride_block(feats, ind)
        return strided, loss, max_ind, sel_vec

class ResPointUpLevel(nn.Module):
    """
    One decoder level: two ResPointBlocks followed by a ResPointUnstrideBlock.

    The first block maps in_chan to out_chan channels; the second block and the
    up-sampling block keep out_chan channels. The up-sampling block is stored
    under the attribute name res_stride_block.
    """

    def __init__(self, in_chan, out_chan, n_knn=10, increase_factor=6):
        super().__init__()
        self.k = n_knn
        self.knn = KNN(k=n_knn + 1)
        self.res1_block = ResPointBlock(in_chan, out_chan, n_knn=self.k)
        self.res2_block = ResPointBlock(out_chan, out_chan, n_knn=self.k)
        self.res_stride_block = ResPointUnstrideBlock(
            out_chan,
            n_knn=self.k,
            increase_factor=increase_factor,
        )

    def forward(self, x):
        """
        Run the level on x and return the up-sampled features.

        The neighbour indices are computed once from x and shared by all three
        blocks.
        """
        _dist, ind = self.knn(x, x)

        feats = self.res1_block(x, ind)
        feats = self.res2_block(feats, ind)

        return self.res_stride_block(feats, ind)


class PointAutoencNet(nn.Module):
    """
    Point autoencoder: two down-sampling encoder levels, a residual block that
    produces the code, two up-sampling decoder levels and three final residual
    blocks, the last of which outputs 3 channels.
    """

    def __init__(self, n_knn=10, fdfu=(6, 6, 6, 6), cc=9):
        """
        n_knn: number of neighbours used by every block.
        fdfu: sequence of four factors, read by index: the decrease factors of
            the two encoder levels followed by the increase factors of the two
            decoder levels. The default is an immutable tuple so that it cannot
            be shared and modified between instances.
        cc: number of channels of the code.
        """
        super().__init__()
        self.k = n_knn
        self.knn = KNN(k=n_knn + 1)

        # The first level receives the 3 input channels of every point stacked
        # with those of its n_knn neighbours.
        self.level_one = ResPointLevel(3 * (self.k + 1), 32, n_knn=self.k, decrease_factor=fdfu[0])
        self.level_two = ResPointLevel(32, 64, n_knn=self.k, decrease_factor=fdfu[1])
        self.res_block = ResPointBlock(64, cc, n_knn=self.k)

        self.level_up_one = ResPointUpLevel(cc, 64, n_knn=self.k, increase_factor=fdfu[2])

        self.level_up_two = ResPointUpLevel(64, 32, n_knn=self.k, increase_factor=fdfu[3])

        self.res1_last_block = ResPointBlock(32, 32, n_knn=self.k)
        self.res2_last_block = ResPointBlock(32, 32, n_knn=self.k)
        self.res3_last_block = ResPointBlock(32, 3, n_knn=self.k)

    def forward(self, x):
        """
        The complete network based on two decoder and encoder levels with one ResNet block additionally processing
        the last selected output and three res blocks which predict the final points.

        x: input with 3 channels per point (the first level is built for
            3 * (n_knn + 1) channels after neighbour gathering).

        Returns (out, loss1, loss2, max_ind1, max_ind2, sel_vec1, sel_vec2):
        the 3-channel output of the last block, followed by the selection loss,
        the selected point indices and the fourth return value of each of the
        two encoder levels.
        """
        _, ind = self.knn(x, x)
        out = gather_knn(x, ind, self.k)

        # encoder; the second level computes its own neighbour indices
        out, loss1, max_ind1, sel_vec1 = self.level_one(out, ind)
        out, loss2, max_ind2, sel_vec2 = self.level_two(out)

        code = self.res_block(out)

        # decoder
        out = self.level_up_one(code)
        out = self.level_up_two(out)

        # The first two final blocks share one set of neighbour indices; the
        # last block is called without indices and computes its own, and it
        # applies no activation to its output.
        _, ind = self.knn(out, out)
        out = self.res1_last_block(out, ind)
        out = self.res2_last_block(out, ind)
        out = self.res3_last_block(out, act=False)

        return out, loss1, loss2, max_ind1, max_ind2, sel_vec1, sel_vec2

