from typing import Dict

import numpy as np
import torch
import torch.nn as nn
from nuplan.planning.simulation.trajectory.trajectory_sampling import TrajectorySampling

from navsim.agents.gaussianfusion.transfuser_backbone import TransfuserBackbone
from navsim.agents.gaussianfusion.transfuser_config import TransfuserConfig
from navsim.agents.gaussianfusion.transfuser_features import BoundingBox2DIndex
from navsim.common.enums import StateSE2Index
from navsim.agents.gaussianfusion.modules.blocks import linear_relu_ln

# from navsim.agents.gaussianfusion.modules.vis_gaussian import vis_gaussian


class TransfuserModel(nn.Module):
    """Torch module for Transfuser.

    Runs the backbone on the camera / LiDAR inputs, encodes the tokens it
    returns together with the ego status into a key/value sequence, and decodes
    a single learned query into a trajectory.
    """

    def __init__(
        self, trajectory_sampling: TrajectorySampling, config: TransfuserConfig
    ):
        """
        Initializes TransFuser torch module.
        :param trajectory_sampling: trajectory sampling specification.
        :param config: global config dataclass of TransFuser.
        """

        super().__init__()

        # Sizes of the query groups handed to the decoder. Only the single
        # trajectory query is active; no agent (bounding-box) queries are
        # created, so AgentHead is not instantiated by this module.
        self._query_splits = [1]

        self._config = config
        self._backbone = TransfuserBackbone(config)

        d_model = config.tf_d_model
        self._query_embedding = nn.Embedding(sum(self._query_splits), d_model)

        bev_embed_dim = self._backbone.gaussian_init.embed_dims
        # The encoded tokens are concatenated with the status token and used as
        # decoder memory, so both encoders must emit `d_model` channels. The
        # width used to be hard-coded to 256, which only worked when
        # config.tf_d_model == 256.
        # NOTE(review): assumes linear_relu_ln(out_dim, in_loops, out_loops,
        # in_dim) -- confirm against navsim.agents.gaussianfusion.modules.blocks.
        self._bev_pos_encoding = nn.Sequential(
            *linear_relu_ln(d_model, 1, 1, 2), nn.Linear(d_model, d_model)
        )
        self._bev_encoding = nn.Sequential(
            *linear_relu_ln(d_model, 1, 1, bev_embed_dim),
        )
        # The status vector has 4 + 2 + 2 = 8 entries.
        # NOTE(review): presumably command (4) + velocity (2) + acceleration (2)
        # -- confirm against the feature builder.
        self._status_encoding = nn.Linear(4 + 2 + 2, d_model)

        tf_decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=config.tf_num_head,
            dim_feedforward=config.tf_d_ffn,
            dropout=config.tf_dropout,
            batch_first=True,
        )

        self._tf_decoder = nn.TransformerDecoder(tf_decoder_layer, config.tf_num_layers)

        self._trajectory_head = TrajectoryHead(
            num_poses=trajectory_sampling.num_poses,
            d_ffn=config.tf_d_ffn,
            d_model=d_model,
        )

    def init_weights(self):
        """Delegates weight initialisation to the backbone."""
        self._backbone.init_weights()

    def forward(
        self, features: Dict[str, torch.Tensor], targets: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """
        Torch module forward pass.
        :param features: input dict; reads "camera_feature", "camera_matrix",
            "lidar_feature" and "status_feature".
        :param targets: target dict, forwarded unchanged to the backbone.
        :return: dict produced by the trajectory head (key "trajectory").
        """

        camera_feature: torch.Tensor = features["camera_feature"]
        camera_matrix: torch.Tensor = features["camera_matrix"]
        lidar_feature: torch.Tensor = features["lidar_feature"]
        status_feature: torch.Tensor = features["status_feature"]

        batch_size = status_feature.shape[0]

        gaussian_dict = self._backbone(
            camera_feature,
            lidar_feature,
            targets,
            camera_matrix,
        )

        gaussians = gaussian_dict["gaussian"]
        gaussian_feature = gaussians.features
        implicit_feature = gaussians.im_features
        gaussian_pos = gaussians.means

        bev_pos_embedding = self._bev_pos_encoding(gaussian_pos)
        # Explicit and implicit tokens are stacked along dim 1 and share one
        # encoder.
        bev_feature = self._bev_encoding(
            torch.cat([gaussian_feature, implicit_feature], dim=1)
        )
        status_encoding = self._status_encoding(status_feature)

        # Only the leading `gaussian_feature` tokens receive a positional
        # embedding; the trailing implicit tokens are left as they are.
        num_gaussian_tokens = gaussian_feature.shape[1]
        bev_feature[:, :num_gaussian_tokens] = (
            bev_feature[:, :num_gaussian_tokens] + bev_pos_embedding
        )
        # Append the status encoding as one extra key/value token.
        keyval = torch.cat([bev_feature, status_encoding[:, None]], dim=1)

        query = self._query_embedding.weight[None, ...].repeat(batch_size, 1, 1)
        query_out = self._tf_decoder(query, keyval)

        trajectory_query = query_out.split(self._query_splits, dim=1)[0]

        output: Dict[str, torch.Tensor] = {}

        trajectory = self._trajectory_head(trajectory_query)
        output.update(trajectory)

        return output


class AgentHead(nn.Module):
    """Head mapping agent queries to bounding-box states and a label score."""

    def __init__(
        self,
        num_agents: int,
        d_ffn: int,
        d_model: int,
    ):
        """
        Builds the state MLP and the single-layer label projection.
        :param num_agents: maximum number of agents to predict
        :param d_ffn: hidden width of the state MLP
        :param d_model: width of the incoming agent queries
        """
        super().__init__()

        self._num_objects = num_agents
        self._d_model = d_model
        self._d_ffn = d_ffn

        # Two-layer MLP producing one value per BoundingBox2DIndex entry.
        state_layers = [
            nn.Linear(d_model, d_ffn),
            nn.ReLU(),
            nn.Linear(d_ffn, BoundingBox2DIndex.size()),
        ]
        self._mlp_states = nn.Sequential(*state_layers)

        # One scalar label score per query.
        self._mlp_label = nn.Sequential(nn.Linear(d_model, 1))

    def forward(self, agent_queries) -> Dict[str, torch.Tensor]:
        """
        Predicts box states and label scores for the given queries.
        :param agent_queries: query tensor whose last dimension is d_model
        :return: dict with "agent_states" and "agent_labels"
        """
        point = BoundingBox2DIndex.POINT
        heading = BoundingBox2DIndex.HEADING

        states = self._mlp_states(agent_queries)

        # Squash the point entries into (-32, 32), written back in place.
        bounded_point = torch.tanh(states[..., point]) * 32
        states[..., point] = bounded_point

        # Squash the heading entry into (-pi, pi), written back in place.
        bounded_heading = torch.tanh(states[..., heading]) * np.pi
        states[..., heading] = bounded_heading

        labels = self._mlp_label(agent_queries).squeeze(dim=-1)

        return {"agent_states": states, "agent_labels": labels}


class TrajectoryHead(nn.Module):
    """Head mapping a query to a sequence of poses."""

    def __init__(self, num_poses: int, d_ffn: int, d_model: int):
        """
        Builds the two-layer MLP that emits all poses at once.
        :param num_poses: number of (x,y,θ) poses to predict
        :param d_ffn: hidden width of the MLP
        :param d_model: width of the incoming queries
        """
        super().__init__()

        self._num_poses = num_poses
        self._d_model = d_model
        self._d_ffn = d_ffn

        # The output layer emits every pose component in one flat vector.
        flat_pose_dim = num_poses * StateSE2Index.size()
        self._mlp = nn.Sequential(
            nn.Linear(d_model, d_ffn),
            nn.ReLU(),
            nn.Linear(d_ffn, flat_pose_dim),
        )

    def forward(self, object_queries) -> Dict[str, torch.Tensor]:
        """
        Predicts poses for the given queries.
        :param object_queries: query tensor whose last dimension is d_model
        :return: dict with "trajectory", reshaped to (-1, num_poses, pose size)
        """
        heading = StateSE2Index.HEADING

        flat_poses = self._mlp(object_queries)
        poses = flat_poses.reshape(-1, self._num_poses, StateSE2Index.size())

        # Squash the heading entry into (-pi, pi), written back in place.
        bounded_heading = torch.tanh(poses[..., heading]) * np.pi
        poses[..., heading] = bounded_heading

        return {"trajectory": poses}
