# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
MaskFormer Training Script.

This script is a simplified version of the training script in detectron2/tools.
"""
try:
    # Best-effort: silence the ShapelyDeprecationWarning triggered through fvcore.
    import warnings

    from shapely.errors import ShapelyDeprecationWarning
    warnings.filterwarnings('ignore', category=ShapelyDeprecationWarning)
except Exception:
    # shapely may be missing, broken, or may not define this warning class;
    # in every such case there is nothing to silence. ``Exception`` (rather
    # than a bare ``except:``) keeps KeyboardInterrupt/SystemExit propagating.
    pass

import argparse
import copy
import itertools
import json
import logging
import os
os.environ['OPENBLAS_NUM_THREADS'] = '1'


import shutil
import sys
from collections import OrderedDict
from glob import glob
from typing import Any, Dict, List, Set

import detectron2.data.transforms as T
import detectron2.utils.comm as comm
import numpy as np
import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import (DatasetCatalog, MetadataCatalog,
                             build_detection_test_loader,
                             build_detection_train_loader)
from detectron2.data.datasets import load_sem_seg
from detectron2.data.detection_utils import read_image
from detectron2.engine import DefaultTrainer, HookBase, default_setup, launch
from detectron2.engine.defaults import DefaultPredictor
from detectron2.evaluation import (CityscapesInstanceEvaluator,
                                   CityscapesSemSegEvaluator, COCOEvaluator,
                                   COCOPanopticEvaluator, DatasetEvaluators,
                                   LVISEvaluator, SemSegEvaluator,
                                   verify_results)
from detectron2.modeling import build_model
from detectron2.projects.deeplab import add_deeplab_config, build_lr_scheduler
from detectron2.solver.build import maybe_add_gradient_clipping
from detectron2.utils.file_io import PathManager
from detectron2.utils.logger import setup_logger
from detectron2.utils.visualizer import ColorMode, Visualizer
from PIL import Image

# MaskFormer
from mask2former import (COCOInstanceNewBaselineDatasetMapper,
                         COCOPanopticNewBaselineDatasetMapper,
                         DropoutSemanticDatasetMapper, InstanceSegEvaluator,
                         MaskFormerInstanceDatasetMapper,
                         MaskFormerPanopticDatasetMapper,
                         MaskFormerSemanticDatasetMapper,
                         SemanticSegmentorWithTTA, add_dropout_config,
                         add_maskformer2_config, add_ttt_config)

# KITTI-STEP semantic categories. The position of a name in the tuple below is
# used both as its dataset "id" and as its "trainId".
KITTI_STEP_SEM_SEG_CATEGORIES = [
    {"name": label, "id": index, "trainId": index}
    for index, label in enumerate(
        (
            "road",
            "sidewalk",
            "building",
            "wall",
            "fence",
            "pole",
            "traffic light",
            "traffic sign",
            "vegetation",
            "terrain",
            "sky",
            "person",
            "rider",
            "car",
            "truck",
            "bus",
            "train",
            "motorcycle",
            "bicycle",
        )
    )
]


def get_parser(epilog=None):
    """
    Create a parser with the common detectron2 arguments plus the
    test-time-training (TTT) and dropout-augmentation options of this script.

    Args:
        epilog (str): epilog passed to ArgumentParser describing the usage.
            When empty/None a default set of usage examples is shown.

    Returns:
        argparse.ArgumentParser:
    """
    parser = argparse.ArgumentParser(
        epilog=epilog
        or f"""
            Examples:

            Run on single machine:
                $ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml

            Change some config options:
                $ {sys.argv[0]} --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth SOLVER.BASE_LR 0.001

            Run on multiple machines:
                (machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url <URL> [--other-flags]
                (machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url <URL> [--other-flags]
            """,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument(
        "--resume",
        action="store_true",
        help="Whether to attempt to resume from the checkpoint directory. "
        "See documentation of `DefaultTrainer.resume_or_load()` for what it means.",
    )
    parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
    parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
    parser.add_argument("--num-machines", type=int, default=1, help="total number of machines")
    parser.add_argument(
        "--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
    )

    # PyTorch still may leave orphan processes in multi-gpu training.
    # Therefore we use a deterministic way to obtain port,
    # so that users are aware of orphan processes by seeing the port occupied.
    port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14
    parser.add_argument(
        "--dist-url",
        default="tcp://127.0.0.1:{}".format(port),
        help="initialization URL for pytorch distributed backend. See "
        "https://pytorch.org/docs/stable/distributed.html for details.",
    )
    parser.add_argument(
        "opts",
        help="""
            Modify config options at the end of the command. For Yacs configs, use
            space-separated "PATH.KEY VALUE" pairs.
            For python-based LazyConfig, use "path.key=value".
        """.strip(),
        default=None,
        nargs=argparse.REMAINDER,
    )

    # ---- Test-time training (TTT) arguments; copied into cfg.TTT by setup() ----
    parser.add_argument(
        "--ttt_in_dir",
        type=str,
        default="",
        help="Directory holding the input video frames (zero-padded NNNNNN.png files)",
    )

    parser.add_argument(
        "--ttt_out_dir",
        type=str,
        default="",
        help="Self-training dataset root: pseudo-labels and confidence maps are "
        "written to its 'train' sub-directory, evaluation labels are read from 'val'",
    )

    parser.add_argument(
        '--exp_dir',
        type=str,
        default=None,
        help='Experiment directory to save logs',
    )

    parser.add_argument(
        "--ttt_topl",
        type=float,
        default=None,
        help="Top fraction of confident pixels kept as pseudo-labels",
    )

    parser.add_argument(
        "--ttt_setting",
        type=str,
        default=None,
        help="Online or standard ('standard' reloads the original weights after each frame)",
    )

    parser.add_argument(
        "--st_iters",
        type=int,
        default=None,
        help="Number of iterations of self-training for each image",
    )

    parser.add_argument(
        "--win_size",
        type=str,
        default=None,
        help="Maximum number of frames kept in the self-training queue, or 'inf' for no limit",
    )

    # ---- Dropout augmentation arguments; copied into cfg.DROPOUT_AUG by setup() ----
    parser.add_argument(
        "--drop_aug",
        action="store_true",
        help="Enable dropout augmentation (sets DROPOUT_AUG.ENABLED)",
    )

    parser.add_argument(
        "--drop_ratio",
        type=float,
        default=None,
        help="Dropout augmentation ratio (sets DROPOUT_AUG.RATIO)",
    )

    parser.add_argument(
        "--mask_type",
        type=str,
        default=None,
        help="Dropout mask type (sets DROPOUT_AUG.MASK_TYPE)",
    )

    return parser


def _get_kitti_step_meta():
    """
    Build the metadata for the KITTI-STEP semantic categories.

    Returns:
        dict: with keys
            "stuff_dataset_id_to_contiguous_id": maps each category "id" to its
                position in KITTI_STEP_SEM_SEG_CATEGORIES, and
            "stuff_classes": the category names in the same order.
    """
    class_names = []
    dataset_id_to_contiguous = {}
    for contiguous_id, category in enumerate(KITTI_STEP_SEM_SEG_CATEGORIES):
        dataset_id_to_contiguous[category["id"]] = contiguous_id
        class_names.append(category["name"])

    return {
        "stuff_dataset_id_to_contiguous_id": dataset_id_to_contiguous,
        "stuff_classes": class_names,
    }


def _get_kitti_step_files(image_dir, queue, it):
    files = []
    for img_idx in queue:
        img_root = format(img_idx, "06d")
        label_file = os.path.join(image_dir, img_root + "_" + str(it) + ".png")

        image_file = os.path.join(image_dir, img_root + '.png')
        json_file = os.path.join(image_dir, img_root + '.json')
        conf_file = os.path.join(image_dir, img_root + '_conf_' + str(it) + '.npy')

        files.append((image_file, label_file, json_file, conf_file))
    assert len(files), "No images found in {}".format(image_dir)
    for f in files[0]:
        assert PathManager.isfile(f), f
    return files


def load_kitti_step_semantic(image_dir, queue, it):
    """
    Build the training dataset dicts for the queued frames.

    Args:
        image_dir (str): directory containing frames, pseudo-labels,
            confidence maps and per-frame json files.
        queue (iterable[int]): indices of the frames to load.
        it: self-training round whose label/confidence files are used.

    Returns:
        list[dict]: one dict per frame with "file_name", "sem_seg_file_name",
            "conf_file_name", "height" and "width" (the latter two are read
            from the frame's json file).
    """
    records = []
    for entry in _get_kitti_step_files(image_dir, queue, it):
        image_file, label_file, json_file, conf_file = entry
        with PathManager.open(json_file, "r") as handle:
            info = json.load(handle)
        record = {
            "file_name": image_file,
            "sem_seg_file_name": label_file,
            "conf_file_name": conf_file,
            "height": info["imgHeight"],
            "width": info["imgWidth"],
        }
        records.append(record)

    assert len(records), f"No images found in {image_dir}!"
    return records


def _get_kitti_step_val_files(image_dir, ref_dir, img_idx):
    """
    Return the evaluation file paths of a single frame.

    Args:
        image_dir (str): directory containing the frame image.
        ref_dir (str): directory containing the "_sem.png" label and the json.
        img_idx (int): frame index, zero-padded to 6 digits in file names.

    Returns:
        list[tuple]: a single ``(image, label, json)`` path tuple; every path
            is checked for existence.
    """
    stem = format(img_idx, "06d")
    frame_paths = (
        os.path.join(image_dir, stem + '.png'),
        os.path.join(ref_dir, stem + "_sem.png"),
        os.path.join(ref_dir, stem + '.json'),
    )

    files = [frame_paths]
    assert len(files), "No images found in {}".format(image_dir)
    for path in files[0]:
        assert PathManager.isfile(path), path
    return files


def load_kitti_video_eval(image_dir, ref_dir, img_idx):
    """
    Build the evaluation dataset dicts for one frame.

    Args:
        image_dir (str): directory containing the frame image.
        ref_dir (str): directory containing the reference labels; resolved
            through ``PathManager.get_local_path`` before use.
        img_idx (int): index of the frame to evaluate.

    Returns:
        list[dict]: dicts with "file_name", "sem_seg_file_name", "height" and
            "width" (the latter two are read from the frame's json file).
    """
    local_ref_dir = PathManager.get_local_path(ref_dir)

    records = []
    for entry in _get_kitti_step_val_files(image_dir, local_ref_dir, img_idx):
        image_file, label_file, json_file = entry
        with PathManager.open(json_file, "r") as handle:
            info = json.load(handle)
        record = {
            "file_name": image_file,
            "sem_seg_file_name": label_file,
            "height": info["imgHeight"],
            "width": info["imgWidth"],
        }
        records.append(record)

    assert len(records), f"No images found in {image_dir}!"
    return records


def register_custom_kitti(ref_dir, queue, it):
    """
    Register the "kitti_step_video_sem_seg_train" / "..._val" datasets.

    Args:
        ref_dir (str): dataset root; its "train" sub-directory provides the
            frames and pseudo-labels, its "val" sub-directory the references.
        queue (list[int]): frame indices. The train split loads every queued
            frame, the val split only the last one. The list is looked up when
            a dataset is actually loaded, not at registration time.
        it: self-training round forwarded to the training loader.
    """
    meta = _get_kitti_step_meta()

    image_dir = os.path.join(ref_dir, "train")
    gt_dir = os.path.join(ref_dir, "val")

    def _load_train(x=image_dir, it=it):
        return load_kitti_step_semantic(x, queue, it)

    def _load_val(x=image_dir, y=gt_dir):
        return load_kitti_video_eval(x, y, queue[-1])

    splits = (
        ("kitti_step_video_sem_seg_train", _load_train),
        ("kitti_step_video_sem_seg_val", _load_val),
    )
    for name, loader in splits:
        DatasetCatalog.register(name, loader)
        MetadataCatalog.get(name).set(
            stuff_classes=meta["stuff_classes"][:],
            evaluator_type="sem_seg",
            ignore_label=255,
        )


class SelfTrainingPredictor(DefaultPredictor):
    """
    Single-image predictor that runs a model supplied by the caller.

    The base-class constructor is intentionally not called: this class holds
    no model of its own, the model is passed to :meth:`__call__` on every
    invocation. It also assigns one colour per class to the metadata of the
    first test dataset so predictions can be drawn with
    :meth:`draw_predictions`.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg: config node; ``DATASETS.TEST`` must name at least one dataset
                whose metadata defines ``stuff_classes``.

        Raises:
            ValueError: if ``cfg.DATASETS.TEST`` is empty (the metadata needed
                for the colour table would otherwise be missing).
        """
        self.cfg = cfg.clone()  # cfg can be modified by model
        if not len(cfg.DATASETS.TEST):
            raise ValueError(
                "SelfTrainingPredictor needs at least one dataset in cfg.DATASETS.TEST"
            )
        self.metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0])

        # Imported lazily so matplotlib is only required when a predictor is built.
        import matplotlib.pyplot as plt

        # One colour per class, sampled from the 'rainbow' colormap.
        # ``plt.get_cmap`` is used because ``plt.cm.get_cmap`` is no longer
        # available in recent matplotlib releases.
        class_names = self.metadata.stuff_classes
        cmap = plt.get_cmap('rainbow', len(class_names))
        colors = {}
        for i, c in enumerate(class_names):
            red, green, blue = cmap(i)[:3]
            colors[i] = (red * 255, green * 255, blue * 255)

            # ONLY FOR VISUALIZATION OF UNCONFIDENT: draw that class in black.
            if c == "unconfident":
                colors[i] = (0., 0., 0.)
        self.metadata.stuff_colors = colors

        self.aug = T.ResizeShortestEdge(
            [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
        )

        self.input_format = cfg.INPUT.FORMAT
        assert self.input_format in ["RGB", "BGR"], self.input_format

    def __call__(self, original_image, model):
        """
        Run ``model`` on one image without tracking gradients.

        Args:
            original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).
            model: callable taking a list of input dicts and returning a list
                of predictions; it is used as-is (its train/eval mode is not
                changed here).

        Returns:
            predictions (dict): the model output for this image.
        """
        with torch.no_grad():  # https://github.com/sphinx-doc/sphinx/issues/4258
            # Apply pre-processing to image.
            if self.input_format == "RGB":
                # whether the model expects BGR inputs or RGB
                original_image = original_image[:, :, ::-1]
            height, width = original_image.shape[:2]
            image = self.aug.get_transform(original_image).apply_image(original_image)
            image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))

            inputs = {"image": image, "height": height, "width": width}
            predictions = model([inputs])[0]

            return predictions

    def draw_predictions(self, original_image, preds):
        """
        Draw the per-pixel argmax of ``preds`` on top of the image.

        Args:
            original_image (np.ndarray): image of shape (H, W, C) in BGR order,
                i.e. the same array that is passed to :meth:`__call__`.
            preds (torch.Tensor): per-class scores; the class axis is dim 0.

        Returns:
            The output of ``Visualizer.draw_sem_seg``.
        """
        # The Visualizer needs an RGB image while the caller provides BGR, so
        # the channels are always flipped here. This must not depend on the
        # format the *model* expects (``self.input_format``).
        original_image = original_image[:, :, ::-1]

        # Instance mode NEEDS to be SEGMENTATION
        visualizer = Visualizer(original_image, self.metadata, instance_mode=ColorMode.SEGMENTATION)
        vis_output = visualizer.draw_sem_seg(
            preds.argmax(dim=0).cpu()
        )

        return vis_output


class SelfTrainingDataloader(HookBase):
    """
    Training hook that drives per-frame self-training on a video.

    Before every step the KITTI-STEP datasets are re-registered for the
    current frame queue and self-training round, and the trainer's data
    iterator is rebuilt. After every step the hook either refreshes the
    pseudo-labels of the queued frames, or (once ``cfg.TTT.ST_ITERS`` rounds
    are done) evaluates, logs, and advances the queue to the next frame.
    """

    def __init__(self, registration, cfg):
        """
        Args:
            registration: callable invoked as
                ``registration(out_dir, frame_queue, round_index)`` to register
                the datasets.
            cfg: config node; the ``TTT`` and ``DROPOUT_AUG`` sections are read.
        """
        self.registration = registration
        self.cfg = cfg.clone()
        win_size = cfg.TTT.WIN_SIZE
        self.win_size = "inf" if win_size == "inf" else int(win_size)
        self.predictor = SelfTrainingPredictor(cfg)

        self.ttt_in = cfg.TTT.IN_DIR
        self.video_dir = cfg.TTT.IN_DIR.split('/')[-1]
        self.ttt_out = os.path.join(cfg.TTT.OUT_DIR, "train")

        # Name shared by the experiment directory and the final checkpoint.
        self.mask_type = cfg.DROPOUT_AUG.MASK_TYPE
        self.topl = cfg.TTT.TOPL
        self.uratio = str(cfg.TTT.TOPL)
        self.mratio = str(cfg.DROPOUT_AUG.RATIO)
        self.run_name = (
            str(self.video_dir) + "_" + self.mask_type
            + "_used" + self.uratio + "_mask" + self.mratio
        )

        # Log/save directory: <EXP_DIR>/<run_name>/<win_size>_win
        # (makedirs also creates the intermediate directories).
        self.exp_dir = os.path.join(
            cfg.TTT.EXP_DIR, self.run_name, str(self.win_size) + "_win"
        )
        os.makedirs(self.exp_dir, exist_ok=True)

        # Start from a fresh performance log.
        exp_log = os.path.join(self.exp_dir, 'performance.txt')
        if os.path.isfile(exp_log):
            os.remove(exp_log)

        # Setting
        self.ttt_setting = cfg.TTT.SETTING
        self.orig_model_weights = cfg.MODEL.WEIGHTS

        # Frames currently trained on, and the self-training round within them.
        self.imgs_in_queue = [0]
        self.internal_iter = 0
        self.max_st_iters = cfg.TTT.ST_ITERS
        # Number of frames in the input directory; reaching it ends training.
        self.max_img_idx = len(glob(os.path.join(self.ttt_in, "*.png")))

        self._root = os.getenv("DETECTRON2_DATASETS")

    @staticmethod
    def _select_confident_labels(logits, keep_fraction, ignore_label=255):
        """
        Turn per-class scores into pseudo-labels, keeping only confident pixels.

        Args:
            logits (torch.Tensor): per-class scores with the class axis at
                dim 0 (expected to be (C, H, W)).
            keep_fraction (float): fraction of pixels, ranked by the softmax
                probability of their predicted class, that keep their label.
            ignore_label (int): value written for all other pixels; 255 is the
                ``ignore_label`` the datasets are registered with.

        Returns:
            tuple: ``(labels, probs)`` where ``labels`` holds the argmax class
                or ``ignore_label``, and ``probs`` the probability of the
                argmax class for every pixel.
        """
        label_idx = torch.argmax(logits, dim=0, keepdim=True)
        probs = torch.gather(
            torch.nn.functional.softmax(logits, dim=0), 0, label_idx
        ).squeeze(0)

        num_keep = int(probs.numel() * keep_fraction)
        keep_idx = torch.argsort(probs.flatten(), descending=True)[:num_keep]

        # The mask MUST be boolean: indexing with an integer (long) tensor
        # selects whole rows by index instead of masking individual pixels.
        reject = torch.ones(probs.numel(), dtype=torch.bool, device=label_idx.device)
        reject[keep_idx] = False

        labels = label_idx.squeeze(0)
        labels[reject.view(labels.shape)] = ignore_label
        return labels, probs

    # Re-register dataloader
    def before_step(self):
        """Re-register the datasets, rebuild the train iterator, set train mode."""
        for name in ["train", "val"]:
            DatasetCatalog.remove(f"kitti_step_video_sem_seg_{name}")
            MetadataCatalog.remove(f"kitti_step_video_sem_seg_{name}")

        self.registration(self.cfg.TTT.OUT_DIR, self.imgs_in_queue,
                          self.internal_iter)
        self.trainer._trainer._data_loader_iter = iter(
            self.trainer.build_train_loader(
                self.cfg, self.imgs_in_queue, self.internal_iter
            )
        )

        # after_step may have left the model in eval mode.
        self.trainer.model.train()

    def _refresh_pseudo_labels(self):
        """Predict every queued frame and write its labels for the next round."""
        self.trainer.model.eval()
        next_round = str((self.internal_iter + 1) % self.max_st_iters)
        for img_idx in self.imgs_in_queue:
            stem = format(img_idx, '06d')
            img = read_image(os.path.join(self.ttt_in, stem + '.png'), format="BGR")

            # Run Detectron2 predictor
            predictions = self.predictor(img, self.trainer.model)
            labels, probs = self._select_confident_labels(
                predictions["sem_seg"], self.topl
            )

            # Save confidence values for confidence-weighted thresholding
            np.save(
                os.path.join(self.ttt_out, stem + '_conf_' + next_round + '.npy'),
                probs.cpu().numpy(),
            )

            # Save the labels as a single-channel 8-bit image.
            label_im = Image.fromarray(labels.cpu().numpy().astype(np.uint8))
            label_im.save(os.path.join(self.ttt_out, stem + '_' + next_round + '.png'))

    def _evaluate_latest_frame(self):
        """Evaluate the model, append mIoU to the log, save a visualisation."""
        res = self.trainer.test(self.cfg, self.trainer.model)

        # Results without a 'sem_seg' entry are skipped (presumably the
        # non-main processes in distributed runs; verify against the evaluator).
        if 'sem_seg' in res:
            miou = str(res['sem_seg']['mIoU'])
            exp_log = os.path.join(self.exp_dir, 'performance.txt')
            with open(exp_log, 'a') as fp:
                fp.write(miou + '\n')

            # Save final predictions of the most recent frame.
            stem = format(self.imgs_in_queue[-1], "06d")
            img = read_image(os.path.join(self.ttt_in, stem + '.png'), format="BGR")
            self.trainer.model.eval()
            predictions = self.predictor(img, self.trainer.model)

            vis_out = self.predictor.draw_predictions(img, predictions["sem_seg"])
            vis_out.save(os.path.join(self.exp_dir, stem + '_final.png'))

    # Update model predictions on this image
    def after_step(self):
        """
        Finish one self-training step.

        While rounds remain for the current queue, the pseudo-labels are
        refreshed and the round counter is advanced. After the last round the
        model is evaluated (only if the window is full or unbounded), the next
        frame is appended, the oldest frame is evicted when the window
        overflows, and in the "standard" setting the original weights are
        reloaded. The process exits once the queue reaches the end of the
        video.
        """
        if (self.internal_iter + 1) != self.max_st_iters:
            self._refresh_pseudo_labels()
            self.internal_iter = self.internal_iter + 1
        else:
            # Queue is full and we have finished self-training --> eval on most recent image
            if self.win_size == "inf" or len(self.imgs_in_queue) == self.win_size:
                self._evaluate_latest_frame()

            # Move on to the next frame, evicting the oldest one if needed.
            self.imgs_in_queue.append(self.imgs_in_queue[-1] + 1)
            if self.win_size != "inf" and len(self.imgs_in_queue) > self.win_size:
                self.imgs_in_queue.pop(0)
            self.internal_iter = 0

            # If we are online, do nothing
            # If we are standard, we need to reset the model
            if self.ttt_setting == "standard":
                self.trainer.checkpointer.load(self.orig_model_weights)

        if self.imgs_in_queue[-1] == self.max_img_idx:
            # Early stopping: save the final model, then terminate the process.
            self.after_train()
            sys.exit(0)

    def after_train(self):
        """Save the final checkpoint under a name identifying this run."""
        self.trainer.checkpointer.save("model_final_" + self.run_name)



class Trainer(DefaultTrainer):
    """
    Extension of the Trainer class adapted to MaskFormer.

    The train/test loaders re-register the KITTI-STEP video datasets on every
    call so that they always reflect the current frame queue.
    """

    # Frame queue used by ``build_test_loader``. ``build_train_loader``
    # replaces it with the queue it was called with; this default makes the
    # test loader usable before any train loader has been built (e.g. with
    # ``--eval-only``).
    queue: List[int] = [0]

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """
        Create evaluator(s) for a given dataset.
        This uses the special metadata "evaluator_type" associated with each
        builtin dataset. For your own dataset, you can simply create an
        evaluator manually in your script and do not have to worry about the
        hacky if-else logic here.

        Args:
            cfg: config node.
            dataset_name (str): name of a registered dataset.
            output_folder (str): where evaluators store their output; defaults
                to ``<cfg.OUTPUT_DIR>/inference``.

        Returns:
            A single evaluator, or ``DatasetEvaluators`` wrapping several.

        Raises:
            NotImplementedError: if no evaluator matches the dataset type.
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        evaluator_list = []

        evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
        # semantic segmentation
        if evaluator_type in ["sem_seg", "ade20k_panoptic_seg"]:
            evaluator_list.append(
                SemSegEvaluator(
                    dataset_name,
                    distributed=True,
                    output_dir=output_folder,
                )
            )
        # instance segmentation
        if evaluator_type == "coco":
            evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder))
        # panoptic segmentation
        if evaluator_type in [
            "coco_panoptic_seg",
            "ade20k_panoptic_seg",
            "cityscapes_panoptic_seg",
            "mapillary_vistas_panoptic_seg",
        ]:
            if cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON:
                evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder))
        # COCO
        if evaluator_type == "coco_panoptic_seg" and cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON:
            evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder))
        if evaluator_type == "coco_panoptic_seg" and cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON:
            evaluator_list.append(SemSegEvaluator(dataset_name, distributed=True, output_dir=output_folder))
        # Mapillary Vistas
        if evaluator_type == "mapillary_vistas_panoptic_seg" and cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON:
            evaluator_list.append(InstanceSegEvaluator(dataset_name, output_dir=output_folder))
        if evaluator_type == "mapillary_vistas_panoptic_seg" and cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON:
            evaluator_list.append(SemSegEvaluator(dataset_name, distributed=True, output_dir=output_folder))
        # Cityscapes
        if evaluator_type == "cityscapes_instance":
            assert (
                torch.cuda.device_count() > comm.get_rank()
            ), "CityscapesEvaluator currently do not work with multiple machines."
            return CityscapesInstanceEvaluator(dataset_name)
        if evaluator_type == "cityscapes_sem_seg":
            assert (
                torch.cuda.device_count() > comm.get_rank()
            ), "CityscapesEvaluator currently do not work with multiple machines."
            return CityscapesSemSegEvaluator(dataset_name)
        if evaluator_type == "cityscapes_panoptic_seg":
            if cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON:
                assert (
                    torch.cuda.device_count() > comm.get_rank()
                ), "CityscapesEvaluator currently do not work with multiple machines."
                evaluator_list.append(CityscapesSemSegEvaluator(dataset_name))
            if cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON:
                assert (
                    torch.cuda.device_count() > comm.get_rank()
                ), "CityscapesEvaluator currently do not work with multiple machines."
                evaluator_list.append(CityscapesInstanceEvaluator(dataset_name))
        # ADE20K
        if evaluator_type == "ade20k_panoptic_seg" and cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON:
            evaluator_list.append(InstanceSegEvaluator(dataset_name, output_dir=output_folder))
        # LVIS
        if evaluator_type == "lvis":
            return LVISEvaluator(dataset_name, output_dir=output_folder)
        if len(evaluator_list) == 0:
            raise NotImplementedError(
                "no Evaluator for the dataset {} with the type {}".format(
                    dataset_name, evaluator_type
                )
            )
        elif len(evaluator_list) == 1:
            return evaluator_list[0]
        return DatasetEvaluators(evaluator_list)

    @classmethod
    def _reregister_kitti(cls, cfg, queue, it):
        """Drop and re-register both KITTI-STEP video datasets for ``queue``/``it``."""
        for name in ["train", "val"]:
            DatasetCatalog.remove(f"kitti_step_video_sem_seg_{name}")
            MetadataCatalog.remove(f"kitti_step_video_sem_seg_{name}")
        register_custom_kitti(cfg.TTT.OUT_DIR, queue, it)

    @classmethod
    def build_train_loader(cls, cfg, queue=None, it=0):
        """
        Re-register the KITTI-STEP datasets and build the train loader.

        Args:
            cfg: config node; ``INPUT.DATASET_MAPPER_NAME`` selects the mapper
                (an unknown name falls back to ``mapper=None``).
            queue (list[int]): frame indices to train on; defaults to ``[0]``.
                The list is stored on the class (not copied) so that
                ``build_test_loader`` evaluates the same queue.
            it (int): self-training round whose pseudo-labels are loaded.

        Returns:
            The loader returned by ``build_detection_train_loader``.
        """
        # ``None`` instead of a mutable default, which would be one list
        # shared by every call.
        if queue is None:
            queue = [0]

        cls.queue = queue
        cls._reregister_kitti(cfg, queue, it)

        mapper_classes = {
            # Semantic segmentation dataset mapper
            "mask_former_semantic": MaskFormerSemanticDatasetMapper,
            # Dropout ST dataset mapper
            "mask_former_dropout_st": DropoutSemanticDatasetMapper,
            # Panoptic segmentation dataset mapper
            "mask_former_panoptic": MaskFormerPanopticDatasetMapper,
            # Instance segmentation dataset mapper
            "mask_former_instance": MaskFormerInstanceDatasetMapper,
            # coco instance segmentation lsj new baseline
            "coco_instance_lsj": COCOInstanceNewBaselineDatasetMapper,
            # coco panoptic segmentation lsj new baseline
            "coco_panoptic_lsj": COCOPanopticNewBaselineDatasetMapper,
        }
        mapper_cls = mapper_classes.get(cfg.INPUT.DATASET_MAPPER_NAME)
        mapper = mapper_cls(cfg, True) if mapper_cls is not None else None
        return build_detection_train_loader(cfg, mapper=mapper)

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        """
        Re-register the KITTI-STEP datasets for the current queue (round 0)
        and build the test loader for ``dataset_name``.
        """
        cls._reregister_kitti(cfg, cls.queue, 0)

        return build_detection_test_loader(cfg, dataset_name)

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer):
        """
        It now calls :func:`detectron2.solver.build_lr_scheduler`.
        Overwrite it if you'd like a different scheduler.
        """
        return build_lr_scheduler(cfg, optimizer)

    @classmethod
    def build_optimizer(cls, cfg, model):
        """
        Build the optimizer with per-parameter hyper-parameters.

        Every trainable parameter gets its own group: backbone modules scale
        the learning rate by ``SOLVER.BACKBONE_MULTIPLIER``, position
        tables/embeddings use no weight decay, and normalization / embedding
        modules use ``WEIGHT_DECAY_NORM`` / ``WEIGHT_DECAY_EMBED``.

        Raises:
            NotImplementedError: for ``SOLVER.OPTIMIZER`` other than
                "SGD" or "ADAMW".
        """
        weight_decay_norm = cfg.SOLVER.WEIGHT_DECAY_NORM
        weight_decay_embed = cfg.SOLVER.WEIGHT_DECAY_EMBED

        defaults = {}
        defaults["lr"] = cfg.SOLVER.BASE_LR
        defaults["weight_decay"] = cfg.SOLVER.WEIGHT_DECAY

        norm_module_types = (
            torch.nn.BatchNorm1d,
            torch.nn.BatchNorm2d,
            torch.nn.BatchNorm3d,
            torch.nn.SyncBatchNorm,
            # NaiveSyncBatchNorm inherits from BatchNorm2d
            torch.nn.GroupNorm,
            torch.nn.InstanceNorm1d,
            torch.nn.InstanceNorm2d,
            torch.nn.InstanceNorm3d,
            torch.nn.LayerNorm,
            torch.nn.LocalResponseNorm,
        )

        params: List[Dict[str, Any]] = []
        memo: Set[torch.nn.parameter.Parameter] = set()
        for module_name, module in model.named_modules():
            for module_param_name, value in module.named_parameters(recurse=False):
                if not value.requires_grad:
                    continue
                # Avoid duplicating parameters
                if value in memo:
                    continue
                memo.add(value)

                hyperparams = copy.copy(defaults)
                if "backbone" in module_name:
                    hyperparams["lr"] = hyperparams["lr"] * cfg.SOLVER.BACKBONE_MULTIPLIER
                if (
                    "relative_position_bias_table" in module_param_name
                    or "absolute_pos_embed" in module_param_name
                ):
                    print(module_param_name)
                    hyperparams["weight_decay"] = 0.0
                if isinstance(module, norm_module_types):
                    hyperparams["weight_decay"] = weight_decay_norm
                if isinstance(module, torch.nn.Embedding):
                    hyperparams["weight_decay"] = weight_decay_embed
                params.append({"params": [value], **hyperparams})

        def maybe_add_full_model_gradient_clipping(optim):
            # detectron2 doesn't have full model gradient clipping now
            clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
            enable = (
                cfg.SOLVER.CLIP_GRADIENTS.ENABLED
                and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
                and clip_norm_val > 0.0
            )

            class FullModelGradientClippingOptimizer(optim):
                def step(self, closure=None):
                    # Clip the gradient norm over all parameter groups at once.
                    all_params = itertools.chain(*[x["params"] for x in self.param_groups])
                    torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
                    super().step(closure=closure)

            return FullModelGradientClippingOptimizer if enable else optim

        optimizer_type = cfg.SOLVER.OPTIMIZER
        if optimizer_type == "SGD":
            optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
                params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM
            )
        elif optimizer_type == "ADAMW":
            optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
                params, cfg.SOLVER.BASE_LR
            )
        else:
            raise NotImplementedError(f"no optimizer type {optimizer_type}")
        if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model":
            optimizer = maybe_add_gradient_clipping(cfg, optimizer)
        return optimizer

    @classmethod
    def test_with_TTA(cls, cfg, model):
        """
        Evaluate ``model`` wrapped in ``SemanticSegmentorWithTTA`` on every
        dataset in ``cfg.DATASETS.TEST``.

        Returns:
            OrderedDict: the results of ``cls.test`` with "_TTA" appended to
                every key.
        """
        logger = logging.getLogger("detectron2.trainer")
        # In the end of training, run an evaluation with TTA.
        logger.info("Running inference with test-time augmentation ...")
        model = SemanticSegmentorWithTTA(cfg, model)
        evaluators = [
            cls.build_evaluator(
                cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
            )
            for name in cfg.DATASETS.TEST
        ]
        res = cls.test(cfg, model, evaluators)
        res = OrderedDict({k + "_TTA": v for k, v in res.items()})
        return res


def setup(args):
    """
    Create configs and perform basic setups.

    Starting from detectron2's default config, this registers the DeepLab,
    Mask2Former, TTT and dropout-augmentation config keys, merges
    `args.config_file` and then `args.opts`, and finally copies the TTT and
    dropout command-line arguments into `cfg.TTT` / `cfg.DROPOUT_AUG`.

    NOTE: `cfg.freeze()` is not called here, so the returned config stays
    mutable.

    Args:
        args: parsed command-line arguments; the attributes read here are
            listed in the override tables below.

    Returns:
        The populated config node.
    """
    cfg = get_cfg()
    # add_deeplab_config comes first: it is needed for the poly lr schedule.
    config_extenders = (
        add_deeplab_config,
        add_maskformer2_config,
        add_ttt_config,
        add_dropout_config,
    )
    for extend in config_extenders:
        extend(cfg)

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)

    # Command-line values take precedence over anything merged above.
    ttt_overrides = (
        ("IN_DIR", args.ttt_in_dir),
        ("OUT_DIR", args.ttt_out_dir),
        ("ST_ITERS", args.st_iters),
        ("WIN_SIZE", args.win_size),
        ("TOPL", args.ttt_topl),
        ("SETTING", args.ttt_setting),
        ("EXP_DIR", args.exp_dir),
    )
    for key, value in ttt_overrides:
        setattr(cfg.TTT, key, value)

    dropout_overrides = (
        ("ENABLED", args.drop_aug),
        ("RATIO", args.drop_ratio),
        ("MASK_TYPE", args.mask_type),
    )
    for key, value in dropout_overrides:
        setattr(cfg.DROPOUT_AUG, key, value)

    # Worker count derived from the batch size ...
    cfg.DATALOADER.NUM_WORKERS = cfg.SOLVER.IMS_PER_BATCH // 2
    # ... but the DEBUGGING override below currently replaces it with 0.
    # NOTE(review): remove this override once debugging is finished.
    cfg.DATALOADER.NUM_WORKERS = 0

    default_setup(cfg, args)
    # Setup logger for the "mask2former" module
    setup_logger(
        output=cfg.OUTPUT_DIR,
        distributed_rank=comm.get_rank(),
        name="mask2former",
    )

    return cfg


def main(args):
    """
    Entry point executed by `launch` in each process.

    Builds the config via `setup(args)`, then either:
      * trains: creates a `Trainer`, resumes/loads weights, registers the
        `SelfTrainingDataloader` hook and returns `trainer.train()`; or
      * evaluates (`args.eval_only`): builds the model, loads
        `cfg.MODEL.WEIGHTS`, runs `Trainer.test` (plus `test_with_TTA` when
        `cfg.TEST.AUG.ENABLED`), verifies the results on the main process and
        returns them.
    """
    cfg = setup(args)

    if not args.eval_only:
        # Training path.
        trainer = Trainer(cfg)
        trainer.resume_or_load(resume=args.resume)
        self_training_hook = SelfTrainingDataloader(register_custom_kitti, cfg)
        trainer.register_hooks([self_training_hook])
        return trainer.train()

    # Evaluation-only path.
    model = Trainer.build_model(cfg)
    checkpointer = DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR)
    checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=args.resume)

    res = Trainer.test(cfg, model)
    if cfg.TEST.AUG.ENABLED:
        res.update(Trainer.test_with_TTA(cfg, model))
    # Only the main process checks the results.
    if comm.is_main_process():
        verify_results(cfg, res)
    return res


if __name__ == "__main__":
    # Arguments come from get_parser(); the earlier call to
    # default_argument_parser() is no longer used.
    args = get_parser().parse_args()
    print("Command Line Args:", args)

    # Launch settings forwarded to detectron2's `launch`; `main` receives
    # `args` as its single positional argument.
    launch_kwargs = dict(
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )
    launch(main, args.num_gpus, **launch_kwargs)
