# Copyright (c) OpenMMLab. All rights reserved.
import copy
import json
import math
import os
from collections import OrderedDict
from typing import Dict, Tuple

import cv2
import numpy as np
import torch
from mmengine import Config, is_list_of
from mmengine.runner import load_checkpoint, load_state_dict
from mmengine.structures import InstanceData
from scipy.stats import norm
from sklearn.mixture import GaussianMixture, BayesianGaussianMixture
from torch import Tensor, nn
from tqdm import tqdm

from mmdet.registry import MODELS
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
from tools.analysis_tools.robustness_eval import print_coco_results
from .base import ForwardResults
from .two_stage import TwoStageDetector
from ...apis import inference_detector, init_detector
from ...structures import SampleList, OptSampleList
import einops

@MODELS.register_module()
class FasterRCNNIncrease(TwoStageDetector):
    """Incremental `Faster R-CNN <https://arxiv.org/abs/1506.01497>`_.

    A frozen copy of the previous-task detector (``self.ori_model``, built in
    :meth:`load_base_detector`) predicts boxes on the training images; the
    confident, non-duplicate predictions are appended to the ground truth as
    pseudo labels before the usual RPN / RoI losses are computed. The
    classification and regression layers of this model are initialised from
    the old checkpoint and widened from ``ori_num_classes`` to
    ``num_classes`` (see :meth:`_load_checkpoint_for_new_model`).

    Args:
        pseudo_label_setting (ConfigType): Pseudo-label options. Keys read:
            ``use_weight_pseudo_label`` (default ``False``),
            ``positive_threshold`` and ``negative_threshold`` (only used when
            ``use_weight_pseudo_label`` is true; default ``None``).
        ori_setting (ConfigType): Old-model options. Keys read:
            ``num_classes``, ``ori_num_classes``, ``ori_config_file``,
            ``ori_checkpoint_file``; ``task_num`` and ``base_config_file``
            are popped from it (i.e. removed from the caller's dict).
        current_dataset_setting (ConfigType): Accepted for config
            compatibility; not used by this class.
        backbone, rpn_head, roi_head, train_cfg, test_cfg, neck,
        data_preprocessor, init_cfg: Forwarded to ``TwoStageDetector``.
    """

    def __init__(self,
                 pseudo_label_setting: ConfigType,
                 ori_setting: ConfigType,
                 current_dataset_setting: ConfigType,
                 backbone: ConfigType,
                 rpn_head: ConfigType,
                 roi_head: ConfigType,
                 train_cfg: ConfigType,
                 test_cfg: ConfigType,
                 neck: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            rpn_head=rpn_head,
            roi_head=roi_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg,
            data_preprocessor=data_preprocessor)
        self.ori_setting = ori_setting
        self.num_classes = ori_setting['num_classes']
        # NOTE: ``pop`` mutates the caller's config dict.
        self.task_num = ori_setting.pop('task_num', None)
        if self.task_num is not None and self.task_num > 1:
            self.base_config_file = ori_setting.pop('base_config_file', None)
        else:
            self.base_config_file = None
        self.old_num_classes = ori_setting['ori_num_classes']
        self.ori_num_classes = None  # assigned in load_base_detector

        # Pseudo-label mining options. ``loss`` reads these attributes; they
        # were previously never assigned. The key names mirror the attribute
        # names -- NOTE(review): confirm against the project's configs.
        pseudo_label_setting = pseudo_label_setting or {}
        self.pseudo_label_setting = pseudo_label_setting
        self.use_weight_pseudo_label = pseudo_label_setting.get(
            'use_weight_pseudo_label', False)
        self.positive_threshold = pseudo_label_setting.get(
            'positive_threshold', None)
        self.negative_threshold = pseudo_label_setting.get(
            'negative_threshold', None)

        self.load_base_detector(ori_setting)
        # NMS IoU for un-weighted mining, and the GT-overlap IoU for weighted
        # mining.
        self.iou_threshold = 0.8
        # Un-weighted mining: minimum old-model score for a pseudo label.
        self.pseudo_score_threshold = 0.7
        # Un-weighted mining: drop a candidate whose IoU with any existing
        # GT (including already-added pseudo labels) exceeds this.
        self.pseudo_gt_iou_threshold = 0.8

    def get_old_model(self):
        """Build a detector from the old config/checkpoint via ``init_detector``."""
        cfg_path = self.ori_setting['ori_config_file']
        pt_path = self.ori_setting['ori_checkpoint_file']
        old_model = init_detector(
            cfg_path, pt_path, base_config_file=self.base_config_file)
        return old_model

    def calculate_iou(self, box1, box2):
        """Return the IoU of two ``(x1, y1, x2, y2)`` boxes as a Python float.

        Both boxes are moved to the CPU and unpacked into four numbers.
        Returns ``0.0`` when the boxes do not intersect, and also when the
        union area is not positive (degenerate boxes), instead of dividing
        by zero.
        """
        x1_1, y1_1, x2_1, y2_1 = box1.cpu().tolist()
        x1_2, y1_2, x2_2, y2_2 = box2.cpu().tolist()
        area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
        area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
        x_left = max(x1_1, x1_2)
        y_top = max(y1_1, y1_2)
        x_right = min(x2_1, x2_2)
        y_bottom = min(y2_1, y2_2)
        if x_right < x_left or y_bottom < y_top:
            return 0.0
        intersection_area = (x_right - x_left) * (y_bottom - y_top)
        union_area = area1 + area2 - intersection_area
        if union_area <= 0:
            return 0.0
        return intersection_area / union_area

    def _nms_keep_indices(self, pred_instances):
        """Return the set of prediction indices kept by OpenCV NMS.

        ``cv2.dnn.NMSBoxes`` expects ``(x, y, w, h)`` rectangles, so the
        ``(x1, y1, x2, y2)`` predictions are converted on a copy (``numpy()``
        may share memory with a CPU tensor). Uses ``self.iou_threshold`` as
        the NMS IoU and a score threshold of 0.
        """
        bboxes = pred_instances['bboxes'].detach().cpu().numpy()
        if bboxes.shape[0] == 0:
            return set()
        scores = pred_instances['scores'].detach().cpu().numpy()
        xywh = bboxes.copy()
        xywh[:, 2:] -= xywh[:, :2]
        keep = cv2.dnn.NMSBoxes(xywh, scores, 0., self.iou_threshold)
        # Depending on the OpenCV version this is (N,), (N, 1) or ().
        return {int(i) for i in np.asarray(keep).reshape(-1)}

    def _passes_score_filter(self, bbox_index, bbox, nms_keep):
        """Return True if a single prediction is a pseudo-label candidate.

        Weighted mode: the score must be >= ``negative_threshold`` and
        > ``positive_threshold``. Un-weighted mode: the score must be
        >= ``pseudo_score_threshold`` and the index must survive NMS.
        """
        if self.use_weight_pseudo_label:
            confidence = bbox['scores'].item()
            if confidence < self.negative_threshold:
                return False
            return confidence > self.positive_threshold
        if bbox['scores'] < self.pseudo_score_threshold:
            return False
        return bbox_index in nms_keep

    def _overlaps_any(self, box, gt_bboxes, iou_thr):
        """Return True if ``box`` has IoU > ``iou_thr`` with any of ``gt_bboxes``."""
        return any(self.calculate_iou(gt_bbox, box) > iou_thr
                   for gt_bbox in gt_bboxes)

    def _add_pseudo_labels(self, result, data_sample, image=None):
        """Append old-model detections to ``data_sample.gt_instances``.

        Args:
            result: Old-model prediction for one image; its
                ``pred_instances`` must have ``bboxes`` and ``scores``.
            data_sample: Matching training sample. Its ``gt_instances`` is
                replaced by the concatenation with each kept detection (with
                ``scores`` removed).
            image: When not ``None``, stored as ``data_sample.img_tensor``
                if at least one pseudo label is added.

        Returns:
            tuple[list[float], list[float]]: Classification and regression
            weights, one ``1.0`` per pseudo label added.
        """
        pred_instances = result.pred_instances
        if self.use_weight_pseudo_label:
            nms_keep = None
            gt_iou_thr = self.iou_threshold
        else:
            nms_keep = self._nms_keep_indices(pred_instances)
            gt_iou_thr = self.pseudo_gt_iou_threshold

        num_added = 0
        for bbox_index, bbox in enumerate(pred_instances):
            if not self._passes_score_filter(bbox_index, bbox, nms_keep):
                continue
            # ``gt_instances`` is re-read every iteration, so pseudo labels
            # added earlier in this loop also suppress later duplicates.
            if self._overlaps_any(bbox['bboxes'][0],
                                  data_sample.gt_instances['bboxes'],
                                  gt_iou_thr):
                continue
            bbox.__delattr__('scores')
            data_sample.gt_instances = data_sample.gt_instances.cat(
                [data_sample.gt_instances, bbox])
            if image is not None:
                data_sample.img_tensor = image
            num_added += 1
        return [1.0] * num_added, [1.0] * num_added

    def _rpn_loss_and_proposals(self, x, data_samples):
        """Compute RPN losses and proposals for ``data_samples``.

        All GT labels are set to 0 on a deep copy before calling the RPN, and
        RPN loss keys lacking ``rpn`` are prefixed with ``rpn_`` so they do
        not collide with RoI-head loss names. Without an RPN, pre-computed
        ``proposals`` are read from the samples and the loss dict is empty.

        Returns:
            tuple[dict, list]: ``(rpn_losses, rpn_results_list)``.
        """
        if not self.with_rpn:
            assert data_samples[0].get('proposals', None) is not None
            return {}, [data_sample.proposals for data_sample in data_samples]

        proposal_cfg = self.train_cfg.get('rpn_proposal', self.test_cfg.rpn)
        rpn_data_samples = copy.deepcopy(data_samples)
        for data_sample in rpn_data_samples:
            data_sample.gt_instances.labels = \
                torch.zeros_like(data_sample.gt_instances.labels)

        rpn_losses, rpn_results_list = self.rpn_head.loss_and_predict(
            x, rpn_data_samples, proposal_cfg=proposal_cfg)
        for key in list(rpn_losses.keys()):
            if 'loss' in key and 'rpn' not in key:
                rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key)
        return rpn_losses, rpn_results_list

    def loss(self, batch_inputs: Tensor, epoch: int,
             batch_data_samples: SampleList, has_old_task_samples_all=None,
             pseudo_label_cls_weights_only_all=None, gt_pseudo_num_all=None,
             gt_new_num_all=None, only_old_sign=False, only_new_sign=False,
             task_id=None):
        """Compute training losses, optionally after pseudo-label mining.

        Two branches:

        * Current-batch branch -- taken when ``batch_inputs`` is given and
          ``only_old_sign`` is falsy, or when both ``only_old_sign`` and
          ``only_new_sign`` are truthy. The old model predicts on
          ``batch_inputs``, pseudo labels are appended to
          ``batch_data_samples`` in place, and RPN + ``roi_head.loss_weight``
          losses are computed (the ``labels`` entry is removed).
        * Cached-sample branch -- otherwise. Images are stacked from
          ``has_old_task_samples_all[i].img_tensor`` and losses are computed
          on those samples with the supplied weights/counts.

        Returns:
            dict: The losses, when ``epoch != -2``.
            tuple: When ``epoch == -2``: ``(losses, samples, cls_weights,
            gt_pseudo_num, gt_new_num, task_num, ori_num_classes,
            num_classes)``, where the per-sample lists are filtered to the
            samples that received at least one pseudo label. Presumably the
            caller feeds these back as the ``*_all`` arguments -- confirm.
        """
        collect_old_samples = epoch == -2
        # Initialised up front so the ``epoch == -2`` return is well defined
        # even when only the cached-sample branch runs.
        has_old_task_samples = []
        pseudo_label_cls_weights_list = []
        pseudo_label_reg_weights_list = []
        gt_pseudo_num = []
        gt_new_num = []

        run_current_batch = ((batch_inputs is not None and not only_old_sign)
                             or (only_old_sign and only_new_sign))
        if run_current_batch:
            x = self.extract_feat(batch_inputs)
            result_ori = self.ori_model.predict(
                batch_inputs, copy.deepcopy(batch_data_samples), rescale=False)
            # Per-image tensors are only attached when caching samples.
            image_inputs = (batch_inputs if collect_old_samples
                            else [None] * len(batch_data_samples))
            for result, data_sample, image in zip(
                    result_ori, batch_data_samples, image_inputs):
                cls_weights, reg_weights = self._add_pseudo_labels(
                    result, data_sample, image)
                if collect_old_samples:
                    has_old_task_samples.append(data_sample)
                pseudo_label_cls_weights_list.append(cls_weights)
                pseudo_label_reg_weights_list.append(reg_weights)
                gt_pseudo_num.append(len(reg_weights))
                gt_new_num.append(
                    len(data_sample.gt_instances) - len(reg_weights))

            losses = dict()
            rpn_losses, rpn_results_list = self._rpn_loss_and_proposals(
                x, batch_data_samples)
            losses.update(rpn_losses)
            roi_losses = self.roi_head.loss_weight(x, rpn_results_list,
                                                   batch_data_samples,
                                                   pseudo_label_cls_weights_list,
                                                   pseudo_label_reg_weights_list,
                                                   gt_pseudo_num,
                                                   gt_new_num,
                                                   self.ori_num_classes,
                                                   self.num_classes,
                                                   self.task_num,
                                                   task_id,
                                                   )
            del roi_losses['labels']
            # ``bbox_feats`` / ``bbox_targets`` are carried over by update().
            losses.update(roi_losses)
        else:
            losses = dict()
            img_inputs = torch.stack(
                [data_sample.img_tensor for data_sample in has_old_task_samples_all])
            x = self.extract_feat(img_inputs)
            rpn_losses, rpn_results_list = self._rpn_loss_and_proposals(
                x, has_old_task_samples_all)
            losses.update(rpn_losses)
            # Only the cls weights are returned for cached samples; in this
            # class cls and reg weights are always identical lists, so the
            # cls weights are passed for both.
            roi_losses = self.roi_head.loss_weight(x, rpn_results_list,
                                                   has_old_task_samples_all,
                                                   pseudo_label_cls_weights_only_all,
                                                   pseudo_label_cls_weights_only_all,
                                                   gt_pseudo_num_all,
                                                   gt_new_num_all,
                                                   self.ori_num_classes,
                                                   self.num_classes,
                                                   self.task_num,
                                                   task_id,
                                                   )
            losses.update(roi_losses)

        if not collect_old_samples:
            return losses
        has_pseudo_inds = [i for i, weights in
                           enumerate(pseudo_label_reg_weights_list) if weights]
        return (losses,
                [has_old_task_samples[i] for i in has_pseudo_inds],
                [pseudo_label_cls_weights_list[i] for i in has_pseudo_inds],
                [gt_pseudo_num[i] for i in has_pseudo_inds],
                [gt_new_num[i] for i in has_pseudo_inds],
                self.task_num, self.ori_num_classes, self.num_classes)

    def parse_losses_v3(
        self, losses: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Reduce raw losses and split the total into named groups.

        Each value is reduced to a scalar (``mean()`` of a tensor, or the sum
        of means of a list of tensors); anything else raises ``TypeError``.
        Among keys containing ``loss``, the groups are: old-positive cls,
        old-positive bbox, new-positive cls, new-positive bbox, and "shared"
        (neither ``old_positive`` nor ``new_positive``). A group with no
        matching key sums to the int ``0``.

        Returns:
            tuple: ``(loss_old_positive_cls, loss_old_positive_bbox,
            loss_new_positive_cls, loss_new_positive_bbox, loss_shared,
            log_vars)`` where ``log_vars`` is an ``OrderedDict`` starting
            with the overall ``'loss'`` entry.
        """
        log_vars = []
        for loss_name, loss_value in losses.items():
            if isinstance(loss_value, torch.Tensor):
                log_vars.append([loss_name, loss_value.mean()])
            elif is_list_of(loss_value, torch.Tensor):
                log_vars.append(
                    [loss_name,
                     sum(_loss.mean() for _loss in loss_value)])
            else:
                raise TypeError(
                    f'{loss_name} is not a tensor or list of tensors')

        def _total(predicate):
            # Sum of all ``loss`` entries whose key satisfies ``predicate``.
            return sum(value for key, value in log_vars
                       if 'loss' in key and predicate(key))

        loss_old_positive_cls = _total(
            lambda k: 'old_positive' in k and 'cls' in k)
        loss_old_positive_bbox = _total(
            lambda k: 'old_positive' in k and 'bbox' in k)
        loss_new_positive_cls = _total(
            lambda k: 'new_positive' in k and 'cls' in k)
        loss_new_positive_bbox = _total(
            lambda k: 'new_positive' in k and 'bbox' in k)
        loss_shared = _total(
            lambda k: 'old_positive' not in k and 'new_positive' not in k)
        loss = _total(lambda k: True)
        log_vars.insert(0, ['loss', loss])
        log_vars = OrderedDict(log_vars)  # type: ignore
        return (loss_old_positive_cls, loss_old_positive_bbox,
                loss_new_positive_cls, loss_new_positive_bbox,
                loss_shared, log_vars)  # type: ignore

    def get_positive_and_negative_score(self, output, class_index):
        """Return ``output[0, pos]`` and ``output[0, neg]`` for a class.

        NOTE(review): ``self.positive_index`` is not assigned anywhere in
        this class; it must be set externally before calling this.
        """
        positive_index, negative_index = self.positive_index[class_index]
        return output[0, positive_index], output[0, negative_index]

    def load_base_detector(self, ori_setting):
        """Build the frozen old-task detector and initialise this model from it.

        The old model is built from ``ori_setting['ori_config_file']``, loaded
        from ``ori_setting['ori_checkpoint_file']``, put in eval mode and has
        ``requires_grad`` disabled on all parameters. The same checkpoint,
        with its class heads widened, is then loaded into ``self``.

        Args:
            ori_setting: Mapping with ``ori_config_file``,
                ``ori_checkpoint_file`` and ``ori_num_classes``.
        """
        checkpoint_file = ori_setting['ori_checkpoint_file']
        assert os.path.isfile(checkpoint_file), \
            '{} is not a valid file'.format(checkpoint_file)

        ori_cfg = Config.fromfile(ori_setting['ori_config_file'])
        if hasattr(ori_cfg.model, 'latest_model_flag'):
            ori_cfg.model.latest_model_flag = False
        ori_model = MODELS.build(ori_cfg.model)
        load_checkpoint(ori_model, checkpoint_file, strict=False)
        ori_model.eval()
        # The old model must not carry its own nested teacher.
        ori_model.ori_model = None
        for param in ori_model.parameters():
            param.requires_grad = False

        self.ori_num_classes = ori_setting['ori_num_classes']
        self._load_checkpoint_for_new_model(checkpoint_file, strict=False)
        self.ori_model = ori_model

    def _load_checkpoint_for_new_model(self, checkpoint_file, map_location=None,
                                       strict=True, logger=None):
        """Load an old-task checkpoint into ``self`` with widened class heads.

        ``fc_cls`` rows become ``[old classes, new classes, last row]``,
        where each new-class row (and bias) is the mean of the old-class
        rows. ``fc_reg`` gets one new 4-row group per new class, each the
        mean of all existing 4-row groups. Keys starting with ``ori_model``
        are skipped.

        Raises:
            RuntimeError: If no state dict is found in the checkpoint.
        """
        checkpoint = torch.load(checkpoint_file, map_location=map_location)
        if isinstance(checkpoint, OrderedDict):
            state_dict = checkpoint
        elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            raise RuntimeError(
                'No state_dict found in checkpoint file {}'.format(checkpoint_file))

        # Strip a DataParallel-style ``module.`` prefix. Uses ``state_dict``
        # (not ``checkpoint['state_dict']``) so raw OrderedDict checkpoints
        # also work.
        prefix = 'module.'
        if next(iter(state_dict), '').startswith(prefix):
            state_dict = {(k[len(prefix):] if k.startswith(prefix) else k): v
                          for k, v in state_dict.items()}

        num_old = self.ori_num_classes
        num_added = self.num_classes - self.ori_num_classes

        # Classification head.
        cls_w_key = 'roi_head.bbox_head.fc_cls.weight'
        cls_b_key = 'roi_head.bbox_head.fc_cls.bias'
        cls_weight = state_dict[cls_w_key]
        cls_bias = state_dict[cls_b_key]
        added_cls_weight = cls_weight[:num_old].mean(
            dim=0, keepdim=True).expand(num_added, -1)
        added_cls_bias = cls_bias[:num_old].mean(
            dim=0, keepdim=True).expand(num_added)
        state_dict[cls_w_key] = torch.cat(
            (cls_weight[:num_old], added_cls_weight, cls_weight[-1:]), dim=0)
        state_dict[cls_b_key] = torch.cat(
            (cls_bias[:num_old], added_cls_bias, cls_bias[-1:]), dim=0)

        # Regression head: NOTE(review) the mean here is over all existing
        # 4-row groups, not only the first ``num_old`` as for fc_cls.
        reg_w_key = 'roi_head.bbox_head.fc_reg.weight'
        reg_b_key = 'roi_head.bbox_head.fc_reg.bias'
        reg_weight = state_dict[reg_w_key]
        reg_bias = state_dict[reg_b_key]
        added_reg_weight = einops.rearrange(
            reg_weight, '(n m) d-> n m d', m=4).mean(dim=0, keepdim=True)
        added_reg_bias = einops.rearrange(
            reg_bias, '(n m)-> n m', m=4).mean(dim=0, keepdim=True)
        added_reg_weight = einops.rearrange(
            added_reg_weight.expand(num_added, -1, -1), 'n m d -> (n m) d')
        added_reg_bias = einops.rearrange(
            added_reg_bias.expand(num_added, -1), 'n m-> (n m)')
        state_dict[reg_w_key] = torch.cat((reg_weight, added_reg_weight), dim=0)
        state_dict[reg_b_key] = torch.cat((reg_bias, added_reg_bias), dim=0)

        state_dict_new = {k: v for k, v in state_dict.items()
                          if not k.startswith('ori_model')}

        if hasattr(self, 'module'):
            load_state_dict(self.module, state_dict_new, strict, logger)
        else:
            load_state_dict(self, state_dict_new, strict, logger)

    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[:obj:`DetDataSample`]: Return the detection results of the
            input images. The returns value is DetDataSample,
            which usually contain 'pred_instances'. And the
            ``pred_instances`` usually contains following keys.

                - scores (Tensor): Classification scores, has a shape
                    (num_instance, )
                - labels (Tensor): Labels of bboxes, has a shape
                    (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                    the last dimension 4 arrange as (x1, y1, x2, y2).
                - masks (Tensor): Has a shape (num_instances, H, W).
        """
        assert self.with_bbox, 'Bbox head must be implemented.'
        x = self.extract_feat(batch_inputs)

        # If there are no pre-defined proposals, use RPN to get proposals
        if batch_data_samples[0].get('proposals', None) is None:
            rpn_results_list = self.rpn_head.predict(
                x, batch_data_samples, rescale=False)
        else:
            rpn_results_list = [
                data_sample.proposals for data_sample in batch_data_samples
            ]

        results_list = self.roi_head.predict(
            x, rpn_results_list, batch_data_samples, rescale=rescale)

        batch_data_samples = self.add_pred_to_datasample(
            batch_data_samples, results_list)
        return batch_data_samples

    def forward(self,
                inputs: torch.Tensor,
                data_samples: OptSampleList = None,
                mode: str = 'tensor', epoch=None, has_old_task_samples_all=None,
                pseudo_label_cls_weights_only_all=None, gt_pseudo_num_all=None,
                gt_new_num_all=None, only_new_sign=False, only_old_sign=False,
                task_id=None) -> ForwardResults:
        """The unified entry for a forward process in both training and test.

        The method should accept three modes: "tensor", "predict" and "loss":

        - "tensor": Forward the whole network and return tensor or tuple of
        tensor without any post-processing, same as a common nn.Module.
        - "predict": Forward and return the predictions, which are fully
        processed to a list of :obj:`DetDataSample`.
        - "loss": Forward and return a dict of losses according to the given
        inputs and data samples.

        Note that this method doesn't handle either back propagation or
        parameter update, which are supposed to be done in :meth:`train_step`.

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general.
            data_samples (list[:obj:`DetDataSample`], optional): A batch of
                data samples that contain annotations and predictions.
                Defaults to None.
            mode (str): Return what kind of value. Defaults to 'tensor'.
            epoch, has_old_task_samples_all, pseudo_label_cls_weights_only_all,
            gt_pseudo_num_all, gt_new_num_all, only_new_sign, only_old_sign,
            task_id: Only used in "loss" mode; forwarded to :meth:`loss`.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
            - If ``mode="predict"``, return a list of :obj:`DetDataSample`.
            - If ``mode="loss"``, return a dict of tensor (or, when
              ``epoch == -2``, the tuple documented in :meth:`loss`).
        """
        if mode == 'loss':
            return self.loss(inputs, epoch, data_samples, has_old_task_samples_all,
                             pseudo_label_cls_weights_only_all, gt_pseudo_num_all,
                             gt_new_num_all, only_new_sign=only_new_sign,
                             only_old_sign=only_old_sign, task_id=task_id)
        elif mode == 'predict':
            return self.predict(inputs, data_samples)
        elif mode == 'tensor':
            return self._forward(inputs, data_samples)
        else:
            raise RuntimeError(f'Invalid mode "{mode}". '
                               'Only supports loss, predict and tensor mode')

