import bisect
import copy
import os
import random
import logging
import time
import math
import pickle
from contextlib import redirect_stdout
import torch.nn.functional as F
from typing import Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import quadprog
import torch
from mmengine.model import detect_anomalous_params
# from scipy.optimize import linear_sum_assignment
from mmdet.structures import DetDataSample
from mmengine.optim import OptimWrapper
from mmengine.structures import InstanceData
from mmengine.runner import BaseLoop
from mmengine.runner.utils import calc_dynamic_intervals
from collections import defaultdict
from torch import Tensor
from torch.overrides import handle_torch_function
from torch.utils.data import DataLoader
from collections import OrderedDict
from mmengine.evaluator import Evaluator
from mmengine.logging import print_log
from mmengine.registry import LOOPS
from mmengine.utils import is_list_of


@LOOPS.register_module()
class IIKC(BaseLoop):
    def __init__(
            self,
            runner,
            dataloader: Union[DataLoader, Dict],
            max_epochs: int,
            is_use: bool = False,
            Lambda: float = 1.0,
            val_begin: int = 0,
            val_interval: int = 1,
            EWC: bool = False,
            directory: Optional[str] = None,
            dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None:
        """Set up the epoch-based training loop and its bookkeeping state.

        Args:
            runner: Runner owning this loop; passed through to
                ``BaseLoop.__init__``.
            dataloader: Dataloader object or dict, passed through to
                ``BaseLoop.__init__``.
            max_epochs: Total number of training epochs. Must be
                integer-valued (checked with an ``assert``).
            is_use: Stored as ``self.use``; not read by this method.
            Lambda: Stored as ``self.lamda``; not read by this method.
            val_begin: Validation only runs once the epoch counter has
                reached this value.
            val_interval: Validation runs when the epoch counter is a
                multiple of this value.
            EWC: Stored as ``self.EWC``; enables the EWC branches of
                ``run`` and ``run_iter``.
            directory: Directory in which ``run`` and ``run_fc_cls`` look
                for the pickled per-task statistics.
            dynamic_intervals: Optional ``(milestone, interval)`` pairs
                handed to ``calc_dynamic_intervals`` together with
                ``val_interval``.
        """
        super().__init__(runner, dataloader)
        self._max_epochs = int(max_epochs)
        assert self._max_epochs == max_epochs, \
            f'`max_epochs` should be a integer number, but get {max_epochs}.'
        self._max_iters = self._max_epochs * len(self.dataloader)
        self._epoch = 0
        self._iter = 0
        self.EWC = EWC
        self.directory = directory
        self.val_begin = val_begin
        self.val_interval = val_interval
        self.lamda = Lambda
        self.use = is_use
        # Counters/statistics; ``run`` overwrites ``new_neg`` and
        # ``num_old_img`` from the loaded pickles when EWC is active.
        self.loss_scale = 1.0
        self.num_old_img = 0
        self.new_neg = 0
        self.num_rpn_bg_anchors = 0
        self.num_rpn_anchors = 0
        self.stop_training = False
        # NOTE(review): ``.module`` assumes the runner's model is wrapped
        # (e.g. by a DDP wrapper) -- confirm for single-process runs.
        self.old_model = self.runner.model.module.get_old_model()
        if hasattr(self.dataloader.dataset, 'metainfo'):
            self.runner.visualizer.dataset_meta = \
                self.dataloader.dataset.metainfo
        else:
            print_log(
                f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
                'metainfo. ``dataset_meta`` in visualizer will be '
                'None.',
                logger='current',
                level=logging.WARNING)

        self.dynamic_milestones, self.dynamic_intervals = \
            calc_dynamic_intervals(
                self.val_interval, dynamic_intervals)

    @property
    def max_epochs(self):
        """int: Total epochs to train model."""
        return self._max_epochs

    @property
    def max_iters(self):
        """int: Total iterations to train model."""
        return self._max_iters

    @property
    def epoch(self):
        """int: Current epoch."""
        return self._epoch

    @property
    def iter(self):
        """int: Current iteration."""
        return self._iter
    
    def EWC_step(self, data: Union[dict, tuple, list],
                 optim_wrapper: OptimWrapper) -> Tuple[None, None, None]:
        """Run one gradient-free forward pass and sync the old backbone.

        The batch is preprocessed and pushed through the wrapped model in
        ``loss`` mode under ``torch.no_grad()``. The forward call is expected
        to return eight values; only the last three are kept, as
        ``self.task_num``, ``self.ori_num_classes`` and ``self.num_classes``.
        Afterwards the current backbone weights are copied into
        ``self.old_model.backbone``.

        Args:
            data: Raw batch handed to the model's ``data_preprocessor``.
            optim_wrapper: Unused; the runner's own optim wrapper supplies
                the context manager.

        Returns:
            Always ``(None, None, None)``.
        """
        runner = self.runner
        # NOTE(review): ``optim_context`` receives the loop rather than the
        # model -- kept as in the original; confirm this is intended.
        with runner.optim_wrapper.optim_context(self):
            data = runner.model.module.data_preprocessor(data, training=True)
            with torch.no_grad():
                # The first five outputs are not needed here; unpacking
                # still enforces that exactly eight values come back.
                (_, _, _, _, _,
                 self.task_num,
                 self.ori_num_classes,
                 self.num_classes) = runner.model._run_forward(
                     data, mode='loss')
        ori_model = runner.model.module
        self.old_model.backbone.load_state_dict(
            ori_model.backbone.state_dict(), strict=True)

        return None, None, None
    
    def run(self,grad_squared_mean_loaded_after=None,run_before_ewc_path=None,old_task_ewc_path=None,task_num=None) -> torch.nn.Module:
        """Launch training.

        When ``self.be_ewc`` or ``self.EWC`` is truthy, pickled statistics are
        loaded before training starts:

        * the object stored at ``old_task_ewc_path`` becomes
          ``self.grad_squared_mean_loaded``;
        * per-task ``'img'`` / ``'anchor'`` counts are accumulated from files
          under ``self.directory`` into ``self.num_old_all_img`` and
          ``self.num_old_all_anchor``;
        * ``self.grad_squared_mean_loaded_after`` is taken from the argument
          of the same name or, if that is ``None``, unpickled from
          ``run_before_ewc_path``.

        With ``self.EWC`` set, per-key importance tensors are then computed,
        divided by their global maximum and stored in
        ``self.param_importance``. Finally epochs are run (with validation)
        until ``max_epochs`` is reached or ``self.stop_training`` is set.

        NOTE(review): ``current_grad_idx``, ``grad_squared_indices`` and
        ``grad_ind_without_rpn_head`` are read below but never defined in the
        visible code, so the ``self.EWC`` path cannot run as written.

        Args:
            grad_squared_mean_loaded_after: Optional pre-loaded dict; must
                provide the keys ``'all'``, ``'old_positive'`` and ``'img'``.
            run_before_ewc_path: Pickle path used when the previous argument
                is ``None``.
            old_task_ewc_path: Pickle path for
                ``self.grad_squared_mean_loaded``.
            task_num: Number of tasks; drives the per-task file loop.

        Returns:
            ``self.runner.model``.
        """
        self.runner.call_hook('before_train')
        ori_model = self.runner.model.module
        # Redundant with the module-level ``import pickle``; harmless.
        import pickle
        # NOTE(review): ``self.be_ewc`` is not assigned anywhere in the visible
        # part of this class -- confirm it is set before ``run`` is called.
        if (self.be_ewc or self.EWC):
            # SECURITY: ``pickle.load`` can execute arbitrary code; only point
            # these paths at trusted files. ``redirect_stdout(None)`` silences
            # anything printed while unpickling.
            with open(old_task_ewc_path, "rb") as f:
                with redirect_stdout(None):
                    self.grad_squared_mean_loaded = pickle.load(f)
            self.num_old_all_img = 0
            self.num_old_all_anchor=0
            all_grad_squared_mean_changed = {}
            for i in range(task_num):
                if i>=1:
                    # From task 1 onwards, subtract the 'img' count stored in
                    # the "only_share_nonew" file (when present) --
                    # presumably to avoid counting shared images twice.
                    file_old_num_same_img = os.path.join(self.directory, f"grad_squared_mean_only_share_nonew_task_{i}.pkl")
                    if os.path.exists(file_old_num_same_img):
                        with open(file_old_num_same_img, "rb") as f:
                            with redirect_stdout(None):
                                num_all_same_img = pickle.load(f)
                            print(num_all_same_img.keys())
                            self.num_old_all_img -= num_all_same_img['img']
                # Add this task's image and anchor counts when its prototype
                # file exists; missing files are skipped silently.
                file_old_num_all_img = os.path.join(self.directory, f"grad_squared_mean_prototype_task_{i}.pkl")
                if os.path.exists(file_old_num_all_img):
                    with open(file_old_num_all_img, "rb") as f:
                        with redirect_stdout(None):
                            num_all_img = pickle.load(f)
                        self.num_old_all_img += num_all_img['img']
                        self.num_old_all_anchor+=num_all_img['anchor']
                if i < task_num - 1:
                    # Only true for i == 0, so the "changed" file of task
                    # ``task_num-2`` is loaded at most once.
                    if self.EWC and (i==0) and (task_num>=2):
                    # elif self.EWC and (i==1) and (task_num>=2):
                        grad_squared_mean_changed_path = os.path.join(self.directory, f"grad_squared_mean_changed_task_{task_num-2}.pkl")
                        print(grad_squared_mean_changed_path)
                        if os.path.exists(grad_squared_mean_changed_path):
                            with open(grad_squared_mean_changed_path, "rb") as f:
                                with redirect_stdout(None):
                                    grad_squared_mean_changed = pickle.load(f)
                            # Keys of the two dicts are paired by position
                            # (iteration order), not by name.
                            for key, key1 in zip(grad_squared_mean_changed.keys(), self.grad_squared_mean_loaded['all'].keys()):
                                # for key, value in grad_squared_mean_changed.items():
                                if key1 in all_grad_squared_mean_changed:
                                    raise KeyError("error")
                                    # Unreachable: follows the unconditional
                                    # raise above.
                                    if i==0:
                                        all_grad_squared_mean_changed[key1] += grad_squared_mean_changed[key]
                                    else:
                                        all_grad_squared_mean_changed[key1] += grad_squared_mean_changed[key] 
                                else:
                                    # Both branches perform the same
                                    # assignment.
                                    if i==0:
                                        all_grad_squared_mean_changed[key1] = grad_squared_mean_changed[key]
                                    else:
                                        all_grad_squared_mean_changed[key1] = grad_squared_mean_changed[key]               

            print("Keys in all_grad_squared_mean_changed:", len(all_grad_squared_mean_changed.keys()))
            print("Keys in grad_squared_mean_loaded['all']:", len(self.grad_squared_mean_loaded['all'].keys()))
            # Fold the "changed" values into the loaded statistics, adding to
            # existing keys and inserting new ones.
            for key in all_grad_squared_mean_changed:
                value=all_grad_squared_mean_changed[key]
                if key in self.grad_squared_mean_loaded['all']:
                    print("no erro")
                    self.grad_squared_mean_loaded['all'][key]+=value
                else:
                    print("erro")
                    self.grad_squared_mean_loaded['all'][key] = value 
            # Prefer the statistics passed in by the caller; otherwise read
            # them from ``run_before_ewc_path``.
            if grad_squared_mean_loaded_after is not None:
                self.grad_squared_mean_loaded_after=grad_squared_mean_loaded_after
            else:
                with open(run_before_ewc_path, "rb") as f:
                    with redirect_stdout(None):
                        self.grad_squared_mean_loaded_after = pickle.load(f)
            self.new_neg = self.grad_squared_mean_loaded_after['old_positive']
            self.num_old_img = self.grad_squared_mean_loaded_after['img']

        if self.EWC:
            # Visits every trainable parameter that also exists in the old
            # model, excluding the ``fc_cls`` / ``fc_reg`` heads.
            for param_name, param in ori_model.named_parameters():
                if (
                    param_name in self.old_model.state_dict()
                    and param.requires_grad
                    and ('fc_cls' not in param_name)
                    and ('fc_reg' not in param_name)
                ):
                    # NOTE(review): ``current_grad_idx`` is never initialised
                    # in this method, so this line raises UnboundLocalError on
                    # the first matching parameter.
                    current_grad_idx += 1
        if self.EWC:
            # NOTE(review): ``grad_squared_indices`` and
            # ``grad_ind_without_rpn_head`` are not defined in the visible
            # code; the index-building logic appears to be missing.
            keys = list(self.grad_squared_mean_loaded['all'].keys())
            self.filtered_grad_squared = [keys[i] for i in grad_squared_indices]
            keys = list(self.grad_squared_mean_loaded_after['all'].keys())
            filtered_grad_squared_after = [
                keys[i] if i is not None else None
                for i in grad_ind_without_rpn_head
            ]
            # NOTE(review): 512.0 is presumably the number of sampled RoIs per
            # image -- confirm against the detector config.
            self.k=(self.new_neg)/(512.0*(self.num_old_all_img)-self.new_neg)
            self.param_importance = {}
            all_max_importance_diffs = []
            for id,(importance_id,im_id_after) in enumerate(zip(self.filtered_grad_squared,filtered_grad_squared_after)):
                importance=self.grad_squared_mean_loaded['all'][importance_id].to(torch.device("cuda:0"))
                # NOTE(review): ``im_id_after`` may be None (see the list
                # above), in which case this lookup uses None as the key and
                # the ``im_after is not None`` test below comes too late.
                im_after=self.grad_squared_mean_loaded_after['all'][im_id_after]
                print(f"importance shape: {importance.shape}, im_after shape: {im_after.shape}")
                # ``importance_id`` has already been used as a key above;
                # ``param`` here is the variable left over from the
                # ``named_parameters`` loop.
                if importance_id is None:
                    raise ValueError(f"Old parameter for {param.name} not found in old model.")
                if im_after is not None:
                    # 1 / ((k+1)^2 / importance + k^2 / im_after), clamped at 0.
                    # ``.pow`` requires ``self.k`` to be a tensor; a plain
                    # Python float has no such method.
                    importance_diff = torch.reciprocal(((self.k+1).pow(2))*torch.reciprocal(importance)+((self.k).pow(2))*torch.reciprocal(im_after))
                    importance_diff = torch.clamp(importance_diff, min=0)
                else:
                    importance_diff =importance
                self.param_importance[importance_id] = importance_diff
                all_max_importance_diffs.append(importance_diff.max())
                del importance ,importance_diff
            # ``max`` raises ValueError if no keys were selected.
            self.max_importance_diff = max(all_max_importance_diffs)
            # Built but not used anywhere in this method.
            file_after_ewc_grad_changed = os.path.join(self.directory, f"grad_squared_mean_changed_task_{task_num-1}.pkl")
        if self.EWC:
            # Normalise every importance tensor by the global maximum.
            for importance_id in self.filtered_grad_squared:
                importance_diff=self.param_importance[importance_id]
                importance=importance_diff/self.max_importance_diff
                self.param_importance[importance_id]=importance
        while self._epoch < self._max_epochs and not self.stop_training:
            self.run_epoch()
            self._decide_current_val_interval()
            if (self.runner.val_loop is not None
                    and self._epoch >= self.val_begin
                    and self._epoch % self.val_interval == 0):
                self.runner.val_loop.run()
        self.runner.call_hook('after_train')
        # Raises AttributeError if the wrapped model has no ``ori_model``
        # attribute.
        del self.runner.model.module.ori_model
        return self.runner.model

    def run_epoch(self) -> None:
        """Iterate one epoch."""
        self.runner.call_hook('before_train_epoch')
        self.runner.model.train()
        for idx, data_batch in enumerate(self.dataloader):
            self.run_iter(idx, data_batch)
        self.runner.call_hook('after_train_epoch')
        self._epoch += 1

    def store_origin_grad(self, paras, grad_new_cls):
        for para, grad_new in zip(paras, grad_new_cls):
            if grad_new is not None:
                para.grad.data.copy_(grad_new)

    def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
        """Iterate one mini-batch.

        Runs the forward pass in ``loss`` mode, splits the result with
        ``parse_losses_v3`` into five loss terms plus ``log_vars``, and
        scales each term through the optim wrapper. With ``self.EWC`` set,
        an EWC penalty is built, its gradients are taken separately with
        ``torch.autograd.grad``, the task losses are back-propagated, and
        the norm-clipped EWC gradients are added onto ``.grad``. The
        optimizer steps when the optim wrapper says an update is due.

        NOTE(review): ``importance`` (first use) and ``clip_threshold`` are
        not defined in the visible code, and ``self.a_ewc`` is not assigned
        in the visible part of the class, so the EWC branch cannot run as
        written. Without ``self.EWC`` no backward pass happens here at all.

        Args:
            idx (int): Index of the batch within the current epoch;
                forwarded to the hooks as ``batch_idx``.
            data_batch (Sequence[dict]): Batch of data from dataloader.
        """
        self.runner.call_hook(
            'before_train_iter', batch_idx=idx, data_batch=data_batch)
        # NOTE(review): ``optim_context`` receives the loop rather than the
        # model -- kept as in the original; confirm this is intended.
        with self.runner.optim_wrapper.optim_context(self):
            data = self.runner.model.module.data_preprocessor(data_batch, training=True)
            self.device = data['inputs'].device
            losses = self.runner.model._run_forward(data, mode='loss', epoch=self._epoch)
        (loss_old_positive_cls, loss_old_positive_bbox,
         loss_new_positive_cls, loss_new_positive_bbox,
         loss_shared, log_vars) = self.runner.model.module.parse_losses_v3(losses)

        step_kwargs = {}
        zero_kwargs = {}

        loss_old_positive_cls = self.runner.optim_wrapper.scale_loss(loss_old_positive_cls)
        loss_old_positive_bbox = self.runner.optim_wrapper.scale_loss(loss_old_positive_bbox)
        loss_new_positive_cls = self.runner.optim_wrapper.scale_loss(loss_new_positive_cls)
        loss_new_positive_bbox = self.runner.optim_wrapper.scale_loss(loss_new_positive_bbox)
        loss_shared = self.runner.optim_wrapper.scale_loss(loss_shared)
        ori_model = self.runner.model.module
        if self.EWC:
            self.runner.optim_wrapper.zero_grad(**zero_kwargs)
            shared_params = []
            paras = []
            ewc_loss = 0.0
            # ``paras``: every trainable parameter of the current model.
            for para in ori_model.parameters():
                if para.requires_grad:
                    paras.append(para)
            # ``shared_params``: trainable parameters that also exist in the
            # old model, excluding the ``fc_cls`` / ``fc_reg`` heads.
            for param_name, param in ori_model.named_parameters():
                if (
                    param_name in self.old_model.state_dict()
                    and param.requires_grad
                    and ('fc_cls' not in param_name)
                    and ('fc_reg' not in param_name)
                ):
                    shared_params.append(param)
            for param, importance_id in zip(shared_params, self.filtered_grad_squared):
                param_value = param
                # NOTE(review): ``importance`` is only assigned further down
                # in this method, so this line raises UnboundLocalError; the
                # subtraction also uses the dictionary key ``importance_id``
                # where an old parameter value is presumably intended.
                ewc_loss += ((self.a_ewc) * (importance) * (param_value - importance_id).pow(2)).sum()
            self.runner.optim_wrapper.zero_grad(**zero_kwargs)
            # EWC gradients are taken w.r.t. *all* trainable parameters.
            grad_ewcs = torch.autograd.grad(ewc_loss, paras, retain_graph=True, allow_unused=True)
            self.runner.optim_wrapper.backward(
                loss_old_positive_cls + loss_old_positive_bbox + loss_new_positive_cls + loss_new_positive_bbox + loss_shared #+ewc_loss#+semantic_loss
            )
            log_vars['ewc_loss'] = ewc_loss
            if self._iter % 10 == 0:
                print(f"ewc_loss:{ewc_loss.item()}")
            # NOTE(review): dropping ``None`` entries and zipping with
            # ``shared_params`` pairs gradients with parameters by position
            # only; verify that the two sequences really line up.
            grad_ewcs_filtered = [grad for grad in grad_ewcs if grad is not None]

            for para, grad_ewc, importance_id in zip(shared_params, grad_ewcs_filtered, self.filtered_grad_squared):
                if grad_ewc is not None:
                    importance = self.param_importance.get(importance_id, 0.0)
                    # NOTE(review): ``clip_threshold`` is not defined in the
                    # visible code.
                    adaptive_threshold = clip_threshold
                    grad_norm = torch.norm(grad_ewc)
                    if grad_norm == 0:
                        grad_norm = grad_norm + 1e-8
                    # Rescale the EWC gradient so its norm does not exceed
                    # the threshold.
                    if grad_norm > adaptive_threshold:
                        scale_factor = adaptive_threshold / (grad_norm)
                        grad_ewc.mul_(scale_factor)
                    para.grad.data.copy_(grad_ewc + para.grad)
                else:
                    para.grad.data.copy_(para.grad)
        if self.runner.optim_wrapper.should_update():
            # self.runner.optim_wrapper.step(**step_kwargs)
            if self.runner.optim_wrapper.clip_grad_kwargs:
                self.runner.optim_wrapper._clip_grad()
            self.runner.optim_wrapper.optimizer.step(**step_kwargs)
            self.runner.optim_wrapper.zero_grad(**zero_kwargs)
        # -------------------------------------------------------
        if self.runner.model.detect_anomalous_params:
            # The helper inspects the *model's* named parameters; the loop
            # object passed originally has none.
            detect_anomalous_params(loss_old_positive_cls + loss_old_positive_bbox
                                    + loss_new_positive_cls + loss_new_positive_bbox
                                    + loss_shared, model=self.runner.model)

        self.runner.call_hook(
            'after_train_iter',
            batch_idx=idx,
            data_batch=data_batch,
            outputs=log_vars)
        self._iter += 1

    def _decide_current_val_interval(self) -> None:
        """Dynamically modify the ``val_interval``."""
        step = bisect.bisect(self.dynamic_milestones, (self.epoch + 1))
        self.val_interval = self.dynamic_intervals[step - 1]

    def run_fc_cls(self,old_task_ewc_path=None,ori_setting=None,task_num=1) -> torch.nn.Module:
        """Launch training of the classifier head on stored prototypes.

        For every task index below ``task_num`` the file
        ``grad_squared_mean_prototype_task_{i}.pkl`` under ``self.directory``
        is unpickled and its ``"protos"``, ``"covs"`` and ``"radiuses"``
        entries are read. Foreground and background statistics are
        concatenated into ``self._protos_old``, ``self._covs_old`` and
        ``self._radiuses_old``. All model parameters are then frozen except
        those of ``roi_head.bbox_head.fc_cls`` and ``protos_linear``, and
        epochs are run via ``run_fc_cls_epoch`` with validation.

        NOTE(review): ``bg_proto``, ``covs_bg``, ``radiuses_bg``,
        ``protos_no_bg``, ``covs_no_bg`` and ``radiuses_no_bg`` are never
        assigned in the visible code, and the accumulators
        (``background_protos`` etc.) are read before their first assignment,
        so the loading loop cannot run as written.

        Args:
            old_task_ewc_path: Not used by this method.
            ori_setting: Not used by this method.
            task_num: Number of task prototype files to load.

        Returns:
            ``self.runner.model``.
        """
        self.runner.call_hook('before_train')
        # ``run_fc_cls_epoch`` fires 'before_train_epoch' again, so the hook
        # runs twice before the first epoch.
        self.runner.call_hook('before_train_epoch')
        for i in range(task_num):
            # No existence check: a missing file raises FileNotFoundError.
            # SECURITY: ``pickle.load`` can execute arbitrary code; only use
            # trusted files.
            proto_path = os.path.join(self.directory, f"grad_squared_mean_prototype_task_{i}.pkl")
            with open(proto_path, "rb") as f:
                with redirect_stdout(None):
                    old_task = pickle.load(f)
                # Read here but not referenced again in the visible code.
                protos = old_task["protos"]
                covs = old_task["covs"]
                radiuses = old_task["radiuses"]
            # NOTE(review): the three right-hand names below are undefined --
            # presumably a step splitting ``protos``/``covs``/``radiuses``
            # into background and foreground parts is missing.
            background_protos_expanded = bg_proto
            background_covs_expanded = covs_bg
            background_radiuses_expanded = radiuses_bg
            # NOTE(review): the accumulators are tested against None before
            # ever being assigned; they likely need ``= None`` before the loop.
            background_protos = background_protos_expanded if background_protos is None else np.concatenate([background_protos, background_protos_expanded], axis=0)
            background_covs = background_covs_expanded if background_covs is None else np.concatenate([background_covs, background_covs_expanded], axis=0)
            background_radiuses = background_radiuses_expanded if background_radiuses is None else np.concatenate([background_radiuses, background_radiuses_expanded], axis=0)

            all_protos_no_bg = protos_no_bg if all_protos_no_bg is None else np.concatenate([all_protos_no_bg, protos_no_bg], axis=0)
            all_covs_no_bg = covs_no_bg if all_covs_no_bg is None else np.concatenate([all_covs_no_bg, covs_no_bg], axis=0)
            all_radiuses_no_bg = radiuses_no_bg if all_radiuses_no_bg is None else np.concatenate([all_radiuses_no_bg, radiuses_no_bg], axis=0)

        # Foreground entries come first, background entries after them.
        self.len_class_old=len(all_protos_no_bg)+len(background_protos)
        self._protos_old = np.concatenate([all_protos_no_bg, background_protos], axis=0)
        self._covs_old = np.concatenate([all_covs_no_bg, background_covs], axis=0)
        self._radiuses_old = np.concatenate([all_radiuses_no_bg, background_radiuses], axis=0)
        # Square root of the mean radius: overall, foreground-only and
        # background-only.
        self._radius_old = np.sqrt(np.mean(self._radiuses_old))
        self._radius_old_f = np.sqrt(np.mean(all_radiuses_no_bg))
        self._radius_old_b=np.sqrt(np.mean(background_radiuses))
        # Raises UnboundLocalError when ``task_num`` is 0 (loop never ran).
        del old_task
        # Freeze everything, then re-enable only the classifier head and the
        # prototype projection layer.
        for param in self.runner.model.parameters():
            param.requires_grad = False
        for param in self.runner.model.module.roi_head.bbox_head.fc_cls.parameters():
            param.requires_grad = True
        for param in self.runner.model.module.protos_linear.parameters():
            param.requires_grad = True
        self.fc_cls = self.runner.model.module.roi_head.bbox_head.fc_cls
        while self._epoch < self._max_epochs and not self.stop_training:
            self.run_fc_cls_epoch()

            self._decide_current_val_interval()
            if (self.runner.val_loop is not None
                    and self._epoch >= self.val_begin
                    and self._epoch % self.val_interval == 0):
                self.runner.val_loop.run()

        self.runner.call_hook('after_train')
        return self.runner.model

    def run_fc_cls_epoch(self) -> None:
        """Iterate one epoch."""
        self.runner.call_hook('before_train_epoch')
        self.runner.model.train()
        for idx, data_batch in enumerate(self.dataloader):
            self.run_iter_fc_cls(idx,data_batch)
        self.runner.call_hook('after_train_epoch')
        self._epoch += 1
    def run_iter_fc_cls(self, idx, data_batch: Sequence[dict]) -> None:
        """Iterate one mini-batch of classifier-head training.

        The current model and ``self.old_model`` are both run without
        gradients to obtain RoI features and targets (``'bbox_feats'`` /
        ``'bbox_targets'``). Old-model features are mapped through
        ``protos_linear`` and compared with the current features via
        ``match_and_l2loss``. Old-class features are additionally sampled
        around the stored prototypes with Gaussian noise, projected,
        concatenated with the current features and classified by
        ``self.fc_cls`` with a cross-entropy loss. The sum of both losses is
        back-propagated and the optimizer steps when an update is due.

        NOTE(review): ``new_bg``, ``new_pos`` and ``old_pos`` are not defined
        in the visible code, so the method cannot run as written;
        ``match_and_l2loss`` is also not defined in the visible part of the
        class.

        Args:
            idx (int): Index of the batch within the current epoch;
                forwarded to the hooks as ``batch_idx``.
            data_batch (Sequence[dict]): Batch of data from dataloader.
        """
        # Use the in-epoch index for both hooks, matching ``run_iter`` and
        # the 'after_train_iter' call below.
        self.runner.call_hook(
            'before_train_iter', batch_idx=idx, data_batch=data_batch)
        with self.runner.optim_wrapper.optim_context(self):
            data = self.runner.model.module.data_preprocessor(data_batch, training=True)
            self.device = data['inputs'].device
            with torch.no_grad():
                losses = self.runner.model._run_forward(data, mode='loss', epoch=self._epoch)
        step_kwargs = {}
        zero_kwargs = {}
        log_vars = {}
        labels_gt = []
        proto_feats = losses['bbox_feats']
        proto_targets_new = losses['bbox_targets']
        self.old_model.eval()
        with torch.no_grad():
            losses, _ = self.old_model.loss(
                    batch_inputs=data['inputs'],
                    batch_data_samples=data['data_samples'],
                    num_class=self.num_classes
                )

        proto_old_feats = losses['bbox_feats']
        proto_old_targets = losses['bbox_targets']
        proto_transfer = self.runner.model.module.protos_linear(proto_old_feats)
        # Collect the ground-truth labels of the whole batch and count how
        # often each label occurs.
        for data_sample in data['data_samples']:

            sample_labels = data_sample.gt_instances['labels']
            labels_gt.append(sample_labels)
        labels_gt = torch.cat(labels_gt, dim=0)
        unique_labels, counts = labels_gt.unique(return_counts=True)

        # NOTE(review): ``new_bg`` / ``new_pos`` / ``old_pos`` are undefined
        # here. ``p_b_f`` is the background share and ``p_f_b`` the
        # foreground share of their sum.
        p_b_f = (new_bg)/(new_bg+new_pos+old_pos)
        p_f_b = (new_pos+old_pos)/(new_bg+new_pos+old_pos)
        N_old_pos = 2*self.ori_num_classes
        N_total_old = round(N_old_pos/p_f_b)

        loss_transfer = self.match_and_l2loss(proto_feats, proto_targets_new, proto_transfer, proto_old_targets, unique_labels, counts, count_bg=round(torch.sum(counts).item()/p_b_f))
        log_vars['loss_transfer'] = loss_transfer

        # Sampling distribution over stored prototypes: the first
        # ``ori_num_classes`` entries share ``p_f_b`` uniformly, the
        # remaining (background) entries share ``p_b_f`` uniformly.
        classes = np.arange(self.len_class_old)
        p_fg = np.ones(self.ori_num_classes) * (p_f_b) / (self.ori_num_classes)
        p_ng = np.ones(self.len_class_old-self.ori_num_classes) * (p_b_f) / ((self.len_class_old-self.ori_num_classes))
        p = np.concatenate([p_fg, p_ng], axis=0)
        p /= p.sum()

        index = np.random.choice(classes, size=N_total_old, replace=True, p=p)
        proto_features_raw = np.array(self._protos_old)[index]
        radius_old = [self._radiuses_old[i] for i in index]
        # Every background prototype index is relabelled to
        # ``self.num_classes``; done only after the features and radii have
        # been gathered with the original indices.
        index[index >= self.ori_num_classes] = self.num_classes
        proto_targets = index
        # Perturb each prototype with Gaussian noise scaled by sqrt(radius).
        proto_features = proto_features_raw + np.random.normal(0, 1, proto_features_raw.shape)*np.sqrt(radius_old)[:, np.newaxis]

        proto_features = torch.from_numpy(proto_features).float().to(self.device, non_blocking=True)
        proto_targets = torch.from_numpy(proto_targets).to(self.device, non_blocking=True)

        # Detached: the sampled old features do not back-propagate into
        # ``protos_linear``.
        proto_features_transfer = self.runner.model.module.protos_linear(proto_features).detach().clone()

        proto_features_transfer = torch.cat([proto_features_transfer, proto_feats], dim=0)
        proto_targets = torch.cat([proto_targets, proto_targets_new], dim=0)
        cls_scores = self.fc_cls(proto_features_transfer)
        loss_cls = F.cross_entropy(cls_scores, proto_targets)

        self.runner.optim_wrapper.zero_grad(**zero_kwargs)
        self.runner.optim_wrapper.backward(loss_cls+loss_transfer)
        if self.runner.optim_wrapper.should_update():
            if self.runner.optim_wrapper.clip_grad_kwargs:
                self.runner.optim_wrapper._clip_grad()
            self.runner.optim_wrapper.optimizer.step(**step_kwargs)
            self.runner.optim_wrapper.zero_grad(**zero_kwargs)

        # -------------------------------------------------------
        if self.runner.model.detect_anomalous_params:
            # The helper inspects the *model's* named parameters; the loop
            # object passed originally has none.
            detect_anomalous_params(loss_cls+loss_transfer, model=self.runner.model)
        self.runner.call_hook(
            'after_train_iter',
            batch_idx=idx,
            data_batch=data_batch,
            outputs=log_vars)
        self._iter += 1