import datetime
import logging
import os
import sys
import time
from collections import OrderedDict, abc
from contextlib import ExitStack, contextmanager
from glob import glob
from typing import List, Union

import detectron2.data.transforms as T
import numpy as np
import torch
from detectron2.data import (DatasetCatalog, MetadataCatalog,
                             build_detection_test_loader,
                             build_detection_train_loader)
from detectron2.data.detection_utils import read_image
from detectron2.engine import DefaultTrainer, HookBase, default_setup, launch
from detectron2.engine.defaults import DefaultPredictor
from detectron2.evaluation import (DatasetEvaluator, DatasetEvaluators,
                                   inference_context)
from detectron2.structures import BitMasks, Boxes, ImageList, Instances
from detectron2.utils.comm import get_world_size, is_main_process
from detectron2.utils.logger import log_every_n_seconds
from detectron2.utils.visualizer import ColorMode, Visualizer
from PIL import Image
from torch import nn


class MAEPredictor(DefaultPredictor):
    """
    Single-image predictor that feeds a randomly patch-masked image to a model.

    The constructor deliberately does not call ``DefaultPredictor.__init__``:
    no model is built or loaded here, and the model to run is instead passed
    to ``__call__`` on every invocation.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg: config node. Reads ``DATASETS.TEST``, ``PRETRAIN.MASK_RATIO``,
                ``MODEL.SWIN.PATCH_SIZE``, ``MODEL.MASK_FORMER.SIZE_DIVISIBILITY``,
                ``INPUT.MIN_SIZE_TEST``, ``INPUT.MAX_SIZE_TEST`` and
                ``INPUT.FORMAT``. A private clone is kept on ``self.cfg``.

        Raises:
            ValueError: if ``cfg.DATASETS.TEST`` is empty, because the colour
                palette is derived from the first test dataset's metadata.
        """
        self.cfg = cfg.clone()  # cfg can be modified by model
        if not cfg.DATASETS.TEST:
            # Fail with an explicit message instead of an AttributeError on
            # ``self.metadata`` a few lines further down.
            raise ValueError(
                "MAEPredictor needs at least one dataset in cfg.DATASETS.TEST "
                "to look up the class metadata used for visualization"
            )
        self.metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0])

        self.mask_ratio = cfg.PRETRAIN.MASK_RATIO[0]
        self.patch_size = cfg.MODEL.SWIN.PATCH_SIZE
        self.size_divisibility = cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY

        # Instance mode NEEDS to be SEGMENTATION
        import matplotlib.pyplot as plt

        class_names = self.metadata.stuff_classes
        # ``pyplot.get_cmap`` is used rather than ``plt.cm.get_cmap``, which
        # newer Matplotlib releases no longer provide.
        cmap = plt.get_cmap('rainbow', len(class_names))
        palette = {}
        for i, name in enumerate(class_names):
            if name == "unconfident":
                # ONLY FOR VISUALIZATION OF UNCONFIDENT: drawn in black.
                palette[i] = (0., 0., 0.)
            else:
                r, g, b = cmap(i)[:3]  # colormap entries are floats in [0, 1]
                palette[i] = (r * 255, g * 255, b * 255)
        # Assigned in one step once the palette is fully built.
        self.metadata.stuff_colors = palette

        self.aug = T.ResizeShortestEdge(
            [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
        )

        self.input_format = cfg.INPUT.FORMAT
        assert self.input_format in ["RGB", "BGR"], self.input_format

    def __call__(self, original_image, model):
        """
        Run ``model`` on one image together with a freshly sampled patch mask.

        Args:
            original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).
            model (callable): called with a one-element list holding the input
                dict; the first element of its return value is returned.

        Returns:
            predictions (dict): ``model([inputs])[0]``.
        """
        with torch.no_grad():  # https://github.com/sphinx-doc/sphinx/issues/4258
            # Apply pre-processing to image.
            if self.input_format == "RGB":
                # whether the model expects BGR inputs or RGB
                original_image = original_image[:, :, ::-1]
            height, width = original_image.shape[:2]
            image = self.aug.get_transform(original_image).apply_image(original_image)
            # HWC -> CHW float tensor
            image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))

            inputs = {
                "image": image,
                "height": height,
                "width": width,
                "image_mask": self._random_image_mask(image),
            }
            return model([inputs])[0]

    def _random_image_mask(self, image):
        """
        Sample a random patch mask for a (C, H, W) image tensor.

        The image is first padded through ``ImageList.from_tensors`` with
        ``self.size_divisibility`` so that ``patchify`` can tile it; the mask
        therefore has the padded spatial size, not the size of ``image``.

        Returns:
            Tensor of shape (1, H_pad, W_pad); 1 marks kept pixels, 0 removed.
        """
        # Each channel is handed over as its own 2-D tensor; the stacked
        # result is unpacked as a rank-3 tensor below.
        padded = ImageList.from_tensors(list(image), self.size_divisibility).tensor
        _, h, w = padded.shape

        _, seq_mask, _ = self.random_masking(self.patchify(padded), self.mask_ratio)

        # Expand the per-patch flag to one flag per pixel of the patch.
        mask = seq_mask.unsqueeze(-1).repeat_interleave(self.patch_size ** 2, dim=-1)
        return self.unpatchify_mask(mask, h, w)

    def patchify(self, imgs):
        """
        Split a single 3-channel image into flattened patches.

        imgs: (3, H, W) -- a batch axis of size 1 is added internally;
            H and W must be multiples of ``self.patch_size``.
        x: (1, L, patch_size**2 * 3) with L = (H / p) * (W / p)
        """
        imgs = imgs.unsqueeze(0)
        p = self.patch_size
        assert imgs.shape[2] % p == 0 and imgs.shape[3] % p == 0

        h = imgs.shape[2] // p
        w = imgs.shape[3] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))

        return x

    def unpatchify(self, x, h, w):
        """
        Inverse of ``patchify`` for 3-channel patches.

        x: (N, L, patch_size**2 * 3)
        h, w: spatial size of the image to rebuild.
        imgs: (N, 3, h, w), with the leading axis squeezed away when N == 1.
        """
        p = self.patch_size
        assert (h // p) * (w // p) == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h // p, w // p, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 3, h, w))

        return imgs.squeeze(0)

    def unpatchify_mask(self, x, h, w):
        """
        Rebuild a single-channel pixel mask from per-patch values.

        x: (N, L, patch_size**2 * 1)
        h, w: spatial size of the mask to rebuild.
        imgs: (N, 1, h, w), with the leading axis squeezed away when N == 1.
        """
        p = self.patch_size
        assert (h // p) * (w // p) == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h // p, w // p, p, p, 1))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 1, h, w))

        return imgs.squeeze(0)

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.

        x: [N, L, D], sequence
        mask_ratio: fraction of the L entries to remove.

        Returns:
            x_masked: [N, len_keep, D], the kept entries in shuffled order.
            mask: [N, L] float tensor in the original order; 1 is keep,
                0 is remove.
            ids_restore: [N, L] indices that undo the shuffle.
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 1 is keep, 0 is remove
        mask = torch.zeros([N, L], device=x.device)
        mask[:, :len_keep] = 1
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def draw_predictions(self, original_image, preds):
        """
        Overlay the per-pixel argmax of ``preds`` on ``original_image``.

        Args:
            original_image (np.ndarray): (H, W, C) image; its channel order is
                reversed when ``self.input_format`` is "BGR".
            preds: tensor whose dim 0 indexes classes; the argmax over that
                dim is what gets drawn.

        Returns:
            The object returned by ``Visualizer.draw_sem_seg``.
        """
        # NOTE(review): the channels are swapped only when the configured
        # input format is "BGR". The callers visible in this file always pass
        # an image read as BGR, so with an "RGB" config the visualizer would
        # presumably receive BGR data -- confirm this is intended.
        if self.input_format == "BGR":
            original_image = original_image[:, :, ::-1]

        visualizer = Visualizer(original_image, self.metadata, instance_mode=ColorMode.SEGMENTATION)
        vis_output = visualizer.draw_sem_seg(
            preds.argmax(dim=0).cpu()
        )

        return vis_output

    def draw_reconstructions(self, recon):
        """
        Convert a (C, H, W) reconstruction tensor into a PIL image.

        Values are clamped to [0, 255] before the uint8 cast so that
        out-of-range values saturate instead of wrapping around.
        Presumably ``recon`` is already on a 0-255 pixel scale -- TODO confirm.
        """
        pixels = recon.detach().clamp(0, 255).permute(1, 2, 0).cpu().numpy()
        return Image.fromarray(pixels.astype(np.uint8))


class MAETestPredictor(MAEPredictor):
    """
    Predictor that runs a whole data loader through a model, saving a
    segmentation overlay and a reconstruction image for every sample while
    collecting evaluator metrics and timing statistics.
    """

    def __init__(self, cfg):
        super().__init__(cfg=cfg)

    def custom_inference_on_dataset(
        self,
        save_dir,
        model, data_loader,
        evaluator: Union[DatasetEvaluator, List[DatasetEvaluator], None]
    ):
        """
        Run model on the data_loader and evaluate the metrics with evaluator.
        Also benchmark the inference speed of `model.__call__` accurately.
        The model will be used in eval mode.
        Args:
            save_dir (str): directory that receives ``<stem>_final.png`` and
                ``<stem>_recon.png`` for every input; created if missing.
            model (callable): a callable which takes an object from
                `data_loader` and returns some outputs.
                If it's an nn.Module, it will be temporarily set to `eval` mode.
                If you wish to evaluate a model in `training` mode instead, you can
                wrap the given model and override its behavior of `.eval()` and `.train()`.
            data_loader: an iterable object with a length.
                The elements it generates will be the inputs to the model.
                Every element must hold exactly one sample dict with the keys
                ``"image"`` and ``"file_name"``.
            evaluator: the evaluator(s) to run. Use `None` if you only want to benchmark,
                but don't want to do any evaluation.
        Returns:
            The return value of `evaluator.evaluate()`, or an empty dict when
            that value is None.
        """
        num_devices = get_world_size()
        logger = logging.getLogger(__name__)
        logger.info("Start inference on {} batches".format(len(data_loader)))

        # Create the output directory up front so the first save cannot fail
        # on a missing folder.
        os.makedirs(save_dir, exist_ok=True)

        total = len(data_loader)  # inference data loader must have a fixed length
        if evaluator is None:
            # create a no-op evaluator
            evaluator = DatasetEvaluators([])
        if isinstance(evaluator, abc.MutableSequence):
            evaluator = DatasetEvaluators(evaluator)
        evaluator.reset()

        num_warmup = min(5, total - 1)
        start_time = time.perf_counter()
        total_data_time = 0
        total_compute_time = 0
        total_eval_time = 0
        with ExitStack() as stack:
            if isinstance(model, nn.Module):
                stack.enter_context(inference_context(model))
            stack.enter_context(torch.no_grad())

            start_data_time = time.perf_counter()
            for idx, inputs in enumerate(data_loader):
                assert len(inputs) == 1, 'Can only use this inference code for 1 image at a time'

                # Mask generation is accounted as data-loading time.
                self._attach_random_mask(inputs[0])

                total_data_time += time.perf_counter() - start_data_time
                if idx == num_warmup:
                    # Warm-up finished: restart every timer.
                    start_time = time.perf_counter()
                    total_data_time = 0
                    total_compute_time = 0
                    total_eval_time = 0

                start_compute_time = time.perf_counter()
                outputs = model(inputs)
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                total_compute_time += time.perf_counter() - start_compute_time

                # Saving is counted in the overall time only, not in the
                # compute or eval buckets.
                self._save_visualizations(save_dir, inputs[0], outputs[0])

                start_eval_time = time.perf_counter()
                evaluator.process(inputs, outputs)
                total_eval_time += time.perf_counter() - start_eval_time

                iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
                data_seconds_per_iter = total_data_time / iters_after_start
                compute_seconds_per_iter = total_compute_time / iters_after_start
                eval_seconds_per_iter = total_eval_time / iters_after_start
                total_seconds_per_iter = (time.perf_counter() - start_time) / iters_after_start
                if idx >= num_warmup * 2 or compute_seconds_per_iter > 5:
                    eta = datetime.timedelta(seconds=int(total_seconds_per_iter * (total - idx - 1)))
                    log_every_n_seconds(
                        logging.INFO,
                        (
                            f"Inference done {idx + 1}/{total}. "
                            f"Dataloading: {data_seconds_per_iter:.4f} s/iter. "
                            f"Inference: {compute_seconds_per_iter:.4f} s/iter. "
                            f"Eval: {eval_seconds_per_iter:.4f} s/iter. "
                            f"Total: {total_seconds_per_iter:.4f} s/iter. "
                            f"ETA={eta}"
                        ),
                        n=5,
                    )
                start_data_time = time.perf_counter()

        # Measure the time only for this worker (before the synchronization barrier)
        total_time = time.perf_counter() - start_time
        total_time_str = str(datetime.timedelta(seconds=total_time))
        # NOTE this format is parsed by grep
        logger.info(
            "Total inference time: {} ({:.6f} s / iter per device, on {} devices)".format(
                total_time_str, total_time / (total - num_warmup), num_devices
            )
        )
        total_compute_time_str = str(datetime.timedelta(seconds=int(total_compute_time)))
        logger.info(
            "Total inference pure compute time: {} ({:.6f} s / iter per device, on {} devices)".format(
                total_compute_time_str, total_compute_time / (total - num_warmup), num_devices
            )
        )

        results = evaluator.evaluate()
        # An evaluator may return None when not in main process.
        # Replace it by an empty dict instead to make it easier for downstream code to handle
        if results is None:
            results = {}
        return results

    def _attach_random_mask(self, sample):
        """Sample a random patch mask for ``sample["image"]`` and store it
        in place under ``sample["image_mask"]``."""
        # Pad through ImageList so the spatial size can be tiled by patchify.
        padded = ImageList.from_tensors(list(sample['image']), self.size_divisibility).tensor
        _, h, w = padded.shape

        _, seq_mask, _ = self.random_masking(self.patchify(padded), self.mask_ratio)

        # Expand the per-patch flag to one flag per pixel of the patch.
        mask = seq_mask.unsqueeze(-1).repeat_interleave(self.patch_size ** 2, dim=-1)
        sample["image_mask"] = self.unpatchify_mask(mask, h, w)

    def _save_visualizations(self, save_dir, sample, output):
        """Write the segmentation overlay and the reconstruction of one
        sample into ``save_dir``, named after the input file's stem."""
        file_name = sample['file_name']
        # basename/splitext handle any path separator and any extension.
        stem = os.path.splitext(os.path.basename(file_name))[0]

        original_image = read_image(file_name, format="BGR")
        overlay = self.draw_predictions(original_image, output["sem_seg"])
        overlay.save(os.path.join(save_dir, stem + '_final.png'))

        recon_image = self.draw_reconstructions(output["recon"])
        recon_image.save(os.path.join(save_dir, stem + '_recon.png'))

class MAEDataloader(HookBase):
    """
    Trainer hook that drives test-time training over the frames of a video.

    Before every step the datasets are re-registered and the train loader is
    rebuilt for the current queue of frame indices. After ``cfg.TTT.ST_ITERS``
    steps on a queue, the newest frame is evaluated (once the window is
    full), its results are written to the experiment directory, and the
    queue advances by one frame. The process exits once the queue reaches the
    number of ``*.png`` files found in ``cfg.TTT.IN_DIR``.
    """

    def __init__(self, registration, cfg):
        """
        Args:
            registration (callable): called before every step as
                ``registration(cfg.TTT.OUT_DIR, imgs_in_queue, internal_iter)``
                to (re-)register the datasets.
            cfg: config node; a private clone is kept on ``self.cfg``. Reads
                the ``TTT.*`` options, ``DROPOUT_AUG.MASK_TYPE``,
                ``DROPOUT_AUG.RATIO`` and ``MODEL.WEIGHTS``.
        """
        self.registration = registration
        self.cfg = cfg.clone()
        win_size = cfg.TTT.WIN_SIZE
        self.win_size = "inf" if win_size == "inf" else int(win_size)
        self.predictor = MAEPredictor(cfg)
        self.restart_optimizer = cfg.TTT.RESTART_OPTIMIZER
        self.ttt_in = cfg.TTT.IN_DIR
        # normpath drops a trailing separator, so "videos/0001/" still yields
        # "0001" rather than an empty name.
        self.video_dir = os.path.basename(os.path.normpath(cfg.TTT.IN_DIR))
        self.ttt_out = os.path.join(cfg.TTT.OUT_DIR, "train")

        # Make the log/save directory for this experiment
        self.mask_type = cfg.DROPOUT_AUG.MASK_TYPE
        self.topl = cfg.TTT.TOPL
        self.uratio = str(cfg.TTT.TOPL)
        self.mratio = str(cfg.DROPOUT_AUG.RATIO)
        self.exp_dir = os.path.join(
            cfg.TTT.EXP_DIR, self._run_name(), str(self.win_size) + "_win"
        )
        # A single call also creates every missing parent directory.
        os.makedirs(self.exp_dir, exist_ok=True)

        # Remove log if exists. EAFP: another process may delete the file
        # between an existence check and the removal.
        try:
            os.remove(os.path.join(self.exp_dir, 'performance.txt'))
        except FileNotFoundError:
            pass

        # Setting
        self.ttt_setting = cfg.TTT.SETTING
        self.orig_model_weights = cfg.MODEL.WEIGHTS

        self.imgs_in_queue = [0]
        self.internal_iter = 0
        self.max_st_iters = cfg.TTT.ST_ITERS
        self.max_img_idx = len(glob(os.path.join(self.ttt_in, "*.png")))

        self._root = os.getenv("DETECTRON2_DATASETS")

    def _run_name(self):
        """Return the identifier shared by the experiment directory and the
        final checkpoint name."""
        return (str(self.video_dir) + "_" + self.mask_type +
                "_used" + self.uratio + "_mask" + self.mratio)

    # Re-register dataloader
    def before_step(self):
        """Re-register the datasets for the current queue, swap in a fresh
        train-loader iterator, and put the model in train mode."""
        for name in ["train", "val"]:
            DatasetCatalog.remove(f"kitti_step_video_sem_seg_{name}")
            MetadataCatalog.remove(f"kitti_step_video_sem_seg_{name}")

        self.registration(self.cfg.TTT.OUT_DIR, self.imgs_in_queue,
                          self.internal_iter)
        # Replace the iterator held by the wrapped trainer so that the next
        # step draws from the newly registered data.
        self.trainer._trainer._data_loader_iter = iter(
            self.trainer.build_train_loader(self.cfg,
                                            self.imgs_in_queue,
                                            self.internal_iter)
        )

        # after_step may have left the model in eval mode.
        self.trainer.model.train()

    # Update model predictions on this image
    def after_step(self):
        """Count self-training iterations; when the budget for the current
        queue is used up, evaluate, advance the queue and optionally reset
        the model and optimizer. Exits the process after the last frame."""
        if (self.internal_iter + 1) == self.max_st_iters:
            # Queue is full and we have finished self-training --> eval on most recent image
            if self.win_size == "inf" or len(self.imgs_in_queue) == self.win_size:
                self._evaluate_latest_frame()

            # Modify queue as necessary: take one more image next iteration
            self.imgs_in_queue.append(self.imgs_in_queue[-1] + 1)
            if self.win_size != "inf" and len(self.imgs_in_queue) > self.win_size:
                # Evict the oldest frame to keep the window size
                self.imgs_in_queue.pop(0)
            self.internal_iter = 0

            # If we are online, do nothing
            # If we are standard, we need to reset the model
            if self.ttt_setting == "standard":
                self.trainer.checkpointer.load(self.orig_model_weights)
            if self.restart_optimizer:
                self.trainer.restart_optimizer()
        else:
            self.internal_iter = self.internal_iter + 1

        if self.imgs_in_queue[-1] == self.max_img_idx:
            # Early stopping: perform after_train and then exit the process.
            self.after_train()
            sys.exit(0)

    def _evaluate_latest_frame(self):
        """Run the trainer's test on the current model and, when semantic
        segmentation results are returned, append the mIoU to
        ``performance.txt`` and save overlay and reconstruction images for
        the newest frame in the queue."""
        res = self.trainer.test(self.cfg, self.trainer.model)

        # Log results to txt file
        if 'sem_seg' not in res:  # Same as comm.is_main_process
            return

        miou = str(res['sem_seg']['mIoU'])
        # Look for experiment log and append
        with open(os.path.join(self.exp_dir, 'performance.txt'), 'a') as fp:
            fp.write(miou + '\n')

        # Save final predictions
        frame_id = format(self.imgs_in_queue[-1], "06d")
        img = read_image(os.path.join(self.ttt_in, frame_id + '.png'), format="BGR")
        # Run Detectron2 predictor
        self.trainer.model.eval()
        predictions = self.predictor(img, self.trainer.model)

        vis_out = self.predictor.draw_predictions(img, predictions["sem_seg"])
        vis_out.save(os.path.join(self.exp_dir, frame_id + '_final.png'))

        recon_preds = self.predictor.draw_reconstructions(predictions["recon"])
        recon_preds.save(os.path.join(self.exp_dir, frame_id + '_recon.png'))

    def after_train(self):
        """Save the final model under a name that encodes the run settings."""
        # Copy models
        self.trainer.checkpointer.save("model_final_" + self._run_name())
