# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from mmdetection3d (https://github.com/open-mmlab/mmdetection3d)
# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------

import mmcv
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import numpy as np

from mmcv.runner import force_fp32, auto_fp16
from mmdet.core import multi_apply
from mmdet.models import DETECTORS
from mmdet.models.builder import build_backbone
from mmdet3d.core import (Box3DMode, Coord3DMode, bbox3d2result,
                          merge_aug_bboxes_3d, show_result)
from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector

from projects.mmdet3d_plugin.models.feedforward_networks.moe import SparseMoE, load_balancing_loss_func
from projects.mmdet3d_plugin.utils import CUDATimer

from .cmt import CmtDetector
from typing import Any, Tuple


@DETECTORS.register_module()
class CmtDetectorWithTopPModalRouter(CmtDetector):

    def __init__(self,
                 top_p_threshold=0.9,
                 loss_router_dynamic_weight=1e-4,
                 loss_router_balance_weight=1e-2,
                 use_mask=False,
                 masked_img_path=None,
                 masked_img_feats_path=None,
                 masked_pts_feats_path=None,
                 always_use_lidar=False,
                 freeze_cmt=False,
                 version='v1',
                 **kwargs):
        super(CmtDetectorWithTopPModalRouter, self).__init__(**kwargs)
        self.always_use_lidar = always_use_lidar
        if self.always_use_lidar:
            num_classes = 5
        else:
            num_classes = 6
        self.router = nn.Linear(self.img_neck.out_channels, num_classes, bias=False)
        self.top_p_threshold = top_p_threshold
        self.loss_router_dynamic_weight = loss_router_dynamic_weight
        self.loss_router_balance_weight = loss_router_balance_weight

        self.use_mask = use_mask
        if self.use_mask:
            if masked_img_path is None:
                raise ValueError('masked_img_path must be specified if using mask')

            self.masked_img = torch.load(masked_img_path, map_location='cpu')
        else:
            self.masked_img = None

        self.masked_img_feats = None
        if masked_img_feats_path is not None:
            self.masked_img_feats = torch.load(masked_img_feats_path, map_location='cpu')
        else:
            self.masked_img_feats = None

        self.masked_pts_feats = None
        if masked_pts_feats_path is not None:
            self.masked_pts_feats = torch.load(masked_pts_feats_path, map_location='cpu')
        else:
            self.masked_pts_feats = None

        self.version = version
        assert self.version in ['v1']

        if freeze_cmt:
            for p in self.parameters():
                p.requires_grad = False
            for p in self.router.parameters():
                p.requires_grad = True

    def mask_feats(self,
                   img_feats,
                   pts_feats,
                   img_metas,
                   feat_weights,
                   feat_weights_original,
                   masked_img_feats=None,
                   masked_pts_feats=None):
        B = feat_weights.shape[0]

        # never mask CAM_FRONT
        cam_front_weight = torch.ones(B, 1, dtype=feat_weights.dtype, device=feat_weights.device)
        feat_weights = torch.concat([cam_front_weight, feat_weights], dim=1)

        if self.always_use_lidar:
            lidar_weight = torch.ones(B, 1, dtype=feat_weights.dtype, device=feat_weights.device)
            feat_weights = torch.concat([feat_weights, lidar_weight], dim=1)

        num_img_feats = feat_weights.shape[1] - 1
        img_feat_weights, pts_feats_weight = feat_weights.split([num_img_feats, 1], dim=1)

        N = img_feats[0].shape[0] // B
        if not self.use_mask:
            img_feats_new = ()
            for img_feat in img_feats:
                if self.version == 'v1':
                    BN, D, H, W = img_feat.shape
                    img_feat = img_feat.reshape(B, N, D, H, W)
                    img_feat = img_feat * img_feat_weights[:, :, None, None, None]
                    img_feat = img_feat.reshape(BN, D, H, W)
                    img_feats_new += (img_feat,)
                else:
                    raise NotImplementedError()

            pts_feats_new = []
            for pts_feat in pts_feats:
                pts_feat = pts_feat * pts_feats_weight[:, :, None, None]
                pts_feats_new.append(pts_feat)

            # set valid_imgs and valid_points so that they will be masked in cross attention
            for idx, meta in enumerate(img_metas):
                meta['valid_imgs'] = img_feat_weights[idx].to(torch.bool)
                meta['valid_points'] = pts_feats_weight[idx, 0].to(torch.bool).item()

        else:
            assert masked_img_feats is not None and masked_pts_feats is not None
            if self.version != 'v1':
                raise NotImplementedError('only v1 is supported if using mask')

            img_feats_new = ()
            img_feat_weights = img_feat_weights[:, :, None, None, None]
            for img_feat, masked_img_feat in zip(img_feats, masked_img_feats):
                BN, D, H, W = img_feat.shape
                img_feat = img_feat.reshape(B, N, D, H, W)
                masked_img_feat = masked_img_feat.unsqueeze(0).expand(B, N, D, H, W)

                img_feat_new = img_feat * img_feat_weights + masked_img_feat * (1.0 - img_feat_weights)
                img_feat_new = img_feat_new.reshape(BN, D, H, W)
                img_feats_new += (img_feat_new,)

            pts_feats_new = []
            pts_feats_weight = pts_feats_weight[:, :, None, None]
            for pts_feat, masked_pts_feat in zip(pts_feats, masked_pts_feats):
                pts_feat_new = pts_feat * pts_feats_weight + masked_pts_feat * (1.0 - pts_feats_weight)
                pts_feats_new.append(pts_feat_new)

        return img_feats_new, pts_feats_new, img_metas

    def router_forward(self, img, img_feats):
        if self.version == 'v1':
            return self.router_forward_v1(img, img_feats)
        else:
            raise NotImplementedError(f'Unsupported router version: {self.version}')

    def router_forward_v1(self, img, img_feats):
        """Score sensors from the front-camera feature and sample a keep mask.

        The first feature level is regrouped per sample, the view at index 0
        (FRONT_CAM) is average-pooled to a vector, and the linear router plus
        a softmax turn it into per-sensor probabilities. The keep mask is
        produced by ``top_p_sampling`` with gradients enabled.

        Args:
            img (Tensor): ``(N, C, H, W)`` for a single sample or
                ``(B, N, C, H, W)`` for a batch.
            img_feats (Sequence[Tensor]): Image features; only
                ``img_feats[0]`` with shape ``(B * N, C, H, W)`` is used.

        Returns:
            tuple[Tensor, Tensor]: ``(weights_sampled, weights)``.
        """
        if img.dim() != 4:
            batch_size, num_views, _, _, _ = img.shape
        else:
            num_views, _, _, _ = img.shape
            batch_size = 1

        first_level = img_feats[0]
        _, channels, height, width = first_level.shape
        per_view = first_level.reshape(batch_size, num_views, channels, height, width)
        front_cam = per_view[:, 0]  # (B, C, H, W)
        pooled = F.adaptive_avg_pool2d(front_cam, output_size=(1, 1))  # (B, C, 1, 1)
        pooled = pooled.flatten(1)  # (B, C)

        probs = F.softmax(self.router(pooled), dim=1)
        keep = top_p_sampling(probs, self.top_p_threshold, with_grad=True)
        return keep, probs

    def router_forward_v2(self, img, img_feats):
        """Same routing as v1, but the sampled mask carries no gradient.

        The front-camera view of the first feature level is average-pooled,
        passed through the linear router and a softmax, and the keep mask is
        produced by ``top_p_sampling`` with ``with_grad=False``.

        Args:
            img (Tensor): ``(N, C, H, W)`` for a single sample or
                ``(B, N, C, H, W)`` for a batch.
            img_feats (Sequence[Tensor]): Image features; only
                ``img_feats[0]`` with shape ``(B * N, C, H, W)`` is used.

        Returns:
            tuple[Tensor, Tensor]: ``(weights_sampled, weights)``.
        """
        if img.dim() != 4:
            batch_size, num_views, _, _, _ = img.shape
        else:
            num_views, _, _, _ = img.shape
            batch_size = 1

        first_level = img_feats[0]
        _, channels, height, width = first_level.shape
        per_view = first_level.reshape(batch_size, num_views, channels, height, width)
        front_cam = per_view[:, 0]  # (B, C, H, W)
        pooled = F.adaptive_avg_pool2d(front_cam, output_size=(1, 1))  # (B, C, 1, 1)
        pooled = pooled.flatten(1)  # (B, C)

        probs = F.softmax(self.router(pooled), dim=1)
        keep = top_p_sampling(probs, self.top_p_threshold, with_grad=False)
        return keep, probs

    def dynamic_loss(self, weights):
        return -(weights * torch.log(weights)).sum(dim=1).mean()

    def load_balancing_loss(self, weights: torch.Tensor, weights_sampled: torch.Tensor):
        assert (weights_sampled == 0.).sum() + (weights_sampled == 1.).sum() == weights_sampled.numel()

        _, num_experts = weights.shape
        expert_mask = weights_sampled.int()
        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
        router_prob_per_expert = torch.mean(weights, dim=0)
        loss = torch.sum(tokens_per_expert * router_prob_per_expert) * num_experts
        return loss

    def count_sensor(self, feat_weights):
        with torch.no_grad():
            B = feat_weights.shape[0]

            # never mask CAM_FRONT
            cam_front_weight = torch.ones(B, 1, dtype=feat_weights.dtype, device=feat_weights.device)
            feat_weights = torch.concat([cam_front_weight, feat_weights], dim=1)
            if self.always_use_lidar:
                lidar_weight = torch.ones(B, 1, dtype=feat_weights.dtype, device=feat_weights.device)
                feat_weights = torch.concat([feat_weights, lidar_weight], dim=1)
            feat_weights = feat_weights.to(torch.float32)
            sensor_cnt = feat_weights.sum(dim=0)
            return sensor_cnt

    def extract_feat_with_mask(self, points, masked_points, img, img_metas):
        """Extract real and placeholder features in a single backbone pass.

        The placeholder points are appended to ``points`` and the stored
        placeholder image is appended to ``img`` as one extra sample. After
        ``extract_feat`` runs on the merged inputs, each output is split
        back: the first ``B`` samples are the real features and the rest are
        the placeholder features.

        Args:
            points (list[Tensor]): Real point clouds, one per sample.
            masked_points (list[Tensor]): Placeholder point clouds.
            img (Tensor): Real images; dim 0 is the batch size ``B``.
            img_metas (list[dict]): Per-sample meta information.

        Returns:
            tuple: ``(img_feats, pts_feats, masked_img_feats,
            masked_pts_feats)``; image outputs are tuples, point outputs are
            lists.
        """
        batch_size = img.shape[0]
        joint_points = points + masked_points
        placeholder_img = self.masked_img.to(device=img.device, dtype=img.dtype).unsqueeze(0)
        joint_img = torch.cat([img, placeholder_img], dim=0)

        joint_img_feats, joint_pts_feats = self.extract_feat(
            joint_points, img=joint_img, img_metas=img_metas)

        pts_pairs = [(feat[:batch_size], feat[batch_size:]) for feat in joint_pts_feats]
        pts_feats = [real for real, _ in pts_pairs]
        masked_pts_feats = [fill for _, fill in pts_pairs]

        img_pairs = []
        for feat in joint_img_feats:
            # The placeholder image adds one sample, hence B + 1 groups of views.
            views_per_sample = feat.shape[0] // (batch_size + 1)
            cut = batch_size * views_per_sample
            img_pairs.append((feat[:cut], feat[cut:]))
        img_feats = tuple(real for real, _ in img_pairs)
        masked_img_feats = tuple(fill for _, fill in img_pairs)

        return img_feats, pts_feats, masked_img_feats, masked_pts_feats

    def forward_train(self,
                      points=None,
                      masked_points=None,
                      img_metas=None,
                      gt_bboxes_3d=None,
                      gt_labels_3d=None,
                      gt_labels=None,
                      gt_bboxes=None,
                      img=None,
                      proposals=None,
                      gt_bboxes_ignore=None):
        """Forward training function.

        Extracts features, lets the router pick the sensors to keep, gates
        the features accordingly and computes the detection losses plus the
        two router regularisers. Per-sensor usage ratios are added to the
        returned dict under ``sensor.{i}.ratio`` for logging.

        Args:
            points (list[torch.Tensor], optional): Points of each sample.
                Defaults to None.
            masked_points (list[torch.Tensor], optional): Placeholder points
                whose features replace dropped LiDAR features; only used
                when ``use_mask`` is enabled. Defaults to None.
            img_metas (list[dict], optional): Meta information of each sample.
                Defaults to None.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
                Ground truth 3D boxes. Defaults to None.
            gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
                of 3D boxes. Defaults to None.
            gt_labels (list[torch.Tensor], optional): Ground truth labels
                of 2D boxes in images. Unused here. Defaults to None.
            gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
                images. Unused here. Defaults to None.
            img (torch.Tensor, optional): Images of each sample; dim 0 is
                treated as the batch size. Defaults to None.
            proposals (list[torch.Tensor], optional): Unused here.
                Defaults to None.
            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
                2D boxes in images to be ignored. Defaults to None.

        Returns:
            dict: Losses of different branches and sensor usage ratios.
        """
        masked_img_feats, masked_pts_feats = None, None
        if self.use_mask:
            img_feats, pts_feats, masked_img_feats, masked_pts_feats = self.extract_feat_with_mask(
                points, masked_points=masked_points, img=img, img_metas=img_metas)
        else:
            img_feats, pts_feats = self.extract_feat(
                points, img=img, img_metas=img_metas)

        weights_sampled, weights = self.router_forward(img, img_feats)

        img_feats, pts_feats, img_metas = self.mask_feats(
            img_feats, pts_feats, img_metas, weights_sampled, weights,
            masked_img_feats, masked_pts_feats)

        losses = dict()
        if pts_feats or img_feats:
            losses_pts = self.forward_pts_train(pts_feats, img_feats, gt_bboxes_3d,
                                                gt_labels_3d, img_metas,
                                                gt_bboxes_ignore)
            losses.update(losses_pts)

            loss_load_balancing = self.compute_load_balancing_loss()
            if loss_load_balancing is not None:
                losses['loss_load_balancing'] = loss_load_balancing

        losses['loss_router_dynamic'] = self.dynamic_loss(weights) * self.loss_router_dynamic_weight
        losses['loss_router_balance'] = (
            self.load_balancing_loss(weights, weights_sampled) * self.loss_router_balance_weight)

        with torch.no_grad():
            sensor_cnt = self.count_sensor(weights_sampled)
            # Collective ops raise when no process group exists (single-GPU
            # or debug runs); fall back to local statistics in that case.
            if dist.is_available() and dist.is_initialized():
                dist.all_reduce(sensor_cnt)
                world_size = dist.get_world_size()
            else:
                world_size = 1
            global_batch_size = img.shape[0] * world_size
            sensor_ratio = sensor_cnt / global_batch_size
            for i, ratio in enumerate(sensor_ratio):
                losses[f'sensor.{i}.ratio'] = ratio

        return losses

    def forward_test(self,
                     points=None,
                     masked_points=None,
                     img_metas=None,
                     img=None, **kwargs):
        """
        Args:
            points (list[torch.Tensor]): the outer list indicates test-time
                augmentations and inner torch.Tensor should have a shape NxC,
                which contains all points in the batch.
            img_metas (list[list[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch
            img (list[torch.Tensor], optional): the outer
                list indicates test-time augmentations and inner
                torch.Tensor should have a shape NxCxHxW, which contains
                all images in the batch. Defaults to None.
        """
        if points is None:
            points = [None]
        if img is None:
            img = [None]
        if masked_points is None:
            masked_points = [None]
        for var, name in [(points, 'points'), (img, 'img'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(var)))

        return self.simple_test(points[0], masked_points[0], img_metas[0], img[0], **kwargs)

    def simple_test(self, points, masked_points, img_metas, img=None, rescale=False):
        """Run inference without test-time augmentation.

        Features are extracted, the router selects the sensors to keep, the
        features are gated by ``mask_feats`` and the detection heads are run
        on the result.

        Args:
            points (list[torch.Tensor]): Points of each sample.
            masked_points (list[torch.Tensor]): Placeholder points; only
                used when ``use_mask`` is enabled.
            img_metas (list[dict]): Meta information of each sample.
            img (torch.Tensor, optional): Images of the batch.
            rescale (bool): Forwarded to the head test functions.

        Returns:
            list[dict]: One result dict per sample, with ``pts_bbox`` and/or
            ``img_bbox`` entries.
        """
        if self.force_no_lidar:
            dtype = points[0].dtype
            device = points[0].device
            points = [torch.zeros(1, 5, dtype=dtype, device=device)]

        if self.force_empty_image:
            img = torch.zeros_like(img)

        masked_img_feats, masked_pts_feats = None, None
        if self.use_mask:
            img_feats, pts_feats, masked_img_feats, masked_pts_feats = self.extract_feat_with_mask(
                points, masked_points=masked_points, img=img, img_metas=img_metas)
        else:
            img_feats, pts_feats = self.extract_feat(
                points, img=img, img_metas=img_metas)

        assert pts_feats is not None and img_feats is not None
        weights_sampled, weights = self.router_forward(img, img_feats)
        # ``weights`` must be passed as ``feat_weights_original``; omitting it
        # shifts the placeholder features into the wrong parameters and
        # leaves ``masked_pts_feats`` as None when ``use_mask`` is enabled.
        img_feats, pts_feats, img_metas = self.mask_feats(
            img_feats, pts_feats, img_metas, weights_sampled, weights,
            masked_img_feats, masked_pts_feats)

        bbox_list = [dict() for _ in range(len(img_metas))]
        if (pts_feats or img_feats) and self.with_pts_bbox:
            bbox_pts = self.simple_test_pts(
                pts_feats, img_feats, img_metas, rescale=rescale)
            for result_dict, pts_bbox in zip(bbox_list, bbox_pts):
                result_dict['pts_bbox'] = pts_bbox
        if img_feats and self.with_img_bbox:
            bbox_img = self.simple_test_img(
                img_feats, img_metas, rescale=rescale)
            for result_dict, img_bbox in zip(bbox_list, bbox_img):
                result_dict['img_bbox'] = img_bbox
        return bbox_list

    def forward_test_sensor_ratio(self, img, img_metas):
        """Report how often each sensor is kept for the given image batch.

        Only the image branch and the router are run; no detection happens.

        Args:
            img (Tensor): Images; dim 0 is used as the batch size.
            img_metas (list[dict]): Per-sample meta information.

        Returns:
            tuple[Tensor, Tensor]: ``(sensor_ratio, sensor_cnt)`` where the
            ratio is the count divided by the batch size.
        """
        img_feats = self.extract_img_feat(img, img_metas)
        keep_mask, _ = self.router_forward(img, img_feats)

        sensor_cnt = self.count_sensor(keep_mask)
        sensor_ratio = sensor_cnt / img.shape[0]
        return sensor_ratio, sensor_cnt

    def simple_test_time(self, points, img_metas, img=None, rescale=False):
        """Time each inference stage for a single sample with sensor skipping.

        Unlike ``simple_test``, backbones are only run for the sensors the
        router keeps: the front camera is always encoded, the other cameras
        are encoded only if selected, and dropped sensors are filled with
        the precomputed placeholder features.

        Args:
            points (list[torch.Tensor]): Points of the sample.
            img_metas (list[dict]): Meta information of the sample.
            img (torch.Tensor): Images with a batch size of exactly 1; dim 1
                indexes the camera views.
            rescale (bool): Forwarded to ``simple_test_pts_time``.

        Returns:
            dict: ``time_dict`` with the merged stage timings and
            ``sensor_mask``, a list of bools ordered as front camera, other
            cameras, LiDAR.

        Raises:
            NotImplementedError: If the point-bbox head cannot be run.
        """
        assert img.shape[0] == 1
        assert self.masked_img_feats is not None
        assert self.masked_pts_feats is not None

        # Keep the placeholder features on the same device as the input
        # (cached on the instance so later calls are no-op transfers).
        masked_img_feats = tuple(feat.to(img.device) for feat in self.masked_img_feats)
        masked_pts_feats = [feat.to(img.device) for feat in self.masked_pts_feats]
        self.masked_img_feats = masked_img_feats
        self.masked_pts_feats = masked_pts_feats

        time_dict = dict()

        if self.force_no_lidar:
            with CUDATimer('mask_lidar_force') as t:
                dtype = points[0].dtype
                device = points[0].device
                points = [torch.zeros(1, 5, dtype=dtype, device=device)]
            time_dict = self._merge_dict(time_dict, t.get_time_dict())

        if self.force_empty_image:
            with CUDATimer('mask_img_force') as t:
                img = torch.zeros_like(img)
            time_dict = self._merge_dict(time_dict, t.get_time_dict())

        with CUDATimer('extract_feat_cam_front') as t:
            img_cam_front = img[:, 0:1]
            img_cam_front_feat = self.extract_img_feat(img_cam_front, img_metas)
        time_dict = self._merge_dict(time_dict, t.get_time_dict())

        with CUDATimer('router') as t:
            router_inputs = F.adaptive_avg_pool2d(img_cam_front_feat[0], output_size=(1, 1))  # (B, C, 1, 1)
            router_inputs = router_inputs[:, :, 0, 0]  # (B, C)
            logits = self.router(router_inputs)
            weights = F.softmax(logits, dim=1)
            mask = top_p_sampling(weights, self.top_p_threshold)[0].to(torch.bool)
            # With always_use_lidar the router has no LiDAR output, so every
            # flag is a camera flag; otherwise the last flag is the LiDAR one.
            if self.always_use_lidar:
                other_cam_mask = mask
                use_lidar = True
            else:
                other_cam_mask = mask[:-1]
                use_lidar = mask[-1].item()
            # Index 0 (front camera) stays False: it was already encoded above.
            camera_mask = torch.zeros(other_cam_mask.numel() + 1, dtype=torch.bool, device=mask.device)
            camera_mask[1:] = other_cam_mask
        time_dict = self._merge_dict(time_dict, t.get_time_dict())

        with CUDATimer('extract_feat_cam_other') as t:
            # NOTE(review): when no other camera is selected this passes an
            # empty view batch to extract_img_feat -- confirm it is supported.
            img_cam_other = img[:, camera_mask]
            img_cam_other_feat = self.extract_img_feat(img_cam_other, img_metas)
        time_dict = self._merge_dict(time_dict, t.get_time_dict())

        with CUDATimer('construct_cam_feat') as t:
            # One host transfer instead of one device sync per camera.
            other_cam_flags = other_cam_mask.tolist()
            img_feats = ()
            for level, front_feat in enumerate(img_cam_front_feat):
                img_feat = [front_feat[0]]
                cam_other_idx = 0
                for use_cam in other_cam_flags:
                    if use_cam:
                        img_feat.append(img_cam_other_feat[level][cam_other_idx])
                        cam_other_idx += 1
                    else:
                        img_feat.append(masked_img_feats[level][0])
                img_feats += (torch.stack(img_feat, dim=0),)
        time_dict = self._merge_dict(time_dict, t.get_time_dict())

        if use_lidar:
            with CUDATimer('extract_lidar_feat') as t:
                pts_feats = self.extract_pts_feat(points, img_feats, img_metas)
            time_dict = self._merge_dict(time_dict, t.get_time_dict())
        else:
            pts_feats = masked_pts_feats
            time_dict['extract_lidar_feat'] = 0.0

        if pts_feats is None:
            pts_feats = [None]
        if img_feats is None:
            img_feats = [None]

        if (pts_feats or img_feats) and self.with_pts_bbox:
            time_dict_pts = self.simple_test_pts_time(
                pts_feats, img_feats, img_metas, rescale=rescale)
            time_dict = self._merge_dict(time_dict, time_dict_pts)
        else:
            raise NotImplementedError()

        # Front camera is always used; LiDAR flag comes last.
        sensor_mask = [True] + other_cam_flags + [bool(use_lidar)]

        return {'time_dict': time_dict, 'sensor_mask': sensor_mask}


def top_p_sampling(logits, top_p=0.9, temperature=1.0, with_grad=True, at_least_one=False):
    """Build a 0/1 keep mask over the last dimension with a top-p rule.

    Entries are sorted in descending order along the last dimension and
    accumulated; an entry is dropped as soon as the running sum (including
    the entry itself) exceeds ``top_p``. The mask is returned in the
    original, unsorted order.

    No softmax is applied here: despite its name, ``logits`` is used as-is
    after division by ``temperature``, so callers are expected to pass
    probabilities.

    :param logits: Tensor of scores with shape ``(..., num_choices)``
    :param top_p: Cumulative threshold (float)
    :param temperature: Divisor applied to ``logits`` before sorting (float)
    :param with_grad: If True the ones in the mask are produced by
        ``ones_like_with_grad`` so gradients pass straight through to
        ``logits``; otherwise the mask carries no gradient
    :param at_least_one: If True the largest entry is always kept
    :return: Tensor shaped like ``logits`` with 1 for kept and 0 for dropped
        entries
    """
    scores = logits / temperature

    sorted_scores, sorted_indices = torch.sort(scores, dim=-1, descending=True)
    drop_sorted = torch.cumsum(sorted_scores, dim=-1) > top_p

    if at_least_one:
        drop_sorted[..., 0] = False

    # argsort of a permutation is its inverse; gather undoes the sort for
    # any number of leading dimensions.
    inverse_indices = torch.argsort(sorted_indices, dim=-1)
    drop = torch.gather(drop_sorted, -1, inverse_indices)

    if with_grad:
        ones = ones_like_with_grad(logits)
    else:
        ones = torch.ones_like(logits)
    return ones * torch.logical_not(drop)

class OnesLikeWithGrad(torch.autograd.Function):
    """Autograd function that outputs ones but passes gradients through.

    The forward result ignores the input values; the backward pass hands
    the incoming gradient to the input unchanged.
    """

    @staticmethod
    def forward(ctx, input_tensor):
        """Return a tensor of ones shaped like ``input_tensor``."""
        return torch.ones_like(input_tensor)

    @staticmethod
    def backward(ctx, grad_output):
        """Return a copy of the incoming gradient as the input gradient."""
        passthrough = grad_output.clone()
        return passthrough

def ones_like_with_grad(input_tensor):
    """Return ones shaped like ``input_tensor`` via ``OnesLikeWithGrad``."""
    ones = OnesLikeWithGrad.apply(input_tensor)
    return ones
