import torch
from fcclip import FCCLIP

from torch.nn import functional as F
from detectron2.config import configurable
from detectron2.modeling import META_ARCH_REGISTRY

from detectron2.data import MetadataCatalog
from detectron2.modeling import build_backbone, build_sem_seg_head
from detectron2.projects.point_rend.point_features import (
    get_uncertain_point_coords_with_randomness,
    point_sample,
)

from ..utils.misc import nested_tensor_from_tensor_list

from fcclip.modeling.criterion import (
    SetCriterion, sigmoid_ce_loss_jit, dice_loss_jit, calculate_uncertainty
)
from fcclip.custom.task_vectors import task_vector_weight
from fcclip.modeling.matcher import HungarianMatcher

from fcclip.fcclip import VILD_PROMPT


class SetContinualCriterion(SetCriterion):
    """SetCriterion variant that can restrict each loss to a subset of its elements.

    When ``masking_loss`` is enabled, only the elements of the unreduced loss
    that lie strictly above the mean of that loss contribute to the reduced
    value (strictly below the mean when ``masking_loss_reverse`` is set).
    """

    def __init__(self, *args, **kwargs):
        """Pop the masking options and forward everything else to ``SetCriterion``.

        Keyword Args:
            masking_loss: enable mean-threshold filtering of the losses
                (default ``False``).
            masking_loss_reverse: keep the below-mean elements instead of the
                above-mean ones (default ``False``).
        """
        self.masking_loss = kwargs.pop('masking_loss', False)
        self.masking_loss_reverse = kwargs.pop('masking_loss_reverse', False)
        super().__init__(*args, **kwargs)

    def masking_loss_func(self, loss):
        """Select the loss elements on one side of the mean.

        Args:
            loss: tensor of unreduced loss values.

        Returns:
            A 1-D tensor with the elements strictly greater than ``loss.mean()``
            (strictly smaller when ``masking_loss_reverse`` is set). If no
            element qualifies -- e.g. a single element or all values equal --
            the whole loss is returned flattened, so that a following
            ``.mean()`` does not produce NaN and a following ``.sum()`` does
            not silently drop the loss.
        """
        # Open design questions kept from the original implementation:
        # through the batch or not
        # use median or not
        # use standard deviation or not
        threshold = loss.mean()
        if self.masking_loss_reverse:
            keep = loss < threshold
        else:
            keep = loss > threshold

        selected = loss[keep]
        if selected.numel() == 0:
            # Nothing is strictly on the requested side of the mean, so there
            # is nothing to filter out; keep every element.
            return loss.reshape(-1)
        return selected

    def loss_labels(self, outputs, targets, indices, num_masks):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]

        Returns a dict with the single key "loss_ce".
        """
        assert "pred_logits" in outputs
        src_logits = outputs["pred_logits"].float()

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        # Every query defaults to index ``num_classes``; matched queries are
        # overwritten with their target label.
        target_classes = torch.full(
            src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
        )
        target_classes[idx] = target_classes_o

        # TODO: change to sigmoid? or control the softmax
        # NOTE(review): with reduction='none' each element is already scaled by
        # its class weight, and the .mean() below divides by the number of
        # elements, not by the sum of weights that reduction='mean' would use.
        loss_ce = F.cross_entropy(
            src_logits.transpose(1, 2), target_classes, weight=self.empty_weight, reduction='none'
        )
        if self.masking_loss:
            loss_ce = self.masking_loss_func(loss_ce)
        return {"loss_ce": loss_ce.mean()}

    def loss_masks(self, outputs, targets, indices, num_masks):
        """Compute the losses related to the masks: the sigmoid cross-entropy loss and the dice loss.
        targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]

        Returns a dict with the keys "loss_mask" and "loss_dice".
        """
        assert "pred_masks" in outputs

        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)
        src_masks = outputs["pred_masks"][src_idx]
        masks = [t["masks"] for t in targets]
        # TODO use valid to mask invalid areas due to padding in loss
        target_masks, _ = nested_tensor_from_tensor_list(masks).decompose()
        target_masks = target_masks.to(src_masks)[tgt_idx]

        # No need to upsample predictions as we are using normalized coordinates :)
        # N x 1 x H x W
        src_masks = src_masks[:, None]
        target_masks = target_masks[:, None]

        with torch.no_grad():
            # sample point_coords
            point_coords = get_uncertain_point_coords_with_randomness(
                src_masks,
                calculate_uncertainty,
                self.num_points,
                self.oversample_ratio,
                self.importance_sample_ratio,
            )
            # get gt labels
            point_labels = point_sample(
                target_masks,
                point_coords,
                align_corners=False,
            ).squeeze(1)

        # Predictions are sampled outside no_grad so gradients reach the masks.
        point_logits = point_sample(
            src_masks,
            point_coords,
            align_corners=False,
        ).squeeze(1)

        loss_mask = sigmoid_ce_loss_jit(point_logits, point_labels, num_masks)
        loss_dice = dice_loss_jit(point_logits, point_labels, num_masks)

        if self.masking_loss:
            loss_mask = self.masking_loss_func(loss_mask)
            loss_dice = self.masking_loss_func(loss_dice)

        return {
            "loss_mask": loss_mask.sum(),
            "loss_dice": loss_dice.sum(),
        }


@META_ARCH_REGISTRY.register()
class FCCLIPContinual(FCCLIP):
    """FCCLIP meta-architecture with continual-learning hooks.

    On top of ``FCCLIP`` this class keeps a reference checkpoint, can rewrite
    ``sem_seg_head`` through ``task_vector_weight`` during training, can
    freeze selected ``sem_seg_head`` parameters, and can append a context
    phrase to the text prompts built for each class name.
    """

    @configurable
    def __init__(self, *args, **kwargs):
        """Pop the continual-learning options and build the underlying FCCLIP.

        Keyword Args:
            model_weight_path: checkpoint file read with ``torch.load``. When
                ``None`` or empty no reference weights are loaded and
                ``progressive_task_arithmetic`` raises if it is called.
            finetune_task_arith: run task arithmetic on every training forward.
            finetune_task_arith_prev: use the previous snapshot of
                ``sem_seg_head`` instead of the loaded checkpoint as reference.
            progress_flag: run task arithmetic every ``progress_change_iter``
                training iterations.
            progress_change_iter: interval used with ``progress_flag``.
            progress_freeze_param: name fragments used by ``progressive_freeze``.
            scaling_coef: forwarded to ``task_vector_weight`` in the
                ``finetune_task_arith_prev`` branch.
            context_template: phrase appended to every text prompt; the empty
                string leaves the prompts unchanged.
        """
        model_weight_path = kwargs.pop("model_weight_path", None)
        self.finetune_task_arith = kwargs.pop('finetune_task_arith', False)
        self.finetune_task_arith_prev = kwargs.pop('finetune_task_arith_prev', False)

        self.progress_flag = kwargs.pop('progress_flag', False)
        self.progress_change_iter = kwargs.pop('progress_change_iter', None)
        self.progress_freeze_param = kwargs.pop('progress_freeze_param', None)
        self.scaling_coef = kwargs.pop('scaling_coef', None)
        # Counts training forward passes, starting at 1.
        self.current_iter = 1

        # Kept before super().__init__(): presumably the parent constructor
        # calls prepare_class_names_from_metadata, which reads this attribute
        # -- TODO confirm against FCCLIP.__init__.
        self.context_template = kwargs.pop('context_template', "")

        super().__init__(*args, **kwargs)

        if model_weight_path:
            # The checkpoint is kept for the lifetime of the model, so load it
            # onto the CPU instead of the device it was saved from; the later
            # snapshots in progressive_task_arithmetic are stored on CPU too.
            self.pretrained_weight = torch.load(model_weight_path, map_location="cpu")
        else:
            self.pretrained_weight = None
        self.prev_weight = self.pretrained_weight

    def prepare_class_names_from_metadata(self, metadata, train_metadata):
        """Build the templated text prompts for the classes in ``metadata``.

        Args:
            metadata: metadata of the dataset to classify; must expose
                ``stuff_classes`` or ``thing_classes``.
            train_metadata: metadata of the training dataset, with the same
                attribute available.

        Returns:
            A tuple ``(category_overlapping_mask, num_templates, class_names)``:
            a long tensor with one entry per class of ``metadata`` that is 1
            when any of its synonyms also appears in ``train_metadata``; a list
            with the number of synonyms of each class; and the flat list of
            prompts (every synonym filled into every ``VILD_PROMPT`` template).
        """
        def split_labels(x):
            # there can be multiple synonyms for single class
            return [x_.replace(', ', ',').split(',') for x_ in x]

        # get text classifier
        try:
            class_names = split_labels(metadata.stuff_classes)  # it includes both thing and stuff
            train_class_names = split_labels(train_metadata.stuff_classes)
        except AttributeError:
            # this could be for insseg, where only thing_classes are available
            class_names = split_labels(metadata.thing_classes)
            train_class_names = split_labels(train_metadata.thing_classes)
        train_class_names = {l for label in train_class_names for l in label}
        category_overlapping_list = [
            not train_class_names.isdisjoint(test_class_names)
            for test_class_names in class_names
        ]
        category_overlapping_mask = torch.tensor(
            category_overlapping_list, dtype=torch.long)

        def fill_all_templates_ensemble(x_=''):
            # Returns the prompts for all synonyms in x_ and the synonym count.
            res = []
            for x in x_:
                for template in VILD_PROMPT:
                    if self.context_template == "":
                        res.append(template.format(x))
                    else:
                        # Drop the trailing period and the generic scene
                        # phrase, then append the configured context.
                        sentence = template.format(x).rstrip('.')
                        sentence = sentence.replace('in the scene', '')
                        res.append(sentence+f", {self.context_template}.")
            return res, len(res) // len(VILD_PROMPT)

        num_templates = []
        templated_class_names = []
        for x in class_names:
            templated_classes, templated_classes_num = fill_all_templates_ensemble(x)
            templated_class_names += templated_classes
            num_templates.append(templated_classes_num)  # how many templates for current classes
        class_names = templated_class_names
        # print("text for classification:", class_names)
        return category_overlapping_mask, num_templates, class_names

    def freeze_some_param(self, freeze_param_names):
        """Disable gradients for matching parameters of ``sem_seg_head``.

        Args:
            freeze_param_names: iterable of substrings; a parameter is frozen
                when its name contains any of them. ``None`` or an empty
                iterable freezes nothing, and a single string is treated as
                one substring rather than a sequence of characters.
        """
        if not freeze_param_names:
            return
        if isinstance(freeze_param_names, str):
            freeze_param_names = [freeze_param_names]
        for module_param_name, param in self.sem_seg_head.named_parameters():
            if any(name in module_param_name for name in freeze_param_names):
                param.requires_grad = False

    def progressive_task_arithmetic(self):
        """Replace ``sem_seg_head`` with the result of ``task_vector_weight``.

        With ``finetune_task_arith_prev`` the reference is ``prev_weight``,
        which is then refreshed with a detached CPU copy of the new
        ``sem_seg_head`` state dict. Otherwise the reference is the ``'model'``
        entry of the loaded checkpoint.

        Raises:
            RuntimeError: if the required reference weights were never loaded.
        """
        # task_vector_weight
        if self.finetune_task_arith_prev:
            if self.prev_weight is None:
                raise RuntimeError(
                    "Task arithmetic needs reference weights, but no "
                    "model_weight_path was given."
                )
            # check the validity of the prev_weight
            # NOTE(review): on the first call prev_weight is the whole object
            # returned by torch.load, afterwards it is a bare sem_seg_head
            # state dict -- confirm task_vector_weight accepts both layouts.
            self.sem_seg_head = task_vector_weight(self.prev_weight, self.sem_seg_head,
                                                   scaling_coef=self.scaling_coef)
            with torch.no_grad():
                self.prev_weight = {
                    key: param.detach().clone().cpu()
                    for key, param in self.sem_seg_head.state_dict().items()
                }
        else:
            if self.pretrained_weight is None:
                raise RuntimeError(
                    "Task arithmetic needs reference weights, but no "
                    "model_weight_path was given."
                )
            self.sem_seg_head = task_vector_weight(self.pretrained_weight['model'], self.sem_seg_head)

    def progressive_freeze(self):
        """Freeze the ``sem_seg_head`` parameters named by ``progress_freeze_param``."""
        self.freeze_some_param(self.progress_freeze_param)

    @classmethod
    def from_config(cls, cfg):
        """Translate a config node into the keyword arguments of ``__init__``."""
        backbone = build_backbone(cfg)
        sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape())

        # Loss parameters:
        deep_supervision = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
        no_object_weight = cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT

        # loss weights
        class_weight = cfg.MODEL.MASK_FORMER.CLASS_WEIGHT
        dice_weight = cfg.MODEL.MASK_FORMER.DICE_WEIGHT
        mask_weight = cfg.MODEL.MASK_FORMER.MASK_WEIGHT

        # building criterion
        matcher = HungarianMatcher(
            cost_class=class_weight,
            cost_mask=mask_weight,
            cost_dice=dice_weight,
            num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
        )

        weight_dict = {"loss_ce": class_weight, "loss_mask": mask_weight, "loss_dice": dice_weight}

        if deep_supervision:
            # Repeat every loss weight with a "_<i>" suffix for each of the
            # DEC_LAYERS - 1 auxiliary outputs.
            dec_layers = cfg.MODEL.MASK_FORMER.DEC_LAYERS
            aux_weight_dict = {
                k + f"_{i}": v
                for i in range(dec_layers - 1)
                for k, v in weight_dict.items()
            }
            weight_dict.update(aux_weight_dict)

        losses = ["labels", "masks"]

        criterion = SetContinualCriterion(
            sem_seg_head.num_classes,
            matcher=matcher,
            weight_dict=weight_dict,
            eos_coef=no_object_weight,
            losses=losses,
            num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
            oversample_ratio=cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO,
            importance_sample_ratio=cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO,
            masking_loss=cfg.MODEL.CONTINUAL.MASKING_LOSS,
            masking_loss_reverse=cfg.MODEL.CONTINUAL.MASKING_LOSS_REVERSE,
        )

        return {
            "backbone": backbone,
            "sem_seg_head": sem_seg_head,
            "criterion": criterion,
            "num_queries": cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES,
            "object_mask_threshold": cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD,
            "overlap_threshold": cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD,
            "train_metadata": MetadataCatalog.get(cfg.DATASETS.TRAIN[0]),
            "test_metadata": MetadataCatalog.get(cfg.DATASETS.TEST[0]),
            "size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY,
            "sem_seg_postprocess_before_inference": (
                    cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE
                    or cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON
                    or cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON
            ),
            "pixel_mean": cfg.MODEL.PIXEL_MEAN,
            "pixel_std": cfg.MODEL.PIXEL_STD,
            # inference
            "semantic_on": cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON,
            "instance_on": cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON,
            "panoptic_on": cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON,
            "test_topk_per_image": cfg.TEST.DETECTIONS_PER_IMAGE,
            "geometric_ensemble_alpha": cfg.MODEL.FC_CLIP.GEOMETRIC_ENSEMBLE_ALPHA,
            "geometric_ensemble_beta": cfg.MODEL.FC_CLIP.GEOMETRIC_ENSEMBLE_BETA,
            "ensemble_on_valid_mask": cfg.MODEL.FC_CLIP.ENSEMBLE_ON_VALID_MASK,
            "model_weight_path": cfg.MODEL.WEIGHTS,
            "finetune_task_arith": cfg.MODEL.CONTINUAL.FINETUNE_TASK_ARITH,
            "finetune_task_arith_prev": cfg.MODEL.CONTINUAL.FINETUNE_TASK_ARITH_PREV,
            # progress
            "progress_flag": cfg.MODEL.CONTINUAL.PROGRESS.FLAG,
            "progress_change_iter": cfg.MODEL.CONTINUAL.PROGRESS.CHANGE_ITER,
            "progress_freeze_param": cfg.MODEL.CONTINUAL.PROGRESS.FREEZE_PARAM,
            "scaling_coef": cfg.MODEL.CONTINUAL.PROGRESS.SCALING_COEF,
            # context
            "context_template": cfg.MODEL.CONTINUAL.CONTEXT.CONTEXT_TEMPLATE
        }

    def forward(self, batched_inputs):
        """Apply the scheduled task arithmetic, then run the FCCLIP forward pass.

        In training mode ``progressive_task_arithmetic`` runs on every call
        when ``finetune_task_arith`` is set, and additionally whenever
        ``current_iter`` is a multiple of ``progress_change_iter`` when
        ``progress_flag`` is set. ``current_iter`` only advances in training
        mode, so evaluation passes do not shift the schedule.

        Raises:
            ValueError: if ``progress_flag`` is set while training but
                ``progress_change_iter`` is ``None`` or zero.
        """
        if self.training:
            if self.finetune_task_arith:
                self.progressive_task_arithmetic()

            if self.progress_flag:
                if not self.progress_change_iter:
                    raise ValueError(
                        "progress_flag is enabled but progress_change_iter is "
                        f"{self.progress_change_iter!r}; a non-zero interval is required."
                    )
                if self.current_iter % self.progress_change_iter == 0:
                    self.progressive_task_arithmetic()
                    # self.progressive_freeze()

            self.current_iter += 1
        return super().forward(batched_inputs)
