import torch
import torch.nn as nn
from .roi_head_template import RoIHeadTemplate
from ...utils import common_utils, loss_utils
from pcdet.config import cfg

class SECONDHead(RoIHeadTemplate):
    """RoI head that pools BEV features inside each RoI and predicts one IoU score per RoI.

    RoIs are sampled from ``spatial_features_2d`` with a rotated affine grid
    (``roi_grid_pool``), passed through shared 1x1 Conv1d/BN/ReLU layers and an IoU
    branch with a single output channel. When ``cfg.SELF_TRAIN`` is set, ``forward``
    additionally writes ``*_mt`` copies of the RoIs and their features (presumably for
    a mean-teacher / consistency objective — TODO confirm), and optionally performs
    IE_AUG proposal augmentation and pre-NMS scoring (IE_AUG / REDB).
    """

    def __init__(self, input_channels, model_cfg, num_class=1, **kwargs):
        """Build the shared FC trunk and the IoU branch.

        Args:
            input_channels: Not used by this head; the pooled feature width is taken
                from ``model_cfg.ROI_GRID_POOL.IN_CHANNEL`` instead.
            model_cfg: Head config (ROI_GRID_POOL, SHARED_FC, DP_RATIO, IOU_FC, ...).
            num_class: Forwarded to ``RoIHeadTemplate``.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        super().__init__(num_class=num_class, model_cfg=model_cfg)
        self.model_cfg = model_cfg

        grid_size = self.model_cfg.ROI_GRID_POOL.GRID_SIZE
        # Each RoI is pooled to a G x G grid; the flattened grid feeds the first Conv1d.
        pre_channel = self.model_cfg.ROI_GRID_POOL.IN_CHANNEL * grid_size * grid_size

        fc_channels = self.model_cfg.SHARED_FC
        shared_fc_list = []
        for k, out_channels in enumerate(fc_channels):
            shared_fc_list.extend([
                nn.Conv1d(pre_channel, out_channels, kernel_size=1, bias=False),
                nn.BatchNorm1d(out_channels),
                nn.ReLU()
            ])
            pre_channel = out_channels

            # Dropout goes between FC layers only, never after the last one.
            if k != len(fc_channels) - 1 and self.model_cfg.DP_RATIO > 0:
                shared_fc_list.append(nn.Dropout(self.model_cfg.DP_RATIO))

        self.shared_fc_layer = nn.Sequential(*shared_fc_list)

        self.iou_layers = self.make_fc_layers(
            input_channels=pre_channel, output_channels=1, fc_list=self.model_cfg.IOU_FC
        )
        self.init_weights(weight_init='xavier')

    def init_weights(self, weight_init='xavier'):
        """Initialize every Conv1d/Conv2d submodule's weight and zero its bias.

        Args:
            weight_init: 'kaiming' (kaiming_normal_), 'xavier' (xavier_normal_) or
                'normal' (normal_ with mean 0, std 0.001).

        Raises:
            NotImplementedError: For any other ``weight_init`` value.
        """
        if weight_init == 'kaiming':
            init_func = nn.init.kaiming_normal_
        elif weight_init == 'xavier':
            init_func = nn.init.xavier_normal_
        elif weight_init == 'normal':
            init_func = nn.init.normal_
        else:
            raise NotImplementedError

        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Conv1d)):
                if weight_init == 'normal':
                    init_func(m.weight, mean=0, std=0.001)
                else:
                    init_func(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def roi_grid_pool(self, batch_dict):
        """Sample a GRID_SIZE x GRID_SIZE feature patch inside every RoI with grid_sample.

        Each RoI is mapped from point-cloud coordinates to BEV feature-map cells.
        Columns 0/1 are treated as the centre, 3/4 as the extent and 6 as the heading.
        The RoI is then turned into a rotated affine transform and sampled from its
        frame's feature map.

        Args:
            batch_dict:
                batch_size: number of frames B.
                rois: (B, num_rois, 7 + C)
                spatial_features_2d: (B, C, H, W)
                dataset_cfg: provides POINT_CLOUD_RANGE and DATA_PROCESSOR[-1].VOXEL_SIZE.

        Returns:
            (B * num_rois, C, GRID_SIZE, GRID_SIZE) pooled features. Both inputs are
            detached first, so no gradient flows back into rois or spatial_features_2d.
        """
        batch_size = batch_dict['batch_size']
        rois = batch_dict['rois'].detach()
        spatial_features_2d = batch_dict['spatial_features_2d'].detach()
        num_rois = rois.size(1)
        num_channels = spatial_features_2d.size(1)
        height, width = spatial_features_2d.size(2), spatial_features_2d.size(3)

        dataset_cfg = batch_dict['dataset_cfg']
        min_x = dataset_cfg.POINT_CLOUD_RANGE[0]
        min_y = dataset_cfg.POINT_CLOUD_RANGE[1]
        voxel_size_x = dataset_cfg.DATA_PROCESSOR[-1].VOXEL_SIZE[0]
        voxel_size_y = dataset_cfg.DATA_PROCESSOR[-1].VOXEL_SIZE[1]
        down_sample_ratio = self.model_cfg.ROI_GRID_POOL.DOWNSAMPLE_RATIO
        grid_size = self.model_cfg.ROI_GRID_POOL.GRID_SIZE
        # Extent of one feature-map cell in point-cloud units along x / y.
        cell_x = voxel_size_x * down_sample_ratio
        cell_y = voxel_size_y * down_sample_ratio
        out_size = torch.Size((num_rois, num_channels, grid_size, grid_size))

        pooled_features_list = []
        # cuDNN is switched off while sampling (the reason is not visible here; presumably
        # a cuDNN grid-sampler issue — TODO confirm). Restore the caller's setting even if
        # sampling raises, instead of unconditionally re-enabling cuDNN afterwards.
        cudnn_was_enabled = torch.backends.cudnn.enabled
        torch.backends.cudnn.enabled = False
        try:
            for b_id in range(batch_size):
                frame_rois = rois[b_id]
                # Axis-aligned RoI extent in feature-map cell coordinates.
                x1 = (frame_rois[:, 0] - frame_rois[:, 3] / 2 - min_x) / cell_x
                x2 = (frame_rois[:, 0] + frame_rois[:, 3] / 2 - min_x) / cell_x
                y1 = (frame_rois[:, 1] - frame_rois[:, 4] / 2 - min_y) / cell_y
                y2 = (frame_rois[:, 1] + frame_rois[:, 4] / 2 - min_y) / cell_y

                angle, _ = common_utils.check_numpy_to_torch(frame_rois[:, 6])
                cosa = torch.cos(angle)
                sina = torch.sin(angle)

                # Affine map from the normalized output grid to normalized input coordinates:
                # scale by the normalized half-extent, rotate by the heading, translate to the centre.
                # NOTE(review): the (W - 1) / (H - 1) normalization corresponds to align_corners=True,
                # while affine_grid / grid_sample below run with their default align_corners — confirm.
                theta = torch.stack((
                    (x2 - x1) / (width - 1) * cosa, (x2 - x1) / (width - 1) * (-sina), (x1 + x2 - width + 1) / (width - 1),
                    (y2 - y1) / (height - 1) * sina, (y2 - y1) / (height - 1) * cosa, (y1 + y2 - height + 1) / (height - 1)
                ), dim=1).view(-1, 2, 3).float()

                grid = nn.functional.affine_grid(theta, out_size)
                pooled_features = nn.functional.grid_sample(
                    spatial_features_2d[b_id].unsqueeze(0).expand(num_rois, num_channels, height, width),
                    grid
                )
                pooled_features_list.append(pooled_features)
        finally:
            torch.backends.cudnn.enabled = cudnn_was_enabled

        return torch.cat(pooled_features_list, dim=0)

    def _pool_and_score(self, batch_dict):
        """Pool features for batch_dict['rois'] and run the shared FC trunk and the IoU branch.

        Returns:
            shared_features: (B*N, C_fc, 1) output of ``shared_fc_layer``.
            rcnn_iou: (B*N, 1) raw (unnormalized) IoU scores.
        """
        pooled_features = self.roi_grid_pool(batch_dict)  # (B*N, C, G, G)
        batch_size_rcnn = pooled_features.shape[0]
        shared_features = self.shared_fc_layer(pooled_features.view(batch_size_rcnn, -1, 1))
        rcnn_iou = self.iou_layers(shared_features).transpose(1, 2).contiguous().squeeze(dim=1)  # (B*N, 1)
        return shared_features, rcnn_iou

    @staticmethod
    def _ie_aug_frame(pred_boxes, pred_boxes_pre_nms, dup_iou_thresh=0.999):
        """Build interpolated and extrapolated copies of one frame's proposals.

        For every post-NMS proposal, the pre-NMS box with the highest BEV IoU is used
        as a reference. Pre-NMS boxes that (near-)duplicate any post-NMS proposal
        (IoU > ``dup_iou_thresh``) are excluded from the search. Two extra proposals
        are derived by moving the box centre (columns :3) halfway towards the
        reference (interpolation) and halfway away from it (extrapolation). All
        remaining box columns are copied unchanged.

        Args:
            pred_boxes: (N, D) post-NMS proposals of one frame.
            pred_boxes_pre_nms: (M, D) pre-NMS proposals of the same frame.
            dup_iou_thresh: BEV IoU above which a pre-NMS box counts as a duplicate.

        Returns:
            (3N, D) tensor ordered as [original; interpolated; extrapolated]. If no
            pre-NMS reference remains, the offset is zero and both extra groups equal
            the originals. This keeps every frame at 3N so the batch can be stacked.
        """
        from pcdet.ops.iou3d_nms import iou3d_nms_utils

        # (M, N) BEV IoU between pre-NMS and post-NMS boxes (first 7 columns).
        iou_matrix = iou3d_nms_utils.boxes_iou_bev(pred_boxes_pre_nms[:, :7], pred_boxes[:, :7])
        keep = ~(iou_matrix > dup_iou_thresh).any(dim=1)
        iou_keep = iou_matrix[keep]  # (M_keep, N)

        if iou_keep.size(0) > 0:
            # For each post-NMS box (column), the kept pre-NMS box with the largest IoU.
            closest_proposal = pred_boxes_pre_nms[keep][torch.argmax(iou_keep, dim=0)]  # (N, D)
            half_offset = (pred_boxes[:, :3] - closest_proposal[:, :3]) / 2
        else:
            half_offset = torch.zeros_like(pred_boxes[:, :3])

        other_params = pred_boxes[:, 3:]
        # Midpoint between the proposal centre and its reference.
        inter_proposal = torch.cat([pred_boxes[:, :3] - half_offset, other_params], dim=1)
        # Centre pushed half the distance further away from the reference.
        extra_proposal = torch.cat([pred_boxes[:, :3] + half_offset, other_params], dim=1)
        return torch.cat([pred_boxes, inter_proposal, extra_proposal], dim=0)

    def _apply_ie_aug(self, targets_dict):
        """Replace the proposals in ``targets_dict`` in place with their IE-augmented version.

        Reads 'rois' (B, N, D), 'roi_labels' (B, N), 'roi_scores' (B, N) and
        'rois_pre_nms' (B, M, D). Writes back 'rois' (B, 3N, D) and 'roi_labels' /
        'roi_scores' (B, 3N). Each extra proposal reuses its source box's label and score.
        """
        rois_all = targets_dict['rois']
        labels_all = targets_dict['roi_labels']
        scores_all = targets_dict['roi_scores']
        rois_pre_nms_all = targets_dict['rois_pre_nms']

        targets_dict['rois'] = torch.stack(
            [self._ie_aug_frame(rois_all[b], rois_pre_nms_all[b]) for b in range(rois_all.shape[0])],
            dim=0,
        )
        # Each frame is [original; interpolated; extrapolated], so labels/scores are tiled 3x.
        targets_dict['roi_labels'] = labels_all.repeat(1, 3)
        targets_dict['roi_scores'] = scores_all.repeat(1, 3)

    def forward(self, batch_dict):
        """Generate proposals, pool their features and predict per-RoI IoU scores.

        Args:
            batch_dict: Must provide 'batch_size', 'spatial_features_2d' and
                'dataset_cfg', plus whatever ``proposal_layer`` / ``assign_targets``
                consume. With SELF_TRAIN.IE_AUG or SELF_TRAIN.REDB enabled,
                'rois_pre_nms' is read as well.

        Returns:
            batch_dict, updated in place with:
              - 'rois', 'roi_labels', 'roi_scores', 'roi_head_features' (B, N, C_fc)
                and 'roi_iou_scores';
              - the ``*_mt`` entries;
              - in eval mode, 'batch_cls_preds', 'batch_box_preds' and
                'cls_preds_normalized', plus '*_pre_nms' variants when IE_AUG/REDB is on.
            In training mode the targets used by ``get_loss`` are stored in
            ``self.forward_ret_dict``.
        """
        targets_dict = self.proposal_layer(
            batch_dict, nms_config=self.model_cfg.NMS_CONFIG['TRAIN' if self.training else 'TEST']
        )

        self_train_cfg = cfg.get('SELF_TRAIN', None)
        use_pre_nms_scores = bool(self_train_cfg) and bool(
            self_train_cfg.get('IE_AUG', None) or self_train_cfg.get('REDB', None)
        )

        if self_train_cfg and self_train_cfg.get('IE_AUG', None) and self.training:
            self._apply_ie_aug(targets_dict)

        if self_train_cfg:
            # Snapshot the RoIs *before* GT assignment and score them without gradients.
            batch_dict['rois_mt'] = targets_dict['rois']
            batch_dict['roi_labels_mt'] = targets_dict['roi_labels']
            batch_dict['roi_scores_mt'] = targets_dict['roi_scores']
            batch_dict['roi_cls_preds_mt'] = targets_dict['roi_cls_preds']
            with torch.no_grad():
                mt_features, mt_iou = self._pool_and_score(batch_dict)
                batch_dict['roi_head_features_mt'] = mt_features.view(
                    batch_dict['batch_size'], -1, mt_features.shape[-2])
                batch_dict['roi_iou_scores_mt'] = mt_iou.view(
                    batch_dict['batch_size'], -1, mt_iou.shape[-1])

        if self.training:
            targets_dict = self.assign_targets(batch_dict)

        # TODO: decide whether the RoIs used below should be taken before or after GT assignment.
        batch_dict['rois'] = targets_dict['rois']
        batch_dict['roi_labels'] = targets_dict['roi_labels']
        batch_dict['roi_scores'] = targets_dict['roi_scores']

        shared_features, rcnn_iou = self._pool_and_score(batch_dict)
        batch_size = batch_dict['batch_size']

        # TODO: decide whether the per-RoI BEV features should be taken before or after shared_fc_layer.
        batch_dict['roi_head_features'] = shared_features.view(batch_size, -1, shared_features.shape[-2])
        batch_dict['roi_iou_scores'] = rcnn_iou.view(batch_size, -1, rcnn_iou.shape[-1])

        if batch_dict.get('rois_mt', None) is None:
            # No *_mt entries were written above (e.g. SELF_TRAIN disabled): store clones
            # of the current outputs instead.
            batch_dict['rois_mt'] = targets_dict['rois'].clone()
            batch_dict['roi_labels_mt'] = targets_dict['roi_labels'].clone()
            batch_dict['roi_scores_mt'] = targets_dict['roi_scores'].clone()
            batch_dict['roi_head_features_mt'] = batch_dict['roi_head_features'].clone()
            batch_dict['roi_iou_scores_mt'] = batch_dict['roi_iou_scores'].clone()

        if not self.training:
            batch_dict['batch_cls_preds'] = rcnn_iou.view(batch_size, -1, rcnn_iou.shape[-1])
            batch_dict['batch_box_preds'] = batch_dict['rois']

            if use_pre_nms_scores:
                # Score the pre-NMS boxes too. The result is only consumed here, so this
                # pass runs only in eval mode (in training it would just waste compute
                # and feed extra batches into the BatchNorm running statistics).
                pre_nms_dict = batch_dict.copy()
                pre_nms_dict['rois'] = batch_dict['rois_pre_nms']
                _, pre_nms_iou = self._pool_and_score(pre_nms_dict)
                batch_dict['batch_cls_preds_pre_nms'] = pre_nms_iou.view(
                    pre_nms_dict['batch_size'], -1, pre_nms_iou.shape[-1])
                batch_dict['batch_box_preds_pre_nms'] = batch_dict['rois_pre_nms']

            batch_dict['cls_preds_normalized'] = False
        else:
            targets_dict['rcnn_iou'] = rcnn_iou
            self.forward_ret_dict = targets_dict

        return batch_dict

    def get_loss(self, tb_dict=None):
        """Return the total RCNN loss (IoU branch only) and updated tensorboard scalars.

        Args:
            tb_dict: Optional dict to update; a new one is created when None.

        Returns:
            (rcnn_loss, tb_dict), where tb_dict gains 'rcnn_loss_iou' and 'rcnn_loss'.
        """
        tb_dict = {} if tb_dict is None else tb_dict
        rcnn_loss = 0
        rcnn_loss_cls, cls_tb_dict = self.get_box_iou_layer_loss(self.forward_ret_dict)
        rcnn_loss += rcnn_loss_cls
        tb_dict.update(cls_tb_dict)

        tb_dict['rcnn_loss'] = rcnn_loss.item()
        return rcnn_loss, tb_dict

    def get_box_iou_layer_loss(self, forward_ret_dict):
        """Compute the weighted IoU-branch loss, averaged over RoIs with a valid label.

        Args:
            forward_ret_dict: Needs 'rcnn_iou' (predictions) and 'rcnn_cls_labels'
                (targets; entries < 0 are ignored).

        Returns:
            (rcnn_loss_iou, {'rcnn_loss_iou': float}).

        Raises:
            NotImplementedError: For an unknown ``LOSS_CONFIG.IOU_LOSS``.
        """
        loss_cfgs = self.model_cfg.LOSS_CONFIG
        rcnn_iou_flat = forward_ret_dict['rcnn_iou'].view(-1)
        # Cast once so every loss variant receives floating-point targets
        # (the BCE branch already required this).
        rcnn_iou_labels = forward_ret_dict['rcnn_cls_labels'].view(-1).float()

        if loss_cfgs.IOU_LOSS == 'BinaryCrossEntropy':
            batch_loss_iou = nn.functional.binary_cross_entropy_with_logits(
                rcnn_iou_flat, rcnn_iou_labels, reduction='none'
            )
        elif loss_cfgs.IOU_LOSS == 'L2':
            batch_loss_iou = nn.functional.mse_loss(rcnn_iou_flat, rcnn_iou_labels, reduction='none')
        elif loss_cfgs.IOU_LOSS == 'smoothL1':
            diff = rcnn_iou_flat - rcnn_iou_labels
            batch_loss_iou = loss_utils.WeightedSmoothL1Loss.smooth_l1_loss(diff, 1.0 / 9.0)
        elif loss_cfgs.IOU_LOSS == 'focalbce':
            batch_loss_iou = loss_utils.sigmoid_focal_cls_loss(rcnn_iou_flat, rcnn_iou_labels)
        else:
            raise NotImplementedError

        # Negative labels mark RoIs to ignore; clamp avoids division by zero.
        iou_valid_mask = (rcnn_iou_labels >= 0).float()
        rcnn_loss_iou = (batch_loss_iou * iou_valid_mask).sum() / torch.clamp(iou_valid_mask.sum(), min=1.0)

        rcnn_loss_iou = rcnn_loss_iou * loss_cfgs.LOSS_WEIGHTS['rcnn_iou_weight']
        tb_dict = {'rcnn_loss_iou': rcnn_loss_iou.item()}
        return rcnn_loss_iou, tb_dict
