"""
Implements the TransFuser vision backbone.
"""

import copy
import math

import timm
import torch
import torch.nn.functional as F
from torch import nn
from mmdet.registry import MODELS
from navsim.agents.gaussianfusion.transfuser_config import TransfuserConfig


class TransfuserBackbone(nn.Module):
    """Backbone that fuses multi-view image and LiDAR features.

    Every sub-module is built from the supplied config through the mmdet
    ``MODELS`` registry: an encoder/neck pair for images, an encoder/neck
    pair for LiDAR, and a Gaussian init / encoder / decoder chain that
    consumes the resulting multi-scale features.
    """

    def __init__(self, config: "TransfuserConfig"):
        """Build all sub-modules described by ``config``.

        Args:
            config: Provides ``image_shape_raw`` plus the registry configs
                ``img_encoder``, ``img_neck``, ``lidar_encoder``,
                ``lidar_neck``, ``gaussian_init``, ``gaussian_encoder`` and
                ``gaussian_decoder``.
        """
        super().__init__()
        self.config = config

        # Metadata forwarded to the Gaussian encoder on every forward pass.
        # "img_wh" is ``image_shape_raw`` reversed (presumably (H, W) ->
        # (W, H) -- confirm against the config definition).
        self.metas = {"img_wh": config.image_shape_raw[::-1]}

        self.image_encoder = MODELS.build(config.img_encoder)
        self.image_neck = MODELS.build(config.img_neck)

        self.lidar_encoder = MODELS.build(config.lidar_encoder)
        self.lidar_neck = MODELS.build(config.lidar_neck)

        self.gaussian_init = MODELS.build(config.gaussian_init)
        self.gaussian_encoder = MODELS.build(config.gaussian_encoder)
        self.gaussian_decoder = MODELS.build(config.gaussian_decoder)

    def init_weights(self):
        """Initialise the image encoder and the Gaussian init/encoder modules.

        Only these three sub-modules have ``init_weights`` invoked here; the
        necks, the LiDAR encoder and the Gaussian decoder are left as built.
        """
        self.image_encoder.init_weights()
        self.gaussian_init.init_weights()
        self.gaussian_encoder.init_weights()

    def forward(self, image, lidar, targets, camera_matrix):
        """Run the fusion backbone.

        Args:
            image: 5-D tensor unpacked as ``(B, N, C, H, W)``; the first two
                dimensions are flattened before the image encoder.
            lidar: LiDAR input tensor; its last two dimensions are swapped
                before the LiDAR encoder.
            targets: Mapping with ``"bev_occ_xy"`` and ``"bev_occ_label"``
                entries, or ``None`` when no targets are available, in which
                case the decoder receives ``None`` for both.
            camera_matrix: Stored as ``metas["projection_mat"]`` and handed
                to the Gaussian encoder.

        Returns:
            The Gaussian decoder output, extended with an ``"anchor_init"``
            entry holding the initial Gaussian anchors.
        """
        B, N, _, H, W = image.shape
        img_features = image.view(B * N, -1, H, W)
        pts_features = lidar.transpose(-2, -1)

        img_features_ms = self.image_neck(self.image_encoder(img_features))
        pts_features_ms = self.lidar_neck(self.lidar_encoder(pts_features))

        # Gaussians are initialised from the first LiDAR pyramid level.
        gaussian_anchors, gaussian_features, implicit_features = self.gaussian_init(
            pts_features_ms[0]
        )

        # Restore the (B, N) split on image features and add a singleton
        # second dimension to the LiDAR features. The number of LiDAR levels
        # drives the iteration, so a shorter image pyramid still raises.
        img_features_per_view = []
        pts_features_per_view = []
        for level, pts_level in enumerate(pts_features_ms):
            img_level = img_features_ms[level]
            _, channels, feat_h, feat_w = img_level.shape
            img_features_per_view.append(
                img_level.reshape(B, N, channels, feat_h, feat_w)
            )
            pts_features_per_view.append(pts_level.unsqueeze(1))

        self.metas["projection_mat"] = camera_matrix
        gaussian_prediction = self.gaussian_encoder(
            gaussian_anchors,
            gaussian_features,
            implicit_features,
            img_features_per_view,
            pts_features_per_view,
            self.metas,
        )

        # Resolve both occupancy inputs under one None check; indexing
        # ``targets`` before the check would raise when targets is None.
        if targets is None:
            occ_xy = None
            occ_label = None
        else:
            occ_xy = targets["bev_occ_xy"]
            occ_label = targets["bev_occ_label"]

        output = self.gaussian_decoder(
            **gaussian_prediction,
            occ_xy=occ_xy,
            occ_label=occ_label,
        )
        output.update({"anchor_init": gaussian_anchors})

        return output
