# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/util/misc.py

# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Modified by XXX
# --------------------------------------------------------

"""
Misc functions, including distributed helpers.

Mostly copy-paste from torchvision references.
"""
from typing import List, Optional, Tuple, Any
import hashlib

import torch
import torchvision
from torch import nn, Tensor, device
import torch.distributed as dist
import torch.nn.functional as F

from detectron2.layers import cat, shapes_to_tensor

from utils.constants import *


def move_dict_to_cpu(data):
    """Recursively move every tensor inside ``data`` to the CPU.

    Dicts are rebuilt with the same keys. Lists AND tuples are both rebuilt
    as plain lists. Tensors are replaced by ``tensor.cpu()``. Any other value
    is returned unchanged.
    """
    if isinstance(data, torch.Tensor):
        return data.cpu()
    if isinstance(data, dict):
        return {key: move_dict_to_cpu(value) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        # Note: tuples come back as lists.
        return [move_dict_to_cpu(item) for item in data]
    return data

def strict_hash(s):
    """Return the hex MD5 digest of the string ``s``.

    Unlike the built-in ``hash()``, the result does not depend on the
    interpreter's per-process hash seed.
    """
    hasher = hashlib.md5()
    hasher.update(s.encode())
    return hasher.hexdigest()

def decode_entity_mask_to_indices(mask, padding_value=-4):
    """Collapse a (rows x columns) entity mask into one row index per column.

    Each column's value is the sum over rows of ``row_index * mask * mask``.
    For a 0/1 mask with at most one active row per column, this is that row's
    index. Columns whose mask sums to zero are set to ``padding_value``.
    """
    row_ids = torch.arange(len(mask), device=mask.device).unsqueeze(1)
    weighted = mask * row_ids
    weighted = weighted * mask.float()
    decoded = weighted.sum(dim=0)
    empty_columns = mask.sum(dim=0) == 0
    decoded[empty_columns] = padding_value
    return decoded

def prepare_vision_features(x, num_feature_levels, pe_layer, input_proj, level_embed):
    """Build per-level projected features and positional embeddings.

    For each level ``lvl`` in ``range(num_feature_levels)``:
      * records ``x[lvl].shape[-2:]``;
      * ``pe_layer(x[lvl], None)`` flattened from dim 2 and permuted (2, 0, 1);
      * ``input_proj[lvl](x[lvl])`` flattened from dim 2, plus
        ``level_embed.weight[lvl]`` broadcast over dims 0 and 2, then
        permuted (2, 0, 1).

    Returns:
        (src, pos, size_list): three lists with one entry per level.
    """
    # No padding mask is passed to pe_layer (per the original note, disabling
    # it does not affect performance).
    src, pos, size_list = [], [], []
    for lvl in range(num_feature_levels):
        feat = x[lvl]
        size_list.append(feat.shape[-2:])
        # presumably N x C x H x W -> HW x N x C
        pos_embed = pe_layer(feat, None).flatten(2)
        projected = input_proj[lvl](feat).flatten(2) + level_embed.weight[lvl][None, :, None]
        pos.append(pos_embed.permute(2, 0, 1))
        src.append(projected.permute(2, 0, 1))
    return src, pos, size_list

def pad_arbitrary_tensors(tensors, padding_value=0.):
    """Stack same-rank tensors of differing sizes into one padded batch.

    Args:
        tensors: non-empty sequence of tensors that share the same number of
            dimensions (any rank >= 1).
        padding_value: value written into every padded position.

    Returns:
        A tensor of shape ``(len(tensors), *max_shape)`` in the default
        floating dtype, on ``tensors[0].device``. ``max_shape`` is the
        element-wise maximum of the input shapes. Each input is copied into
        the leading corner of its slot.
    """
    max_len = torch.stack([torch.tensor(x.shape) for x in tensors]).max(dim=0)[0]
    # Explicit default dtype matches the previous torch.empty(...).fill_(...)
    # behavior even when padding_value is an int.
    padded_tensor = torch.full(
        [len(tensors)] + max_len.tolist(),
        padding_value,
        dtype=torch.get_default_dtype(),
        device=tensors[0].device,
    )
    for i, x in enumerate(tensors):
        # Slice EVERY dim to x's extent. Slicing only two dims crashed on 1-D
        # inputs and let higher-rank inputs broadcast into the padding.
        region = (i,) + tuple(slice(0, s) for s in x.shape)
        padded_tensor[region] = x
    return padded_tensor

def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
    maxes = the_list[0]
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes

class NestedTensor:
    """A batched tensor paired with an optional mask (``None`` when absent)."""

    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        # type: (Device) -> NestedTensor # noqa
        """Return a new NestedTensor with tensors (and mask, if any) moved via ``.to(device)``."""
        moved_tensors = self.tensors.to(device)
        moved_mask = self.mask.to(device) if self.mask is not None else None
        return NestedTensor(moved_tensors, moved_mask)

    def decompose(self):
        """Return the ``(tensors, mask)`` pair."""
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)

def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    """Zero-pad a list of same-rank tensors into one batch plus a padding mask.

    Only 3-D (C, H, W) and 2-D (C, L) inputs are accepted. The batch shape is
    ``[len(tensor_list)] + element-wise max shape``. The boolean mask drops the
    first (channel) dim and is True on padded positions, False on real data.

    Raises:
        ValueError: if the first tensor is neither 2-D nor 3-D.
    """
    # TODO make this more general / support different-sized images
    rank = tensor_list[0].ndim
    if rank not in (2, 3):
        raise ValueError("not supported")
    if torchvision._is_tracing():
        # nested_tensor_from_tensor_list() does not export well to ONNX,
        # so use the trace-friendly implementation instead.
        return _onnx_nested_tensor_from_tensor_list(tensor_list)

    batch_shape = [len(tensor_list)] + _max_by_axis([list(t.shape) for t in tensor_list])
    ref = tensor_list[0]
    padded = torch.zeros(batch_shape, dtype=ref.dtype, device=ref.device)

    if rank == 3:
        n, _, height, width = batch_shape
        mask = torch.ones((n, height, width), dtype=torch.bool, device=ref.device)
        for src, dst, valid in zip(tensor_list, padded, mask):
            dst[: src.shape[0], : src.shape[1], : src.shape[2]].copy_(src)
            valid[: src.shape[1], : src.shape[2]] = False
    else:
        n, _, length = batch_shape
        mask = torch.ones((n, length), dtype=torch.bool, device=ref.device)
        for src, dst, valid in zip(tensor_list, padded, mask):
            dst[: src.shape[0], : src.shape[1]] = src
            valid[: src.shape[1]] = False
    return NestedTensor(padded, mask)

def _collate_and_pad_divisibility(tensor_list: list, div=32):
    max_size = []
    for i in range(tensor_list[0].dim()):
        max_size_i = torch.max(
            torch.tensor([img.shape[i] for img in tensor_list]).to(torch.float32)
        ).to(torch.int64)
        max_size.append(max_size_i)
    max_size = tuple(max_size)

    c,h,w = max_size
    pad_h = (div - h % div) if h % div != 0 else 0
    pad_w = (div - w % div) if w % div != 0 else 0
    max_size = (c,h+pad_h,w+pad_w)
    
    # work around for
    # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
    # m[: img.shape[1], :img.shape[2]] = False
    # which is not yet supported in onnx
    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
        padded_imgs.append(padded_img)

        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
        padded_masks.append(padded_mask.to(torch.bool))
    
    return padded_imgs

# _onnx_nested_tensor_from_tensor_list() is an implementation of
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
    """Trace-friendly padding of same-rank tensors into a NestedTensor.

    Pads every tensor with zeros (at the end of each dim) to the element-wise
    maximum shape. Builds a mask over all dims except the first that is True
    on padded positions. Works for any rank >= 2, so both the 3-D and 2-D
    callers are supported.
    """
    max_size = []
    for i in range(tensor_list[0].dim()):
        dim_sizes = [img.shape[i] for img in tensor_list]
        # This path is normally reached only while tracing, where the original
        # stacked shape entries directly (presumably 0-dim tensors there).
        # In eager mode they are Python ints, which torch.stack rejects.
        if isinstance(dim_sizes[0], Tensor):
            sizes_t = torch.stack(dim_sizes)
        else:
            sizes_t = torch.tensor(dim_sizes)
        max_size_i = torch.max(sizes_t.to(torch.float32)).to(torch.int64)
        max_size.append(max_size_i)
    max_size = tuple(max_size)

    # Padding via F.pad works around in-place slice assignment
    # (pad_img[...].copy_(img), m[...] = False), which ONNX does not support.
    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]

        # F.pad takes (before, after) pairs starting from the LAST dimension.
        img_pad = []
        for amount in reversed(padding):
            img_pad.extend([0, amount])
        padded_imgs.append(torch.nn.functional.pad(img, tuple(img_pad)))

        # The mask covers every dim except the first; padded cells become 1.
        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        mask_pad = []
        for amount in reversed(padding[1:]):
            mask_pad.extend([0, amount])
        padded_mask = torch.nn.functional.pad(m, tuple(mask_pad), "constant", 1)
        padded_masks.append(padded_mask.to(torch.bool))

    tensor = torch.stack(padded_imgs)
    mask = torch.stack(padded_masks)

    return NestedTensor(tensor, mask=mask)

def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is available AND initialized."""
    # Short-circuit: is_initialized() is only consulted when the package is available.
    return dist.is_available() and dist.is_initialized()

# TODO: add background to 
def get_class_names(name):
    """Return the class-name list for dataset ``name`` with "background" appended.

    Most rules are substring matches and are order-sensitive: 'refcoco' is
    tested before 'coco', and 'ade20k_full' before 'ade'. 'refcoco' and 'vlp'
    datasets get only ``["background"]``. ``None`` returns ``None``.

    Raises:
        AssertionError: if no rule matches ``name``.
    """
    if name is None:
        return None
    if 'refcoco' in name:
        return ["background"]
    if 'coco' in name:
        return COCO_PANOPTIC_CLASSES + ["background"]
    if 'ade20k_full' in name:
        return ADE20K_847 + ["background"]
    if 'ade' in name:
        return ADE_PANOPTIC_CLASSES + ["background"]
    if 'scannet_41' in name:
        return SCAN_40 + ["background"]
    if 'scannet_21' in name:
        return SCAN_20 + ["background"]
    if 'sun' in name:
        return SUN_RGBD_37 + ["background"]
    if 'voc' in name:
        return PASCAL_CLASSES + ["background"]
    if name in ('cityscapes_fine_sem_seg_val', 'cityscapes_fine_panoptic_val'):
        return CITYSCAPES + ["background"]
    if name == 'cityscapes_fine_instance_seg_val':
        return CITYSCAPES_THING + ["background"]
    if name == 'bdd10k_val_sem_seg':
        return BDD_SEM + ["background"]
    if name == 'bdd10k_40_panoptic_val':
        return BDD_PANO + ["background"]
    if 'vlp' in name:
        return ["background"]
    # Raised explicitly rather than via `assert False`, which `python -O`
    # strips (making the function silently return None). The exception type
    # is kept as AssertionError for any caller that already catches it.
    raise AssertionError("text dataset name {} is not defined".format(name))

def get_iou(gt_masks, pred_masks, ignore_label=-1):
    """Per-mask intersection-over-union, ignoring pixels where gt == ignore_label.

    Args:
        gt_masks: (n, h, w) tensor. Entries equal to ``ignore_label`` are
            excluded; the rest are cast to bool.
        pred_masks: bool tensor combinable with ``gt_masks`` via ``&``/``|``.
        ignore_label: ground-truth value marking pixels to skip.

    Returns:
        Tensor of n IoU values (NaN where the union is empty).
    """
    valid = gt_masks != ignore_label
    gt = gt_masks.bool()
    num, height, width = gt.shape

    def _count(region):
        # Number of valid True pixels per mask.
        return (region & valid).reshape(num, height * width).sum(dim=-1)

    return _count(gt & pred_masks) / _count(gt | pred_masks)

class Spatial_ImageList(object):
    """
    Structure that holds a list of tensors (of possibly varying sizes) as a
    single batched tensor, padded to a common shape.

    The last THREE dimensions of each input are padded and recorded.

    Attributes:
        image_sizes (list[tuple[int, ...]]): per-item original sizes of the
            trailing dims. `from_tensors` stores 3-tuples
            ``(shape[-3], shape[-2], shape[-1])``. During tracing, it becomes
            list[Tensor] instead.
    """

    def __init__(self, tensor: torch.Tensor, image_sizes: List[Tuple[int, ...]]):
        """
        Arguments:
            tensor (Tensor): batched tensor whose first dim indexes the items.
            image_sizes (list[tuple[int, ...]]): original sizes of each item's
                trailing dims (one entry per trailing dim). They can be smaller
                than the tensor's trailing dims due to padding.
        """
        self.tensor = tensor
        self.image_sizes = image_sizes

    def __len__(self) -> int:
        return len(self.image_sizes)

    def __getitem__(self, idx) -> torch.Tensor:
        """
        Access an individual item cropped back to its original size.

        Args:
            idx: int

        Returns:
            Tensor: ``self.tensor[idx]`` with each trailing dim cropped to the
            matching entry of ``self.image_sizes[idx]``.
        """
        size = self.image_sizes[idx]
        # Crop one trailing dim per recorded entry. `from_tensors` records three
        # entries, so slicing only the last two dims with size[0]/size[1]
        # cropped the wrong axes.
        crop = tuple(slice(None, size[i]) for i in range(len(size)))
        return self.tensor[(idx, Ellipsis) + crop]

    @torch.jit.unused
    def to(self, *args: Any, **kwargs: Any) -> "Spatial_ImageList":
        cast_tensor = self.tensor.to(*args, **kwargs)
        return Spatial_ImageList(cast_tensor, self.image_sizes)

    @property
    def device(self) -> device:
        return self.tensor.device

    @staticmethod
    def from_tensors(
        tensors: List[torch.Tensor], size_divisibility: int = 0, pad_value: float = 0.0
    ) -> "Spatial_ImageList":
        """
        Args:
            tensors: a tuple or list of `torch.Tensor`, each with at least three
                dims, i.e. shape (..., Di, Hi, Wi). The last three dims are padded
                to their maximum over the list with `pad_value`. Leading dims
                are taken from ``tensors[0]``.
            size_divisibility (int): If `size_divisibility > 1`, add padding to ensure
                the common size of the LAST TWO dims is divisible by `size_divisibility`.
            pad_value (float): value to pad

        Returns:
            an `Spatial_ImageList` whose `image_sizes` are 3-tuples
            ``(shape[-3], shape[-2], shape[-1])``.
        """
        assert len(tensors) > 0
        assert isinstance(tensors, (tuple, list))
        for t in tensors:
            assert isinstance(t, torch.Tensor), type(t)

        image_sizes = [(im.shape[-3], im.shape[-2], im.shape[-1]) for im in tensors]

        image_sizes_tensor = [shapes_to_tensor(x) for x in image_sizes]
        max_size = torch.stack(image_sizes_tensor).max(0).values

        if size_divisibility > 1:
            stride = size_divisibility
            # the last two dims are H,W, both subject to divisibility requirement
            max_size[-2:] = (max_size[-2:] + (stride - 1)).div(stride, rounding_mode="floor") * stride

        # handle weirdness of scripting and tracing ...
        if torch.jit.is_scripting():
            max_size: List[int] = max_size.to(dtype=torch.long).tolist()
        else:
            if torch.jit.is_tracing():
                image_sizes = image_sizes_tensor

        if len(tensors) == 1:
            # This seems slightly (2%) faster.
            # TODO: check whether it's faster for multiple images as well
            # With a single input, dim -3 already equals the max, so only the
            # last two dims can need (divisibility) padding.
            image_size = image_sizes[0]
            padding_size = [0, max_size[-1] - image_size[2], 0, max_size[-2] - image_size[1]]
            batched_imgs = F.pad(tensors[0], padding_size, value=pad_value).unsqueeze_(0)
        else:
            # max_size can be a tensor in tracing mode, therefore convert to list
            batch_shape = [len(tensors)] + list(tensors[0].shape[:-3]) + list(max_size)
            batched_imgs = tensors[0].new_full(batch_shape, pad_value)
            for img, pad_img in zip(tensors, batched_imgs):
                pad_img[:img.shape[-3], :img.shape[-2], :img.shape[-1]].copy_(img)

        return Spatial_ImageList(batched_imgs.contiguous(), image_sizes)

def flatten_dict(d, parent_key='', sep='_'):
    """Flatten nested dicts into a single level, joining keys with ``sep``.

    A key is prefixed with ``parent_key + sep`` only when ``parent_key`` is
    truthy. Empty nested dicts contribute no entries.
    """
    flat = {}
    for key, value in d.items():
        if isinstance(value, dict):
            pairs = flatten_dict(value, key, sep).items()
        else:
            pairs = [(key, value)]
        for sub_key, sub_value in pairs:
            new_key = f"{parent_key}{sep}{sub_key}" if parent_key else sub_key
            flat[new_key] = sub_value
    return flat

def flatten_list(nested_list):
    """Flatten arbitrarily nested Python lists into one flat list.

    Only ``list`` instances are expanded; tuples and any other values are
    leaves. A non-list argument is returned wrapped as ``[nested_list]``.

    Recursion depth follows the nesting depth only (not the list length), and
    the total work is linear. The previous head/tail recursion with
    ``nested_list[1:]`` copies was quadratic and raised RecursionError on flat
    lists longer than the recursion limit.
    """
    if not isinstance(nested_list, list):
        return [nested_list]
    flat = []
    for item in nested_list:
        flat.extend(flatten_list(item))
    return flat
