import os
import sys
from glob import glob

import detectron2.data.transforms as T
import numpy as np
import torch
from detectron2.data import (DatasetCatalog, MetadataCatalog,
                             build_detection_test_loader,
                             build_detection_train_loader)
from detectron2.data.detection_utils import read_image
from detectron2.engine import DefaultTrainer, HookBase, default_setup, launch
from detectron2.engine.defaults import DefaultPredictor
from detectron2.utils.visualizer import ColorMode, Visualizer
from PIL import Image


class SelfTrainingPredictor(DefaultPredictor):
    """Predictor that runs an externally supplied model on a single image.

    Unlike ``DefaultPredictor`` it does not own a model: the (live,
    self-training) model is passed to every ``__call__``. Each input image gets
    an extra all-ones channel appended before being fed to the model.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg: detectron2 config. ``cfg.DATASETS.TEST[0]`` must name a
                dataset whose metadata defines ``stuff_classes``.

        Raises:
            ValueError: if ``cfg.DATASETS.TEST`` is empty (no metadata to use).
        """
        # DefaultPredictor.__init__ is deliberately not called: it would build
        # and load its own model, whereas this class receives one per call.
        self.cfg = cfg.clone()  # cfg can be modified by model
        if not cfg.DATASETS.TEST:
            raise ValueError(
                "SelfTrainingPredictor needs cfg.DATASETS.TEST to look up class metadata"
            )
        self.metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0])

        # Build the whole palette first and assign it once, so constructing a
        # second predictor on the same metadata assigns an equal value instead
        # of tripping Metadata's "cannot be set to a different value" check.
        # Visualization instance mode must be SEGMENTATION for these to apply.
        self.metadata.stuff_colors = self._make_stuff_colors(self.metadata.stuff_classes)

        self.aug = T.ResizeShortestEdge(
            [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
        )

        # Channel order the *model* expects; inputs are always BGR.
        self.input_format = cfg.INPUT.FORMAT
        assert self.input_format in ["RGB", "BGR"], self.input_format

    @staticmethod
    def _make_stuff_colors(stuff_classes):
        """Map each class index to an (R, G, B) tuple in 0-255 from the rainbow colormap.

        The class named ``"unconfident"`` is drawn black (visualization only).
        """
        import matplotlib.pyplot as plt

        # pyplot.get_cmap survives matplotlib 3.9, unlike matplotlib.cm.get_cmap.
        cmap = plt.get_cmap("rainbow", len(stuff_classes))
        colors = {}
        for idx, class_name in enumerate(stuff_classes):
            if class_name == "unconfident":
                colors[idx] = (0., 0., 0.)
            else:
                red, green, blue, _alpha = cmap(idx)
                colors[idx] = (red * 255, green * 255, blue * 255)
        return colors

    def __call__(self, original_image, model):
        """Run ``model`` on one image without tracking gradients.

        Args:
            original_image (np.ndarray): an image of shape (H, W, C) in BGR order.
            model: detectron2-style model called with a list of input dicts.

        Returns:
            dict: the model's output for this image (callers read ``"sem_seg"``).
        """
        with torch.no_grad():
            if self.input_format == "RGB":
                # The model wants RGB but the image is BGR: reverse channels.
                original_image = original_image[:, :, ::-1]
            height, width = original_image.shape[:2]
            image = self.aug.get_transform(original_image).apply_image(original_image)
            image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))

            # Append an all-ones mask channel: (C, h, w) -> (C + 1, h, w).
            _, resized_h, resized_w = image.shape
            mask = torch.ones((1, resized_h, resized_w), dtype=image.dtype)
            input_image = torch.cat([image, mask], dim=0)

            inputs = {"image": input_image, "height": height, "width": width}
            return model([inputs])[0]

    def draw_predictions(self, original_image, preds):
        """Render the per-pixel arg-max of ``preds`` over ``original_image``.

        Args:
            original_image (np.ndarray): (H, W, C) image in BGR order, the
                same image passed to ``__call__``.
            preds (torch.Tensor): per-class scores with the class axis first.

        Returns:
            The ``VisImage`` produced by ``Visualizer.draw_sem_seg``.
        """
        # Visualizer expects RGB. The input is BGR no matter which channel
        # order the model uses, so always convert (previously this depended on
        # INPUT.FORMAT and produced swapped colors for RGB models).
        rgb_image = original_image[:, :, ::-1]
        visualizer = Visualizer(rgb_image, self.metadata, instance_mode=ColorMode.SEGMENTATION)
        return visualizer.draw_sem_seg(preds.argmax(dim=0).cpu())


class SelfTrainingDataloader(HookBase):
    """Trainer hook that drives test-time self-training over a frame sequence.

    Before every step it re-registers the ``kitti_step_video_sem_seg_*``
    datasets and rebuilds the train loader for the frames currently in the
    window. After every step it either regenerates pseudo-labels for those
    frames or, every ``cfg.TTT.ST_ITERS`` steps, finishes the current frame:
    it evaluates (once the window is full), logs mIoU, and advances the window
    by one frame.
    """

    def __init__(self, registration, cfg):
        """
        Args:
            registration: callable ``(out_dir, imgs_in_queue, internal_iter)``.
                It is invoked right after the ``kitti_step_video_sem_seg_train``
                and ``..._val`` catalog entries are removed, and is expected
                to register them again.
            cfg: detectron2 config (a clone is stored).
        """
        self.registration = registration
        self.cfg = cfg.clone()
        win_size = cfg.TTT.WIN_SIZE
        # "inf" means frames are never evicted from the window.
        self.win_size = "inf" if win_size == "inf" else int(win_size)
        self.predictor = SelfTrainingPredictor(cfg)

        self.ttt_in = cfg.TTT.IN_DIR
        # normpath first so a trailing "/" in IN_DIR does not yield "".
        self.video_dir = os.path.basename(os.path.normpath(cfg.TTT.IN_DIR))
        self.ttt_out = os.path.join(cfg.TTT.OUT_DIR, "train")

        self.mask_type = cfg.DROPOUT_AUG.MASK_TYPE
        self.topl = cfg.TTT.TOPL  # fraction of most-confident pixels kept as labels
        self.uratio = str(cfg.TTT.TOPL)
        self.mratio = str(cfg.DROPOUT_AUG.RATIO)
        # Shared by the experiment directory and the final checkpoint name.
        self._run_tag = (str(self.video_dir) + "_" + self.mask_type +
                         "_used" + self.uratio + "_mask" + self.mratio)

        # Log/save directory: <EXP_DIR>/<run_tag>/<win_size>_win
        self.exp_dir = os.path.join(cfg.TTT.EXP_DIR, self._run_tag,
                                    str(self.win_size) + "_win")
        os.makedirs(self.exp_dir, exist_ok=True)

        # Every run starts with a fresh performance log.
        self._perf_log = os.path.join(self.exp_dir, 'performance.txt')
        if os.path.isfile(self._perf_log):
            os.remove(self._perf_log)

        # "standard" reloads the original weights after every frame;
        # any other setting (e.g. online) keeps adapting the same weights.
        self.ttt_setting = cfg.TTT.SETTING
        self.orig_model_weights = cfg.MODEL.WEIGHTS

        self.imgs_in_queue = [0]  # frame indices currently in the window
        self.internal_iter = 0  # self-training step within the current frame
        self.max_st_iters = cfg.TTT.ST_ITERS
        self.max_img_idx = len(glob(os.path.join(self.ttt_in, "*.png")))

        self._root = os.getenv("DETECTRON2_DATASETS")
        self._final_saved = False  # makes after_train save only once

    # Re-register dataloader
    def before_step(self):
        """Re-register the datasets and rebuild the train loader for the window."""
        for name in ["train", "val"]:
            DatasetCatalog.remove(f"kitti_step_video_sem_seg_{name}")
            MetadataCatalog.remove(f"kitti_step_video_sem_seg_{name}")

        self.registration(self.cfg.TTT.OUT_DIR, self.imgs_in_queue,
                          self.internal_iter)
        loader = self.trainer.build_train_loader(self.cfg, self.imgs_in_queue,
                                                 self.internal_iter)
        self._set_data_loader_iter(iter(loader))

        self.trainer.model.train()

    def _set_data_loader_iter(self, loader_iter):
        """Replace the iterator the inner trainer draws batches from."""
        inner = self.trainer._trainer
        # Recent detectron2 versions expose ``_data_loader_iter`` as a
        # read-only property backed by ``_data_loader_iter_obj``; older ones
        # store the iterator directly in ``_data_loader_iter``.
        if hasattr(inner, "_data_loader_iter_obj"):
            inner._data_loader_iter_obj = loader_iter
        else:
            inner._data_loader_iter = loader_iter

    # Update model predictions on this image
    def after_step(self):
        """Refresh pseudo-labels, or finish the current frame; exit after the last frame."""
        if (self.internal_iter + 1) != self.max_st_iters:
            self._update_pseudo_labels()
            self.internal_iter = self.internal_iter + 1
        else:
            self._finish_frame()

        if self.imgs_in_queue[-1] == self.max_img_idx:
            # No more frames. detectron2 has no early-stop hook, so save and
            # terminate. SystemExit also triggers the trainer's own
            # after_train; the guard in after_train prevents a second save.
            self.after_train()
            sys.exit(0)

    def _update_pseudo_labels(self):
        """Predict each queued frame and save its filtered labels and confidences."""
        self.trainer.model.eval()
        suffix = str((self.internal_iter + 1) % self.max_st_iters)
        for img_idx in self.imgs_in_queue:
            stem = format(img_idx, "06d")
            img = read_image(os.path.join(self.ttt_in, stem + '.png'), format="BGR")
            preds = self.predictor(img, self.trainer.model)["sem_seg"]

            # Hard labels and the softmax probability of the chosen class.
            label_idx = torch.argmax(preds, dim=0, keepdim=True)
            probs = torch.nn.functional.softmax(preds, dim=0)
            # squeeze(0) (not squeeze()) keeps (H, W) even when H or W is 1.
            probs = torch.gather(probs, 0, label_idx).squeeze(0)
            label_idx = label_idx.squeeze(0)

            unconfident = self._unconfident_mask(probs)

            # Save confidence values for confidence-weighted thresholding.
            np.save(os.path.join(self.ttt_out, stem + '_conf_' + suffix + '.npy'),
                    probs.cpu().numpy())

            # 255 is the ignore label (KITTI-STEP). The mask MUST be bool:
            # a 0/1 long tensor would act as row indices and overwrite rows
            # 0 and 1 instead of the unconfident pixels.
            label_idx[unconfident] = 255

            # Single-channel PNG whose pixel value is the category id.
            label_im = Image.fromarray(label_idx.cpu().numpy().astype(np.uint8))
            label_im.save(os.path.join(self.ttt_out, stem + '_' + suffix + '.png'))

    def _unconfident_mask(self, probs):
        """Return a bool mask, True outside the top ``self.topl`` fraction of pixels by confidence."""
        num_pixels = probs.numel()
        keep = int(num_pixels * self.topl)
        best = torch.argsort(probs.flatten(), descending=True)[:keep]
        mask = torch.ones(num_pixels, dtype=torch.bool, device=probs.device)
        mask[best] = False
        return mask.view_as(probs)

    def _finish_frame(self):
        """Evaluate once the window is full, advance the window and reset state."""
        if self.win_size == "inf" or len(self.imgs_in_queue) == self.win_size:
            self._evaluate_and_log()

        # The next frame joins the window; evict the oldest one on overflow.
        self.imgs_in_queue.append(self.imgs_in_queue[-1] + 1)
        if self.win_size != "inf" and len(self.imgs_in_queue) > self.win_size:
            self.imgs_in_queue.pop(0)
        self.internal_iter = 0

        # Standard TTT restarts every frame from the original weights.
        if self.ttt_setting == "standard":
            self.trainer.checkpointer.load(self.orig_model_weights)

    def _evaluate_and_log(self):
        """Run the trainer's test, append mIoU and save a visualization of the newest frame."""
        res = self.trainer.test(self.cfg, self.trainer.model)
        # Only the main process gets results (same as comm.is_main_process).
        if not res or 'sem_seg' not in res:
            return
        miou = str(res['sem_seg']['mIoU'])
        with open(self._perf_log, 'a') as fp:
            fp.write(miou + '\n')

        last_stem = format(self.imgs_in_queue[-1], '06d')
        img = read_image(os.path.join(self.ttt_in, last_stem + '.png'), format="BGR")
        self.trainer.model.eval()
        preds = self.predictor(img, self.trainer.model)["sem_seg"]
        vis_out = self.predictor.draw_predictions(img, preds)
        vis_out.save(os.path.join(self.exp_dir, last_stem + '_final.png'))

    def after_train(self):
        """Save the adapted model as ``model_final_<run tag>`` (at most once)."""
        if self._final_saved:
            return
        self.trainer.checkpointer.save("model_final_" + self._run_tag)
        self._final_saved = True

