# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved

import argparse
import glob
import multiprocessing as mp
import os
import time
import cv2
import tqdm
import json
import numpy as np

from detectron2.config import get_cfg

from detectron2.projects.deeplab import add_deeplab_config
from detectron2.data.detection_utils import read_image
from detectron2.utils.logger import setup_logger
from open_vocab_seg import add_ovseg_config

from open_vocab_seg.utils import VisualizationDemo
from datasets.kitti360_labels import kitti360_labels

# constants
WINDOW_NAME = "Open vocabulary segmentation"

# ADE
# def read_ade_classes():
#     with open("./ov-seg-clip/open_clip_training/openclip_data/ade20k_150/ade20k_150_classnames.txt") as f:
#         lines = f.readlines()
#         return [l.strip().split(":")[-1] for l in lines]
#     # return ['wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road, route', 'bed', 'window', 'grass', 'cabinet', 'sidewalk, pavement', 'person', 'earth, ground', 'door', 'table', 'mountain, mount', 'plant', 'curtain', 'chair', 'car', 'water', 'painting, picture', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', 'field', 'armchair', 'seat', 'fence', 'desk', 'rock, stone', 'wardrobe, closet, press', 'lamp', 'tub', 'rail', 'cushion', 'base, pedestal, stand', 'box', 'column, pillar', 'signboard, sign', 'chest of drawers, chest, bureau, dresser', 'counter', 'sand', 'sink', 'skyscraper', 'fireplace', 'refrigerator, icebox', 'grandstand, covered stand', 'path', 'stairs', 'runway', 'case, display case, showcase, vitrine', 'pool table, billiard table, snooker table', 'pillow', 'screen door, screen', 'stairway, staircase', 'river', 'bridge, span', 'bookcase', 'blind, screen', 'coffee table', 'toilet, can, commode, crapper, pot, potty, stool, throne', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove', 'palm, palm tree', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', 'arcade machine', 'hovel, hut, hutch, shack, shanty', 'bus', 'towel', 'light', 'truck', 'tower', 'chandelier', 'awning, sunshade, sunblind', 'street lamp', 'booth', 'tv', 'plane', 'dirt track', 'clothes', 'pole', 'land, ground, soil', 'bannister, banister, balustrade, balusters, handrail', 'escalator, moving staircase, moving stairway', 'ottoman, pouf, pouffe, puff, hassock', 'bottle', 'buffet, counter, sideboard', 'poster, posting, placard, notice, bill, card', 'stage', 'van', 'ship', 'fountain', 'conveyer belt, conveyor belt, conveyer, conveyor, transporter', 'canopy', 'washer, automatic washer, washing machine', 'plaything, toy', 'pool', 'stool', 'barrel, cask', 'basket, handbasket', 'falls', 'tent', 'bag', 'minibike, motorbike', 'cradle', 'oven', 'ball', 'food, solid food', 'step, stair', 'tank, storage tank', 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake', 
'dishwasher', 'screen', 'blanket, cover', 'sculpture', 'hood, exhaust hood', 'sconce', 'vase', 'traffic light', 'tray', 'trash can', 'fan', 'pier', 'crt screen', 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass, drinking glass', 'clock', 'flag']

# SCANNET_CLASSES = read_ade_classes()
# ONLY_VAL_CLASSES = read_ade_classes()

# ScanNet++
# CLASSES_FILE = open("./data/scannet++/semantic_classes.json")
# SCANNET_CATEGORIES = json.load(CLASSES_FILE)
# SCANNET_CLASSES = [x["name"] for x in SCANNET_CATEGORIES]

# VAL_IDS_PATH = "./data/scannet++/semantic_classes_val.txt"
# def read_indices():
#     with open(VAL_IDS_PATH) as f:
#         lines = f.readlines()
#         return [int(l.strip()) for l in lines]
# ONLY_VAL_INDICES = read_indices()
# ONLY_VAL_CLASSES = [SCANNET_CLASSES[i] for i in ONLY_VAL_INDICES]

# KITTI
# Full KITTI-360 label table (names and colors), in the order given by
# kitti360_labels. Used when rendering the ground-truth masks.
kitti_classes_gt = [label.name for label in kitti360_labels]
kitti_colors_gt = [label.color for label in kitti360_labels]
# Subset used for prediction: drops the first six entries and the last one
# of the full table (names and colors stay index-aligned).
kitti_classes = kitti_classes_gt[6:-1]
kitti_colors = kitti_colors_gt[6:-1]
# The generic names below are what the script body reads; the commented-out
# ADE / ScanNet++ setups above assign the same two names.
ONLY_VAL_CLASSES = kitti_classes
SCANNET_CLASSES = kitti_classes_gt


def setup_cfg(args):
    """Build the frozen config used by the demo.

    Args:
        args: Parsed command-line namespace. ``args.config_file`` is merged
            first, then the ``args.opts`` overrides.

    Returns:
        The config object after ``freeze()`` has been called on it.
    """
    config = get_cfg()
    # Register the extra config keys before merging the file
    # (the deeplab ones are needed for the poly lr schedule).
    for register_keys in (add_deeplab_config, add_ovseg_config):
        register_keys(config)
    # File values first, then command-line overrides on top.
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.freeze()
    return config


def get_parser():
    """Create the command-line parser for the segmentation demo.

    Returns:
        argparse.ArgumentParser: parser exposing ``--config-file``,
        ``--input``, ``--gt``, ``--class-names``, ``--output`` and ``--opts``.
    """
    cli = argparse.ArgumentParser(description="Detectron2 demo for open vocabulary segmentation")

    # One (flag, keyword-arguments) entry per option, registered in order.
    option_specs = (
        (
            "--config-file",
            {
                "default": "configs/ovseg_swinB_vitL_demo.yaml",
                "metavar": "FILE",
                "help": "path to config file",
            },
        ),
        (
            "--input",
            {
                "nargs": "+",
                "help": "A list of space separated input images; "
                "or a single glob pattern such as 'directory/*.jpg'",
            },
        ),
        (
            "--gt",
            {
                "nargs": "+",
                "help": "A list of space separated gt images; "
                "or a single glob pattern such as 'directory/*.png'",
            },
        ),
        (
            "--class-names",
            {
                "nargs": "+",
                "help": "A list of user-defined class_names",
            },
        ),
        (
            "--output",
            {
                "help": "A file or directory to save output visualizations. "
                "If not given, will show output in an OpenCV window.",
            },
        ),
        (
            "--opts",
            {
                "help": "Modify config options using the command-line 'KEY VALUE' pairs",
                "default": [],
                # Everything after --opts is collected verbatim.
                "nargs": argparse.REMAINDER,
            },
        ),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli


def build_output_paths(path, output, adapter_type):
    """Return where the prediction and GT visualizations of one image go.

    Args:
        path: Path of the input image.
        output: Value of ``--output``; an existing directory or a file path.
        adapter_type: Tag appended to the prediction file name
            (the script passes ``cfg.MODEL.CLIP_ADAPTER.TYPE``).

    Returns:
        tuple[str, str]: ``(prediction_path, gt_path)``.
        If ``output`` is a directory, both files are placed in it and named
        after the input image's base name without its extension. Otherwise
        ``output`` itself is the prediction path and the GT path is the same
        path with ``_gt`` inserted before the extension.
    """
    if os.path.isdir(output):
        # Strip directory and extension properly: splitting the raw path on
        # "." yields "" for "./dir/img.png", which made every image collide.
        stem = os.path.splitext(os.path.basename(path))[0]
        return (
            os.path.join(output, f"{stem}_{adapter_type}.png"),
            os.path.join(output, f"{stem}_gt.png"),
        )
    root, ext = os.path.splitext(output)
    return output, f"{root}_gt{ext}"


if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    args = get_parser().parse_args()
    setup_logger(name="fvcore")
    logger = setup_logger()
    logger.info("Arguments: " + str(args))

    cfg = setup_cfg(args)

    demo = VisualizationDemo(cfg)
    # NOTE: --class-names is parsed but not used; the module-level label
    # tables are passed to the model instead.
    class_names = ONLY_VAL_CLASSES
    class_names_gt = SCANNET_CLASSES
    if not args.input:
        raise NotImplementedError

    if len(args.input) == 1:
        # Sorted so that images and GT masks expanded from two glob patterns
        # pair up by name rather than by directory listing order.
        args.input = sorted(glob.glob(os.path.expanduser(args.input[0])))
        assert args.input, "The input path(s) was not found"
    # --gt is optional: only expand and validate it when it was given.
    if args.gt:
        if len(args.gt) == 1:
            args.gt = sorted(glob.glob(os.path.expanduser(args.gt[0])))
            assert args.gt, "The gt path(s) was not found"
        assert len(args.gt) == len(args.input), (
            f"Got {len(args.input)} input image(s) but {len(args.gt)} gt image(s)"
        )

    for i, path in enumerate(tqdm.tqdm(args.input, disable=not args.output)):
        # use PIL, to be consistent with evaluation
        img = read_image(path, format="BGR")
        start_time = time.time()
        gt = None
        if args.gt:
            print(f"GT masks: {args.gt[i]}")
            gt = read_image(args.gt[i]).astype(np.float32)
        predictions, visualized_output = demo.run_on_image(img, class_names, gt, colors=kitti_colors)
        # The GT rendering is only produced when a GT mask was loaded.
        visualized_output_gt = None
        if gt is not None:
            predictions, visualized_output_gt = demo.run_on_image(
                img,
                class_names,
                gt,
                use_gt=True,
                class_names_gt=class_names_gt,
                colors=kitti_colors_gt,
            )
        logger.info(
            "{}: {} in {:.2f}s".format(
                path,
                "detected {} instances".format(len(predictions["instances"]))
                if "instances" in predictions
                else "finished",
                time.time() - start_time,
            )
        )

        if args.output:
            if not os.path.isdir(args.output):
                assert len(args.input) == 1, "Please specify a directory with args.output"
            out_filename, out_gt_filename = build_output_paths(
                path, args.output, cfg.MODEL.CLIP_ADAPTER.TYPE
            )
            if visualized_output_gt is not None:
                print(f"Saving the file as: {out_filename} and {out_gt_filename}")
            else:
                print(f"Saving the file as: {out_filename}")
            visualized_output.save(out_filename)
            if visualized_output_gt is not None:
                visualized_output_gt.save(out_gt_filename)
        else:
            cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
            cv2.imshow(WINDOW_NAME, visualized_output.get_image()[:, :, ::-1])
            if cv2.waitKey(0) == 27:
                break  # esc to quit