# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import warnings
import math

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, constant_

from . import utils
# Debugging helper: calling `st()` drops into an interactive debugger.
# ipdb is only a development convenience, so fall back to the stdlib pdb
# when it is not installed instead of making this module unimportable.
try:
    import ipdb
    st = ipdb.set_trace
except ImportError:
    import pdb
    st = pdb.set_trace

def _is_power_of_2(n):
    if (not isinstance(n, int)) or (n < 0):
        raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
    return (n & (n-1) == 0) and n != 0


class MSDeformAttnPC(nn.Module):
    """Multi-scale deformable attention over point-cloud features.

    Point-cloud counterpart of Deformable DETR's ``MSDeformAttn``. For every
    query, head, level and sampling point, a 3-D offset is predicted around
    the query's reference point. Features are interpolated at the resulting
    locations from that level's points (``utils.pc_feature_interpolation``).
    The samples are then combined with softmax-normalised attention weights.
    """

    def __init__(
        self,
        d_model=256,
        n_levels=4,
        n_heads=8,
        n_points=8,
        n_sample=8,
        radius=0.8,
    ):
        """
        Multi-Scale Deformable Attention Module
        :param d_model      hidden dimension
        :param n_levels     number of feature levels
        :param n_heads      number of attention heads
        :param n_points     number of sampling points per attention head per feature level
        :param n_sample     number of neighbours considered per sampling location
        :param radius       neighbourhood radius forwarded to
                            ``utils.pc_feature_interpolation`` and
                            ``utils.get_repulsion_loss`` (previously hard-coded to 0.8)
        :raises ValueError  if d_model is not divisible by n_heads
        """
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
        _d_per_head = d_model // n_heads
        # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
        if not _is_power_of_2(_d_per_head):
            warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
                          "which is more efficient in our CUDA implementation.")

        self.d_model = d_model
        self.n_levels = n_levels
        self.n_heads = n_heads
        self.n_points = n_points
        self.n_sample = n_sample
        self.radius = radius

        # One 3-D offset per (head, level, point).
        self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 3)
        # One logit per (head, level, point); softmax-normalised over levels*points in forward().
        self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
        self.value_proj = nn.Linear(d_model, d_model)
        self.output_proj = nn.Linear(d_model, d_model)

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialise all projections.

        Zero-initialised attention logits give uniform attention weights at
        the start of training.
        """
        xavier_uniform_(self.sampling_offsets.weight.data)
        constant_(self.sampling_offsets.bias.data, 0.)
        constant_(self.attention_weights.weight.data, 0.)
        constant_(self.attention_weights.bias.data, 0.)
        xavier_uniform_(self.value_proj.weight.data)
        constant_(self.value_proj.bias.data, 0.)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.)

    @staticmethod
    def _per_head_xyz(input_xyzs, lid, n_heads):
        """Return level ``lid`` point coordinates as (N*M, 3, S), one copy per head.

        Row ``n * n_heads + m`` holds batch ``n``'s coordinates. This matches the
        batch-major / head-minor row order of the values and sampling locations.
        ``repeat_interleave`` (not ``repeat``) is what produces that order.
        """
        # (N, S, L, 3) -> (N, S, 3) -> (N, 3, S) -> (N*M, 3, S)
        return input_xyzs[:, :, lid].permute(0, 2, 1).repeat_interleave(n_heads, dim=0)

    @staticmethod
    def _per_head_value(value_l):
        """Reshape one level's values (N, S_l, M, D) into (N*M, D, S_l), batch-major / head-minor."""
        n, s_l, m, d = value_l.shape
        # (N, S_l, M, D) -> (N, S_l, M*D) -> (N, M*D, S_l) -> (N*M, D, S_l)
        return value_l.flatten(2).transpose(1, 2).reshape(n * m, d, s_l)

    def ms_deform_attn_core_pc_pytorch(
        self,
        value,
        input_xyzs,
        sampling_locations,
        attention_weights,
        value_spatial_shapes,
        use_losses=False,
    ):
        """
        Sample point features at the deformable locations and aggregate them.

        value: Point Cloud Features (N, S, M, D)
        sampling_locations: (N, Lq_, M, L, P, 3)
        input_xyzs: (N, S_l, L, 3)
        attention_weights: (N, Lq_, M, L, P)
        value_spatial_shapes: per-level point counts, used as split sizes along dim 1 of value
        Glossary
        N_  -> Batch Size
        S_  -> Seq (points, all levels concatenated)
        M_  -> Num Heads
        D_  -> Dim
        Lq_ -> Length of query
        L_  -> num levels
        P   -> num sampling points

        Returns (output, aux_loss).
        - output has shape (N_, Lq_, M_*D_).
        - When use_losses is set, aux_loss is the mean over levels of
          ``2 * fit_loss + repulsion_loss``. Otherwise it is 0.0.
        """
        N_, _, M_, D_ = value.shape
        _, Lq_, _, L_, P_, _ = sampling_locations.shape
        radius = self.radius
        value_list = value.split(value_spatial_shapes, dim=1)
        sampling_value_list = []
        aux_loss = 0.0

        for lid_, value_level in enumerate(value_list):
            # (N_, S_l, L_, 3) -> (N_*M_, 3, S_l)
            input_xyz_l = self._per_head_xyz(input_xyzs, lid_, M_)
            # (N_, S_l, M_, D_) -> (N_*M_, D_, S_l); uses this level's own point count
            value_l = self._per_head_value(value_level)
            # (N_, Lq_, M_, P_, 3) -> (N_, M_, Lq_, P_, 3) -> (N_*M_, Lq_*P_, 3)
            sampling_locations_l = sampling_locations[:, :, :, lid_].transpose(
                1, 2).flatten(0, 1).flatten(1, 2)

            # Interpolated samples: (N_*M_, Lq_*P_, D_)
            sampling_value_l_, fit_loss_l_ = utils.pc_feature_interpolation(
                input_xyz_l,
                value_l,
                sampling_locations_l,
                radius=radius,
                nsample=self.n_sample,
                use_losses=use_losses
            )
            # (N_*M_, Lq_*P_, D_) -> (N_*M_, D_, Lq_*P_) -> (N_*M_, D_, Lq_, P_)
            sampling_value_list.append(
                sampling_value_l_.permute(0, 2, 1).reshape(N_*M_, D_, Lq_, P_))

            if use_losses:
                # (N_*M_, Lq_*P_, 3) -> (N_*M_*Lq_, P_, 3): one group of P_ points per (batch, head, query)
                grouped_locations = sampling_locations_l.reshape(N_*M_*Lq_, P_, 3)
                rep_loss_l_ = utils.get_repulsion_loss(grouped_locations, radius)
                aux_loss = aux_loss + (2 * fit_loss_l_ + rep_loss_l_)
        aux_loss = aux_loss / len(value_list)

        # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_*M_, 1, Lq_, L_*P_)
        attention_weights = attention_weights.transpose(1, 2).reshape(
            N_*M_, 1, Lq_, L_*P_)

        # Stack levels: (N_*M_, D_, Lq_, L_, P_) -> (N_*M_, D_, Lq_, L_*P_),
        # then weight and sum over all L_*P_ samples -> (N_, M_*D_, Lq_).
        output = (torch.stack(
            sampling_value_list, dim=-2).flatten(-2) * attention_weights
                  ).sum(-1).view(N_, M_*D_, Lq_)

        return output.transpose(1, 2).contiguous(), aux_loss

    def forward(
        self,
        query,
        input_flatten,
        input_xyzs,
        reference_points,
        input_spatial_shapes,
        use_losses=False,
    ):
        """
        :param query                (N, Length_{query}, C) (pc features + pos_enc)
        :param input_flatten        (N, Length_{all level points}, C) (pc features, all levels concatenated)
        :param input_xyzs           (N, Num_Points, L, 3) point coordinates per level
        :param reference_points     (N, Length_{query}, L, 3) reference points (during encoder), or
                                    (N, Length_{query}, L, 6) 3d bbox as (center xyz, size xyz) (during decoder)
        :param input_spatial_shapes sequence of L per-level point counts; used as split sizes along
                                    dim 1 of input_flatten, so they must sum to Length_{all level points}
        :param use_losses           also compute the auxiliary fit/repulsion loss
        :return output              (N, Length_{query}, C), and aux_loss (0.0 when use_losses is False)
        :raises ValueError          if the last dim of reference_points is neither 3 nor 6
        """
        N, Len_q, _ = query.shape
        N, Len_in, _ = input_flatten.shape

        value = self.value_proj(input_flatten)
        value = value.view(
            N, Len_in, self.n_heads, self.d_model // self.n_heads)
        sampling_offsets = self.sampling_offsets(query).view(
            N, Len_q, self.n_heads, self.n_levels, self.n_points, 3)
        # Softmax jointly over all levels and points of each head.
        attention_weights = self.attention_weights(query).view(
            N, Len_q, self.n_heads, self.n_levels * self.n_points)
        attention_weights = F.softmax(
            attention_weights, -1).view(
                N, Len_q, self.n_heads, self.n_levels, self.n_points)

        # sampling_locations: (N, Len_q, n_heads, n_levels, n_points, 3)
        if reference_points.shape[-1] == 3:
            sampling_locations = reference_points[:, :, None, :, None, :] \
                                 + sampling_offsets
        elif reference_points.shape[-1] == 6:
            # Offsets are scaled by half the box size around the box center.
            sampling_locations = reference_points[:, :, None, :, None, :3] \
                            + (sampling_offsets / self.n_points) \
                            * reference_points[:, :, None, :, None, 3:] * 0.5
        else:
            raise ValueError(
                'Last dim of reference_points must be 3 or 6, but get {} instead.'.format(reference_points.shape[-1]))
        output, aux_loss = self.ms_deform_attn_core_pc_pytorch(
            value=value,
            input_xyzs=input_xyzs,
            value_spatial_shapes=input_spatial_shapes,
            sampling_locations=sampling_locations,
            attention_weights=attention_weights,
            use_losses=use_losses
        )

        output = self.output_proj(output)
        return output, aux_loss

