# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

import numpy as np
from mmcv.runner import HOOKS, BaseRunner
from mmcv.runner.dist_utils import master_only
from mmcv.runner.hooks.checkpoint import CheckpointHook
from mmcv.runner.hooks.evaluation import DistEvalHook, EvalHook
from mmcv.runner.hooks.logger.wandb import WandbLoggerHook


@HOOKS.register_module()
class MMClsWandbHook(WandbLoggerHook):
    """Enhanced Wandb logger hook for classification.

    Comparing with the :class:`mmcv.runner.WandbLoggerHook`, this hook can not
    only automatically log all information in ``log_buffer`` but also log
    the following extra information.

    - **Checkpoints**: If ``log_checkpoint`` is True, the checkpoint saved at
      every checkpoint interval will be saved as W&B Artifacts. This depends on
      the :class:`mmcv.runner.CheckpointHook` whose priority is higher than
      this hook. Please refer to
      https://docs.wandb.ai/guides/artifacts/model-versioning to learn more
      about model versioning with W&B Artifacts.

    - **Checkpoint Metadata**: If ``log_checkpoint_metadata`` is True, every
      checkpoint artifact will have a metadata associated with it. The metadata
      contains the evaluation metrics computed on validation data with that
      checkpoint along with the current epoch/iter. It depends on
      :class:`EvalHook` whose priority is higher than this hook.

    - **Evaluation**: At every interval, this hook logs the model prediction as
      interactive W&B Tables. The number of samples logged is given by
      ``num_eval_images``. Currently, this hook logs the predicted labels along
      with the ground truth at every evaluation interval. This depends on the
      :class:`EvalHook` whose priority is higher than this hook. Also note that
      the data is just logged once and subsequent evaluation tables use
      references to the logged data to save memory usage. Please refer to
      https://docs.wandb.ai/guides/data-vis to learn more about W&B Tables.

    Here is a config example:

    .. code:: python

        checkpoint_config = dict(interval=10)

        # To log checkpoint metadata, the interval of checkpoint saving should
        # be divisible by the interval of evaluation.
        evaluation = dict(interval=5)

        log_config = dict(
            ...
            hooks=[
                ...
                dict(type='MMClsWandbHook',
                     init_kwargs={
                         'entity': "YOUR_ENTITY",
                         'project': "YOUR_PROJECT_NAME"
                     },
                     log_checkpoint=True,
                     log_checkpoint_metadata=True,
                     num_eval_images=100)
            ])

    Args:
        init_kwargs (dict): A dict passed to wandb.init to initialize
            a W&B run. Please refer to https://docs.wandb.ai/ref/python/init
            for possible key-value pairs.
        interval (int): Logging interval (every k iterations). Defaults to 10.
        log_checkpoint (bool): Save the checkpoint at every checkpoint interval
            as W&B Artifacts. Use this for model versioning where each version
            is a checkpoint. Defaults to False.
        log_checkpoint_metadata (bool): Log the evaluation metrics computed
            on the validation data with the checkpoint, along with current
            epoch as a metadata to that checkpoint. Only takes effect when
            ``log_checkpoint`` is also True. Defaults to False.
        num_eval_images (int): The number of validation images to be logged.
            If zero, the evaluation won't be logged. Defaults to 100.
    """

    def __init__(self,
                 init_kwargs=None,
                 interval=10,
                 log_checkpoint=False,
                 log_checkpoint_metadata=False,
                 num_eval_images=100,
                 **kwargs):
        super(MMClsWandbHook, self).__init__(init_kwargs, interval, **kwargs)

        self.log_checkpoint = log_checkpoint
        # Metadata is attached to checkpoint artifacts, so it is meaningless
        # without checkpoint logging.
        self.log_checkpoint_metadata = (
            log_checkpoint and log_checkpoint_metadata)
        self.num_eval_images = num_eval_images
        self.log_evaluation = (num_eval_images > 0)
        # The hooks and the values derived from them are resolved in
        # ``before_run``; they stay ``None`` when the matching hook is absent.
        self.ckpt_hook: CheckpointHook = None
        self.eval_hook: EvalHook = None
        self.ckpt_interval = None
        self.eval_interval = None
        self.val_dataset = None

    @master_only
    def before_run(self, runner: BaseRunner):
        """Locate the companion hooks and upload the ground-truth table.

        Features whose prerequisite hook is missing are switched off with a
        warning instead of failing the run.

        Args:
            runner (BaseRunner): The runner whose ``hooks`` are inspected.

        Raises:
            AssertionError: If checkpoint metadata logging is enabled and the
                checkpoint interval is not a multiple of the evaluation
                interval.
        """
        super(MMClsWandbHook, self).before_run(runner)

        # Inspect CheckpointHook and EvalHook
        for hook in runner.hooks:
            if isinstance(hook, CheckpointHook):
                self.ckpt_hook = hook
            if isinstance(hook, (EvalHook, DistEvalHook)):
                self.eval_hook = hook

        # Check conditions to log checkpoint
        if self.log_checkpoint:
            if self.ckpt_hook is None:
                self.log_checkpoint = False
                self.log_checkpoint_metadata = False
                runner.logger.warning(
                    'To log checkpoint in MMClsWandbHook, `CheckpointHook` is '
                    'required, please check hooks in the runner.')
            else:
                self.ckpt_interval = self.ckpt_hook.interval

        # Check conditions to log evaluation
        if self.log_evaluation or self.log_checkpoint_metadata:
            if self.eval_hook is None:
                self.log_evaluation = False
                self.log_checkpoint_metadata = False
                runner.logger.warning(
                    'To log evaluation or checkpoint metadata in '
                    'MMClsWandbHook, `EvalHook` or `DistEvalHook` in mmcls '
                    'is required, please check whether the validation '
                    'is enabled.')
            else:
                self.eval_interval = self.eval_hook.interval
                # Only touch the dataloader when an eval hook really exists.
                self.val_dataset = self.eval_hook.dataloader.dataset
                num_val_samples = len(self.val_dataset)
                if (self.log_evaluation
                        and self.num_eval_images > num_val_samples):
                    # Warn before clamping so that the message reports the
                    # value the user actually requested.
                    runner.logger.warning(
                        f'The num_eval_images ({self.num_eval_images}) is '
                        'greater than the total number of validation samples '
                        f'({num_val_samples}). The complete validation '
                        'dataset will be logged.')
                    self.num_eval_images = num_val_samples

        # Check conditions to log checkpoint metadata
        if self.log_checkpoint_metadata:
            assert self.ckpt_interval % self.eval_interval == 0, \
                'To log checkpoint metadata in MMClsWandbHook, the interval ' \
                f'of checkpoint saving ({self.ckpt_interval}) should be ' \
                'divisible by the interval of evaluation ' \
                f'({self.eval_interval}).'

        # Initialize evaluation table
        if self.log_evaluation:
            # Initialize data table
            self._init_data_table()
            # Add ground truth to the data table
            self._add_ground_truth()
            # Log ground truth data
            self._log_data_table()

    @master_only
    def after_train_epoch(self, runner):
        """Log the checkpoint artifact and prediction table for this epoch.

        Does nothing beyond the parent logging when the hook is not running
        in epoch-based mode.

        Args:
            runner (BaseRunner): The runner that just finished an epoch.
        """
        super(MMClsWandbHook, self).after_train_epoch(runner)

        if not self.by_epoch:
            return

        # Save checkpoint and metadata. The whole condition is gated on
        # ``log_checkpoint``: ``self.ckpt_hook`` may be None when checkpoint
        # logging is disabled, and nothing must be uploaded in that case.
        if self.log_checkpoint and (
                self.every_n_epochs(runner, self.ckpt_interval) or
            (self.ckpt_hook.save_last and self.is_last_epoch(runner))):
            if self.log_checkpoint_metadata and self.eval_hook:
                metadata = {
                    'epoch': runner.epoch + 1,
                    **self._get_eval_results()
                }
            else:
                metadata = None
            aliases = [f'epoch_{runner.epoch+1}', 'latest']
            model_path = osp.join(self.ckpt_hook.out_dir,
                                  f'epoch_{runner.epoch+1}.pth')
            self._log_ckpt_as_artifact(model_path, aliases, metadata)

        # Save prediction table
        if self.log_evaluation and self.eval_hook._should_evaluate(runner):
            results = self.eval_hook.latest_results
            # Initialize evaluation table
            self._init_pred_table()
            # Add predictions to evaluation table
            self._add_predictions(results, runner.epoch + 1)
            # Log the evaluation table
            self._log_eval_table(runner.epoch + 1)

    @master_only
    def after_train_iter(self, runner):
        """Log the checkpoint artifact and prediction table for this iter.

        When ``get_mode`` reports ``'train'`` only the parent logging runs.
        The checkpoint and prediction logging below are reached only in the
        other modes, and only when the hook is iteration-based.

        Args:
            runner (BaseRunner): The runner that just finished an iteration.
        """
        if self.get_mode(runner) == 'train':
            # An ugly patch. The iter-based eval hook will call the
            # `after_train_iter` method of all logger hooks before evaluation.
            # Use this trick to skip that call.
            # Don't call super method at first, it will clear the log_buffer
            return super(MMClsWandbHook, self).after_train_iter(runner)
        else:
            super(MMClsWandbHook, self).after_train_iter(runner)

        if self.by_epoch:
            return

        # Save checkpoint and metadata. Gated on ``log_checkpoint`` for the
        # same reason as in ``after_train_epoch``.
        if self.log_checkpoint and (
                self.every_n_iters(runner, self.ckpt_interval) or
            (self.ckpt_hook.save_last and self.is_last_iter(runner))):
            if self.log_checkpoint_metadata and self.eval_hook:
                metadata = {
                    'iter': runner.iter + 1,
                    **self._get_eval_results()
                }
            else:
                metadata = None
            aliases = [f'iter_{runner.iter+1}', 'latest']
            model_path = osp.join(self.ckpt_hook.out_dir,
                                  f'iter_{runner.iter+1}.pth')
            self._log_ckpt_as_artifact(model_path, aliases, metadata)

        # Save prediction table
        if self.log_evaluation and self.eval_hook._should_evaluate(runner):
            results = self.eval_hook.latest_results
            # Initialize evaluation table
            self._init_pred_table()
            # Log predictions
            self._add_predictions(results, runner.iter + 1)
            # Log the table
            self._log_eval_table(runner.iter + 1)

    @master_only
    def after_run(self, runner):
        """Finish the W&B run once training is over."""
        self.wandb.finish()

    def _log_ckpt_as_artifact(self, model_path, aliases, metadata=None):
        """Log model checkpoint as W&B Artifact.

        Args:
            model_path (str): Path of the checkpoint to log.
            aliases (list): List of the aliases associated with this artifact.
            metadata (dict, optional): Metadata associated with this artifact.
        """
        model_artifact = self.wandb.Artifact(
            f'run_{self.wandb.run.id}_model', type='model', metadata=metadata)
        model_artifact.add_file(model_path)
        self.wandb.log_artifact(model_artifact, aliases=aliases)

    def _get_eval_results(self):
        """Get model evaluation results.

        Re-evaluates the latest results kept by the eval hook on the
        validation dataset, using the eval hook's own ``eval_kwargs``.

        Returns:
            The value returned by ``val_dataset.evaluate``; it is unpacked
            into the checkpoint metadata dict by the callers.
        """
        results = self.eval_hook.latest_results
        eval_results = self.val_dataset.evaluate(
            results, logger='silent', **self.eval_hook.eval_kwargs)
        return eval_results

    def _init_data_table(self):
        """Initialize the W&B Tables for validation data."""
        columns = ['image_name', 'image', 'ground_truth']
        self.data_table = self.wandb.Table(columns=columns)

    def _init_pred_table(self):
        """Initialize the W&B Tables for model evaluation.

        The first column is ``epoch`` or ``iter`` depending on the mode,
        followed by the data columns and one column per class.
        """
        columns = ['epoch'] if self.by_epoch else ['iter']
        columns += ['image_name', 'image', 'ground_truth', 'prediction'
                    ] + list(self.val_dataset.CLASSES)
        self.eval_table = self.wandb.Table(columns=columns)

    def _add_ground_truth(self):
        """Fill the data table with a fixed random subset of validation data.

        The selected dataset indices are stored in ``eval_image_indexs`` so
        that predictions can later be matched to the same samples.
        """
        # Get image loading pipeline
        from mmcls.datasets.pipelines import LoadImageFromFile
        img_loader = None
        for t in self.val_dataset.pipeline.transforms:
            if isinstance(t, LoadImageFromFile):
                img_loader = t

        CLASSES = self.val_dataset.CLASSES
        self.eval_image_indexs = np.arange(len(self.val_dataset))
        # Use a private, fixed-seed generator so that the same validation
        # subset is logged each time without reseeding the global NumPy RNG.
        # ``RandomState(42).shuffle`` yields the same permutation as
        # ``np.random.seed(42)`` followed by ``np.random.shuffle``.
        rng = np.random.RandomState(42)
        rng.shuffle(self.eval_image_indexs)
        self.eval_image_indexs = self.eval_image_indexs[:self.num_eval_images]

        for idx in self.eval_image_indexs:
            img_info = self.val_dataset.data_infos[idx]
            if img_loader is not None:
                img_info = img_loader(img_info)
                # Get image and convert from BGR to RGB
                image = img_info['img'][..., ::-1]
            else:
                # For CIFAR dataset.
                image = img_info['img']
            image_name = img_info.get('filename', f'img_{idx}')
            gt_label = img_info.get('gt_label').item()

            self.data_table.add_data(image_name, self.wandb.Image(image),
                                     CLASSES[gt_label])

    def _add_predictions(self, results, idx):
        """Add one row per logged validation sample to the evaluation table.

        Args:
            results: Predictions indexable by dataset index. Each entry is
                presumably a per-class score vector, since it is passed to
                ``np.argmax`` and unpacked into the class columns.
            idx (int): The epoch or iteration number stored in the first
                column.
        """
        table_idxs = self.data_table_ref.get_index()
        assert len(table_idxs) == len(self.eval_image_indexs)

        for ndx, eval_image_index in enumerate(self.eval_image_indexs):
            result = results[eval_image_index]

            # Reuse name / image / ground truth from the uploaded data table
            # so that the images are referenced rather than uploaded again.
            self.eval_table.add_data(
                idx, self.data_table_ref.data[ndx][0],
                self.data_table_ref.data[ndx][1],
                self.data_table_ref.data[ndx][2],
                self.val_dataset.CLASSES[np.argmax(result)], *tuple(result))

    def _log_data_table(self):
        """Log the W&B Tables for validation data as artifact and calls
        `use_artifact` on it so that the evaluation table can use the reference
        of already uploaded images.

        This allows the data to be uploaded just once.
        """
        data_artifact = self.wandb.Artifact('val', type='dataset')
        data_artifact.add(self.data_table, 'val_data')

        self.wandb.run.use_artifact(data_artifact)
        data_artifact.wait()

        self.data_table_ref = data_artifact.get('val_data')

    def _log_eval_table(self, idx):
        """Log the W&B Tables for model evaluation.

        The table will be logged multiple times creating new version. Use this
        to compare models at different intervals interactively.

        Args:
            idx (int): The epoch or iteration number used in the alias.
        """
        pred_artifact = self.wandb.Artifact(
            f'run_{self.wandb.run.id}_pred', type='evaluation')
        pred_artifact.add(self.eval_table, 'eval_data')
        if self.by_epoch:
            aliases = ['latest', f'epoch_{idx}']
        else:
            aliases = ['latest', f'iter_{idx}']
        self.wandb.run.log_artifact(pred_artifact, aliases=aliases)
