"""
Implements a BEV sensor-fusion backbone.
It uses SimpleBEV to project camera features to BEV and then concatenates them with the LiDAR features.
"""

import torch
from torch import nn
import torch.nn.functional as F
from mmengine.config import Config
from mmdet.registry import MODELS
from .modules.blocks import linear_relu_ln
from .utils import transfuser_utils as t_u
from .modules.lidar_encoder import PointPreprocess
import timm


class BevEncoder(nn.Module):
    """
    BEV sensor-fusion backbone.

    Extracts multi-scale image features (timm backbone + mmdet neck) and
    multi-scale LiDAR features (point preprocessing + timm backbone + mmdet
    neck). It then initialises a set of Gaussians from the first LiDAR neck
    level and refines them with both feature streams via ``gaussian_encoder``.
    Finally, it decodes them with ``gaussian_decoder`` on a fixed BEV (x, y)
    query grid.

    Args:
        config: Global config object. Attributes read here: ``bev_config``
            (path to an mmengine config file), ``image_architecture``,
            ``lidar_architecture``, ``cropped_height``, ``cropped_width``,
            ``normalize_imagenet``, plus whatever
            ``t_u.calculate_ego2image_proj`` reads from it.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        bev_cfg = Config.fromfile(config.bev_config)
        # Indices 0/3 and 1/4 are used as min/max pairs of x and y below;
        # presumably [x_min, y_min, z_min, x_max, y_max, z_max] — confirm.
        self.pc_range = bev_cfg.point_cloud_range
        self.real_w = self.pc_range[3] - self.pc_range[0]
        self.real_h = self.pc_range[4] - self.pc_range[1]

        self.embed_dims = bev_cfg.img_neck["out_channels"]

        # Camera branch: pretrained timm backbone exposing 3 stages + neck.
        self.img_backbone = timm.create_model(
            config.image_architecture,
            pretrained=True,
            out_indices=(2, 3, 4),
            features_only=True,
        )
        self.img_neck = MODELS.build(bev_cfg.img_neck)

        # LiDAR branch: PointPreprocess output is fed to a 2D timm backbone
        # whose input channel count is bev_cfg.num_filters[-1].
        self.lidar_preprocess = PointPreprocess(bev_cfg)
        self.pts_backbone = timm.create_model(
            config.lidar_architecture,
            pretrained=False,
            in_chans=bev_cfg.num_filters[-1],
            out_indices=(2, 3, 4),
            features_only=True,
        )
        self.pts_neck = MODELS.build(bev_cfg.pts_neck)

        # Gaussian initialisation, refinement and decoding stages.
        self.gaussian_init = MODELS.build(bev_cfg.gaussian_init)
        self.gaussian_encoder = MODELS.build(bev_cfg.gaussian_encoder)
        self.gaussian_decoder = MODELS.build(bev_cfg.gaussian_decoder)

        # Encodes the concatenated Gaussian features; input width is
        # bev_cfg.embed_dims.
        self._bev_encoding = nn.Sequential(
            *linear_relu_ln(256, 1, 1, bev_cfg.embed_dims),
        )
        # Positional encoding of the Gaussian means (2 input features).
        self._bev_pos_encoding = nn.Sequential(
            *linear_relu_ln(256, 1, 1, 2), nn.Linear(256, 256)
        )

        self.init_weights()

        # Ego-to-image projection. Registered as a plain, non-persistent
        # buffer: it follows .to()/.cuda() but is not trained and not saved in
        # state_dict. detach() keeps it out of autograd, as the previous
        # Parameter(requires_grad=False) wrapper did.
        projection_mat = (
            t_u.calculate_ego2image_proj(self.config)
            .to(torch.float)
            .unsqueeze(0)
            .detach()
        )
        self.register_buffer("projection_mat", projection_mat, persistent=False)

        image_shape = [self.config.cropped_height, self.config.cropped_width]
        # Static metadata for the Gaussian encoder; img_wh is [width, height].
        # forward() never mutates this dict (see the comment there).
        self.metas = {}
        self.metas["img_wh"] = image_shape[::-1]

        self.register_buffer(
            "bev_map_xy",
            self.get_bev_map_xy(bev_cfg),
            persistent=False,
        )

    def init_weights(self):
        """Initialise the Gaussian init and encoder sub-modules."""
        self.gaussian_init.init_weights()
        self.gaussian_encoder.init_weights()

    def get_bev_map_xy(self, config):
        """
        Build the (x, y) coordinates of the BEV query grid.

        Args:
            config: Config exposing ``bev_occ_size`` as (W, H), ``pc_range``
                and ``voxel_size``.

        Returns:
            Float tensor of shape (W, H, 2) with ``[i, j] == (x_i, y_j)``.
        """
        W, H = config.bev_occ_size
        pc_range = config.pc_range
        voxel_size = config.voxel_size
        # NOTE(review): the grid is inset by a full voxel at each end, not by
        # half a voxel (cell centres); confirm this matches the decoder.
        # NOTE(review): this reads ``pc_range`` while __init__ reads
        # ``point_cloud_range``; confirm both keys agree in the config.
        xs = (
            torch.linspace(
                pc_range[0] + voxel_size[0],
                pc_range[3] - voxel_size[0],
                W,
                dtype=torch.float,
            )
            .view(W, 1)
            .expand(W, H)
        )
        ys = (
            torch.linspace(
                pc_range[1] + voxel_size[1],
                pc_range[4] - voxel_size[1],
                H,
                dtype=torch.float,
            )
            .view(1, H)
            .expand(W, H)
        )
        return torch.stack((xs, ys), -1)

    def forward(self, image, lidar, bev_semantic_label):
        """
        Image + LiDAR feature fusion in BEV.

        Args:
            image: Camera tensor, unpacked as (B, C, H, W).
            lidar: Point input forwarded to ``self.lidar_preprocess``.
            bev_semantic_label: Forwarded to ``self.lidar_preprocess``.

        Returns:
            Tuple ``(pred_occ, fused_features, img_feature, gaussian)``:
            the first element of the decoder's ``pred_occ``, the encoded
            Gaussian features with positional embedding added, the first
            image-neck level, and the decoder's ``gaussian`` output.
        """
        if self.config.normalize_imagenet:
            img_feats = t_u.normalize_imagenet(image)
        else:
            img_feats = image

        # Unpack all four dims to keep the original rank check on `image`.
        batch_size, _, _, _ = image.shape

        pts_feats = self.lidar_preprocess(
            points=lidar, batch_size=batch_size, bev_semantic_label=bev_semantic_label
        )

        img_features_ms = self.img_neck(self.img_backbone(img_feats))
        pts_features_ms = self.pts_neck(self.pts_backbone(pts_feats))

        pts_bev_feature = pts_features_ms[0]
        gaussian_anchors, gaussian_features, implicit_features = self.gaussian_init(
            pts_bev_feature
        )

        # Insert a singleton axis at dim 1 on every level (presumably the
        # camera/view axis the encoder expects — confirm). Only the first
        # len(pts_features_ms) image levels are passed on.
        num_levels = len(pts_features_ms)
        img_features_ms_ = [
            img_features_ms[lvl].unsqueeze(1) for lvl in range(num_levels)
        ]
        pts_features_ms_ = [feat.unsqueeze(1) for feat in pts_features_ms]

        # Build a per-call metas dict instead of writing into self.metas:
        # nn.DataParallel replicas share self.metas, so storing the
        # device-specific buffer there races across devices.
        metas = dict(self.metas)
        metas["projection_mat"] = self.projection_mat
        gaussian_prediction = self.gaussian_encoder(
            gaussian_anchors,
            gaussian_features,
            implicit_features,
            img_features_ms_,
            pts_features_ms_,
            metas,
        )

        occ_xy = self.bev_map_xy.unsqueeze(0).repeat(batch_size, 1, 1, 1)
        output = self.gaussian_decoder(
            **gaussian_prediction,
            occ_xy=occ_xy,
        )
        gaussians = output["gaussian"]

        fused_features = self._bev_encoding(
            torch.cat([gaussians.features, gaussians.im_features], dim=1)
        )
        bev_pos_embedding = self._bev_pos_encoding(gaussians.means)

        # The positional embedding is added only to the leading
        # bev_pos_embedding.shape[1] entries along dim 1 (presumably the part
        # coming from `gaussians.features`); the rest is left unchanged.
        num_pos = bev_pos_embedding.shape[1]
        fused_features[:, :num_pos] = fused_features[:, :num_pos] + bev_pos_embedding

        return (
            output["pred_occ"][0],
            fused_features,
            img_features_ms[0],
            gaussians,
        )