from fcclip import FCCLIP
from typing import Tuple

import torch
import os
import numpy as np
from torch import nn
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.data import MetadataCatalog
from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head
from detectron2.modeling.backbone import Backbone
from detectron2.modeling.postprocessing import sem_seg_postprocess
from detectron2.structures import Boxes, ImageList, Instances, BitMasks
from detectron2.utils.memory import retry_if_cuda_oom

from fcclip.fcclip import get_classification_logits, MaskPooling

from fcclip.custom.criterion_point import SetCriterionPoint
from fcclip.custom.matcher_point import HungarianPointMatcher

# To use SAM
from segment_anything import sam_model_registry, SamPredictor


@META_ARCH_REGISTRY.register()
class FCCLIPoint(FCCLIP):
    """FC-CLIP variant trained from point annotations.

    Training targets are built from per-instance point annotations. Optionally,
    every annotated point is expanded into a mask — predicted on the fly with a
    ``SamPredictor`` or read from pre-computed ``sam_masks`` on the instances —
    and a random subset of that mask's pixels is used as supervision points.
    Inference is the standard FC-CLIP in-vocab / out-vocab ensemble.
    """

    @configurable
    def __init__(self, *args, **kwargs):
        """Build the model.

        Accepts everything :class:`FCCLIP` accepts, plus these optional kwargs:
            sam_predictor: ``SamPredictor`` used during training to turn each
                annotated point into a mask (``None`` disables on-the-fly SAM).
            sam_storage: if truthy, masks are read from ``instances.sam_masks``
                instead of running SAM.
            prompt_tuning: if True, freeze ``sem_seg_head`` and then re-enable
                only the parameters selected by :meth:`tuning_only_prompt`.
        """
        sam_predictor = kwargs.pop('sam_predictor', None)
        sam_storage = kwargs.pop('sam_storage', False)
        prompt_tuning = kwargs.pop('prompt_tuning', False)
        super().__init__(*args, **kwargs)
        self.sam_predictor = sam_predictor
        self.sam_storage = sam_storage
        if prompt_tuning:
            self.freeze_everything()
            self.tuning_only_prompt()

    def freeze_everything(self):
        """Set ``requires_grad = False`` on every parameter of ``sem_seg_head``.

        Note: despite the name, only the segmentation head is frozen here; the
        backbone's ``requires_grad`` flags are not touched by this method.
        """
        for param in self.sem_seg_head.parameters():
            param.requires_grad = False

    def tuning_only_prompt(self):
        """Make the query / prediction-embedding parameters of ``sem_seg_head`` trainable.

        Meant to run after :meth:`freeze_everything`: any ``sem_seg_head``
        parameter whose name contains one of the keys below gets
        ``requires_grad = True``; all others keep their current flag.
        """
        # TODO: Tuning only prompt (not mask_embed and class_embed)
        # Parameters whose names contain any of these substrings are UN-frozen.
        trainable_keys = ('mask_embed', 'class_embed', 'query_feat', 'query_embed')
        for param_name, param in self.sem_seg_head.named_parameters():
            if any(key in param_name for key in trainable_keys):
                param.requires_grad = True

    @classmethod
    def from_config(cls, cfg):
        """Translate a detectron2 config into ``__init__`` keyword arguments."""
        backbone = build_backbone(cfg)
        sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape())

        # Loss parameters:
        deep_supervision = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
        no_object_weight = cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT

        # loss weights
        class_weight = cfg.MODEL.MASK_FORMER.CLASS_WEIGHT
        dice_weight = cfg.MODEL.MASK_FORMER.DICE_WEIGHT
        mask_weight = cfg.MODEL.MASK_FORMER.MASK_WEIGHT

        # building criterion
        matcher = HungarianPointMatcher(
            cost_class=class_weight,
            cost_mask=mask_weight,
            cost_dice=dice_weight,
            num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
        )

        weight_dict = {"loss_ce": class_weight, "loss_mask": mask_weight, "loss_dice": dice_weight}

        if deep_supervision:
            # One suffixed copy of every loss weight per intermediate decoder layer.
            dec_layers = cfg.MODEL.MASK_FORMER.DEC_LAYERS
            aux_weight_dict = {}
            for i in range(dec_layers - 1):
                aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
            weight_dict.update(aux_weight_dict)

        losses = ["labels", "masks"]

        criterion = SetCriterionPoint(
            sem_seg_head.num_classes,
            matcher=matcher,
            weight_dict=weight_dict,
            eos_coef=no_object_weight,
            losses=losses,
            num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
            oversample_ratio=cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO,
            importance_sample_ratio=cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO,
        )

        config_dict = {
            "backbone": backbone,
            "sem_seg_head": sem_seg_head,
            "criterion": criterion,
            "num_queries": cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES,
            "object_mask_threshold": cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD,
            "overlap_threshold": cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD,
            "train_metadata": MetadataCatalog.get(cfg.DATASETS.TRAIN[0]),
            "test_metadata": MetadataCatalog.get(cfg.DATASETS.TEST[0]),
            "size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY,
            "sem_seg_postprocess_before_inference": (
                    cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE
                    or cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON
                    or cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON
            ),
            "pixel_mean": cfg.MODEL.PIXEL_MEAN,
            "pixel_std": cfg.MODEL.PIXEL_STD,
            # inference
            "semantic_on": cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON,
            "instance_on": cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON,
            "panoptic_on": cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON,
            "test_topk_per_image": cfg.TEST.DETECTIONS_PER_IMAGE,
            "geometric_ensemble_alpha": cfg.MODEL.FC_CLIP.GEOMETRIC_ENSEMBLE_ALPHA,
            "geometric_ensemble_beta": cfg.MODEL.FC_CLIP.GEOMETRIC_ENSEMBLE_BETA,
            "ensemble_on_valid_mask": cfg.MODEL.FC_CLIP.ENSEMBLE_ON_VALID_MASK
        }

        # SAM Point Sampling
        if cfg.INPUT.SAM.DIRECT:
            # Optional config overrides; defaults reproduce the previous hard-coded setup.
            sam_checkpoint = cfg.INPUT.SAM.get("CHECKPOINT", "./sam_vit_l_0b3195.pth")
            model_type = cfg.INPUT.SAM.get("MODEL_TYPE", "vit_l")
            # SamPredictor is not an nn.Module child of this model, so model.to(...)
            # will not move it; place it on the configured device explicitly.
            sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(cfg.MODEL.DEVICE)
            config_dict['sam_predictor'] = SamPredictor(sam)
        elif cfg.INPUT.SAM.STORAGE:
            config_dict['sam_storage'] = cfg.INPUT.SAM.STORAGE

        config_dict['prompt_tuning'] = cfg.MODEL.PROMPT_TUNING

        return config_dict

    @staticmethod
    def _build_target(labels, points, points_orig, points_labels, device):
        """Pack one image's point supervision into the dict the criterion consumes.

        ``points``, ``points_orig`` and ``points_labels`` are copied and converted
        to tensors on ``device``; ``points_labels`` gets a new axis at dim 1.
        """
        return {
            "labels": labels,
            "points": torch.tensor(points.copy(), device=device),
            "points_orig": torch.tensor(points_orig.copy(), device=device),
            "points_labels": torch.tensor(points_labels[:, None].copy(), device=device),
        }

    def prepare_targets(self, targets, images):
        """Convert point-annotated ``Instances`` into criterion targets.

        Args:
            targets: per-image ``Instances`` with ``gt_classes``, ``gt_points``,
                ``gt_points_orig`` and ``gt_points_labels``.
            images: unused.

        Returns:
            list[dict]: one dict per image with keys "labels", "points",
            "points_orig" and "points_labels".
        """
        return [
            self._build_target(
                targets_per_image.gt_classes,
                targets_per_image.gt_points,
                targets_per_image.gt_points_orig,
                targets_per_image.gt_points_labels,
                targets_per_image.gt_classes.device,
            )
            for targets_per_image in targets
        ]

    @torch.no_grad()
    def inference_sam(self, image_id, image, targets_per_image,
                      selectn=100, sam_storage=False):
        """Expand every annotated point into points sampled from its mask.

        For point ``idx`` a mask is taken from ``targets_per_image.sam_masks[idx]``
        (when ``sam_storage`` is truthy) or predicted by prompting
        ``self.sam_predictor`` with that single positive point. Up to ``selectn``
        of the mask's foreground pixels are drawn without replacement; if the
        mask is empty, the annotated point itself is used.

        Args:
            image_id: unused.
            image: passed to ``SamPredictor.set_image`` (only when not ``sam_storage``).
            targets_per_image: ``Instances`` with ``gt_points_orig``, ``image_size``
                and, for the storage path, ``sam_masks``.
            selectn: maximum number of points sampled from each mask.
            sam_storage: use stored masks instead of running SAM.

        Returns:
            (points, labels). ``points`` has shape (N, P, 2) with normalized (x, y)
            coordinates, where N is the number of annotated points and P the total
            number of sampled points; the same P points are repeated for every row.
            ``labels[i, j]`` is 1 when point j came from point i's mask, else 0.
            With no annotated points the shapes are (1, 0) and (1,).
        """
        image_size = targets_per_image.image_size  # detectron2 Instances: (height, width)
        # All coordinates here are (x, y): that is what SamPredictor.predict takes and what
        # the mask samples are flipped into below, so scale/normalize by (width, height).
        # NOTE(review): assumes gt_points_orig is (x, y) normalized by (width, height) —
        # confirm against the dataset mapper.
        xy_extent = np.array(image_size[::-1], dtype=float)
        point_coords = np.round(targets_per_image.gt_points_orig * xy_extent[None, :])

        if not sam_storage:
            self.sam_predictor.set_image(image)

        sampling_points = []
        points_labels_num = []
        for idx, point_coord in enumerate(point_coords):
            if sam_storage:
                mask = targets_per_image.sam_masks[idx]
            else:
                mask, _, _ = self.sam_predictor.predict(
                    point_coords=point_coord[None, :],
                    point_labels=np.array([1]),
                    multimask_output=False,
                )
                mask = mask.squeeze(0)
            indice = np.argwhere(mask == 1)  # rows are (row, col) = (y, x)

            if len(indice) > 0:
                num_points = indice.shape[0]
                # At most `selectn` distinct pixels; a random permutation when fewer exist.
                chosen = indice[np.random.choice(num_points, min(num_points, selectn), replace=False)]
                # (y, x) -> (x, y)
                chosen = chosen[:, ::-1]
            else:
                # Empty mask: fall back to the annotated point (already (x, y) pixels).
                chosen = point_coord[None, :]
            sampling_points.append(chosen)
            points_labels_num.append(chosen.shape[0])

        if len(sampling_points) > 0:
            sampling_points = np.concatenate(sampling_points)[None, :].repeat(point_coords.shape[0], 0).astype(float)
            sampling_points /= xy_extent[None, None, :]
            sampling_points_labels = np.zeros(sampling_points.shape[:-1])
            # Row idx marks the contiguous slice of points that came from point idx's mask.
            stack_num = 0
            for idx, n in enumerate(points_labels_num):
                sampling_points_labels[idx, stack_num:stack_num + n] = 1.
                stack_num += n
        else:
            sampling_points = np.array(sampling_points)[None, :]
            sampling_points_labels = np.zeros(sampling_points.shape[:-1])
        return sampling_points, sampling_points_labels

    def prepare_targets_sam(self, targets, images, images_id):
        """Build criterion targets whose points are sampled from SAM masks.

        Same output format as :meth:`prepare_targets`, but "points" and
        "points_labels" come from :meth:`inference_sam`.
        """
        new_targets = []
        for targets_per_image, image, image_id in zip(targets, images, images_id):
            points, points_labels = self.inference_sam(image_id, image, targets_per_image,
                                                       sam_storage=self.sam_storage)
            new_targets.append(
                self._build_target(
                    targets_per_image.gt_classes,
                    points,
                    targets_per_image.gt_points_orig,
                    points_labels,
                    targets_per_image.gt_classes.device,
                )
            )
        return new_targets

    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:
                   * "image": Tensor, image in (C, H, W) format.
                   * "instances": per-region ground truth
                   * "image_ori", "image_id": read only in training when SAM
                     targets are enabled (predictor or storage).
                   * Other information that's included in the original dicts, such as:
                     "height", "width" (int): the output resolution of the model (may be different
                     from input resolution), used in inference.
        Returns:
            In training: dict of weighted losses (only keys present in
            ``criterion.weight_dict``).
            In inference: list[dict]:
                each dict has the results for one image. The dict contains the following keys:

                * "sem_seg":
                    A Tensor that represents the
                    per-pixel segmentation predicted by the head.
                    The prediction has shape KxHxW that represents the logits of
                    each class for each pixel.
                * "panoptic_seg":
                    A tuple that represent panoptic output
                    panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
                    segments_info (list[dict]): Describe each segment in `panoptic_seg`.
                        Each dict contains keys "id", "category_id", "isthing".
                * "instances": instance segmentation results (when enabled).
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)

        use_sam_targets = self.sam_predictor is not None or self.sam_storage

        images_ori, images_id = None, None
        if use_sam_targets and self.training:
            images_ori = [x["image_ori"] for x in batched_inputs]
            images_id = [x["image_id"] for x in batched_inputs]

        features = self.backbone(images.tensor)
        text_classifier, num_templates = self.get_text_classifier()
        # Append void class weight
        text_classifier = torch.cat([text_classifier, F.normalize(self.void_embedding.weight, dim=-1)], dim=0)
        features['text_classifier'] = text_classifier
        features['num_templates'] = num_templates
        outputs = self.sem_seg_head(features)

        if self.training:
            # mask classification target
            if "instances" in batched_inputs[0]:
                gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
                if use_sam_targets:
                    targets = self.prepare_targets_sam(gt_instances, images_ori, images_id)
                else:
                    targets = self.prepare_targets(gt_instances, images)
            else:
                targets = None

            # bipartite matching-based loss
            losses = self.criterion(outputs, targets)

            for k in list(losses.keys()):
                if k in self.criterion.weight_dict:
                    losses[k] *= self.criterion.weight_dict[k]
                else:
                    # remove this loss if not specified in `weight_dict`
                    losses.pop(k)
            return losses
        else:
            mask_cls_results = outputs["pred_logits"]
            mask_pred_results = outputs["pred_masks"]

            # We ensemble the pred logits of in-vocab and out-vocab
            clip_feature = features["clip_vis_dense"]
            mask_for_pooling = F.interpolate(mask_pred_results, size=clip_feature.shape[-2:],
                                             mode='bilinear', align_corners=False)
            if "convnext" in self.backbone.model_name.lower():
                pooled_clip_feature = self.mask_pooling(clip_feature, mask_for_pooling)
                pooled_clip_feature = self.backbone.visual_prediction_forward(pooled_clip_feature)
            elif "rn" in self.backbone.model_name.lower():
                pooled_clip_feature = self.backbone.visual_prediction_forward(clip_feature, mask_for_pooling)
            else:
                raise NotImplementedError

            out_vocab_cls_results = get_classification_logits(pooled_clip_feature, text_classifier,
                                                              self.backbone.clip_model.logit_scale, num_templates)
            in_vocab_cls_results = mask_cls_results[..., :-1]  # remove void
            out_vocab_cls_results = out_vocab_cls_results[..., :-1]  # remove void

            # Reference: https://github.com/NVlabs/ODISE/blob/main/odise/modeling/meta_arch/odise.py#L1506
            out_vocab_cls_probs = out_vocab_cls_results.softmax(-1)
            in_vocab_cls_results = in_vocab_cls_results.softmax(-1)
            category_overlapping_mask = self.category_overlapping_mask.to(self.device)

            if self.ensemble_on_valid_mask:
                # Only include out_vocab cls results on masks with valid pixels
                # We empirically find that this is important to obtain reasonable AP/mIOU score with ResNet CLIP models
                valid_masking = (mask_for_pooling > 0).to(mask_for_pooling).sum(-1).sum(-1) > 0
                valid_masking = valid_masking.to(in_vocab_cls_results.dtype).unsqueeze(-1)
                alpha = torch.ones_like(in_vocab_cls_results) * self.geometric_ensemble_alpha
                beta = torch.ones_like(in_vocab_cls_results) * self.geometric_ensemble_beta
                alpha = alpha * valid_masking
                beta = beta * valid_masking
            else:
                alpha = self.geometric_ensemble_alpha
                beta = self.geometric_ensemble_beta

            # Geometric ensemble: alpha weights seen (overlapping) categories, beta unseen ones.
            cls_logits_seen = (
                    (in_vocab_cls_results ** (1 - alpha) * out_vocab_cls_probs ** alpha).log()
                    * category_overlapping_mask
            )
            cls_logits_unseen = (
                    (in_vocab_cls_results ** (1 - beta) * out_vocab_cls_probs ** beta).log()
                    * (1 - category_overlapping_mask)
            )
            cls_results = cls_logits_seen + cls_logits_unseen

            # This is used to filtering void predictions.
            is_void_prob = F.softmax(mask_cls_results, dim=-1)[..., -1:]
            mask_cls_probs = torch.cat([
                cls_results.softmax(-1) * (1.0 - is_void_prob),
                is_void_prob], dim=-1)
            mask_cls_results = torch.log(mask_cls_probs + 1e-8)

            # upsample masks
            mask_pred_results = F.interpolate(
                mask_pred_results,
                size=(images.tensor.shape[-2], images.tensor.shape[-1]),
                mode="bilinear",
                align_corners=False,
            )

            del outputs

            processed_results = []
            for mask_cls_result, mask_pred_result, input_per_image, image_size in zip(
                    mask_cls_results, mask_pred_results, batched_inputs, images.image_sizes
            ):
                height = input_per_image.get("height", image_size[0])
                width = input_per_image.get("width", image_size[1])
                processed_results.append({})

                if self.sem_seg_postprocess_before_inference:
                    mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
                        mask_pred_result, image_size, height, width
                    )
                    mask_cls_result = mask_cls_result.to(mask_pred_result)

                # semantic segmentation inference
                if self.semantic_on:
                    r = retry_if_cuda_oom(self.semantic_inference)(mask_cls_result, mask_pred_result)
                    if not self.sem_seg_postprocess_before_inference:
                        r = retry_if_cuda_oom(sem_seg_postprocess)(r, image_size, height, width)
                    processed_results[-1]["sem_seg"] = r

                # panoptic segmentation inference
                if self.panoptic_on:
                    panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(mask_cls_result, mask_pred_result)
                    processed_results[-1]["panoptic_seg"] = panoptic_r

                # instance segmentation inference
                if self.instance_on:
                    instance_r = retry_if_cuda_oom(self.instance_inference)(mask_cls_result, mask_pred_result)
                    processed_results[-1]["instances"] = instance_r

            return processed_results