# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from functools import wraps

from mmcv.runner import BaseModule
from mmcv.utils import import_modules_from_strings


def function_wrapper(ctx, method, method_str):
    """Wrap ``method`` so the teacher's outputs are replayed to the student.

    Every call records its positional and keyword arguments in
    ``ctx.method_args`` / ``ctx.method_kwargs`` under ``method_str``. While
    ``ctx.is_teacher`` is truthy the raw ``method`` is executed and its
    result is stored in ``ctx.method_return``; otherwise ``method`` is NOT
    executed and the stored result is returned instead.

    Args:
        ctx: Object exposing the dict attributes ``method_args``,
            ``method_kwargs`` and ``method_return`` plus the flag
            ``is_teacher``.
        method (callable): The raw function to wrap.
        method_str (str): Key under which inputs and outputs are recorded.

    Returns:
        callable: The wrapper, carrying the metadata (name, docstring, ...)
        of ``method``.

    Raises:
        KeyError: Raised by the wrapper when it is called in student mode
            before any teacher-mode call stored an output for
            ``method_str``.
    """

    # Preserve the wrapped callable's metadata so that introspection and
    # debugging still see the original name and docstring.
    @wraps(method)
    def wrapper(*args, **kwargs):
        # record inputs
        ctx.method_args[method_str] = args
        ctx.method_kwargs[method_str] = kwargs
        # TODO cover more usecases, not only pass teacher's outputs to
        # student.
        if ctx.is_teacher:
            # execute the raw function and record its outputs
            outputs = method(*args, **kwargs)
            ctx.method_return[method_str] = outputs
            return outputs

        # Student mode: replay the teacher's outputs instead of running
        # the raw function.
        try:
            return ctx.method_return[method_str]
        except KeyError:
            raise KeyError(
                f'No teacher output recorded for "{method_str}"; the '
                'wrapped function must be called in teacher mode before '
                'it is called in student mode.') from None

    return wrapper


class FunctionContext():
    """Function context manager for rewrite function.

    On enter, the target function is looked up inside the imported module
    and replaced by a wrapper built with ``function_wrapper``; on exit the
    original function is put back.

    Args:
        ctx (ConversionContext): The distiller's overall context manager.
        method (str): The name of the function to rewrite, relative to the
            imported module. May be a dotted path such as
            ``'SomeClass.some_method'``.
        import_module (str, optional): Passed to
            ``import_modules_from_strings``; the result is the namespace in
            which ``method`` is looked up. NOTE(review): with the default
            ``None`` there is presumably nothing to look ``method`` up in,
            so entering the context would fail -- confirm against callers.
    """

    def __init__(self, ctx, method, import_module=None):
        self.ctx = ctx

        self.import_module = import_modules_from_strings(import_module)
        self.method_str = method
        # Kept for backward compatibility with code that reads it; lookups
        # below use getattr/setattr instead of eval/exec on this string.
        self.method_exec_str = f'self.import_module.{method}'
        # Original implementation, filled in by ``__enter__``. Initialised
        # here so ``__exit__`` is safe even if ``__enter__`` never ran.
        self.method_impl = None

    def _resolve_owner(self):
        """Return ``(owner, attr_name)`` for the dotted ``method_str``."""
        *parents, attr_name = self.method_str.split('.')
        owner = self.import_module
        for name in parents:
            owner = getattr(owner, name)
        return owner, attr_name

    def _set_method(self, method):
        """Modify a function."""
        owner, attr_name = self._resolve_owner()
        setattr(owner, attr_name, method)

    def __enter__(self):
        """Rewrite the function."""
        owner, attr_name = self._resolve_owner()
        self.method_impl = getattr(owner, attr_name)

        if self.method_impl:
            # ``function_wrapper`` takes exactly (ctx, method, method_str).
            self._set_method(
                function_wrapper(self.ctx, self.method_impl,
                                 self.method_str))
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Restore the function."""
        if self.method_impl:
            self._set_method(self.method_impl)


class ConversionContext():
    """Context manager for record functions' inputs or outputs.

    Args:
        hooks (Iterable[dict]): One dict per function to rewrite. Each dict
            is unpacked as keyword arguments of ``FunctionContext`` (after
            this context itself as the first argument).

    Attributes are plain dicts keyed by the function name string:
    ``method_args`` / ``method_kwargs`` hold recorded inputs and
    ``method_return`` holds recorded outputs. ``is_teacher`` selects whether
    wrapped functions run and record (``True``, the initial value) or replay
    the recorded outputs.
    """

    def __init__(self, hooks):
        # save functions' inputs
        self.method_args = {}
        self.method_kwargs = {}
        # save functions' outputs
        self.method_return = {}

        self.is_teacher = True
        # Each function will have a sub context manager, the function will be
        # rewritten when enter the sub context manager.
        self.hooks = []
        for hook in hooks:
            self.hooks.append(FunctionContext(self, **hook))

    def __enter__(self):
        """Enter every sub context managers.

        If one of them fails to enter, the ones already entered are exited
        again (in reverse order) before the error propagates, because
        ``__exit__`` is not invoked when ``__enter__`` raises.
        """
        entered = []
        try:
            for hook in self.hooks:
                hook.__enter__()
                entered.append(hook)
        except BaseException as err:
            for hook in reversed(entered):
                hook.__exit__(type(err), err, err.__traceback__)
            raise
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Exit every sub context managers.

        Hooks are exited in reverse entry order so that nested rewrites of
        the same function unwind back to the original implementation.
        """
        for hook in reversed(self.hooks):
            hook.__exit__(exc_type, exc_value, traceback)


class BaseDistiller(BaseModule, metaclass=ABCMeta):
    """Abstract base class for distillers.

    A distillation algorithm needs some intermediate results of the teacher
    to be captured and handed over to the student.

    Outputs of an nn.Module are captured with pytorch forward hooks.
    Outputs of a plain python function are captured with a dedicated
    context manager.

    Args:
        align_methods (Iterable[dict], optional): Descriptions of the
            functions whose outputs need to be captured; forwarded as-is to
            ``ConversionContext``. When ``None`` no context manager is
            built.
        **kwargs: Forwarded to the parent class constructor.
    """

    def __init__(self, align_methods=None, **kwargs):
        super().__init__(**kwargs)

        # Capturing a python function's outputs requires a dedicated context
        # manager. Entering it rewrites the functions so that their inputs
        # and outputs can be recorded and passed from teacher to student;
        # leaving it restores the original functions.
        self.context_manager = (
            None if align_methods is None else
            ConversionContext(align_methods))

    @abstractmethod
    def prepare_from_student(self, supernet):
        """Register forward hooks to students and teachers."""

    @abstractmethod
    def teacher_forward_output_hook(self, module, inputs, outputs):
        """Save the teacher output."""

    @abstractmethod
    def student_forward_output_hook(self, module, inputs, outputs):
        """Save the student output."""

    def reset_ctx_teacher_mode(self, mode=True):
        """Set the context manager's ``is_teacher`` flag to ``mode``.

        Does nothing when no context manager was built.
        """
        manager = self.context_manager
        if manager is None:
            return
        manager.is_teacher = mode

    @abstractmethod
    def exec_teacher_forward(self, data):
        """Execute the teacher's forward function."""

    @abstractmethod
    def exec_student_forward(self, student, data):
        """Execute the student's forward function."""

    @abstractmethod
    def compute_distill_loss(self, data):
        """Compute distill loss according teacher's outputs and student's
        outputs."""
