import copy
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .blocks import GaussianDeformableAttn, bias_init_with_prob, linear_relu_ln


class DiffMotionPlanningRefinementModule(nn.Module):
    """Per-mode prediction heads used for trajectory refinement.

    Maps every trajectory-mode feature to a scalar score (``plan_cls_branch``)
    and to ``ego_fut_ts`` 2-D offsets (``plan_reg_branch``).
    """

    def __init__(
        self,
        embed_dims=256,
        ego_fut_ts=8,
        ego_fut_mode=20,
    ):
        super(DiffMotionPlanningRefinementModule, self).__init__()
        self.embed_dims = embed_dims
        self.ego_fut_ts = ego_fut_ts
        # Stored for reference only; forward() infers the number of modes
        # from the shape of its input.
        self.ego_fut_mode = ego_fut_mode

        self.plan_cls_branch = nn.Sequential(
            *linear_relu_ln(embed_dims, 1, 2),
            nn.Linear(embed_dims, 1),
        )
        self.plan_reg_branch = nn.Sequential(
            nn.Linear(embed_dims, embed_dims),
            nn.ReLU(),
            nn.Linear(embed_dims, embed_dims),
            nn.ReLU(),
            nn.Linear(embed_dims, ego_fut_ts * 2),
        )
        # When True, init_weight() zeroes the last regression layer so the
        # predicted offsets start at exactly zero.
        self.if_zeroinit_reg = False

        self.init_weight()

    def init_weight(self):
        """Initialize the output layers of both branches.

        Optionally zero-initializes the final regression layer, then sets the
        final classification bias to ``bias_init_with_prob(0.01)``
        (presumably a prior-probability bias init for a sigmoid score --
        confirm in ``.blocks``).
        """
        if self.if_zeroinit_reg:
            nn.init.constant_(self.plan_reg_branch[-1].weight, 0)
            nn.init.constant_(self.plan_reg_branch[-1].bias, 0)

        bias_init = bias_init_with_prob(0.01)
        nn.init.constant_(self.plan_cls_branch[-1].bias, bias_init)

    def forward(self, traj_feature):
        """Predict per-mode trajectory offsets and scores.

        :param traj_feature: tensor of shape (bs, num_modes, embed_dims)
        :return: ``(plan_reg, plan_cls)`` with shapes
            (bs, num_modes, ego_fut_ts, 2) and (bs, num_modes)
        """
        bs, num_modes, _ = traj_feature.shape

        # Each branch is evaluated exactly once (the score head used to run
        # twice, discarding the first result).
        plan_cls = self.plan_cls_branch(traj_feature).view(bs, num_modes)
        traj_delta = self.plan_reg_branch(traj_feature)
        plan_reg = traj_delta.view(bs, num_modes, self.ego_fut_ts, 2)

        return plan_reg, plan_cls


class CustomTransformerDecoderLayer(nn.Module):
    """One refinement step over a set of candidate trajectories.

    Cross-attends trajectory queries to ``gaussians``, lets them attend to the
    ego query (plus an optional target-point token), then predicts per-mode
    offsets that are added to the incoming trajectory points, and per-mode
    scores.
    """

    def __init__(
        self,
        d_model,
        d_ffn,
        num_poses,
        in_bev_dims=128,
    ):
        super().__init__()

        # Cross-attention from trajectory queries to the scene representation;
        # the number of sampling points is tied to ``num_poses``.
        self.cross_gaussian_attention = GaussianDeformableAttn(
            embed_dims=d_model,
            num_heads=8,
            num_points=num_poses,
            in_bev_dims=in_bev_dims,
        )

        # Encodes the target point into one d_model token (the input width is
        # presumably 2 via the last linear_relu_ln argument -- confirm).
        self.tp_encoder = nn.Sequential(
            *linear_relu_ln(d_model, 1, 2, 2),
            nn.Linear(d_model, d_model),
        )
        self.tp_dropout = nn.Dropout(p=0.25)

        tf_decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=8,
            dim_feedforward=d_ffn,
            dropout=0.0,
            batch_first=True,
        )

        # Trajectory features (tgt) attend to ego / target-point tokens (memory).
        self._tf_decoder = nn.TransformerDecoder(tf_decoder_layer, 3)

        self.task_decoder = DiffMotionPlanningRefinementModule(
            embed_dims=d_model,
            ego_fut_ts=num_poses,
            ego_fut_mode=20,
        )

    def forward(
        self, traj_feature, noisy_traj_points, gaussians, ego_query, target_point
    ):
        """Run one refinement step.

        :param traj_feature: trajectory query features, forwarded to
            ``cross_gaussian_attention`` (may be ``None`` if that module
            supports it -- confirm in ``.blocks``)
        :param noisy_traj_points: current trajectory points; the regressed
            offsets of shape (bs, num_modes, num_poses, 2) are added to them
        :param gaussians: scene representation consumed by the cross-attention
        :param ego_query: ego tokens of shape (bs, n_ego, d_model), used as
            decoder memory
        :param target_point: target point encoded by ``tp_encoder`` and
            appended as one extra memory token; when ``None`` the memory is
            ``ego_query`` alone
        :return: ``(poses_reg, poses_cls)``: refined points of shape
            (bs, num_modes, num_poses, 2) and scores of shape (bs, num_modes)
        """
        traj_feature = self.cross_gaussian_attention(
            traj_feature, noisy_traj_points, gaussians
        )

        memory = ego_query
        if target_point is not None:
            tp_feature = self.tp_dropout(self.tp_encoder(target_point))
            memory = torch.cat([ego_query, tp_feature[:, None]], dim=1)

        traj_feature = self._tf_decoder(traj_feature, memory)
        poses_reg, poses_cls = self.task_decoder(
            traj_feature
        )  # (bs, num_modes, num_poses, 2); (bs, num_modes)
        # The task decoder predicts offsets; make them absolute positions.
        poses_reg[..., :2] = poses_reg[..., :2] + noisy_traj_points

        return poses_reg, poses_cls


def _get_clones(module, N):
    # FIXME: copy.deepcopy() is not defined on nn.module
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


class CustomTransformerDecoder(nn.Module):
    """Stack of refinement layers that iteratively update trajectory points.

    Every layer is a deep copy of ``decoder_layer``. Layer ``k`` receives the
    first two coordinates predicted by layer ``k - 1`` (cloned and detached),
    while ``traj_feature``, ``gaussians``, ``ego_query`` and ``target_point``
    are passed unchanged to every layer.
    """

    def __init__(
        self,
        decoder_layer,
        num_layers,
        norm=None,
    ):
        super().__init__()
        torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
        # NOTE: ``norm`` is accepted for signature compatibility but unused.
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers

    def forward(
        self, traj_feature, noisy_traj_points, gaussians, ego_query, target_point
    ):
        """Run all layers and collect each layer's predictions.

        :return: ``(reg_outputs, cls_outputs)``, two lists with one entry per
            layer, in layer order
        """
        reg_outputs, cls_outputs = [], []
        current_points = noisy_traj_points
        for layer in self.layers:
            layer_reg, layer_cls = layer(
                traj_feature, current_points, gaussians, ego_query, target_point
            )
            reg_outputs.append(layer_reg)
            cls_outputs.append(layer_cls)
            # The next layer starts from these points, without gradients
            # flowing back into this layer's prediction.
            current_points = layer_reg[..., :2].clone().detach()
        return reg_outputs, cls_outputs


class TrajectoryHead(nn.Module):
    """Trajectory prediction head.

    Refines a fixed set of anchor trajectories, loaded from disk, with a
    two-layer ``CustomTransformerDecoder``. It then selects, for each sample,
    the mode with the highest score from the last layer.
    """

    def __init__(
        self,
        num_poses: int,
        d_ffn: int,
        d_model: int,
        plan_anchor_path: str,
    ):
        """
        Initializes trajectory head.
        :param num_poses: number of (x, y) poses per trajectory
        :param d_ffn: dimensionality of feed-forward network
        :param d_model: input dimensionality
        :param plan_anchor_path: path of an array loadable with ``np.load``,
            expected to have shape (num_modes, num_poses, 2)
        """
        super(TrajectoryHead, self).__init__()

        self._num_poses = num_poses
        self._d_model = d_model
        self._d_ffn = d_ffn
        # Not read anywhere in this class.
        self.diff_loss_weight = 2.0
        self.ego_fut_mode = 20

        plan_anchor = np.load(plan_anchor_path)

        # Frozen Parameter: saved in the state_dict and moved by .to(), but
        # excluded from gradient updates.
        self.plan_anchor = nn.Parameter(
            torch.tensor(plan_anchor, dtype=torch.float32),
            requires_grad=False,
        )  # (num_modes, num_poses, 2), e.g. 20,8,2

        diff_decoder_layer = CustomTransformerDecoderLayer(
            num_poses=num_poses,
            d_model=d_model,
            d_ffn=d_ffn,
        )
        self.diff_decoder = CustomTransformerDecoder(diff_decoder_layer, 2)

    def forward(
        self,
        ego_query,
        gaussians,
        target_points=None,
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], torch.Tensor]:
        """Refine the anchors and pick the best trajectory per sample.

        :param ego_query: ego query tokens, batch dimension first
        :param gaussians: scene representation, forwarded to the decoder
        :param target_points: target point(s), forwarded to the decoder
        :return: tuple ``(poses_reg_list, poses_cls_list, best_reg)``.
            The two lists hold one entry per decoder layer: regressed
            trajectories of shape (bs, num_modes, num_poses, 2) and mode
            scores of shape (bs, num_modes). ``best_reg`` is the last layer's
            highest-scoring trajectory, of shape (bs, num_poses, 2).
        """
        bs = ego_query.shape[0]

        # Tile the shared anchor set over the batch.
        plan_anchor = self.plan_anchor.unsqueeze(0).repeat(bs, 1, 1, 1)

        # No trajectory features are supplied up front.
        # NOTE(review): presumably GaussianDeformableAttn derives its queries
        # from the points when given None -- confirm in .blocks.
        traj_feature = None

        poses_reg_list, poses_cls_list = self.diff_decoder(
            traj_feature, plan_anchor, gaussians, ego_query, target_points
        )

        poses_reg = poses_reg_list[-1]
        poses_cls = poses_cls_list[-1]

        # Pick each sample's highest-scoring mode from the final layer.
        mode_idx = poses_cls.argmax(dim=-1)
        batch_idx = torch.arange(poses_reg.shape[0], device=poses_reg.device)
        best_reg = poses_reg[batch_idx, mode_idx]

        return poses_reg_list, poses_cls_list, best_reg
