import math
import numpy as np
import torch
import torch.nn as nn
from torch.cuda.amp.autocast_mode import autocast
from .utils import get_rotation_matrix


def linear_relu_ln(embed_dims, in_loops, out_loops, input_dims=None):
    """Build a flat list of ``Linear -> ReLU`` groups closed by ``LayerNorm``.

    Each of the ``out_loops`` groups holds ``in_loops`` pairs of
    ``nn.Linear`` / ``nn.ReLU(inplace=True)`` followed by a single
    ``nn.LayerNorm(embed_dims)``.

    Args:
        embed_dims: Output width of every linear layer and of the layer norm.
        in_loops: Number of ``Linear``/``ReLU`` pairs per group.
        out_loops: Number of groups.
        input_dims: Input width of the very first linear layer; falls back
            to ``embed_dims`` when ``None``.

    Returns:
        A plain Python list of modules (not wrapped in ``nn.Sequential``).
    """
    current_in = embed_dims if input_dims is None else input_dims
    stack = []
    for _outer in range(out_loops):
        for _inner in range(in_loops):
            stack += [nn.Linear(current_in, embed_dims), nn.ReLU(inplace=True)]
            # Only the first linear layer sees the external input width.
            current_in = embed_dims
        stack.append(nn.LayerNorm(embed_dims))
    return stack


def gen_sineembed_for_position(pos_tensor, hidden_dim=256):
    """Sinusoidal embedding of 2-D positions (adapted from DAB-DETR).

    Source: https://github.com/IDEA-opensource/DAB-DETR/

    Args:
        pos_tensor: Tensor whose last axis holds ``x`` at index 0 and ``y``
            at index 1; any leading axes are kept.
        hidden_dim: Total embedding width. Half of it is spent on each
            coordinate.

    Returns:
        Tensor with the leading axes of ``pos_tensor`` and a last axis that
        concatenates the ``y`` embedding followed by the ``x`` embedding.
    """
    per_axis_dim = hidden_dim // 2
    two_pi = 2 * math.pi
    channel = torch.arange(
        per_axis_dim, dtype=torch.float32, device=pos_tensor.device
    )
    # Consecutive channel pairs share one frequency (sin / cos partners).
    freq = 10000 ** (2 * (channel // 2) / per_axis_dim)

    def _encode(coord):
        phase = (coord * two_pi)[..., None] / freq
        interleaved = torch.stack(
            (phase[..., 0::2].sin(), phase[..., 1::2].cos()), dim=-1
        )
        return interleaved.flatten(-2)

    emb_x = _encode(pos_tensor[..., 0])
    emb_y = _encode(pos_tensor[..., 1])
    # y comes first in the output.
    return torch.cat((emb_y, emb_x), dim=-1)


def bias_init_with_prob(prior_prob):
    """Return the conv/fc bias value corresponding to a given prior probability.

    The value is the negative natural log of ``(1 - prior_prob) / prior_prob``,
    returned as a Python ``float``.
    """
    odds_against = (1 - prior_prob) / prior_prob
    return float(-np.log(odds_against))


class GaussianDeformableAttn(nn.Module):
    """Cross-attend trajectory anchors to their nearest Gaussians.

    For every trajectory point the ``topk`` Gaussians whose means are closest
    (squared Euclidean distance) are selected. Their projected features plus
    a sinusoidal encoding of their means form the memory of a single-layer
    ``nn.TransformerDecoder``. The query is one learned embedding shared by
    all anchors plus an encoding of the anchor's whole trajectory.

    Args:
        embed_dims: Model width of the decoder and of the returned features.
            Also used as the width of the sinusoidal encoding of Gaussian
            means.
        num_heads: Number of attention heads of the decoder layer.
        num_levels: Stored on the module; not used by the visible code.
        in_bev_dims: Width of ``gaussians.features`` fed to ``value_proj``.
        num_points: Number of points per trajectory anchor (``P``).
    """

    # Width of the sinusoidal embedding computed for each trajectory point.
    # With the default ``num_points=8`` the anchor encoder input is 8 * 64 = 512.
    _TRAJ_POS_DIM = 64

    def __init__(
        self,
        embed_dims,
        num_heads,
        num_levels=1,
        in_bev_dims=64,
        num_points=8,
    ):
        super().__init__()
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.num_levels = num_levels
        self.num_points = num_points

        self.value_proj = nn.Sequential(
            nn.Linear(in_bev_dims, embed_dims),
        )

        # Input is the concatenation of the per-point embeddings of one anchor.
        self.anchor_pos_encoder = nn.Sequential(
            *linear_relu_ln(embed_dims, 1, 1, num_points * self._TRAJ_POS_DIM),
            nn.Linear(embed_dims, embed_dims),
        )
        self.gs_pos_encoder = nn.Sequential(
            *linear_relu_ln(embed_dims, 1, 1, embed_dims),
            nn.Linear(embed_dims, embed_dims),
        )
        cross_attention_layer = nn.TransformerDecoderLayer(
            d_model=embed_dims,
            nhead=num_heads,
            dim_feedforward=embed_dims * 4,
            dropout=0.1,
            batch_first=True,
        )

        self.anchor_query = nn.Embedding(1, embed_dims)
        self.anchor_cross_attention = nn.TransformerDecoder(cross_attention_layer, 1)

    def init_weight(self):
        """Initialise the optional ``attention_weights`` / ``output_proj`` layers.

        ``__init__`` does not create these layers, so each one is initialised
        only if it has been attached to the module; otherwise it is skipped
        instead of raising ``AttributeError``.
        """
        attention_weights = getattr(self, "attention_weights", None)
        if attention_weights is not None:
            nn.init.constant_(attention_weights.weight, 0)
            nn.init.constant_(attention_weights.bias, 0)

        output_proj = getattr(self, "output_proj", None)
        if output_proj is not None:
            nn.init.xavier_uniform_(output_proj.weight)
            nn.init.constant_(output_proj.bias, 0)

    @staticmethod
    def _batched_gather(x, idx):
        """Gather ``x[b, idx[b, n, k]]`` for every batch element.

        Args:
            x: Tensor of shape ``[B, G, ...]``.
            idx: Integer tensor of shape ``[B, N, K]`` indexing axis 1 of ``x``.

        Returns:
            Tensor of shape ``[B, N, K, ...]``.
        """
        batch_idx = torch.arange(x.shape[0], device=x.device).view(-1, 1, 1)
        return x[batch_idx, idx]

    def prepare_gaussian_args(self, gaussians):
        """Collect Gaussian parameters and the inverse of their covariance.

        Shapes below follow the annotations of the original code and are not
        enforced here (``b`` batch, ``g`` Gaussians, ``c`` classes).

        Args:
            gaussians: Object exposing ``means`` (b, g, 2), ``scales``
                (b, g, 2), ``rotations``, ``semantics`` (b, g, c) and
                ``opacities`` (b, g, 1).

        Returns:
            Tuple ``(means, opacities, semantics, scales, cov_inv)`` where
            ``cov_inv`` is the top-left 2x2 block of the inverse covariance,
            on the same device as the inputs.
        """
        means = gaussians.means
        scales = gaussians.scales
        rotations = gaussians.rotations
        semantics = gaussians.semantics
        origi_opa = gaussians.opacities
        if origi_opa.numel() == 0:
            # No opacity supplied: treat every Gaussian as fully opaque.
            origi_opa = torch.ones_like(semantics[..., :1], requires_grad=False)

        bs, g, _ = means.shape
        # Lift the 2-D scales to a 3x3 diagonal matrix with unit third axis.
        S = torch.zeros(bs, g, 3, 3, dtype=means.dtype, device=means.device)
        S[..., 0, 0] = scales[..., 0]
        S[..., 1, 1] = scales[..., 1]
        S[..., 2, 2] = 1
        R = get_rotation_matrix(rotations)  # presumably (b, g, 3, 3)
        M = torch.matmul(S, R)
        Cov = torch.matmul(M.transpose(-1, -2), M).float()
        # The inversion still runs on the CPU as before, but the result goes
        # back to the device the inputs live on rather than the default CUDA
        # device, so CPU-only and multi-GPU runs work.
        CovInv = Cov.cpu().inverse().to(Cov.device)
        return means, origi_opa, semantics, scales, CovInv[..., :2, :2]

    def forward(self, traj_feature, traj_points, gaussians, topk=8, **kwargs):
        """Aggregate Gaussian features around each trajectory anchor.

        Args:
            traj_feature: Unused; kept for interface compatibility.
            traj_points: Tensor ``[B, A, P, 2]`` of anchor points, where ``P``
                must equal ``num_points``.
            gaussians: Object exposing ``features`` ``[B, G, in_bev_dims]``
                and ``means`` ``[B, G, 2]``.
            topk: Number of nearest Gaussians gathered per trajectory point;
                clamped to ``G`` when fewer Gaussians are available.
            **kwargs: Ignored.

        Returns:
            Tensor ``[B, A, embed_dims]``.

        Raises:
            ValueError: If ``P`` differs from ``num_points``.
        """
        B, A, P, _ = traj_points.shape
        if P != self.num_points:
            raise ValueError(
                f"traj_points has {P} points per anchor, "
                f"expected num_points={self.num_points}"
            )

        # Only the means and features are needed here, so the covariance
        # inversion of ``prepare_gaussian_args`` is deliberately not run.
        means = gaussians.means  # [B, G, 2]
        value = self.value_proj(gaussians.features)  # [B, G, D]

        # Squared Euclidean distance from every sampled point to every Gaussian.
        sampled_xy = traj_points.flatten(1, 2)  # [B, A*P, 2]
        diff = sampled_xy[:, :, None, :] - means[:, None, :, :]  # [B, A*P, G, 2]
        dist_sq = (diff**2).sum(-1)  # [B, A*P, G]

        # Indices of the k nearest Gaussians for each sampled point.
        k = min(topk, means.shape[1])
        _, indices = torch.topk(dist_sq, k, dim=-1, largest=False)  # [B, A*P, k]

        # One positional feature per anchor, built from all of its points.
        traj_pos_embed = gen_sineembed_for_position(
            traj_points, hidden_dim=self._TRAJ_POS_DIM
        )  # [B, A, P, 64]
        traj_pos_embed = traj_pos_embed.flatten(-2)  # [B, A, P*64]
        traj_pos_feature = self.anchor_pos_encoder(traj_pos_embed).view(
            B * A, 1, -1
        )  # [B*A, 1, D]
        traj_query = self.anchor_query.weight[None, ...].expand(A * B, -1, -1)

        # Group the gathered Gaussians per anchor: P points * k neighbours each.
        value_sel = self._batched_gather(value, indices).reshape(B * A, P * k, -1)
        means_sel = self._batched_gather(means, indices).reshape(B * A, P * k, -1)

        # The embedding width must match the input width of ``gs_pos_encoder``.
        gs_pos_embed = gen_sineembed_for_position(
            means_sel, hidden_dim=self.embed_dims
        )
        gs_pos_feature = self.gs_pos_encoder(gs_pos_embed)
        out = self.anchor_cross_attention(
            traj_query + traj_pos_feature, value_sel + gs_pos_feature
        )
        return out.view(B, A, -1)
