from dataclasses import dataclass
from typing import Tuple

import numpy as np
from nuplan.common.actor_state.tracked_objects_types import TrackedObjectType
from nuplan.common.maps.abstract_map import SemanticMapLayer


@dataclass
class TransfuserConfig:
    """Global TransFuser config.

    Groups the hyperparameters for LiDAR preprocessing, the ResNet + FPN
    image/LiDAR encoders, the Gaussian initializer/encoder/decoder, the
    transformer settings, loss weights, BEV semantic rasterization and the
    optimizer.

    NOTE: only attributes with a type annotation are dataclass fields (and can
    therefore be overridden through ``__init__``). Un-annotated attributes such
    as the nested ``dict`` module configs, the ``np.ndarray`` ranges and the
    lists are plain class attributes: they are shared by every instance, so
    mutating one in place (e.g. ``cfg.img_neck["num_outs"] = 4``) changes it
    for all configs.
    """

    image_architecture: str = "resnet34"
    lidar_architecture: str = "resnet34"

    # Planning anchors loaded from a .npy file; the file name suggests k-means
    # clustered trajectories (20 anchors?) — TODO confirm against the loader.
    plan_anchor_path: str = "navsim_anchors/kmeans_navsim_traj_20.npy"
    # Channel width shared by both FPN necks and the Gaussian init/encoder modules.
    embed_dims = 128

    ######## preprocess lidar #########
    # Range ordered [x_min, y_min, z_min, x_max, y_max, z_max] (see the
    # lidar_min_* / lidar_max_* fields below); presumably meters.
    point_cloud_range = np.array([-32.0, -32.0, -2.0, 32.0, 32.0, 8.0])
    # Per-axis voxel size. Assuming [x, y, z] order, the z size (10) equals the
    # whole z extent (8 - (-2)), i.e. a single vertical voxel.
    voxel_size = np.array([0.25, 0.25, 10])
    num_point_features: int = 3
    num_filters = [32, 32]

    #### raw lidar preprocess
    use_ground_plane = False
    lidar_split_height: float = 0.2
    max_height_lidar: float = 100.0
    # 4 px per unit, consistent with bev_pixel_size = 0.25 below.
    pixels_per_meter: float = 4.0
    hist_max_per_pixel: int = 5

    ########### image encoder ##############
    # Registry-style module configs ("type" selects the class). They follow
    # mmcv/mmdet conventions — confirm against the builder that consumes them.
    img_encoder = dict(
        type="ResNet",
        init_cfg=dict(type="Pretrained", checkpoint="torchvision://resnet34"),
        depth=34,
        num_stages=4,
        out_indices=(1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type="BN2d", requires_grad=True),
        # norm_eval=False,
        style="pytorch",
    )

    ####### image fpn ###########
    # Number of FPN output levels; shared by img_neck and lidar_neck.
    level_num = 3
    img_neck = dict(
        type="FPN",
        # Presumably the ResNet-34 stage widths selected by out_indices (1, 2, 3).
        in_channels=[128, 256, 512],
        out_channels=embed_dims,
        start_level=0,
        add_extra_convs="on_output",
        num_outs=level_num,
        relu_before_extra_convs=True,
    )

    ####### lidar encoder #######
    # Same ResNet-34 layout as img_encoder but with a 1-channel input and no
    # pretrained init_cfg.
    lidar_encoder = dict(
        type="ResNet",
        depth=34,
        in_channels=1,
        num_stages=4,
        out_indices=(1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type="BN2d", requires_grad=True),
        # norm_eval=False,
        style="pytorch",
    )

    ####### lidar fpn ###########
    # Reuses level_num from the image FPN section (previously redefined here).
    lidar_neck = dict(
        type="FPN",
        in_channels=[128, 256, 512],
        out_channels=embed_dims,
        start_level=0,
        add_extra_convs="on_output",
        num_outs=level_num,
        relu_before_extra_convs=True,
    )

    ############## Gaussian Initializer #########
    # Presumably (height, width) of the raw camera images — TODO confirm.
    image_shape_raw = [1080, 1920]
    # Flags forwarded to the Gaussian init / anchor encoder / refine configs
    # (include_ele is also forwarded to the decoder).
    include_opa = False
    include_ele = False
    semantics = True
    # Semantic channels per Gaussian; also the decoder's num_classes.
    # Equal to num_bev_classes (7) further below.
    semantic_dim = 7
    scale_range = [0.01, 3.2]
    xy_coordinate = "cartesian"
    phi_activation = "sigmoid"
    # Number of encoder layers; operation_order below is tiled this many times.
    num_decoder = 4
    # Alias (the same array object, not a copy) of point_cloud_range.
    pc_range = point_cloud_range

    gaussian_init = dict(
        type="GaussianInit",
        num_anchor=0,
        embed_dims=embed_dims,
        anchor_grad=False,
        feat_grad=True,
        semantics=semantics,
        semantic_dim=semantic_dim,
        include_opa=include_opa,
        include_ele=include_ele,
        projection_in=None,
        random_samples=512,
    )

    ############# Gaussian Encoder ##############
    gaussian_encoder = dict(
        type="GaussianEncoder",
        anchor_encoder=dict(
            type="SparseGaussian3DEncoder",
            embed_dims=embed_dims,
            include_opa=include_opa,
            include_ele=include_ele,
            semantics=semantics,
            semantic_dim=semantic_dim,
        ),
        norm_layer=dict(type="LN", normalized_shape=embed_dims),
        ffn=dict(
            _delete_=True,
            type="AsymmetricFFN",
            in_channels=embed_dims,
            embed_dims=embed_dims,
            feedforward_channels=embed_dims * 4,
            ffn_drop=0.1,
            add_identity=False,
        ),
        # Deformable sampling of LiDAR features (single "camera", no camera embed).
        deformable_model_pts=dict(
            type="DeformableFeatureAggregation_LiDAR",
            embed_dims=embed_dims,
            num_groups=4,
            num_levels=3,  # presumably should stay equal to level_num
            num_cams=1,
            attn_drop=0.15,
            use_deformable_func=True,
            use_camera_embed=False,
            residual_mode="none",
            kps_generator=dict(
                type="SparseGaussianKeyPointsGenerator_LiDAR",
                embed_dims=embed_dims,
                phi_activation=phi_activation,
                xyz_coordinate=xy_coordinate,
                num_learnable_pts=6,
                fix_scale=[
                    [0, 0],
                    [0.45, 0],
                    [-0.45, 0],
                    [0, 0.45],
                    [0, -0.45],
                ],
                pc_range=pc_range,
                scale_range=scale_range,
            ),
        ),
        # Deformable sampling of image features from 3 cameras with camera embed.
        deformable_model_img=dict(
            type="DeformableFeatureAggregation_CAM",
            embed_dims=embed_dims,
            num_groups=4,
            num_levels=3,  # presumably should stay equal to level_num
            num_cams=3,
            attn_drop=0.15,
            use_deformable_func=True,
            use_camera_embed=True,
            residual_mode="none",
            kps_generator=dict(
                type="SparseGaussianKeyPointsGenerator_CAM",
                embed_dims=embed_dims,
                phi_activation=phi_activation,
                xyz_coordinate=xy_coordinate,
                num_learnable_pts=6,
                fix_scale=[
                    [0, 0],
                    [0.45, 0],
                    [-0.45, 0],
                    [0, 0.45],
                    [0, -0.45],
                ],
                pc_range=pc_range,
                scale_range=scale_range,
            ),
        ),
        refine_layer=dict(
            type="SparseGaussianRefinementModuleV2",
            embed_dims=embed_dims,
            pc_range=pc_range,
            scale_range=scale_range,
            restrict_xy=False,
            unit_xy=[4.0, 4.0],
            refine_manual=None,
            phi_activation=phi_activation,
            semantics=semantics,
            semantic_dim=semantic_dim,
            include_opa=include_opa,
            include_ele=include_ele,
            xy_coordinate=xy_coordinate,
            semantics_activation="identity",
        ),
        # Disabled alternative for spconv_layer (sparse-conv variant), kept for reference:
        # spconv_layer=dict(
        #     _delete_=True,
        #     type="SparseConv2D",
        #     in_channels=embed_dims,
        #     embed_channels=embed_dims,
        #     pc_range=pc_range,
        #     grid_size=[2.0, 2.0],
        #     phi_activation=phi_activation,
        #     xyz_coordinate=xy_coordinate,
        #     use_out_proj=True,
        #     use_multi_layer=True,
        # ),
        spconv_layer=dict(
            _delete_=True,
            type="GaussianAttention",
            embed_dims=embed_dims,
            num_head=4,
            dropout=0.15,
            batch_first=True,
        ),
        implicit_fusion=dict(
            _delete_=True,
            type="ImplicitFlattenFusion",
            embed_dims=embed_dims,
            num_groups=4,
            attn_drop=0.15,
            img_feature_size=[3, 8, 14],
            pts_feature_size=[8, 8],
        ),
        num_decoder=num_decoder,
        # Op sequence of ONE encoder layer; the list is tiled num_decoder times.
        operation_order=[
            "identity",
            "deformable_pts",
            "add",
            "norm",
            "identity",
            "deformable_img",
            "add",
            "norm",
            "identity",
            "ffn",
            "add",
            "norm",
            "identity",
            "spconv",
            "add",
            "norm",
            "identity",
            "ffn",
            "add",
            "norm",
            "implicit_fusion",
            "refine",
        ]
        * num_decoder,
    )

    ############# Gaussian Decoder ##############
    gaussian_decoder = dict(
        type="GaussianDecoder",
        apply_loss_type="random_1",
        num_classes=semantic_dim,
        empty_args=dict(
            # _delete_=True,
            mean=[0, 0, -1.0],
            scale=[100, 100, 8.0],
        ),
        with_empty=False,
        include_ele=include_ele,
        use_localaggprob=True,
        use_localaggprob_fast=True,
        combine_geosem=True,
        # H * grid_size = 32 and W * grid_size = 64, starting at pc_min; this
        # presumably covers x in [0, 32] and y in [-32, 32] — confirm axis order.
        cuda_kwargs=dict(
            # _delete_=True,
            scale_multiplier=4,
            H=128,
            W=256,
            D=1,
            pc_min=[0, -32.0, 0.0],
            grid_size=0.25,
        ),
    )
    ###############################

    # Scalar bounds unpacked from pc_range ([x_min, y_min, z_min, x_max,
    # y_max, z_max]); the values are np.float64 (a float subclass).
    lidar_min_x: float = pc_range[0]
    lidar_max_x: float = pc_range[3]
    lidar_min_y: float = pc_range[1]
    lidar_max_y: float = pc_range[4]
    lidar_min_z: float = pc_range[2]
    lidar_max_z: float = pc_range[5]

    # Backbone settings, including the GPT transformer used in the vision backbone
    latent = False
    lidar_seq_len: int = 1
    block_exp = 4
    n_layer = 2  # Number of transformer layers used in the vision backbone
    n_head = 4
    n_scale = 4
    embd_pdrop = 0.1
    resid_pdrop = 0.1
    attn_pdrop = 0.1
    # Mean of the normal distribution initialization for linear layers in the GPT
    gpt_linear_layer_init_mean = 0.0
    # Std of the normal distribution initialization for linear layers in the GPT
    gpt_linear_layer_init_std = 0.02
    # Initial weight of the layer norms in the gpt.
    gpt_layer_norm_init_weight = 1.0

    # Feature / auxiliary-task toggles
    perspective_downsample_factor = 1
    transformer_decoder_join = True
    detect_boxes = True
    use_bev_semantic = True
    use_semantic = False
    use_depth = False
    add_features = True

    # Transformer
    tf_d_model: int = 256
    tf_d_ffn: int = 1024
    tf_num_layers: int = 3
    tf_num_head: int = 8
    tf_dropout: float = 0.0

    # detection
    num_bounding_boxes: int = 30

    # loss weights
    trajectory_weight: float = 20.0
    trajectory_cls_weight: float = 1.0
    trajectory_reg_weight: float = 0.8
    diff_loss_weight: float = 20.0
    agent_class_weight: float = 10.0
    agent_box_weight: float = 1.0
    bev_semantic_weight: float = 10.0

    # BEV mapping: class index -> (geometry kind, source map layers / object types).
    # Index 0 is not listed (presumably background), hence 6 entries here but
    # num_bev_classes = 7 below.
    bev_semantic_classes = {
        1: ("polygon", [SemanticMapLayer.LANE, SemanticMapLayer.INTERSECTION]),  # road
        2: ("polygon", [SemanticMapLayer.WALKWAYS]),  # walkways
        3: (
            "linestring",
            [SemanticMapLayer.LANE, SemanticMapLayer.LANE_CONNECTOR],
        ),  # centerline
        4: (
            "box",
            [
                TrackedObjectType.CZONE_SIGN,
                TrackedObjectType.BARRIER,
                TrackedObjectType.TRAFFIC_CONE,
                TrackedObjectType.GENERIC_OBJECT,
            ],
        ),  # static_objects
        5: ("box", [TrackedObjectType.VEHICLE]),  # vehicles
        6: ("box", [TrackedObjectType.PEDESTRIAN]),  # pedestrians
    }

    # 256 px * 0.25 per px = 64, matching the 64-unit x/y extent of point_cloud_range.
    bev_pixel_width: int = 256
    bev_pixel_height: int = 256
    bev_pixel_size: float = 0.25

    num_bev_classes = 7
    bev_features_channels: int = 64
    bev_down_sample_factor: int = 4
    bev_upsample_factor: int = 2

    # optimizer
    weight_decay: float = 1e-4
    lr_steps = [70]  # presumably the MultiStepLR milestones — confirm the unit (epochs?)
    optimizer_type = "AdamW"
    scheduler_type = "MultiStepLR"
    cfg_lr_mult = 0.5
    # LR multiplier applied to parameters matched by name "image_encoder"; the
    # exact matching rule depends on the optimizer builder — confirm.
    opt_paramwise_cfg = {"name": {"image_encoder": {"lr_mult": cfg_lr_mult}}}

    @property
    def bev_semantic_frame(self) -> Tuple[int, int]:
        """Return the BEV semantic map size as ``(bev_pixel_height, bev_pixel_width)``."""
        return (self.bev_pixel_height, self.bev_pixel_width)

    @property
    def bev_radius(self) -> float:
        """Return the largest absolute value among the x/y LiDAR bounds."""
        values = [
            self.lidar_min_x,
            self.lidar_max_x,
            self.lidar_min_y,
            self.lidar_max_y,
        ]
        return max(abs(value) for value in values)
