from typing import Dict, List, Optional, Union

import torch
from mmengine.model import BaseModel
from mmengine.runner import load_checkpoint
from mmengine.structures import BaseDataElement
from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm

from mmrazor.models.utils import add_prefix
from mmrazor.registry import MODELS
from ...base import BaseAlgorithm, LossResults


# Apply the learned masks in cascade to successive teacher feature maps and
# compute the downstream-task loss of the teacher on the masked features.
@MODELS.register_module()
class BEVQueryGuidedDistillCascadeTeacherAssist(BaseAlgorithm):
    """Single-teacher distillation with a cascaded, masked teacher pass.

    On top of the usual teacher/student forward passes and distillation
    losses, this algorithm takes the teacher feature maps and the spatial /
    channel masks returned by ``distiller.compute_distill_losses()``, applies
    the masks to the teacher features at the anchors named in
    ``cascade_prompt``, re-runs the remaining teacher modules
    (``pts_backbone`` / ``pts_neck`` / head) on the masked features and adds
    the resulting task loss to the loss dict under the ``teacher`` prefix.

    Args:
        distiller (dict): The config dict for built distiller.
        teacher (dict | BaseModel): The config dict for teacher model or built
            teacher model.
        cascade_prompt (List[str]): Required. Names of the teacher anchors at
            which a mask is applied. Allowed values are
            ``'before_pts_backbone'``, ``'before_pts_neck'`` and
            ``'before_head'`` (more anchors could be added later; only these
            three are implemented). Its length must equal the number of
            feature maps / masks returned by the distiller. The input of
            ``pts_neck`` is a tuple whose entries have different H and W,
            which is not compatible with the BEV query, so for
            ``'before_pts_neck'`` only the first entry of that tuple is
            masked.
        teacher_ckpt (str): The path of teacher's checkpoint. Defaults to None.
        teacher_trainable (bool): Whether the teacher is trainable. Defaults
            to False.
        teacher_norm_eval (bool): Whether to set teacher's norm layers to eval
            mode, namely, freeze running stats (mean and var). Note: Effect on
            Batch Norm and its variants only. Defaults to True.
        student_trainable (bool): Whether the student is trainable. Defaults
            to True.
        calculate_student_loss (bool): Whether to calculate student loss
            (original task loss) to update student model. Defaults to True.
        teacher_module_inplace(bool): Whether to allow teacher module inplace
            attribute True. Defaults to False.
        mask_learning_stopped (bool): Initial value of
            ``distiller.mask_learning_stopped``. While it is True, ``loss``
            skips the cascaded teacher pass and drops every loss entry whose
            key starts with ``'teacher'``. Defaults to False.
    """

    # Anchors accepted in ``cascade_prompt``.
    _ALLOWED_CASCADE_PROMPTS = frozenset(
        {'before_pts_backbone', 'before_pts_neck', 'before_head'})

    def __init__(self,
                 distiller: dict,
                 teacher: Union[BaseModel, Dict],
                 cascade_prompt: List[str],
                 teacher_ckpt: Optional[str] = None,
                 teacher_trainable: bool = False,
                 teacher_norm_eval: bool = True,
                 student_trainable: bool = True,
                 calculate_student_loss: bool = True,
                 teacher_module_inplace: bool = False,
                 mask_learning_stopped: bool = False,
                 **kwargs) -> None:
        super().__init__(**kwargs)

        self.distiller = MODELS.build(distiller)

        if isinstance(teacher, dict):
            teacher = MODELS.build(teacher)

        if not isinstance(teacher, BaseModel):
            raise TypeError('teacher should be a `dict` or '
                            f'`BaseModel` instance, but got '
                            f'{type(teacher)}')

        self.teacher = teacher

        # Find all nn.Modules in the model that contain the 'inplace' attribute
        # and set them to False.
        self.teacher_module_inplace = teacher_module_inplace
        if not self.teacher_module_inplace:
            self.set_module_inplace_false(teacher, 'self.teacher')

        if teacher_ckpt:
            _ = load_checkpoint(self.teacher, teacher_ckpt)
            # avoid loaded parameters be overwritten
            self.teacher._is_init = True
        self.teacher_trainable = teacher_trainable
        if not self.teacher_trainable:
            for param in self.teacher.parameters():
                param.requires_grad = False
        self.teacher_norm_eval = teacher_norm_eval

        # The student model will not calculate gradients and update parameters
        # in some pretraining process.
        self.student_trainable = student_trainable

        # The student loss will not be updated into ``losses`` in some
        # pretraining process.
        self.calculate_student_loss = calculate_student_loss

        # In ``ConfigurableDistller``, the recorder manager is just
        # constructed, but not really initialized yet.
        self.distiller.prepare_from_student(self.student)
        self.distiller.prepare_from_teacher(self.teacher)

        # may be modified by stop distillation hook
        self.distillation_stopped = False
        self.distiller.mask_learning_stopped = mask_learning_stopped

        self.cascade_prompt = cascade_prompt

    @property
    def student(self) -> nn.Module:
        """Alias for ``architecture``."""
        return self.architecture

    def loss(
        self,
        batch_inputs: torch.Tensor,
        data_samples: Optional[List[BaseDataElement]] = None,
    ) -> LossResults:
        """Calculate losses from a batch of inputs and data samples.

        Runs the teacher and the student in ``mode='loss'`` inside the
        distiller's recorders/deliveries, then (unless distillation is
        stopped) adds the distillation losses and, while mask learning is
        active, the task loss of the cascaded masked teacher pass.

        Args:
            batch_inputs (torch.Tensor): Inputs forwarded to both models.
            data_samples (List[BaseDataElement], optional): Annotations
                forwarded to both models and to the teacher head loss.

        Returns:
            LossResults: Dict of losses, keyed with the ``teacher``,
            ``student`` and ``distill`` prefixes produced by ``add_prefix``.
        """
        losses = dict()

        # If the `override_data` of a delivery is False, the delivery will
        # record the origin data.
        self.distiller.set_deliveries_override(False)
        if self.teacher_trainable:
            with self.distiller.teacher_recorders, self.distiller.deliveries:
                teacher_losses = self.teacher(
                    batch_inputs, data_samples, mode='loss')

            losses.update(add_prefix(teacher_losses, 'teacher'))
        else:
            with self.distiller.teacher_recorders, self.distiller.deliveries:
                with torch.no_grad():
                    _ = self.teacher(batch_inputs, data_samples, mode='loss')

        # If the `override_data` of a delivery is True, the delivery will
        # override the origin data with the recorded data.
        self.distiller.set_deliveries_override(True)
        # Original task loss will not be used during some pretraining process.
        if self.calculate_student_loss:
            with self.distiller.student_recorders, self.distiller.deliveries:
                student_losses = self.student(
                    batch_inputs, data_samples, mode='loss')
            losses.update(add_prefix(student_losses, 'student'))
        else:
            with self.distiller.student_recorders, self.distiller.deliveries:
                if self.student_trainable:
                    _ = self.student(batch_inputs, data_samples, mode='loss')
                else:
                    with torch.no_grad():
                        _ = self.student(
                            batch_inputs, data_samples, mode='loss')

        if self.distillation_stopped:
            # Distillation is off: keep only the non-distillation entries.
            return {
                k: v
                for k, v in losses.items() if not k.startswith('distill')
            }

        # Automatically compute distill losses based on
        # `loss_forward_mappings`.
        # The required data already exists in the recorders.
        (distill_losses, teacher_featuremaps_all, mask_spatial_all,
         mask_channel_all) = self.distiller.compute_distill_losses()

        losses.update(add_prefix(distill_losses, 'distill'))

        # can be stopped by self.distiller.mask_learning_stopped
        if not self.distiller.mask_learning_stopped:
            masked_teacher_loss = self.cascade_pass_teacher_model(
                teacher_featuremaps_all, mask_spatial_all, mask_channel_all,
                data_samples)
            # NOTE(review): this uses the same 'teacher' prefix as the plain
            # teacher task loss above, so with ``teacher_trainable=True``
            # identically named entries are presumably overwritten here --
            # confirm this is intended.
            losses.update(add_prefix(masked_teacher_loss, 'teacher'))
        else:
            # Mask learning is off: drop every teacher-prefixed entry.
            losses = {
                k: v
                for k, v in losses.items() if not k.startswith('teacher')
            }

        return losses

    def train(self, mode: bool = True) -> None:
        """Switch train/eval mode, keeping teacher BatchNorm layers in eval.

        Args:
            mode (bool): Whether to set training mode. When True and
                ``teacher_norm_eval`` is set, every ``_BatchNorm`` module of
                the teacher is put back into eval mode afterwards.
        """
        super().train(mode)
        if mode and self.teacher_norm_eval:
            for m in self.teacher.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()

    def cascade_pass_teacher_model(self, teacher_featuremaps_all,
                                   mask_spatial_all, mask_channel_all,
                                   data_samples):
        """Re-run the teacher tail on masked features and return its loss.

        The three input sequences are consumed in lockstep, one entry per
        active anchor, always in the order ``before_pts_backbone`` ->
        ``before_pts_neck`` -> ``before_head`` (the order of the strings
        inside ``cascade_prompt`` is not used). The first active anchor masks
        the recorded teacher feature map; every later anchor masks the output
        of the teacher module that was just re-run. The input sequences are
        only read, never modified.

        Args:
            teacher_featuremaps_all (Sequence): Teacher features, one entry
                per anchor. For ``before_pts_neck`` as first anchor the entry
                is indexed with ``[0]`` and ``[1]``.
            mask_spatial_all (Sequence): Spatial masks, one per anchor.
            mask_channel_all (Sequence): Channel masks, one per anchor.
            data_samples: Annotations passed to the teacher head's ``loss``.

        Returns:
            The value returned by ``teacher.bbox_head.loss`` or, if the
            teacher has no bbox head, by ``teacher.seg_head.loss``.

        Raises:
            AssertionError: If the sequence lengths differ from each other or
                from ``len(cascade_prompt)``, if ``cascade_prompt`` is empty
                or holds an unknown anchor, if the ``pts_neck`` input has
                fewer than two entries, or if the teacher has neither a bbox
                head nor a seg head.
        """
        assert (len(teacher_featuremaps_all) == len(mask_spatial_all) ==
                len(mask_channel_all) == len(self.cascade_prompt)), \
            'length not matched'
        assert set(self.cascade_prompt).issubset(
            self._ALLOWED_CASCADE_PROMPTS), "Illegal values in cascade_prompt"
        # Lengths are equal, so this also rejects an empty cascade_prompt.
        assert len(mask_spatial_all) > 0, 'not enought mask'

        anchors = set(self.cascade_prompt)
        # Read-only view of (spatial mask, channel mask, recorded feature);
        # ``cursor`` points at the entry of the next active anchor.
        entries = list(
            zip(mask_spatial_all, mask_channel_all, teacher_featuremaps_all))
        cursor = 0

        feats = None
        # True once a teacher module has been re-run, i.e. ``feats`` holds a
        # live teacher activation instead of a recorded one.
        resumed = False

        if 'before_pts_backbone' in anchors:
            spatial, channel, recorded = entries[cursor]
            cursor += 1
            feats = self.teacher.pts_backbone(
                self.mask_featuremap(spatial, channel, recorded))
            resumed = True

        if 'before_pts_neck' in anchors:
            spatial, channel, recorded = entries[cursor]
            cursor += 1
            if resumed:
                # Mask only the first scale of the backbone output; tuple()
                # keeps the concatenation valid if the backbone returns a
                # list.
                feats = (self.mask_featuremap(spatial, channel, feats[0]),
                         ) + tuple(feats[1:])
            else:
                # Start from the recorded neck input; only its first scale is
                # masked, the second one is forwarded unchanged.
                feats = [
                    self.mask_featuremap(spatial, channel, recorded[0]),
                    recorded[1]
                ]
                resumed = True

        if resumed:
            # Check the actual neck input, not the remaining recorded
            # feature maps (which may already be exhausted at this point).
            assert len(feats) > 1, 'the input of pts_neck must be tuple'
            feats = self.teacher.pts_neck(feats)

        if 'before_head' in anchors:
            spatial, channel, recorded = entries[cursor]
            if resumed:
                # Mask the first neck output and wrap it in a list for the
                # head.
                feats = [self.mask_featuremap(spatial, channel, feats[0])]
            else:
                # NOTE(review): when 'before_head' is the only anchor the
                # masked recorded feature is passed to the head without the
                # list wrapper used above -- confirm the head accepts both.
                feats = self.mask_featuremap(spatial, channel, recorded)

        return self._teacher_head_loss(feats, data_samples)

    def _teacher_head_loss(self, feats, data_samples):
        """Return the task loss of the teacher's bbox or seg head."""
        if self.teacher.with_bbox_head:
            return self.teacher.bbox_head.loss(feats, data_samples)
        assert self.teacher.with_seg_head, \
            'downstream task must be either segmentation or detection'
        return self.teacher.seg_head.loss(feats, data_samples)

    def mask_featuremap(self, softmax_spatial_mask, softmax_channel_mask,
                        featuremaps):
        """Multiply a feature map by its spatial and channel masks.

        Args:
            softmax_spatial_mask (torch.Tensor): Spatial mask; used as is and
                must broadcast against ``featuremaps``.
            softmax_channel_mask (torch.Tensor): Channel mask; singleton
                dimensions are inserted at positions 2 and 3 before the
                multiplication (presumably ``(B, C)`` -> ``(B, C, 1, 1)`` --
                TODO confirm).
            featuremaps (torch.Tensor): Feature map to mask (presumably
                ``(B, C, H, W)`` -- TODO confirm).

        Returns:
            torch.Tensor: ``featuremaps * spatial mask * channel mask``.
        """
        channel_mask = softmax_channel_mask.unsqueeze(2).unsqueeze(3)
        return featuremaps * softmax_spatial_mask * channel_mask