from typing import Dict
import numpy as np
import torch
import torch.nn as nn
from nuplan.planning.simulation.trajectory.trajectory_sampling import TrajectorySampling

import copy
from navsim.agents.gaussianfusion.transfuser_config import TransfuserConfig
from navsim.agents.gaussianfusion.transfuser_backbone import TransfuserBackbone
from navsim.agents.gaussianfusion.transfuser_features import BoundingBox2DIndex
from navsim.common.enums import StateSE2Index
from diffusers.schedulers import DDIMScheduler
from navsim.agents.gaussianfusion.modules.conditional_unet1d import (
    ConditionalUnet1D,
    SinusoidalPosEmb,
)
import torch.nn.functional as F
from navsim.agents.gaussianfusion.modules.blocks import (
    linear_relu_ln,
    bias_init_with_prob,
    gen_sineembed_for_position,
    GaussianDeformableAttn,
    GridSampleCrossBEVAttention,
)
from navsim.agents.gaussianfusion.modules.multimodal_loss import (
    LossComputer,
)
from torch.nn import TransformerDecoder, TransformerDecoderLayer
from typing import Any, List, Dict, Optional, Union

# from navsim.agents.gaussianfusion.modules.vis_gaussian import vis_gaussian


class TransfuserModel(nn.Module):
    """Gaussian-fusion TransFuser: backbone, token encoders and trajectory head."""

    def __init__(
        self, trajectory_sampling: TrajectorySampling, config: TransfuserConfig
    ):
        """
        Builds the backbone, the token encoders and the trajectory head.
        :param trajectory_sampling: sampling spec; only `num_poses` is read here.
        :param config: global config dataclass of TransFuser.
        """
        super().__init__()

        self._config = config
        self._backbone = TransfuserBackbone(config)

        # NOTE: sub-modules are created in a fixed order on purpose; changing
        # the order would change seeded parameter initialisation.
        gaussian_embed_dim = self._backbone.gaussian_init.embed_dims

        # Encoder applied to the Gaussian means in forward(); ends in a
        # 256 -> 256 linear layer.
        # NOTE(review): presumably linear_relu_ln(256, 1, 1, 2) maps 2 input
        # coordinates to 256 channels - confirm against the helper.
        pos_layers = linear_relu_ln(256, 1, 1, 2)
        self._bev_pos_encoding = nn.Sequential(*pos_layers, nn.Linear(256, 256))

        # Encoder applied to the concatenated Gaussian feature tokens.
        token_layers = linear_relu_ln(256, 1, 1, gaussian_embed_dim)
        self._bev_encoding = nn.Sequential(*token_layers)

        # Status vector has 4 + 2 + 2 = 8 entries.
        # NOTE(review): the meaning of the three groups is not visible here.
        status_dim = 4 + 2 + 2
        self._status_encoding = nn.Linear(status_dim, config.tf_d_model)

        self._trajectory_head = TrajectoryHead(
            num_poses=trajectory_sampling.num_poses,
            d_ffn=config.tf_d_ffn,
            d_model=config.tf_d_model,
            plan_anchor_path=config.plan_anchor_path,
            config=config,
        )

    def init_weights(self):
        """Delegates weight initialisation to the backbone."""
        self._backbone.init_weights()

    def forward(
        self, features: Dict[str, torch.Tensor], targets: Dict[str, torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Runs the backbone, builds the key/value tokens and predicts a trajectory.
        :param features: must contain "camera_feature", "camera_matrix",
            "lidar_feature" and "status_feature".
        :param targets: forwarded to the backbone and to the trajectory head.
        :return: dict with "pred_bev_occ", "gt_bev_occ" and every entry
            returned by the trajectory head.
        """
        camera_feature: torch.Tensor = features["camera_feature"]
        camera_matrix: torch.Tensor = features["camera_matrix"]
        lidar_feature: torch.Tensor = features["lidar_feature"]
        status_feature: torch.Tensor = features["status_feature"]

        scene = self._backbone(camera_feature, lidar_feature, targets, camera_matrix)
        gaussians = scene["gaussian"]

        explicit_tokens = gaussians.features
        implicit_tokens = gaussians.im_features
        num_explicit = explicit_tokens.shape[1]

        pos_embedding = self._bev_pos_encoding(gaussians.means)
        # Explicit tokens come first along dim 1, implicit tokens after them.
        tokens = self._bev_encoding(
            torch.cat([explicit_tokens, implicit_tokens], dim=1)
        )
        status_token = self._status_encoding(status_feature)

        # Only the leading (explicit) tokens receive the positional embedding;
        # the slice is updated in place.
        tokens[:, :num_explicit] = tokens[:, :num_explicit] + pos_embedding

        # The status encoding is appended as one extra token.
        keyval = torch.cat([tokens, status_token.unsqueeze(1)], dim=1)

        output = {
            "pred_bev_occ": scene["pred_occ"],
            "gt_bev_occ": scene["sampled_label"],
        }

        # Debug hook (disabled): vis_gaussian(scene)

        output.update(self._trajectory_head(keyval, gaussians, targets=targets))
        return output


class DiffMotionPlanningRefinementModule(nn.Module):
    """Prediction head producing per-mode poses and a per-mode score."""

    def __init__(
        self,
        embed_dims=256,
        ego_fut_ts=8,
        ego_fut_mode=20,
        if_zeroinit_reg=True,
    ):
        """
        :param embed_dims: channel width of the incoming trajectory features.
        :param ego_fut_ts: number of future poses regressed per mode.
        :param ego_fut_mode: stored on the instance; forward() takes the mode
            count from its input instead.
        :param if_zeroinit_reg: currently ignored, see the note below.
        """
        super().__init__()
        self.embed_dims = embed_dims
        self.ego_fut_ts = ego_fut_ts
        self.ego_fut_mode = ego_fut_mode

        # Classification branch: one logit per mode.
        cls_layers = linear_relu_ln(embed_dims, 1, 2)
        self.plan_cls_branch = nn.Sequential(*cls_layers, nn.Linear(embed_dims, 1))

        # Regression branch: three values per future pose.
        reg_layers = [
            nn.Linear(embed_dims, embed_dims),
            nn.ReLU(),
            nn.Linear(embed_dims, embed_dims),
            nn.ReLU(),
            nn.Linear(embed_dims, ego_fut_ts * 3),
        ]
        self.plan_reg_branch = nn.Sequential(*reg_layers)

        # NOTE(review): the `if_zeroinit_reg` argument is not used; zero
        # initialisation of the regression output layer is always disabled.
        self.if_zeroinit_reg = False

        self.init_weight()

    def init_weight(self):
        """Sets the classification bias and, if enabled, zeroes the last regression layer."""
        cls_bias = bias_init_with_prob(0.01)
        nn.init.constant_(self.plan_cls_branch[-1].bias, cls_bias)

        if not self.if_zeroinit_reg:
            return
        last_reg = self.plan_reg_branch[-1]
        nn.init.constant_(last_reg.weight, 0)
        nn.init.constant_(last_reg.bias, 0)

    def forward(
        self,
        traj_feature,
    ):
        """
        :param traj_feature: 3-D tensor (batch, modes, channels).
        :return: tuple `(plan_reg, plan_cls)` where `plan_reg` has shape
            (batch, modes, ego_fut_ts, 3) and `plan_cls` has the trailing
            singleton dimension of the classification output squeezed away.
        """
        batch_size, num_modes, _ = traj_feature.shape
        per_mode = traj_feature.view(batch_size, num_modes, -1)

        scores = self.plan_cls_branch(per_mode).squeeze(-1)
        poses = self.plan_reg_branch(per_mode).reshape(
            batch_size, num_modes, self.ego_fut_ts, 3
        )
        return poses, scores


class CustomTransformerDecoderLayer(nn.Module):
    """One refinement stage: Gaussian cross-attention, transformer decoder, pose head."""

    def __init__(self, num_poses, d_model, d_ffn, config):
        """
        :param num_poses: poses per trajectory; also passed to the attention
            module as `num_points`.
        :param d_model: not used here; widths are read from `config`.
        :param d_ffn: not used here; the FFN width is read from `config`.
        :param config: provides `tf_d_model`, `tf_num_head`, `tf_d_ffn` and
            `tf_dropout`.
        """
        super().__init__()

        # NOTE: keep the construction order of the sub-modules unchanged so
        # seeded initialisation is reproducible.
        self.cross_gaussian_attention = GaussianDeformableAttn(
            config.tf_d_model,
            config.tf_num_head,
            num_points=num_poses,
            config=config,
            in_bev_dims=256,
        )

        decoder_template = nn.TransformerDecoderLayer(
            d_model=config.tf_d_model,
            nhead=config.tf_num_head,
            dim_feedforward=config.tf_d_ffn,
            dropout=config.tf_dropout,
            batch_first=True,
        )
        self._tf_decoder = nn.TransformerDecoder(decoder_template, num_layers=2)

        self.task_decoder = DiffMotionPlanningRefinementModule(
            embed_dims=config.tf_d_model,
            ego_fut_ts=num_poses,
            ego_fut_mode=20,
        )

    def forward(
        self,
        traj_feature,
        noisy_traj_points,
        gaussians,
        ego_query,
    ):
        """
        Refines the given trajectory points once.
        :param traj_feature: passed straight to the Gaussian cross-attention.
        :param noisy_traj_points: current trajectory points; added to the
            first two regressed channels.
        :param gaussians: Gaussian scene representation for the cross-attention.
        :param ego_query: memory sequence of the transformer decoder.
        :return: tuple `(poses_reg, poses_cls)`.
        """
        attended = self.cross_gaussian_attention(
            traj_feature, noisy_traj_points, gaussians
        )
        decoded = self._tf_decoder(attended, ego_query)
        poses_reg, poses_cls = self.task_decoder(decoded)

        # The head regresses offsets for the first two channels; add the
        # input points to obtain refined points (written in place).
        refined_xy = poses_reg[..., :2] + noisy_traj_points
        poses_reg[..., :2] = refined_xy

        # Squash the heading channel into (-pi, pi).
        heading = StateSE2Index.HEADING
        poses_reg[..., heading] = poses_reg[..., heading].tanh() * np.pi

        return poses_reg, poses_cls


def _get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    # FIXME: copy.deepcopy() is not defined on nn.module
    replicas = nn.ModuleList()
    for _ in range(N):
        replicas.append(copy.deepcopy(module))
    return replicas


class CustomTransformerDecoder(nn.Module):
    """Stack of refinement layers applied one after another to trajectory points."""

    def __init__(self, decoder_layer, num_layers, norm=None):
        """
        :param decoder_layer: template layer; it is deep-copied `num_layers` times.
        :param num_layers: number of stacked layers.
        :param norm: accepted but not used.
        """
        super().__init__()
        torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers

    def forward(
        self,
        traj_feature,
        noisy_traj_points,
        gaussians,
        ego_query,
    ):
        """
        Runs every layer and collects the per-layer predictions.
        :param traj_feature: handed unchanged to every layer.
        :param noisy_traj_points: trajectory points fed to the first layer.
        :param gaussians: handed unchanged to every layer.
        :param ego_query: handed unchanged to every layer.
        :return: tuple `(poses_reg_list, poses_cls_list)` with one entry per layer.
        """
        reg_per_layer, cls_per_layer = [], []
        current_points = noisy_traj_points

        for layer in self.layers:
            reg, cls = layer(traj_feature, current_points, gaussians, ego_query)
            reg_per_layer.append(reg)
            cls_per_layer.append(cls)
            # The next layer starts from this layer's first two channels;
            # the copy is detached, so no gradient flows through the points.
            current_points = reg[..., :2].clone().detach()

        return reg_per_layer, cls_per_layer


class TrajectoryHead(nn.Module):
    """Anchor-based trajectory prediction head."""

    def __init__(
        self,
        num_poses: int,
        d_ffn: int,
        d_model: int,
        plan_anchor_path: str,
        config: TransfuserConfig,
    ):
        """
        Initializes trajectory head.
        :param num_poses: number of (x,y,θ) poses to predict
        :param d_ffn: dimensionality of feed-forward network
        :param d_model: input dimensionality
        :param plan_anchor_path: path of the array file loaded with `np.load`
            that holds the planning anchors
        :param config: global config, passed to the decoder layer and the loss
        """
        super().__init__()

        self._num_poses = num_poses
        self._d_model = d_model
        self._d_ffn = d_ffn
        self.diff_loss_weight = 2.0
        self.ego_fut_mode = 20

        # Frozen anchors, stored as a parameter without gradient.
        # Per the original note the expected shape is (20, 8, 2) - TODO confirm.
        anchor_tensor = torch.tensor(np.load(plan_anchor_path), dtype=torch.float32)
        self.plan_anchor = nn.Parameter(anchor_tensor, requires_grad=False)

        layer_template = CustomTransformerDecoderLayer(
            num_poses=num_poses,
            d_model=d_model,
            d_ffn=d_ffn,
            config=config,
        )
        self.diff_decoder = CustomTransformerDecoder(layer_template, 2)

        self.loss_computer = LossComputer(config)

    def _batched_anchors(self, batch_size):
        """Tiles the anchors along a new leading batch dimension."""
        return self.plan_anchor.unsqueeze(0).repeat(batch_size, 1, 1, 1)

    def _select_best_mode(self, poses_reg, poses_cls):
        """Returns, per sample, the candidate with the highest score in `poses_cls`."""
        best_mode = poses_cls.argmax(dim=-1)
        # Expand the mode index so it can gather a full (num_poses, 3) trajectory.
        gather_index = best_mode[..., None, None, None].repeat(
            1, 1, self._num_poses, 3
        )
        return torch.gather(poses_reg, 1, gather_index).squeeze(1)

    def forward(
        self,
        ego_query,
        gaussians,
        targets=None,
    ) -> Dict[str, torch.Tensor]:
        """Torch module forward pass; dispatches on `self.training`."""
        if not self.training:
            return self.forward_test(ego_query, gaussians)
        return self.forward_train(ego_query, gaussians, targets)

    def forward_train(
        self,
        ego_query,
        gaussians,
        targets=None,
    ) -> Dict[str, torch.Tensor]:
        """
        Training pass: decodes the anchors and accumulates one loss per decoder layer.
        :return: dict with "trajectory" (best mode of the last layer),
            "trajectory_loss" (sum over layers) and "trajectory_loss_dict".
        """
        anchors = self._batched_anchors(ego_query.shape[0])

        # No initial trajectory feature is supplied; the decoder receives None.
        poses_reg_list, poses_cls_list = self.diff_decoder(
            None, anchors, gaussians, ego_query
        )

        trajectory_loss_dict = {
            f"trajectory_loss_{layer_idx}": self.loss_computer(
                reg, cls, targets, anchors
            )
            for layer_idx, (reg, cls) in enumerate(
                zip(poses_reg_list, poses_cls_list)
            )
        }
        total_loss = 0
        for layer_loss in trajectory_loss_dict.values():
            total_loss += layer_loss

        return {
            "trajectory": self._select_best_mode(
                poses_reg_list[-1], poses_cls_list[-1]
            ),
            "trajectory_loss": total_loss,
            "trajectory_loss_dict": trajectory_loss_dict,
        }

    def forward_test(
        self,
        ego_query,
        gaussians,
    ) -> Dict[str, torch.Tensor]:
        """
        Inference pass: decodes the anchors and keeps the best mode of the last layer.
        :return: dict with the single key "trajectory".
        """
        anchors = self._batched_anchors(ego_query.shape[0])

        # No initial trajectory feature is supplied; the decoder receives None.
        poses_reg_list, poses_cls_list = self.diff_decoder(
            None,
            anchors,
            gaussians,
            ego_query,
        )
        return {
            "trajectory": self._select_best_mode(
                poses_reg_list[-1], poses_cls_list[-1]
            )
        }
