
import copy
import logging

import cv2
import numpy as np
import sklearn.neighbors as nn
import torch
from detectron2.config import configurable
from detectron2.data import MetadataCatalog
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
from detectron2.structures import BitMasks, Instances
from fvcore.transforms.transform import Transform
from skimage.color import rgb2gray, rgb2lab
from torch.nn import functional as F

# Public API: only the dataset mapper is exported by `from <module> import *`;
# the collation helper, transform and augmentation builder are module-internal.
__all__ = ["ColorizationDatasetMapper"]

def colorization_collation(in_batch):
    """Split each sample's RGB ``image`` into a Lab L-channel tensor and ab array.

    Every sample dict is shallow-copied into a new dict. Entries other than
    ``"image"`` are copied as-is. An ``"image"`` entry is expected to be a
    3-channel channel-first torch tensor (it is moved to CPU and transposed to
    H x W x C before ``rgb2lab``). It is replaced by a ``(1, H, W)`` float32
    L-channel tensor, and its ab channels are stored as a NumPy array under
    ``"image_ab"``.

    NOTE(review): skimage interprets float inputs to ``rgb2lab`` as values in
    [0, 1]; confirm the range callers provide. ``ColorizationDatasetMapper``
    in this file already emits a 1-channel L image, so this function presumably
    targets a different, RGB-producing pipeline. Confirm before use.

    Args:
        in_batch: iterable of per-sample dicts.

    Returns:
        list of new dicts, one per input sample, in the same order.
    """

    def _convert_sample(sample):
        # Build a fresh dict so the caller's sample is left untouched.
        converted = {}
        for key, value in sample.items():
            if key != 'image':
                converted[key] = value
                continue
            lab = rgb2lab(value.cpu().numpy().transpose(1, 2, 0))
            lightness = torch.as_tensor(np.ascontiguousarray(lab[:, :, 0]))
            converted[key] = lightness.unsqueeze(0).to(torch.float32)
            converted["image_ab"] = lab[:, :, 1:]
        return converted

    return [_convert_sample(sample) for sample in in_batch]


class ColorizationTransform(Transform):
    """Turn a color image into its Lab lightness plus a per-pixel color-bin map.

    The ab chroma of each pixel is matched to its nearest entry in a fixed
    array of ab bins loaded from ``col_hull_path``. The index of that entry is
    the pixel's color category.
    """

    def __init__(
        self,
        img_format,
        col_hull_path,
    ):
        """
        Args:
            img_format: "BGR" or "RGB"; channel order of images given to ``apply_image``.
            col_hull_path: path loadable with ``np.load``. The loaded array must have
                2 columns (one per ab channel), because queries are reshaped to (-1, 2).
        """
        super().__init__()
        assert img_format in ["BGR", "RGB"]
        self.is_rgb = img_format == "RGB"
        # Removed so `_set_attributes(locals())` below does not store it.
        # NOTE: the remaining locals (`col_hull_path`, `bins`) ARE stored as attributes.
        del img_format

        # Nearest-neighbour index over the ab bins.
        bins = np.load(col_hull_path)
        self.neighbors = nn.NearestNeighbors(n_neighbors=1, algorithm='auto').fit(bins)

        self._set_attributes(locals())

    def apply_coords(self, coords):
        # A color-space conversion moves no pixels, so coordinates are unchanged.
        return coords

    def apply_segmentation(self, segmentation):
        # Segmentation maps are likewise passed through untouched.
        return segmentation

    def apply_image(self, img, interp=None):
        """Split ``img`` into its Lab L channel and its quantized ab categories.

        NOTE: unlike the usual ``Transform.apply_image`` contract, this returns a
        tuple. ``ColorizationDatasetMapper`` calls it directly and unpacks it.

        Args:
            img: H x W x 3 image in the channel order given by ``img_format``.
            interp: unused; accepted for signature compatibility.

        Returns:
            tuple ``(img_l, sem_seg)``: the Lab L channel (H x W) and an H x W
            array holding each pixel's nearest-bin index.
        """
        rgb = img if self.is_rgb else img[:, :, [2, 1, 0]]
        lab = rgb2lab(rgb)
        return lab[:, :, 0], self.get_color_categories(lab[:, :, 1:])

    def get_color_categories(self, img_ab):
        """Return the index of the nearest ab bin for every pixel of an H x W x 2 array."""
        height, width, _ = img_ab.shape
        _, nearest = self.neighbors.kneighbors(img_ab.reshape(-1, 2))
        return nearest.reshape((height, width))


def build_transform_gen(cfg, is_train):
    """Build the geometric augmentation list used for colorization training.

    The list is an optional ``T.RandomFlip`` (skipped when
    ``cfg.INPUT.RANDOM_FLIP == "none"``; otherwise horizontal when it equals
    "horizontal" and vertical when it equals "vertical"). It is followed by
    ``T.ResizeScale`` (scale range ``cfg.INPUT.MIN_SCALE``..``cfg.INPUT.MAX_SCALE``)
    and a square ``T.FixedSizeCrop`` of side ``cfg.INPUT.IMAGE_SIZE``.

    Returns:
        list[Augmentation]
    """
    assert is_train, "Only support training augmentation"
    side = cfg.INPUT.IMAGE_SIZE
    scale_lo = cfg.INPUT.MIN_SCALE
    scale_hi = cfg.INPUT.MAX_SCALE

    flip_stage = []
    flip_mode = cfg.INPUT.RANDOM_FLIP
    if flip_mode != "none":
        flip_stage.append(
            T.RandomFlip(
                horizontal=flip_mode == "horizontal",
                vertical=flip_mode == "vertical",
            )
        )

    resize_and_crop = [
        T.ResizeScale(
            min_scale=scale_lo,
            max_scale=scale_hi,
            target_height=side,
            target_width=side,
        ),
        T.FixedSizeCrop(crop_size=(side, side)),
    ]

    return flip_stage + resize_and_crop


class ColorizationDatasetMapper:
    """
    A callable which takes a dataset dict in Detectron2 Dataset format
    and maps it into the inputs and targets used for colorization training.

    The callable currently does the following:

    1. Read the image from "file_name"
    2. Apply the geometric augmentations (``augmentations``) to the image
    3. Use the first colorization transform to split the augmented image into
       its Lab L channel (model input) and a per-pixel color-bin index map (target)
    4. Prepare the input, the bin map ("sem_seg") and one binary mask per
       color bin present in the image ("instances") as Tensors
    """

    @configurable
    def __init__(
        self,
        is_train=True,
        *,
        augmentations,
        col_augmentations,
        image_format,
        half_prec,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            is_train: for training or inference; ``__call__`` only supports training.
            augmentations: a list of augmentations or deterministic transforms to apply
            col_augmentations: a list whose first element provides
                ``apply_image(image) -> (L channel, color-bin map)``
                (e.g. :class:`ColorizationTransform`); only the first element is used.
            image_format: an image format supported by :func:`detection_utils.read_image`.
            half_prec: if True the L-channel input tensor is cast to float16,
                otherwise to float32.
        """
        self.is_train = is_train
        self.tfm_gens = augmentations
        self.col_aug = col_augmentations
        self.img_format = image_format
        self.half_prec = half_prec

        logger = logging.getLogger(__name__)
        mode = "training" if is_train else "inference"
        logger.info(f"[{self.__class__.__name__}] Augmentations used in {mode}: {augmentations}")

    @classmethod
    def from_config(cls, cfg, is_train=True):
        """Build the ``__init__`` keyword arguments from a detectron2 config node."""
        augs = build_transform_gen(cfg, is_train)
        col_augs = [ColorizationTransform(img_format=cfg.INPUT.FORMAT, col_hull_path=cfg.INPUT.COL_HULL_PATH)]

        return {
            "is_train": is_train,
            "augmentations": augs,
            "col_augmentations": col_augs,
            "image_format": cfg.INPUT.FORMAT,
            "half_prec": cfg.SOLVER.AMP.ENABLED,
        }

    @staticmethod
    def _bins_to_instances(sem_seg_gt, image_shape):
        """Build an ``Instances`` with one class id and boolean mask per distinct bin index."""
        instances = Instances(image_shape)
        classes = np.unique(sem_seg_gt)
        instances.gt_classes = torch.tensor(classes, dtype=torch.int64)

        if classes.size == 0:
            # Only reachable for an empty bin map; keep an explicit (0, H, W) tensor.
            instances.gt_masks = torch.zeros((0, sem_seg_gt.shape[-2], sem_seg_gt.shape[-1]))
        else:
            # Broadcast (K, 1, 1) against (H, W) to get all K masks in one comparison
            # instead of K separate full-image passes.
            masks = classes[:, None, None] == sem_seg_gt
            instances.gt_masks = BitMasks(torch.from_numpy(np.ascontiguousarray(masks))).tensor
        return instances

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a copy of ``dataset_dict`` with these keys added:
                "half_precision": the ``half_prec`` flag;
                "image": (1, H, W) L-channel tensor, float16 if ``half_prec`` else float32;
                and, when the colorization transform yields a bin map:
                "sem_seg": int64 tensor of per-pixel color-bin indices;
                "instances": :class:`Instances` whose ``gt_classes`` are the distinct
                    bin indices present and ``gt_masks`` one boolean mask per class.
        """
        assert self.is_train, "ColorizationDatasetMapper should only be used for training!"

        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        utils.check_image_size(dataset_dict, image)

        aug_input = T.AugInput(image)
        aug_input, _ = T.apply_transform_gens(self.tfm_gens, aug_input)
        # Split into the L channel (network input) and the per-pixel color-bin map (target).
        image, sem_seg_gt = self.col_aug[0].apply_image(aug_input.image)

        image = torch.as_tensor(np.ascontiguousarray(image)).unsqueeze(0)
        image = image.to(torch.float16 if self.half_prec else torch.float32)
        image_shape = (image.shape[-2], image.shape[-1])  # h, w

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["half_precision"] = self.half_prec
        dataset_dict["image"] = image

        if sem_seg_gt is not None:
            sem_seg_gt = sem_seg_gt.astype(np.int64)
            dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt)
            dataset_dict["instances"] = self._bins_to_instances(sem_seg_gt, image_shape)

        return dataset_dict
