# Import Python packages.
import abc
import datetime
import functools
import json
import os
import re
import shutil
import tempfile
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    Mapping,
    NamedTuple,
    Optional,
    Tuple,
    Type,
    TypeVar,
    cast,
    get_type_hints,
)

# Import external packages.
import numpy as np
from typeguard import TypeCheckError, check_type

# Import relatively from other modules.
from ...io import WSP, mkdirs, rmtree
from ...time import wait_until_true
from ...types import NPANYS


# Type variables.
# "Input" and "Output" parameterize "BaseTransform" by the data types that a transformation
# consumes and produces.
Input = TypeVar("Input")
Output = TypeVar("Output")


# Self types.
# Bound type variables used to annotate "self" (and "return self" results) of the metaclass and
# the base class, so that annotations follow the concrete subclass.
SelfTypeTransform = TypeVar("SelfTypeTransform", bound="TypeTransform")
SelfBaseTransform = TypeVar("SelfBaseTransform", bound="BaseTransform[Any, Any]")


class ErrorTransformUnsupport(Exception):
    r"""
    Exception for unsupported functionality of transformation.
    Pay attention that this should not be used for partially unsupported functionality with
    conditions.

    "BaseTransform._is_defined" treats a method raising this exception as "not defined", and the
    default implementations of "transform_", "inverse", "inverse_" and "fit" in "BaseTransform"
    raise it.
    """


class ErrorTransformUnsupportPartial(Exception):
    r"""
    Exception for partially unsupported functionality of transformation.
    Raise this if functionality is defined, but is conditionally unsupported due to configurations,
    arguments, accessibilities or other factors.

    It is not a subclass of "ErrorTransformUnsupport", so "BaseTransform._is_defined" still
    reports a method raising it as defined.
    """


class ErrorTransformInvalidDefinition(Exception):
    r"""
    Exception for invalid functionality definition of transformation.

    "BaseTransform.__init_children__" raises it when the "get_child_*" accessors of a
    transformation are inconsistent with its children.
    """


class TransformTags(NamedTuple):
    r"""
    Tags of the transformation.
    The tuple is built by the "tags" property of "BaseTransform", and its type is validated by
    "TypeTransform.__call__" each time a transformation instance is created.
    """
    # Annotations.
    # True if transformation with inplacement ("transform_") is defined.
    inplacable: bool
    # True if inversion without inplacement ("inverse") is defined.
    invertible: bool
    # True if parameter fitting ("fit") is defined.
    parametric: bool
    # True if the transformation has at least one child transformation.
    hierarchy: bool


# Create transformation class registration container.
# It maps a transformation identifier to the transformation class registered under it, and is
# filled by "TypeTransform.register_transform".
_REGISTRATIONS: Dict[str, Type["BaseTransform[Any, Any]"]]
_REGISTRATIONS = {}


class TypeTransform(abc.ABCMeta):
    r"""
    Metaclass of transformation.
    Every class created through this metaclass is registered by its "_IDENTIFIER", and every
    instance created by calling such a class gets its tags validated.
    """

    @staticmethod
    def __new__(
        cls: Type["TypeTransform"], /, *args: Any, **kwargs: Any
    ) -> Type["BaseTransform[Any, Any]"]:
        r"""
        Create a new transformation class and register it.

        Args
        ----
        - cls
            The metaclass used to create the new class.
            The new class is first created by "abc.ABCMeta", then is casted into a transformation
            class.

        Returns
        -------
        - transform
            The newly created (and registered) transformation class.
        """
        # Create a class by super call.
        # Pay attention that create a type class requires casting.
        transform = cast(Type["BaseTransform[Any, Any]"], abc.ABCMeta.__new__(cls, *args, **kwargs))

        # Register the transformation class by its default identifier.
        # A subclass inheriting "_IDENTIFIER" unchanged is rejected as a duplicate registration.
        TypeTransform.register_transform(transform, transform._IDENTIFIER)
        return transform

    @staticmethod
    def register_transform(transform: Type["BaseTransform[Any, Any]"], identifier: str, /) -> None:
        r"""
        Register given transform by given identifier.

        Args
        ----
        - transform
            Transformation class to be registered.
        - identifier
            Identifier to register the class under.
            It must be dot-separated segments of letters, digits and underscores, and must not be
            registered yet.

        Returns
        -------

        Raises
        ------
        - AssertionError
            If the identifier is in wrong format, or is already registered.
        """
        # Validate explicitly rather than by "assert" statements, so that the checks survive
        # "python -O"; otherwise a duplicated identifier would silently overwrite the registration.
        # The exception type is kept as "AssertionError" for backward compatibility.
        if re.fullmatch(r"[a-zA-Z0-9_]+(\.[a-zA-Z0-9_]+)*", identifier) is None:
            raise AssertionError(
                f'Registering transformation identifier: "{identifier:s}" is in wrong format.'
            )
        if identifier in _REGISTRATIONS:
            raise AssertionError(
                f'Registering transformation identifier: "{identifier:s}" is already registered .'
            )

        # Register the transformation class.
        # The container is only mutated, never rebound, thus no "global" declaration is needed.
        _REGISTRATIONS[identifier] = transform

    def __call__(
        self: SelfTypeTransform, /, *args: Any, **kwargs: Any
    ) -> "BaseTransform[Any, Any]":
        r"""
        Call the class as a function, i.e., create a transformation instance.

        Args
        ----
        - args
            Positional arguments forwarded to instance creation.
        - kwargs
            Keyword arguments forwarded to instance creation.

        Returns
        -------
        - obj
            A new instance typed by the class (type).
        """
        # Create an instance by super call.
        transform = cast("BaseTransform[Any, Any]", abc.ABCMeta.__call__(self, *args, **kwargs))

        # Validate transformation tagging.
        # Pay attention that accessing "tags" evaluates all definition status flags of the instance.
        check_type(transform.tags, TransformTags)
        return transform


class BaseTransform(Generic[Input, Output], abc.ABC, metaclass=TypeTransform):
    r"""
    Base of transformation.
    It is generic over the input and output data types of the transformation.
    Subclasses must implement the abstract methods ("input", "output", "transform" and the
    metadata, numeric data and alphabetic data getters and setters), while "transform_",
    "inverse", "inverse_" and "fit" are optional and raise "ErrorTransformUnsupport" by default.
    """
    # Annotations at class level.
    _IDENTIFIER: str

    # Transformation unique identifier.
    # The metaclass registers the class under this identifier at class creation; it is also used
    # to name the cache directory and is stored in saved metadata.
    _IDENTIFIER = "_base"

    def __annotate__(self: SelfBaseTransform, /) -> None:
        r"""
        Annotate attributes at instance level.
        The body only holds bare attribute annotations without assignment, thus calling it does not
        set any attribute.
        It is called by "__new__" on every new instance.

        Args
        ----

        Returns
        -------
        """
        # NOTE(review): "__annotate__" is also the name of the annotation hook introduced by
        # PEP 649 (Python 3.14); confirm this method does not clash with it on newer interpreters.
        # Annotations at instance level.
        self._children: Mapping[str, "BaseTransform[Any, Any]"]

    @staticmethod
    def __new__(
        cls: Type["BaseTransform[Any, Any]"], /, *args: Any, **kwargs: Any
    ) -> "BaseTransform[Any, Any]":
        r"""
        Allocate a new instance of the class.

        Args
        ----
        - cls
            The class (type) to allocate an instance of.
            Extra positional and keyword arguments are accepted but not used here.

        Returns
        -------
        - obj
            The newly allocated instance.
        """
        # Allocate directly through "object"; constructor arguments are left to "__init__".
        instance = object.__new__(cls)

        # Run the instance level annotation hook right after allocation, so that it is always
        # executed (e.g., counted by coverage in unit testing).
        instance.__annotate__()
        return instance

    def __init__(
        self: SelfBaseTransform,
        /,
        *args: Any,
        cache_prefix: Optional[str] = None,
        cache_suffix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        r"""
        Initialize the class: allocate the cache directory, then link and validate children.

        Args
        ----
        - cache_prefix
            Root directory under which the cache directory of this transformation is made.
            The cache keeps potentially useful midterm content, e.g., log of all searched
            hyperparameter combinations, while only essential parameters are saved for resuming.
            If it is None, a temporary directory is created and used as the root.
        - cache_suffix
            Extra directory level appended to the cache path.
            It should be a unique string distinguishing different calls of the same type of
            transformation.
            If it is None, no extra level is appended, and such calls share the same cache path.
        - kwargs
            Only "children" (mapping from name to child transformation) is read; it defaults to
            an empty mapping.

        Returns
        -------
        """
        # Decide cache root, and remember if it is temporary for later cleaning.
        if cache_prefix is not None:
            # Explicit cache root is given.
            self._cache = cache_prefix
            self._is_temporary_cache = False
        else:
            # No cache root is given, thus make a temporary one and wait for it to show up.
            self._cache = str(tempfile.mkdtemp())
            self._is_temporary_cache = True
            wait_until_true(os.path.isdir, self._cache)

        # Extend cache root level by level: identifier-based directory first, then optional suffix.
        segments = [self._IDENTIFIER.replace(".", WSP)]
        if cache_suffix is not None:
            # Suffix distinguishes different calls of the same transformation.
            segments.append(cache_suffix)
        for segment in segments:
            # Join one level at a time.
            self._cache = os.path.join(self._cache, segment)
        mkdirs(self._cache, must_be_new=True, allow_auto_remove=False)

        # Fetch type-checked "children" from keyword arguments.
        try:
            # Given children must be a mapping from names to transformations.
            children = check_type(kwargs["children"], Mapping[str, "BaseTransform[Any, Any]"])
        except KeyError:
            # No children are given, thus fall back to no child.
            children = {}

        # Run initialization segments in order.
        self.__init_children__(children)

    def __init_children__(
        self: SelfBaseTransform, children: Mapping[str, "BaseTransform[Any, Any]"], /  # noqa: W504
    ) -> None:
        r"""
        Initialize children of the class.
        Every child named "xxx" must be accessible through a method "get_child_xxx" with a return
        type annotation, and no other attribute may use the "get_child_" prefix.

        Args
        ----
        - children
            Children transformations.

        Returns
        -------

        Raises
        ------
        - AssertionError
            If a child has no corresponding "get_child_*" method.
        - ErrorTransformInvalidDefinition
            If a "get_child_*" attribute has no corresponding child, if a getter has no return type
            annotation, or if a getter returns something that is not a transformation or is
            inconsistent with its return type annotation.
        """
        # Directly link with given children transformation collection.
        # Children collection is assumed to be immutable, so linking should be fine.
        self._children = children

        # Every child must have typed accessibility.
        for name in self._children:
            # Use getting method to force typing for each child.
            assert hasattr(self, f"get_child_{name:s}"), f'"get_child_{name:s}" is not defined.'

        # Prefix "get_child_" is reserved only for child typing.
        # Collect all violations (wrapped with string quotes for ease of error logging), and
        # report at most the first three of them.
        prefix = "get_child_"
        invalid = [
            f'"{name:s}"'
            for name in dir(self)
            if name.startswith(prefix) and name[len(prefix) :] not in self._children
        ]
        if len(invalid) > 3:
            # Truncate long report.
            invalid = [*invalid[:3], "..."]
        if invalid:
            # Raise error if at least one invalid definition is captured.
            raise ErrorTransformInvalidDefinition(
                "Invalid method definitions which are reserved for child typing: {:s}.".format(
                    ", ".join(invalid)
                )
            )

        for name in self._children:
            # Type guard every child.
            get = getattr(self, f"get_child_{name:s}")
            hints = get_type_hints(get)
            if "return" not in hints:
                # Without return annotation, there is nothing to type guard the child against.
                # Report it as an invalid definition instead of leaking a bare "KeyError".
                raise ErrorTransformInvalidDefinition(
                    f'Child "{name:s}" getter "get_child_{name:s}" has no return type annotation.'
                )
            annotation = hints["return"]
            child = get()
            if not isinstance(child, BaseTransform):
                # Child of transformation must be transformation, thus its type must be subclass of
                # BaseTransform.
                raise ErrorTransformInvalidDefinition(
                    f'Child "{name:s}" type "{str(type(child)):s}" is not a subclass of'
                    f' "BaseTransform".'
                )
            try:
                # Use typeguard to compare child runtime type with its type annotation.
                check_type(child, annotation)
            except TypeCheckError as error:
                # Child runtime type must be consistent with its type annotation.
                # Keep the original type checking failure as the cause.
                raise ErrorTransformInvalidDefinition(
                    f'Child "{name:s}" type "{str(type(child)):s}" is different from its'
                    f' annotation "{str(annotation):s}".'
                ) from error

    def __del__(self: SelfBaseTransform, /) -> None:
        r"""
        Delete the class, and clean its cache directory if the cache is temporary.

        Args
        ----

        Returns
        -------
        """
        # Finalizer is still invoked for an instance whose initialization fails (or never runs)
        # before cache attributes are assigned, thus read them defensively to avoid raising
        # "AttributeError" inside the finalizer.
        cache = getattr(self, "_cache", None)
        is_temporary = getattr(self, "_is_temporary_cache", False)

        # Remove only temporary cache.
        # Pay attention that temporary cache can be already removed before this instance is deleted,
        # e.g., when error raises during test.
        if is_temporary and cache is not None and os.path.isdir(cache):
            # Remove cache and all content in it.
            rmtree(cache)

    @classmethod
    def _is_defined(
        cls: Type[SelfBaseTransform], f: Callable[..., Any], /, *args: Any, **kwargs: Any
    ) -> bool:
        r"""
        Probe if a method is defined by actually calling it with given arguments.

        Args
        ----
        - f
            Method to be probed.
        - args
            Positional arguments passed to the method.
        - kwargs
            Keyword arguments passed to the method.

        Returns
        -------
        - flag
            False only if the call raises "ErrorTransformUnsupport".
            True if the call succeeds or raises any other exception.
        """
        # Probe by a real call.
        try:
            # Result of the call is not of interest.
            f(*args, **kwargs)
        except ErrorTransformUnsupport:
            # The method explicitly declares itself as not defined.
            return False
        except Exception:
            # Any other failure still means that the method is defined.
            pass
        return True

    @functools.cached_property
    def _is_defined_transform_(self: SelfBaseTransform, /) -> bool:
        r"""
        Definition status of transformation with inplacement, computed once and cached.

        Args
        ----

        Returns
        -------
        - flag
            Definition status.
        """
        # Probe "transform_" by an input made from no raw data.
        probe = self._is_defined
        method = self.transform_
        sample = self.input(None)
        return probe(method, sample)

    @functools.cached_property
    def _is_defined_inverse(self: SelfBaseTransform, /) -> bool:
        r"""
        Definition status of inversion without inplacement, computed once and cached.

        Args
        ----

        Returns
        -------
        - flag
            Definition status.
        """
        # Probe "inverse" by an output made from no raw data.
        probe = self._is_defined
        method = self.inverse
        sample = self.output(None)
        return probe(method, sample)

    @functools.cached_property
    def _is_defined_inverse_(self: SelfBaseTransform, /) -> bool:
        r"""
        Definition status of inversion with inplacement, computed once and cached.

        Args
        ----

        Returns
        -------
        - flag
            Definition status.
        """
        # Probe "inverse_" by an output made from no raw data.
        probe = self._is_defined
        method = self.inverse_
        sample = self.output(None)
        return probe(method, sample)

    @functools.cached_property
    def _is_defined_fit(self: SelfBaseTransform, /) -> bool:
        r"""
        Definition status of parameter fitting, computed once and cached.

        Args
        ----

        Returns
        -------
        - flag
            Definition status.
        """
        # Probe "fit" by an example input and output made from no raw data.
        probe = self._is_defined
        method = self.fit
        sample_input = self.input(None)
        sample_output = self.output(None)
        return probe(method, sample_input, sample_output)

    @property
    def tags(self: SelfBaseTransform, /) -> TransformTags:
        r"""
        Tags of the transformation, summarized from definition status flags and children.

        Args
        ----

        Returns
        -------
        - tags
            Tags of the transformation.
        """
        # Collect all flags directly into the final container.
        flags = TransformTags(
            inplacable=self._is_defined_transform_,
            invertible=self._is_defined_inverse,
            parametric=self._is_defined_fit,
            hierarchy=len(self._children) > 0,
        )

        # An invertible transformation must support inplacement either in both directions or in
        # neither of them.
        # Status of inplacement inversion is always collected, even for non-invertible ones.
        inversion_inplacable = self._is_defined_inverse_
        assert (
            not flags.invertible or flags.inplacable == inversion_inplacable
        ), "Inplacement operation status is not consistent for invertible transformation."
        return flags

    @abc.abstractmethod
    def input(self: SelfBaseTransform, raw: Any, /) -> Input:
        r"""
        Convert raw data into input to the transformation.
        Abstract: it has no default implementation.
        Pay attention that definition status flags of this class call it with None as raw data to
        make a probing input, thus implementations presumably need to accept None.

        Args
        ----
        - raw
            Raw data.

        Returns
        -------
        - process
            Processed data compatible with the transformation.
        """

    @abc.abstractmethod
    def output(self: SelfBaseTransform, raw: Any, /) -> Output:
        r"""
        Convert raw data into output from the transformation.
        Abstract: it has no default implementation.
        Pay attention that definition status flags of this class call it with None as raw data to
        make a probing output, thus implementations presumably need to accept None.

        Args
        ----
        - raw
            Raw data.

        Returns
        -------
        - process
            Processed data compatible with the transformation.
        """

    @abc.abstractmethod
    def transform(self: SelfBaseTransform, input: Input, /, *args: Any, **kwargs: Any) -> Output:
        r"""
        Transform input into output without inplacement.
        Abstract: it has no default implementation.
        "fit_transform" calls it right after "fit", with the same extra positional and keyword
        arguments as passed to "fit".

        Args
        ----
        - input
            Input to the transformation.

        Returns
        -------
        - output
            Output from the transformation.
        """

    def transform_(
        self: SelfBaseTransform, input: Input, /, *args: Any, **kwargs: Any
    ) -> SelfBaseTransform:
        r"""
        Transform input with inplacement.
        The default implementation always raises "ErrorTransformUnsupport"; override it to
        support inplacement.

        Args
        ----
        - input
            Input to the transformation.

        Returns
        -------
        - self
            Class instance itself.
        """
        # Nothing is supported unless overridden.
        message = (
            f'Transformation with inplacement is not defined for "{self._IDENTIFIER:s}"'
            f" transformation."
        )
        raise ErrorTransformUnsupport(message)

    def inverse(self: SelfBaseTransform, output: Output, /, *args: Any, **kwargs: Any) -> Input:
        r"""
        Inverse output back into input without inplacement.
        The default implementation always raises "ErrorTransformUnsupport"; override it to
        support inversion.

        Args
        ----
        - output
            Output from the transformation.

        Returns
        -------
        - input
            Input to the transformation.
        """
        # Nothing is supported unless overridden.
        message = f'Inversed transformation is not defined for "{self._IDENTIFIER:s}" transformation.'
        raise ErrorTransformUnsupport(message)

    def inverse_(
        self: SelfBaseTransform, output: Output, /, *args: Any, **kwargs: Any
    ) -> SelfBaseTransform:
        r"""
        Inverse output back with inplacement.
        The default implementation always raises "ErrorTransformUnsupport"; override it to
        support inversion with inplacement.

        Args
        ----
        - output
            Output from the transformation.

        Returns
        -------
        - self
            Class instance itself.
        """
        # Nothing is supported unless overridden.
        message = (
            f'Inversed transformation with inplacement is not defined for "{self._IDENTIFIER:s}"'
            f" transformation."
        )
        raise ErrorTransformUnsupport(message)

    def fit(
        self: SelfBaseTransform, input: Input, output: Output, /, *args: Any, **kwargs: Any
    ) -> SelfBaseTransform:
        r"""
        Fit transformation parameters by example input and output.
        The default implementation always raises "ErrorTransformUnsupport"; override it to
        support fitting.

        Args
        ----
        - input
            Example input to the transformation.
        - output
            Example output from the transformation.

        Returns
        -------
        - self
            Class instance itself.
        """
        # Nothing is supported unless overridden.
        message = f'Fitting is not defined for "{self._IDENTIFIER:s}" transformation.'
        raise ErrorTransformUnsupport(message)

    def fit_transform(
        self: SelfBaseTransform,
        example: Tuple[Input, Output],
        input: Input,
        /,
        *args: Any,
        **kwargs: Any,
    ) -> Output:
        r"""
        Fit transformation parameters by example, then transform input into output without
        inplacement.
        Extra positional and keyword arguments are passed to both steps.

        Args
        ----
        - example
            Example (input, output) pair for fitting.
        - input
            Input to the transformation.

        Returns
        -------
        - output
            Output from the transformation.
        """
        # Step one: fit by unpacked example.
        self.fit(*example, *args, **kwargs)

        # Step two: transform by fitted parameters.
        result = self.transform(input, *args, **kwargs)
        return cast(Output, result)

    def fit_transform_(
        self: SelfBaseTransform,
        example: Tuple[Input, Output],
        input: Input,
        /,
        *args: Any,
        **kwargs: Any,
    ) -> SelfBaseTransform:
        r"""
        Fit transformation parameters by example, then transform input with inplacement.
        Extra positional and keyword arguments are passed to both steps.

        Args
        ----
        - example
            Example (input, output) pair for fitting.
        - input
            Input to the transformation.

        Returns
        -------
        - self
            Class instance itself.
        """
        # Step one: fit by unpacked example.
        self.fit(*example, *args, **kwargs)

        # Step two: transform with inplacement; its own return value is dropped.
        self.transform_(input, *args, **kwargs)
        return self

    def fit_inverse(
        self: SelfBaseTransform,
        example: Tuple[Input, Output],
        output: Output,
        /,
        *args: Any,
        **kwargs: Any,
    ) -> Input:
        r"""
        Fit transformation parameters by example, then inverse output back into input without
        inplacement.
        Extra positional and keyword arguments are passed to both steps.

        Args
        ----
        - example
            Example (input, output) pair for fitting.
        - output
            Output from the transformation.

        Returns
        -------
        - input
            Input to the transformation.
        """
        # Step one: fit by unpacked example.
        self.fit(*example, *args, **kwargs)

        # Step two: inverse by fitted parameters.
        result = self.inverse(output, *args, **kwargs)
        return cast(Input, result)

    def fit_inverse_(
        self: SelfBaseTransform,
        example: Tuple[Input, Output],
        output: Output,
        /,
        *args: Any,
        **kwargs: Any,
    ) -> SelfBaseTransform:
        r"""
        Fit transformation parameters by example, then inverse output back with inplacement.
        Extra positional and keyword arguments are passed to both steps.

        Args
        ----
        - example
            Example (input, output) pair for fitting.
        - output
            Output from the transformation.

        Returns
        -------
        - self
            Class instance itself.
        """
        # Step one: fit by unpacked example.
        self.fit(*example, *args, **kwargs)

        # Step two: inverse with inplacement; its own return value is dropped.
        self.inverse_(output, *args, **kwargs)
        return self

    @abc.abstractmethod
    def get_metadata(self: SelfBaseTransform, /) -> Mapping[str, Any]:
        r"""
        Get metadata of the transformation.
        Abstract: it has no default implementation.
        "save" dumps the returned mapping into "metadata.json", thus it must be JSON-serializable.
        If it has key "_identifier", the value must equal "_IDENTIFIER" of the transformation;
        otherwise, "save" fills the key automatically.

        Args
        ----

        Returns
        -------
        - metadata
            Metadata of the transformation.
        """

    @abc.abstractmethod
    def get_numeric_data(self: SelfBaseTransform, /) -> Mapping[str, NPANYS]:
        r"""
        Get numeric data of the transformation.
        Abstract: it has no default implementation.
        "save" passes the returned mapping to "np.savez" as keyword arguments to write
        "numeric.npz".

        Args
        ----

        Returns
        -------
        - data
            Numeric data of the transformation.
        """

    @abc.abstractmethod
    def get_alphabetic_data(self: SelfBaseTransform, /) -> Mapping[str, Any]:
        r"""
        Get alphabetic data of the transformation.
        Abstract: it has no default implementation.
        "save" dumps the returned mapping into "alphabetic.json", thus it must be
        JSON-serializable.

        Args
        ----

        Returns
        -------
        - data
            Alphabetic data of the transformation.
        """

    @abc.abstractmethod
    def set_metadata(self: SelfBaseTransform, metadata: Mapping[str, Any], /) -> SelfBaseTransform:
        r"""
        Set metadata of the transformation.
        Abstract: it has no default implementation.
        "load" calls it with the content parsed from "metadata.json"; a file written by "save"
        always includes key "_identifier".

        Args
        ----
        - metadata
            Metadata of the transformation.

        Returns
        -------
        - self
            Class instance itself.
        """

    @abc.abstractmethod
    def set_numeric_data(
        self: SelfBaseTransform, data: Mapping[str, NPANYS], /  # noqa: W504
    ) -> SelfBaseTransform:
        r"""
        Set numeric data of the transformation.
        Abstract: it has no default implementation.
        "load" calls it with a dictionary holding every array stored in "numeric.npz", indexed by
        array name.

        Args
        ----
        - data
            Numeric data of the transformation.

        Returns
        -------
        - self
            Class instance itself.
        """

    @abc.abstractmethod
    def set_alphabetic_data(
        self: SelfBaseTransform, data: Mapping[str, Any], /  # noqa: W504
    ) -> SelfBaseTransform:
        r"""
        Set alphabetic data of the transformation.
        Abstract: it has no default implementation.
        "load" calls it with the content parsed from "alphabetic.json".

        Args
        ----
        - data
            Alphabetic data of the transformation.

        Returns
        -------
        - self
            Class instance itself.
        """

    def save(
        self: SelfBaseTransform,
        path: str,
        /,
        *,
        allow_auto_backup: bool = True,
        copy: bool = False,
        must_be_new: bool = True,
        allow_auto_remove: bool = True,
        timeout: Optional[float] = None,
        interval: Optional[float] = None,
    ) -> SelfBaseTransform:
        r"""
        Save essential data of transformation on file system.
        The save directory holds "metadata.json", "numeric.npz", "alphabetic.json",
        "children.json" (names of children), and one sub directory per child named by the child.

        Args
        ----
        - path
            Path to save essential data of the transformation.
        - allow_auto_backup
            If True, automatic backup is allowed if the save directory already exists.
            Existing directory will be renamed by appending a datetime string.
            Otherwise, do nothing for backup.
        - copy
            If True, make a copy to generate backup.
            Otherwise, rename to generate backup.
        - must_be_new
            If True, the save directory must be newly made, thus removal is required if it already
            exists.
            Otherwise, it will do nothing if the save directory already exists.
        - allow_auto_remove
            If True, automatic removal is allowed before save directory making.
            Otherwise, removal must be explicitly called before calling this function.
        - timeout
            Maximum seconds to wait before raising timeout error in allocating save directory.
        - interval
            The interval in seconds to check removal status in allocating save directory.

        Returns
        -------
        - self
            Class instance itself.

        Raises
        ------
        - AssertionError
            If metadata provides an identifier different from the transformation itself.
        """
        # Collect essential flags.
        is_existing = os.path.isdir(path)

        # Backup if related conditions are satisfied.
        if is_existing and allow_auto_backup:
            # Make the backup for existing directory based on arguments.
            self.backup(path, copy=copy)

        # There will be various types of data saved in the same directory.
        mkdirs(
            path,
            must_be_new=must_be_new,
            allow_auto_remove=allow_auto_remove,
            timeout=timeout,
            interval=interval,
        )

        # Save metadata in JSON format.
        # Identifier lookup is kept outside of the assertion: an "assert" statement is removed
        # under "python -O", and the lookup would vanish with it, thus a missing identifier would
        # never be filled.
        metadata = self.get_metadata()
        try:
            # Identifier in metadata has special requirement.
            identifier = metadata["_identifier"]
        except KeyError:
            # If identifier is not provided in metadata, fill automatically.
            metadata = {"_identifier": self._IDENTIFIER, **metadata}
        else:
            # Provided identifier must agree with the transformation.
            assert identifier == self._IDENTIFIER, (
                "Transformation identifier in metadata must be consistent with the transformation"
                " itself."
            )
        with open(os.path.join(path, "metadata.json"), "w") as file:
            # Use readable JSON format.
            json.dump(metadata, file, indent=4)

        # Save numeric data in NumPy Zipfile format.
        numeric_data = self.get_numeric_data()
        np.savez(os.path.join(path, "numeric.npz"), **numeric_data)

        # Save alphabetic data in JSON format.
        alphabetic_data = self.get_alphabetic_data()
        with open(os.path.join(path, "alphabetic.json"), "w") as file:
            # Use readable JSON format.
            json.dump(alphabetic_data, file, indent=4)

        # Save all transformation children.
        # Pay attention to save an additional index file for the ease and safety of reproduction.
        with open(os.path.join(path, "children.json"), "w") as file:
            # Use readable JSON format.
            json.dump(list(self._children.keys()), file, indent=4)
        for name, child in self._children.items():
            # Save every child in independent directory indexed by its name.
            # Since we have already made necessary backup including children before reaching here,
            # no backup should be performed for children to avoid redundancy.
            child.save(
                os.path.join(path, name),
                allow_auto_backup=False,
                copy=False,
                must_be_new=must_be_new,
                allow_auto_remove=allow_auto_remove,
                timeout=timeout,
                interval=interval,
            )
        return self

    def load(self: SelfBaseTransform, path: str, /) -> SelfBaseTransform:
        r"""
        Load essential data of transformation from file system.
        It reads "metadata.json", "numeric.npz", "alphabetic.json" and "children.json" from the
        directory, passes loaded content to corresponding setters, then loads every listed child
        from the sub directory named by the child.

        Args
        ----
        - path
            Path to load essential data of the transformation.

        Returns
        -------
        - self
            Class instance itself.

        Raises
        ------
        - AssertionError
            If the directory lists children which this transformation does not have.
        """
        # Overwrite metadata by loaded data.
        with open(os.path.join(path, "metadata.json"), "r") as file:
            # Load metadata in JSON format.
            metadata = json.load(file)
        self.set_metadata(metadata)

        # Load numeric data in NumPy Zipfile format and overwrite.
        # All arrays are read while the archive is open, and the archive is closed afterwards.
        # NOTE(review): "allow_pickle=True" can execute arbitrary code when loading object arrays;
        # only load from trusted directories.
        with np.load(os.path.join(path, "numeric.npz"), allow_pickle=True) as npzfile:
            # Index every array by its name.
            numeric_data = {name: npzfile[name] for name in npzfile.files}
        self.set_numeric_data(numeric_data)

        # Overwrite alphabetic data by loaded data.
        with open(os.path.join(path, "alphabetic.json"), "r") as file:
            # Load alphabetic data in JSON format.
            alphabetic_data = json.load(file)
        self.set_alphabetic_data(alphabetic_data)

        # Load all transformation children.
        # Use the children index saved on file system, thus an extension is still compatible with
        # data of origin.
        with open(os.path.join(path, "children.json"), "r") as file:
            # Load children names as loading index.
            names = json.load(file)

        # Saved children must all be supported by this transformation.
        # Unsupported names are sorted to make the error message deterministic.
        diff = [f'"{name:s}"' for name in sorted(set(names) - set(self._children.keys()))]
        assert (
            len(diff) == 0
        ), 'Fail to load from "{:s}" since it has unsupported children: {:s}.'.format(
            path, ", ".join(diff)
        )
        for name in names:
            # Load every child in independent directory indexed by its name.
            self._children[name].load(os.path.join(path, name))
        return self

    def backup(self: SelfBaseTransform, path: str, /, *, copy: bool = False) -> SelfBaseTransform:
        r"""
        Backup essential data of transformation on file system.
        The backup is placed next to the given directory, named by the directory name appended
        with a datetime string (resolution of seconds).

        Args
        ----
        - path
            Path to save essential data of the transformation.
        - copy
            If True, make a copy to generate backup.
            Otherwise, rename to generate backup.

        Returns
        -------
        - self
            Class instance itself.

        Raises
        ------
        - AssertionError
            If the backup target directory already exists.
        """
        # Simply append a datetime string to create the backup.
        # Trailing path separators are dropped first; otherwise, the datetime string would be
        # appended after the separator, and the backup target would fall inside the directory being
        # backed up.
        # A path made only of separators is kept as it is.
        suffix = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        separators = os.sep + (os.altsep or "")
        src = path.rstrip(separators) or path
        dst = WSP.join([src, suffix])

        # Target path must not be occupied for safety.
        assert not os.path.isdir(dst), f'Backup target "{dst:s}" already exists.'

        # Backup based on specified mode.
        if copy:
            # Copy mode.
            shutil.copytree(src, dst)
        else:
            # Rename mode.
            os.rename(src, dst)
        return self
