import inspect
from abc import ABCMeta, abstractmethod
from typing import Any, NamedTuple, Optional


class PrintMixin:
    """
    Mixin supplying a generic ``__repr__`` built from the instance attributes.

    The representation looks like ``<ClassName at 0x...: a=1 b='x'>``. The
    ``: ...`` suffix is left out when the instance has no attributes.
    """
    def __repr__(self):
        pieces = [f"{key}={value!r}" for key, value in self.__dict__.items()]
        # every piece contains "=", so a non-empty list always yields a non-empty suffix
        suffix = f": {' '.join(pieces)}" if pieces else ""
        return f"<{self.__class__.__name__} at {hex(id(self))}{suffix}>"


class AbstractSignatureChecker(ABCMeta):
    """
    Meta class for strictly enforcing signatures of @abstractmethod's in an abstract base class

    When a class is created, each abstract method listed in the ``__abstractmethods__``
    of every direct base is looked up on the new class. Its ``inspect.getfullargspec``
    (argument names, defaults, varargs, keyword-only arguments and annotations) must
    equal the one declared on the base. All mismatches are collected and reported
    together in a single ``TypeError``.

    Note: only direct bases are inspected. A method that an intermediate class has
    already implemented is no longer abstract there, so it is not re-checked in that
    class's own subclasses.
    """
    def __init__(cls, name, bases, attrs, **kwargs):
        """
        Args:
            name (str): name of the class being created
            bases (tuple): direct base classes
            attrs (dict): class namespace
            **kwargs: class keyword arguments (e.g. consumed by ``__init_subclass__``);
                forwarded unchanged to ``type.__init__``, which ignores them

        Raises:
            TypeError: if an abstract method is overridden with a different signature
        """
        errors = []
        for base_cls in bases:
            # non-ABC bases (e.g. plain mixins) have no `__abstractmethods__`
            for meth_name in getattr(base_cls, "__abstractmethods__", ()):
                orig_argspec = inspect.getfullargspec(getattr(base_cls, meth_name))
                # Resolved through the MRO, so an implementation inherited from another
                # base is checked too; if nothing overrides the method, the inherited
                # abstract one matches trivially and ABCMeta blocks instantiation instead.
                target_argspec = inspect.getfullargspec(getattr(cls, meth_name))
                if orig_argspec != target_argspec:
                    errors.append(
                        f"Subclass `{cls.__name__}` of `{base_cls.__name__}` not implemented with correct signature "
                        f"in abstract method {meth_name!r}.\n"
                        f"Expected: {orig_argspec}\n"
                        f"Got: {target_argspec}\n")
        if errors:
            raise TypeError("\n".join(errors))
        # forward class keywords so `class C(Base, key=...)` works with `__init_subclass__`
        super().__init__(name, bases, attrs, **kwargs)


class Distribution(PrintMixin, metaclass=AbstractSignatureChecker):
    """
    Abstract interface for a distribution that can be sampled from.

    Subclasses implement ``__call__`` with exactly this signature. The
    ``AbstractSignatureChecker`` metaclass raises ``TypeError`` when a direct
    subclass overrides it with a different one.
    """
    @abstractmethod
    def __call__(self, rng, shape=None):
        """
        Draw a sample from the distribution.

        Args:
            rng (np.random.Generator): numpy pseudorandom number generator
            shape (tuple, optional): tuple shape for sampling

        Returns:
            ndarray: sample of shape `shape`, or of shape `()` when `shape` is `None`
        """


class GraphModel(PrintMixin, metaclass=AbstractSignatureChecker):
    """
    Abstract interface for a graph model p(G).

    Subclasses implement ``__call__`` with exactly this signature. The
    ``AbstractSignatureChecker`` metaclass raises ``TypeError`` when a direct
    subclass overrides it with a different one.
    """
    @abstractmethod
    def __call__(self, rng, n_vars):
        """
        Sample a graph.

        Args:
            rng (np.random.Generator): numpy pseudorandom number generator
            n_vars (int): number of nodes in the graph

        Returns:
            ndarray: binary adjacency matrix of shape `[n_vars, n_vars]`
        """


class MechanismModel(PrintMixin, metaclass=AbstractSignatureChecker):
    """
    Abstract base class for data-generating mechanism p(D|G)

    Subclasses implement ``__call__`` with exactly this signature, including the
    ``seed=None`` default. The ``AbstractSignatureChecker`` metaclass raises
    ``TypeError`` when a direct subclass overrides it with a different one.
    """
    @abstractmethod
    def __call__(self, rng, g, n_observations_obs, n_observations_int, seed=None):
        """
        Args:
            rng (np.random.Generator): numpy pseudorandom number generator
            g (ndarray): binary adjacency matrix of shape `[n_vars, n_vars]` as generated by a `GraphModel` subclass
            n_observations_obs (int): number of observational data points to be sampled
            n_observations_int (int): number of interventional data points to be sampled
            seed (optional): additional seed argument, `None` by default; its use is left to
                subclasses (NOTE(review): presumably an integer seed -- confirm intended semantics)

        Returns:
            Data: namedtuple containing the `x_obs` and `x_int` data matrices and the
            `is_count_data` flag
        """
        pass


class NoiseModel(PrintMixin, metaclass=AbstractSignatureChecker):
    """
    Abstract interface for the noise of an SCM.

    Subclasses implement ``__call__`` with exactly this signature. The
    ``AbstractSignatureChecker`` metaclass raises ``TypeError`` when a direct
    subclass overrides it with a different one.
    """
    @abstractmethod
    def __call__(self, rng, x, is_parent):
        """
        Sample noise for a single node.

        Args:
            rng (np.random.Generator): numpy pseudorandom number generator
            x (ndarray): data matrix of shape `[n_observations, n_vars]`
            is_parent (ndarray): binary vector of shape `[n_vars,]` indicating which nodes are parents of this node

        Returns:
            ndarray: sampled noise of shape `[n_observations]`
        """


class CustomClassWrapper(NamedTuple):
    """
    Immutable record ``(name, kwargs, paths)`` describing a custom class.

    NOTE(review): presumably `name` identifies a user-defined class, `kwargs` are the
    arguments it is constructed with, and `paths` are the locations it is loaded
    from. Confirm against the code that consumes this wrapper.
    """
    name: str  # identifier of the custom class
    kwargs: Any  # keyword arguments associated with it (type not constrained here)
    paths: Any  # associated path(s) (type not constrained here)


class SyntheticSpec(NamedTuple):
    """
    Data structure for a single component of the data-generating distribution.

    Positional field order: ``(graph, mechanism, n_observations_obs, n_observations_int)``.
    """
    # graph component (presumably a `GraphModel` instance -- not enforced here)
    graph: Any
    # mechanism component (presumably a `MechanismModel` instance -- not enforced here)
    mechanism: Any
    # number of observational data points; the default is None, hence Optional
    # (presumably forwarded to `MechanismModel.__call__`'s argument of the same name)
    n_observations_obs: Optional[int] = None
    # number of interventional data points; defaults to 0
    n_observations_int: int = 0


class Data(NamedTuple):
    """
    Immutable container for the data a `MechanismModel` returns.

    Positional field order: ``(x_obs, x_int, is_count_data)``.
    """
    # observational data matrix (type not constrained here)
    x_obs: Any
    # interventional data matrix (type not constrained here)
    x_int: Any
    # NOTE(review): presumably marks whether x_obs/x_int hold count-valued data -- confirm
    is_count_data: bool