#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) BaseDetection, Inc. and its affiliates. All Rights Reserved

"""
DETR model and criterion classes.
"""
import torch
from torch.functional import Tensor
import torch.nn.functional as F
from torch import nn
from typing import Optional, List
import pickle
import numpy as np

from cvpods.layers import ShapeSpec, position_encoding_dict
from cvpods.modeling.backbone import Transformer
from cvpods.modeling.matcher import HungarianMatcher
from cvpods.structures import Boxes, ImageList, Instances
from cvpods.structures import boxes as box_ops
from cvpods.structures.boxes import generalized_box_iou
from cvpods.utils import comm
from cvpods.utils.metrics import accuracy


class DETR(nn.Module):
    """DETR detector: backbone + transformer + set-prediction heads.

    In training mode ``forward`` returns a dict of weighted losses computed by
    :class:`SetCriterion`; in eval mode it returns one ``{"instances": Instances}``
    dict per input image.
    """

    def __init__(self, cfg, task_number):
        """
        Args:
            cfg: cvpods config. Reads ``cfg.MODEL.*`` / ``cfg.MODEL.DETR.*`` and must
                provide ``cfg.build_backbone``.
            task_number: forwarded to :class:`SetCriterion`, which uses it to select the
                classes supervised by the classification loss.
        """
        super(DETR, self).__init__()

        self.device = torch.device(cfg.MODEL.DEVICE)

        # Build Backbone
        self.backbone = cfg.build_backbone(
            cfg, input_shape=ShapeSpec(channels=len(cfg.MODEL.PIXEL_MEAN))
        )

        # Build Transformer
        self.transformer = Transformer(cfg)

        self.aux_loss = not cfg.MODEL.DETR.NO_AUX_LOSS
        self.num_classes = cfg.MODEL.DETR.NUM_CLASSES
        self.num_queries = cfg.MODEL.DETR.NUM_QUERIES
        # Key of the backbone output fed to the transformer. The same key sizes
        # ``input_proj`` below and selects the feature map in ``forward``, so the two
        # can never disagree.
        self.in_features = cfg.MODEL.DETR.IN_FEATURES
        hidden_dim = self.transformer.d_model

        # Build FFN (the extra logit is the "no-object" class)
        self.class_embed = nn.Linear(hidden_dim, self.num_classes + 1)
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
        # Build Object Queries
        self.query_embed = nn.Embedding(self.num_queries, hidden_dim)

        backbone_out_shapes = self.backbone.output_shape()[self.in_features]
        self.input_proj = nn.Conv2d(backbone_out_shapes.channels, hidden_dim, kernel_size=1)

        self.position_embedding = position_encoding_dict[cfg.MODEL.DETR.POSITION_EMBEDDING](
            num_pos_feats=hidden_dim // 2,
            temperature=cfg.MODEL.DETR.TEMPERATURE,
            normalize=cfg.MODEL.DETR.POSITION_EMBEDDING == "sine",
            scale=None,
        )

        self.weight_dict = {
            "loss_ce": cfg.MODEL.DETR.CLASS_LOSS_COEFF,
            "loss_bbox": cfg.MODEL.DETR.BBOX_LOSS_COEFF,
            "loss_giou": cfg.MODEL.DETR.GIOU_LOSS_COEFF,
        }

        if self.aux_loss:
            # One suffixed copy of every weight per intermediate decoder layer.
            self.aux_weight_dict = {}
            for i in range(cfg.MODEL.DETR.TRANSFORMER.NUM_DEC_LAYERS - 1):
                self.aux_weight_dict.update({k + f"_{i}": v for k, v in self.weight_dict.items()})
            self.weight_dict.update(self.aux_weight_dict)

        losses = ["labels", "boxes", "cardinality"]

        matcher = HungarianMatcher(
            cost_class=cfg.MODEL.DETR.COST_CLASS,
            cost_bbox=cfg.MODEL.DETR.COST_BBOX,
            cost_giou=cfg.MODEL.DETR.COST_GIOU,
        )

        self.criterion = SetCriterion(
            self.num_classes,
            matcher=matcher,
            weight_dict=self.weight_dict,
            eos_coef=cfg.MODEL.DETR.EOS_COEFF,
            losses=losses,
            task_number=task_number,
        )

        self.post_processors = {"bbox": PostProcess()}

        pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(3, 1, 1)
        pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(3, 1, 1)

        if not cfg.MODEL.RESNETS.STRIDE_IN_1X1:
            # Custom or torch pretrain weights: rescale to [0, 1] before normalizing
            self.normalizer = lambda x: (x / 255.0 - pixel_mean) / pixel_std
        else:
            # MSRA pretrain weights
            self.normalizer = lambda x: (x - pixel_mean) / pixel_std

        self.to(self.device)

    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper` .
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:
                * image: Tensor, image in (C, H, W) format.
                * instances: Instances
                Other information that's included in the original dicts, such as:
                * "height", "width" (int): the output resolution of the model, used in inference.
                  See :meth:`postprocess` for details.
        Returns:
            Training: dict[str: Tensor] mapping each loss name to its weighted value.
            Inference: list with one ``{"instances": Instances}`` dict per image; the
                instances carry ``pred_boxes``, ``scores``, ``pred_classes`` and
                ``pred_score_dist`` (the full softmax distribution).
        """
        images = self.preprocess_image(batched_inputs)

        B, _, H, W = images.tensor.shape
        device = images.tensor.device

        # True marks padding; the valid region of every image is cleared to False.
        mask = torch.ones((B, H, W), dtype=torch.bool, device=device)
        for img_shape, m in zip(images.image_sizes, mask):
            m[: img_shape[0], : img_shape[1]] = False

        src = self.backbone(images.tensor)[self.in_features]
        # Resize the padding mask to the feature-map resolution.
        mask = F.interpolate(mask[None].float(), size=src.shape[-2:]).bool()[0]
        pos = self.position_embedding(src, mask)

        hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos)[0]

        outputs_class = self.class_embed(hs)
        outputs_coord = self.bbox_embed(hs).sigmoid()
        # Index -1 is the final decoder layer; earlier entries feed the auxiliary losses.
        out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]}

        if self.training:

            targets = self.convert_anno_format(batched_inputs)

            if self.aux_loss:
                out["aux_outputs"] = [
                    {"pred_logits": a, "pred_boxes": b}
                    for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
                ]
            loss_dict = self.criterion(out, targets)
            for k, v in loss_dict.items():
                loss_dict[k] = v * self.weight_dict[k] if k in self.weight_dict else v
            return loss_dict
        else:
            target_sizes = torch.stack(
                [
                    torch.tensor([
                        bi.get("height", img_size[0]),
                        bi.get("width", img_size[1])],
                        device=self.device)
                    for bi, img_size in zip(batched_inputs, images.image_sizes)
                ]
            )
            res = self.post_processors["bbox"](out, target_sizes)

            processed_results = []
            for results_per_image, image_size in zip(res, images.image_sizes):
                result = Instances(image_size)
                result.pred_boxes = Boxes(results_per_image["boxes"].float())
                result.scores = results_per_image["scores"].float()
                result.pred_classes = results_per_image["labels"]
                result.pred_score_dist = results_per_image["prob"]
                processed_results.append({"instances": result})

            return processed_results

    def preprocess_image(self, batched_inputs):
        """
        Normalize, pad and batch the input images.
        """
        images = [x["image"].float().to(self.device) for x in batched_inputs]
        images = [self.normalizer(img) for img in images]
        images = ImageList.from_tensors(images, self.backbone.size_divisibility)
        return images

    def convert_anno_format(self, batched_inputs):
        """Build one DETR target dict per input from its ``instances`` annotations.

        Boxes are converted to (cx, cy, w, h) normalized by the image width/height.
        The dict also holds ``area``, ``labels``, the optional ``labels_non_masked``
        and ``masks``, an all-zero ``iscrowd``, ``orig_size``, ``size`` and ``image_id``.
        """
        targets = []
        for bi in batched_inputs:
            target = {}
            h, w = bi["image"].shape[-2:]
            boxes = box_ops.box_xyxy_to_cxcywh(
                bi["instances"].gt_boxes.tensor / torch.tensor([w, h, w, h], dtype=torch.float32)
            )
            target["boxes"] = boxes.to(self.device)
            target["area"] = bi["instances"].gt_boxes.area().to(self.device)
            target["labels"] = bi["instances"].gt_classes.to(self.device)

            if hasattr(bi["instances"], "gt_classes_non_masked"):
                target["labels_non_masked"] = bi["instances"].gt_classes_non_masked.to(self.device)

            if hasattr(bi["instances"], "gt_masks"):
                target["masks"] = bi["instances"].gt_masks
            target["iscrowd"] = torch.zeros_like(target["labels"], device=self.device)
            target["orig_size"] = torch.tensor([bi["height"], bi["width"]], device=self.device)
            target["size"] = torch.tensor([h, w], device=self.device)
            target["image_id"] = torch.tensor(bi["image_id"], device=self.device)
            targets.append(target)

        return targets


class SetCriterion(nn.Module):
    """ This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)

    For continual learning, the classification loss only supervises the classes listed
    for ``task_number`` in a pickled task distribution (see ``load_task_distribution``).
    """

    # Used when no ``task_distribution_path`` is given (the original hard-coded location).
    DEFAULT_TASK_DISTRIBUTION_PATH = (
        "/home/naitik/projects/SGG_Continual/models/experiments/SGTR_short/"
        "task_distributions/c1_short.pkl"
    )

    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses, task_number,
                 task_distribution_path=None):
        """ Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their
                        relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
            task_number: key into the loaded task distribution selecting the classes that the
                        classification loss supervises.
            task_distribution_path: path of the pickled task distribution; defaults to
                        ``DEFAULT_TASK_DISTRIBUTION_PATH``.
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses
        self.task_number = task_number
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = self.eos_coef
        self.register_buffer("empty_weight", empty_weight)
        if task_distribution_path is None:
            task_distribution_path = self.DEFAULT_TASK_DISTRIBUTION_PATH
        self.task_distribution_path = task_distribution_path

        self.task_distribution = self.load_task_distribution()

    def load_task_distribution(self):
        """Load the pre-defined task-wise class division from ``self.task_distribution_path``.

        Returns:
            The unpickled object. It is indexed with ``self.task_number`` and the result
            is treated as a collection of allowed class ids.
        """
        # NOTE(security): unpickling can execute arbitrary code; only load trusted files.
        with open(self.task_distribution_path, "rb") as file:
            return pickle.load(file)

    def _allowed_class_mask(self, classes):
        """Return a bool tensor shaped like ``classes`` that is True where the class id
        belongs to ``self.task_distribution[self.task_number]``."""
        class_allowed = self.task_distribution[self.task_number]
        allowed = torch.as_tensor(list(class_allowed), dtype=torch.int64, device=classes.device)
        # Broadcast every entry against every allowed id; an empty allowed set gives all False.
        return (classes.unsqueeze(-1) == allowed).any(-1)

    def filter_logits_and_gt(self, src_logits, target_classes):
        """Keep only the entries whose target class is allowed for the current task.

        Args:
            src_logits: tensor whose first dimension is aligned with ``target_classes``.
            target_classes: 1-D integer tensor of class ids.

        Returns:
            (src_logits_new, target_classes_new): the filtered tensors. They keep their
            original device and dtype, and gradients still flow through the logits.
        """
        keep = self._allowed_class_mask(target_classes)
        return src_logits[keep.to(src_logits.device)], target_classes[keep]

    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]

        Unmatched queries are supervised as the no-object class (``self.num_classes``).
        Entries with a negative label, and entries whose class is not allowed for the
        current task, are excluded. If nothing remains, the loss is a zero that stays
        attached to the graph.
        """
        assert "pred_logits" in outputs
        del num_boxes

        src_logits = outputs["pred_logits"]

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])

        target_classes = torch.full(
            src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
        )
        target_classes[idx] = target_classes_o

        # Negative labels mark masked targets; also drop classes outside the current task.
        keep = (target_classes >= 0) & self._allowed_class_mask(target_classes)
        kept_logits = src_logits[keep]
        kept_targets = target_classes[keep]

        if kept_targets.numel() > 0:
            loss_ce = F.cross_entropy(kept_logits, kept_targets, self.empty_weight)
        else:
            # A weighted-mean cross entropy over zero rows is NaN (0 / 0).
            loss_ce = src_logits.sum() * 0.0

        losses = {"loss_ce": loss_ce}

        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses["class_error"] = 100 - accuracy(kept_logits, kept_targets)[0]

        return losses

    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices, num_boxes):
        """
        Compute the cardinality error, ie the absolute error in the number of predicted non-empty
        boxes. This is not really a loss, it is intended for logging purposes only. It doesn't
        propagate gradients
        """
        del indices
        del num_boxes
        pred_logits = outputs["pred_logits"]
        device = pred_logits.device
        tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
        # Count the number of predictions that are NOT "no-object" (which is the last class)
        card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
        losses = {"cardinality_error": card_err}
        return losses

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """
        Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
        targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
        The target boxes are expected in format (center_x, center_y, w, h), normalized by the
        image size. Matched pairs whose target label is negative are excluded from both losses.
        """
        assert "pred_boxes" in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs["pred_boxes"][idx]

        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)

        # Apply the same validity mask to L1 and GIoU so both losses see the same pairs.
        valid_indx = target_classes_o >= 0
        src_boxes = src_boxes[valid_indx]
        target_boxes = target_boxes[valid_indx]

        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none")

        losses = {}
        losses["loss_bbox"] = loss_bbox.sum() / num_boxes

        loss_giou = 1 - torch.diag(
            generalized_box_iou(
                box_ops.box_cxcywh_to_xyxy(src_boxes), box_ops.box_cxcywh_to_xyxy(target_boxes)
            )
        )
        losses["loss_giou"] = loss_giou.sum() / num_boxes
        return losses

    def _get_src_permutation_idx(self, indices):
        """
        extract the src_idx of matching result with the batch indicator array,
        and concatenate them into whole tensor array
        Args:
            indices: Tuple[Tuple[Tensor, Tensor]]

        Returns:
            batch_idx : Tensor,
            src_idx: Tensor

        """
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        """Like ``_get_src_permutation_idx`` but for the target side of each match."""
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        """Dispatch to the loss method registered under the name ``loss``."""
        loss_map = {
            "labels": self.loss_labels,
            "cardinality": self.loss_cardinality,
            "boxes": self.loss_boxes,
        }
        assert loss in loss_map, f"do you really want to compute {loss} loss?"
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

    def forward(self, outputs, targets, with_match_idx=False):
        """
        This performs the loss computation.

        Parameters:
            outputs: dict of tensors, see the output specification of the model for the format
            targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each
                      loss' doc
            with_match_idx: if True, also return the matching computed for the final
                      (non-auxiliary) outputs.
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        # Compute the average number of target boxes across all nodes, for normalization purposes
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor(
            [num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device
        )

        if comm.get_world_size() > 1:
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / comm.get_world_size(), min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))

        # In case of auxiliary losses, we repeat this process with the output of
        # each intermediate layer. A separate name keeps ``indices`` pointing at the
        # final-layer matching returned below.
        if "aux_outputs" in outputs:
            for i, aux_outputs in enumerate(outputs["aux_outputs"]):
                aux_indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    if loss == "masks":
                        # Intermediate masks losses are too costly to compute, we ignore them.
                        continue
                    kwargs = {}
                    if loss == "labels":
                        # Logging is enabled only for the last layer
                        kwargs = {"log": False}
                    l_dict = self.get_loss(
                        loss, aux_outputs, targets, aux_indices, num_boxes, **kwargs
                    )
                    l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
                    losses.update(l_dict)

        if with_match_idx:
            return losses, indices
        return losses


class PostProcess(nn.Module):
    """Turn raw DETR predictions into per-image result dicts in COCO-API layout."""

    @torch.no_grad()
    def forward(self, outputs, target_sizes, topk=100):
        """
        Convert a batch of raw predictions into per-image results.

        Parameters:
            outputs: raw outputs of the model; must hold "pred_logits" and "pred_boxes"
            target_sizes: tensor of shape [batch_size x 2] with the (height, width) of each
                        image of the batch.
                For evaluation, this must be the original image size (before any data augmentation)
                For visualization, this should be the image size after data augment,
                but before padding
            topk: accepted for interface compatibility; not used.

        Returns:
            A list with one dict per image holding "scores", "labels", "boxes" (absolute
            x0, y0, x1, y1), "prob" (full softmax) and "boxes_norm" (relative x0, y0, x1, y1).
        """
        logits = outputs["pred_logits"]
        pred_boxes = outputs["pred_boxes"]

        assert len(logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2

        probs = F.softmax(logits, -1)
        # The last column is "no-object"; the best label is chosen among real classes only.
        scores, labels = probs[..., :-1].max(-1)

        # Boxes come in relative (cx, cy, w, h); convert to relative corners, then
        # scale each image's boxes by its own width/height.
        normalized_xyxy = box_ops.box_cxcywh_to_xyxy(pred_boxes)
        heights, widths = target_sizes.unbind(1)
        scale = torch.stack([widths, heights, widths, heights], dim=1).unsqueeze(1)
        absolute_xyxy = normalized_xyxy * scale

        results = []
        for per_scores, per_labels, per_boxes, per_probs, per_norm in zip(
            scores, labels, absolute_xyxy, probs, normalized_xyxy
        ):
            results.append(
                {
                    "scores": per_scores,
                    "labels": per_labels,
                    "boxes": per_boxes,
                    "prob": per_probs,
                    "boxes_norm": per_norm,
                }
            )
        return results


class MLP(nn.Module):
    """Plain feed-forward network (FFN) built from ``num_layers`` linear layers.

    Every layer except the last is followed by a ReLU. Hidden layers have width
    ``hidden_dim``; the final layer maps to ``output_dim``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        hidden_widths = [hidden_dim] * (num_layers - 1)
        in_dims = [input_dim] + hidden_widths
        out_dims = hidden_widths + [output_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(d_in, d_out) for d_in, d_out in zip(in_dims, out_dims)]
        )

    def forward(self, x):
        """Apply the layers in order, with ReLU between them (not after the last)."""
        last = self.num_layers - 1
        for idx, linear in enumerate(self.layers):
            x = linear(x)
            if idx < last:
                x = F.relu(x)
        return x
