# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from: https://github.com/facebookresearch/detectron2/blob/master/demo/demo.py


import argparse
import glob
import json
import multiprocessing as mp
import os
# fmt: off
import sys

import torch
from PIL import Image

sys.path.insert(1, os.path.join(sys.path[0], '../..'))
sys.path.insert(1, os.path.join(sys.path[0], '../../demo'))
# fmt: on

import tempfile
import time
import warnings

import cv2
import numpy as np
import tqdm
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.data.detection_utils import read_image
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.utils.logger import setup_logger
from detectron2.utils.visualizer import ColorMode, Visualizer
from mask2former import add_maskformer2_config
from predictor import VisualizationDemo
import torch

# constants
# Presumably the title of an OpenCV preview window; it is not referenced in the
# visible part of this file.
WINDOW_NAME = "mask2former demo"
# Near-white colour (float components) passed as ``edge_color`` to
# ``draw_binary_mask`` in ``CustomVisualizer.draw_sem_seg``.
_OFF_WHITE = (1.0, 1.0, 240.0 / 255)


# # Metadata
# KITTI_STEP_SEM_SEG_CATEGORIES = [
#     {"name" : "masked", "id": 0, "trainId" : 0},     # ONLY FOR VISUALIZATION OF MASKING + CONFIDENCE
#     {"name" : "ST labels", "id": 1, "trainId" : 1},     # ONLY FOR VISUALIZATION OF MASKING + CONFIDENCE
# ]


# def _get_kitti_step_meta():
#     stuff_ids = [k["id"] for k in KITTI_STEP_SEM_SEG_CATEGORIES]

#     # For semantic segmentation, this mapping maps from contiguous stuff id
#     # (in [0, 91], used in models) to ids in the dataset (used for processing results)
#     stuff_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(stuff_ids)}
#     stuff_classes = [k["name"] for k in KITTI_STEP_SEM_SEG_CATEGORIES]

#     ret = {
#         "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id,
#         "stuff_classes": stuff_classes,
#     }
#     return ret

# Metadata: one entry per semantic class (19 in total).
#   "color"   - three integers; presumably an RGB triple in the 0-255 range
#               (TODO confirm against the consumer of ``stuff_colors``)
#   "name"    - human-readable class name
#   "id"      - dataset id of the class
#   "trainId" - identical to "id" for every entry in this table
# ``_get_kitti_step_meta`` derives its id mapping, class names and colours
# from this list, in the order given here.
KITTI_STEP_SEM_SEG_CATEGORIES = [
    {"color": [96, 96, 96], "name": "road", "id": 0, "trainId": 0},
    {"color": [138, 140, 255], "name": "sidewalk", "id": 1, "trainId": 1},
    {"color": [116, 112, 0], "name": "building", "id": 2, "trainId": 2},
    {"color": [184, 44, 78], "name": "wall", "id": 3, "trainId": 3},
    {"color": [252, 187, 187], "name": "fence", "id": 4, "trainId": 4},
    {"color": [0, 208, 255], "name": "pole", "id": 5, "trainId": 5},
    {"color": [250, 170, 30], "name": "traffic light", "id": 6, "trainId": 6},
    {"color": [220, 220, 0], "name": "traffic sign", "id": 7, "trainId": 7},
    {"color": [107, 142, 35], "name": "vegetation", "id": 8, "trainId": 8},
    {"color": [205, 178, 247], "name": "terrain", "id": 9, "trainId": 9},
    {"color": [70, 130, 180], "name": "sky", "id": 10, "trainId": 10},
    {"color": [220, 20, 60], "name": "person", "id": 11, "trainId": 11},
    {"color": [255, 77, 255], "name": "rider", "id": 12, "trainId": 12},
    {"color": [0, 0, 142], "name": "car", "id": 13, "trainId": 13},
    {"color": [0, 0, 70], "name": "truck", "id": 14, "trainId": 14},
    {"color": [255, 140, 0], "name": "bus", "id": 15, "trainId": 15},
    {"color": [0, 80, 100], "name": "train", "id": 16, "trainId": 16},
    {"color": [0, 0, 230], "name": "motorcycle", "id": 17, "trainId": 17},
    {"color": [119, 11, 32], "name": "bicycle", "id": 18, "trainId": 18},
]

def _get_kitti_step_meta():
    """Derive semantic-segmentation metadata from ``KITTI_STEP_SEM_SEG_CATEGORIES``.

    Returns:
        dict with three entries, all following the order of the category table:
            "stuff_dataset_id_to_contiguous_id": maps each category's dataset
                ``"id"`` to its position (0, 1, 2, ...) in the table.
            "stuff_classes": list of category names.
            "stuff_colors": list of category colours (the very same list
                objects that are stored in the table, not copies).
    """
    categories = KITTI_STEP_SEM_SEG_CATEGORIES
    dataset_ids = [category["id"] for category in categories]

    return {
        # dataset id -> contiguous index used to look up names and colours
        "stuff_dataset_id_to_contiguous_id": {
            dataset_id: position for position, dataset_id in enumerate(dataset_ids)
        },
        "stuff_classes": [category["name"] for category in categories],
        "stuff_colors": [category["color"] for category in categories],
    }



def setup_cfg(args):
    """Build the frozen config for the demo from the parsed CLI arguments.

    Args:
        args: parsed arguments; ``args.config_file`` is merged first and the
            ``args.opts`` key/value list is merged on top of it.

    Returns:
        The config object after ``freeze()`` has been called on it.
    """
    config = get_cfg()
    # Both project extensions are applied before any merge, presumably so that
    # their keys already exist when the file and the overrides are merged.
    for extend_config in (add_deeplab_config, add_maskformer2_config):
        extend_config(config)

    # Config file first, then command-line overrides.
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)

    config.freeze()
    return config


def get_parser():
    """Create the command-line parser for the visualization script.

    The options are declared as a table of ``(flag, add_argument keyword
    arguments)`` pairs and registered in table order.

    Returns:
        argparse.ArgumentParser: the fully configured parser.
    """
    option_table = [
        # Options carried over from the stock demo.
        (
            "--config-file",
            dict(
                default="configs/coco/panoptic-segmentation/maskformer2_R50_bs16_50ep.yaml",
                metavar="FILE",
                help="path to config file",
            ),
        ),
        ("--webcam", dict(action="store_true", help="Take inputs from webcam.")),
        ("--video-input", dict(help="Path to video file.")),
        (
            "--input",
            dict(
                nargs="+",
                help="A list of space separated input images; "
                "or a single glob pattern such as 'directory/*.jpg'",
            ),
        ),
        (
            "--output",
            dict(
                help="A file or directory to save output visualizations. "
                "If not given, will show output in an OpenCV window.",
            ),
        ),
        (
            "--confidence-threshold",
            dict(
                type=float,
                default=0.5,
                help="Minimum score for instance predictions to be shown",
            ),
        ),
        (
            "--opts",
            dict(
                help="Modify config options using the command-line 'KEY VALUE' pairs",
                default=[],
                nargs=argparse.REMAINDER,
            ),
        ),
        # Visualize GT and pre-TTT preds
        ("--input_label_dir", dict(type=str, default='')),
        ("--ttt_output", dict(type=str, default='')),
        ('--topl', dict(type=float, default=None, help="Top percent of labels to take")),
        ('--mask', dict(type=float, default=None, help="Masking ratio")),
        ("--vis_dir", dict(type=str, default='../visualizations/kitti_step/ttt_0009')),
    ]

    parser = argparse.ArgumentParser(description="maskformer2 demo for builtin configs")
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser


class CustomVisualizer(Visualizer):
    """``Visualizer`` whose ``draw_sem_seg`` accepts two extra mask arguments.

    The extended signature lets callers hand over a self-training mask and a
    masked-patch mask together with the label map. At the moment only the
    label map is rendered; the other two arguments are accepted and ignored.
    """

    def __init__(self, image, metadata, instance_mode=ColorMode.IMAGE):
        super().__init__(image, metadata, instance_mode=instance_mode)

    def draw_sem_seg(self, sem_seg, st_sem_seg, mae_sem_seg, area_threshold=None, alphas=(0.8, 0.8, 0.8)):
        """
        Draw semantic segmentation predictions/labels.

        Args:
            sem_seg (Tensor or ndarray): the segmentation of shape (H, W).
                Each value is the integer label of the pixel.
            st_sem_seg: currently unused; kept so existing callers keep working.
            mae_sem_seg: currently unused; kept so existing callers keep working.
            area_threshold (int): currently unused; no segment is filtered by area.
            alphas (sequence of float): currently unused; every mask is drawn
                with a fixed opacity of 0.8.

        Returns:
            output (VisImage): image object with visualizations.
        """
        if isinstance(sem_seg, torch.Tensor):
            # detach + cpu so tensors on an accelerator or tracked by autograd
            # can be converted as well.
            sem_seg = sem_seg.detach().cpu().numpy()

        labels, areas = np.unique(sem_seg, return_counts=True)
        # Draw the largest segments first.
        labels = labels[np.argsort(-areas)]

        num_classes = len(self.metadata.stuff_classes)
        for label in labels:
            # Values outside the class list (e.g. an ignore label) are not drawn.
            if label >= num_classes:
                continue
            try:
                mask_color = [x / 255 for x in self.metadata.stuff_colors[label]]
            except (AttributeError, IndexError):
                print("No mask color defined by label class: " + str(label))
                mask_color = None

            binary_mask = (sem_seg == label).astype(np.uint8)
            self.draw_binary_mask(
                binary_mask,
                color=mask_color,
                edge_color=_OFF_WHITE,
                alpha=0.8,
            )

        # NOTE: overlays for ``st_sem_seg`` and ``mae_sem_seg`` are not drawn.
        return self.output


def test_ttt_update_vis(image, label_idx, st_mask, mae_mask, cfg, instance_mode):
    """Render a label map on top of an image, cropped to a multiple of 16.

    Image, label map and both masks are cropped with the same centred window,
    whose height and width are the image's height and width rounded down to a
    multiple of 16.

    Args:
        image (np.ndarray): an image of shape (H, W, C) (in BGR order).
            This is the format used by OpenCV.
        label_idx (Tensor or ndarray): per-pixel labels; squeezed and then
            indexed as (H, W).
        st_mask (Tensor): indexed as (H, W) and moved to the CPU before being
            forwarded to ``CustomVisualizer.draw_sem_seg``.
        mae_mask (Tensor): handled exactly like ``st_mask``.
        cfg: config whose ``DATASETS.TEST[0]`` (if any) names the metadata
            entry used for drawing.
        instance_mode: colour mode forwarded to the visualizer.

    Returns:
        vis_output (VisImage): the visualized image output.
    """
    # Metadata is taken from the registered test dataset; the table returned
    # by ``_get_kitti_step_meta`` is not applied here.
    metadata = MetadataCatalog.get(
        cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
    )

    # Convert image from OpenCV BGR format to Matplotlib RGB format.
    image = image[:, :, ::-1]
    label_idx = label_idx.squeeze()

    # Centred crop window: drop the remainder modulo 16, split between both sides.
    h, w = image.shape[:2]
    h_extra = h % 16
    w_extra = w % 16
    rows = slice(h_extra // 2, h - (h_extra - h_extra // 2))
    cols = slice(w_extra // 2, w - (w_extra - w_extra // 2))

    visualizer = CustomVisualizer(image[rows, cols, :], metadata, instance_mode=instance_mode)
    cpu = torch.device("cpu")
    return visualizer.draw_sem_seg(
        label_idx[rows, cols],
        st_mask[rows, cols].to(cpu),
        mae_mask[rows, cols].to(cpu),
        alphas=[0.8, 0.5, 1.0],
    )



def mae_image(img, drop_ratio):
    """Zero out a random subset of 16x16 patches of a 2-D map.

    A leading axis is added to ``img``, the result is centre-cropped so both
    spatial sizes are multiples of 16, cut into patches, and the patches that
    ``random_masking`` does not keep are multiplied by zero. The stripped
    border is then padded back with the constant 255.

    Args:
        img (Tensor): map of shape (H, W).
        drop_ratio (float): passed to ``random_masking`` as its mask ratio.

    Returns:
        Tensor of dtype ``torch.uint8`` and shape (1, H, W).
    """
    patch = 16
    img = img.unsqueeze(0)  # (1, H, W)

    # Remainder modulo the patch size, split between the two sides.
    _, height, width = img.shape
    spare_h = height % patch
    spare_w = width % patch
    top, left = spare_h // 2, spare_w // 2
    bottom, right = spare_h - top, spare_w - left

    centre = torch.as_tensor(img[:, top:height - bottom, left:width - right])
    _, crop_h, crop_w = centre.shape

    patches = patchify(centre)
    keep_mask = random_masking(patches, drop_ratio)[1]
    # keep_mask is 1 for kept patches and 0 for removed ones.
    restored = unpatchify(patches * keep_mask.unsqueeze(-1), crop_h, crop_w)

    # Padding is listed for the last axis first: (left, right), then (top, bottom).
    restored = torch.nn.functional.pad(restored, (left, right, top, bottom), "constant", 255)
    assert restored.shape == img.shape
    return restored.type(torch.uint8)

def patchify(imgs, patch_size=16):
    """Cut a single image into flattened, non-overlapping square patches.

    A batch axis of size 1 is added to ``imgs`` before patching.

    Args:
        imgs (Tensor): one image of shape (C, H, W); H and W must be
            multiples of ``patch_size``.
        patch_size (int): side length of a patch.

    Returns:
        Tensor of shape (1, L, patch_size**2 * C) with
        L = (H // patch_size) * (W // patch_size). Patches are ordered
        row-major over the patch grid; within a patch the values are ordered
        (row, column, channel).
    """
    p = patch_size
    imgs = imgs.unsqueeze(0)  # (1, C, H, W)
    assert imgs.shape[2] % p == 0 and imgs.shape[3] % p == 0

    n, c = imgs.shape[0], imgs.shape[1]
    h = imgs.shape[2] // p
    w = imgs.shape[3] // p
    x = imgs.reshape(n, c, h, p, w, p)
    # Move the patch-grid axes (h, w) in front of the in-patch axes (p, q).
    x = torch.einsum('nchpwq->nhwpqc', x)
    x = x.reshape(n, h * w, p * p * c)

    return x

def unpatchify(x, h, w, patch_size=16):
    """Reassemble flattened patches into an image (inverse of ``patchify``).

    Args:
        x (Tensor): patches of shape (N, L, patch_size**2 * C), where each
            patch is flattened in (row, column, channel) order.
        h (int): height of the reassembled image.
        w (int): width of the reassembled image.
        patch_size (int): side length of a patch.

    Returns:
        Tensor of shape (N, C, h, w) with the leading axis squeezed away when
        N == 1, i.e. (C, h, w) for a single image.
    """
    p = patch_size
    assert (h // p) * (w // p) == x.shape[1]

    # Channel count is implied by the flattened patch length.
    c = x.shape[2] // (p * p)
    x = x.reshape(x.shape[0], h // p, w // p, p, p, c)
    # Interleave grid and in-patch axes again: (n, c, h, p, w, q).
    x = torch.einsum('nhwpqc->nchpwq', x)
    imgs = x.reshape(x.shape[0], c, h, w)

    return imgs.squeeze(0)

def random_masking(x, mask_ratio):
    """
    Randomly keep a fixed number of sequence entries, independently per sample.

    Each sample draws uniform noise per position; the positions with the
    smallest noise are the ones kept.

    x: [N, L, D], sequence
    mask_ratio: fraction of positions to remove; int(L * (1 - mask_ratio))
        positions are kept per sample.

    Returns (x_masked, mask, ids_restore):
        x_masked: [N, len_keep, D], the kept entries in shuffled order
        mask: [N, L] float tensor in the original order, 1 is keep, 0 is remove
        ids_restore: [N, L] indices that undo the shuffle
    """
    batch, seq_len, dim = x.shape
    n_keep = int(seq_len * (1 - mask_ratio))

    # Uniform noise in [0, 1); its ranking defines the per-sample shuffle.
    scores = torch.rand(batch, seq_len, device=x.device)
    shuffle_order = torch.argsort(scores, dim=1)
    restore_order = torch.argsort(shuffle_order, dim=1)

    # The first n_keep shuffled positions survive.
    kept_positions = shuffle_order[:, :n_keep]
    gather_index = kept_positions.unsqueeze(-1).repeat(1, 1, dim)
    x_kept = torch.gather(x, dim=1, index=gather_index)

    # Build the mask in shuffled order, then map it back to the original order.
    shuffled_mask = torch.zeros([batch, seq_len], device=x.device)
    shuffled_mask[:, :n_keep] = 1
    keep_mask = torch.gather(shuffled_mask, dim=1, index=restore_order)

    return x_kept, keep_mask, restore_order



if __name__ == "__main__":
    # Script entry point: for every input image that has a stored prediction,
    # run the model, build the self-training and masked-patch masks, and save
    # the stored prediction rendered over the image into ``args.vis_dir``.
    mp.set_start_method("spawn", force=True)
    parser = get_parser()
    args = parser.parse_args()
    # Both values feed arithmetic inside the loop below; reject a missing one
    # up front instead of failing with a TypeError after the model is loaded.
    if args.input and (args.topl is None or args.mask is None):
        parser.error("--topl and --mask must both be given when --input is used")
    setup_logger(name="fvcore")
    logger = setup_logger()
    logger.info("Arguments: " + str(args))

    cfg = setup_cfg(args)

    demo = VisualizationDemo(cfg, instance_mode=ColorMode.SEGMENTATION)

    # NOTE: the directory holding the stored predictions is hard-coded.
    pred_dir = 'exp_dir/custom_16/gtd_0.3/berkeley_rand_used0.1_mask0.8/16_win'

    if args.input:
        if len(args.input) == 1:
            args.input = glob.glob(os.path.expanduser(args.input[0]))
            assert args.input, "The input path(s) was not found"
        for path in tqdm.tqdm(args.input, disable=not args.output):
            # File name up to its first dot; used for both the prediction
            # file name and the output file name.
            stem = os.path.basename(path).split('.')[0]

            # use PIL, to be consistent with evaluation
            img = read_image(path, format="BGR")
            pred_path = os.path.join(pred_dir, stem + '_preds.png')
            try:
                predicted_img = read_image(pred_path, format="BGR")
                # Only the first channel of the stored prediction is used.
                predicted_img = predicted_img[:, :, 0]
            except Exception as err:
                # Best effort: images without a usable stored prediction are
                # skipped, but no longer silently.
                logger.warning("Skipping %s: cannot read prediction %s (%s)", path, pred_path, err)
                continue
            start_time = time.time()

            predictions, visualized_output = demo.run_on_image(img)
            logger.info(
                "{}: {} in {:.2f}s".format(
                    path,
                    "detected {} instances".format(len(predictions["instances"]))
                    if "instances" in predictions
                    else "finished",
                    time.time() - start_time,
                )
            )

            preds = predictions["sem_seg"]
            label_idx = torch.argmax(preds, dim=0, keepdim=True)

            # Compute ST label mask: the ``args.topl`` fraction of pixels with
            # the highest arg-max score is set to 1, every other pixel to 255.
            probs = torch.gather(preds, 0, label_idx).squeeze()

            rows, cols = probs.shape
            topl = int((rows * cols) * args.topl)
            best_probs = torch.argsort(probs.flatten(), descending=True)[:topl]
            row_idx = best_probs // cols
            col_idx = best_probs % cols
            st_mask = torch.ones_like(probs, dtype=torch.long, device=label_idx.device) * 255
            st_mask[row_idx, col_idx] = 1

            # Compute masked mask: an all-255 map with random patches zeroed.
            mae_mask = torch.ones_like(probs, dtype=torch.long, device=label_idx.device) * 255
            mae_mask = mae_image(mae_mask, drop_ratio=args.mask).squeeze(0)

            # Created lazily so no directory appears when every image is skipped.
            os.makedirs(args.vis_dir, exist_ok=True)
            # The stored prediction, not the fresh arg-max, is what gets drawn.
            visualized_output = test_ttt_update_vis(img, predicted_img, st_mask, mae_mask,
                                        cfg, ColorMode.SEGMENTATION)

            vis_out = os.path.join(args.vis_dir, stem + '_' + str(args.topl) + '.png')
            visualized_output.save(vis_out)

