# =============================================================================
# Model Utilities
# =============================================================================

import inspect
from collections.abc import Mapping

from gpytorch.constraints import Interval, Positive, LessThan, GreaterThan
from gpytorch.kernels import Kernel, MaternKernel, RBFKernel, ScaleKernel
from gpytorch.likelihoods import Likelihood, GaussianLikelihood
from gpytorch.means import Mean, ZeroMean, ConstantMean, LinearMean
from gpytorch.priors import (
    Prior, GammaPrior, NormalPrior, UniformPrior, LogNormalPrior
)



# -----------------------------------------------------------------------------
# Constraint Factory
# -----------------------------------------------------------------------------

class ConstraintFactory:
    """Factory that creates parameter constraint objects.

    Constraint classes are looked up by name in ``REGISTRY``. Keyword
    arguments the selected class cannot accept are silently dropped, so a
    configuration mapping may carry unrelated keys.
    """

    # Maps a config name to the constraint *class* (not an instance).
    REGISTRY: dict[str, type[Interval]] = {
        "Interval": Interval,
        "Positive": Positive,
        "LessThan": LessThan,
        "GreaterThan": GreaterThan,
    }

    @staticmethod
    def _accepted_kwargs(target_cls: type) -> set[str]:
        """Return the keyword argument names ``target_cls(...)`` accepts.

        When an ``__init__`` takes ``**kwargs`` the search continues into the
        next class in the MRO that defines ``__init__``, on the assumption
        that the extra keywords are forwarded there via ``super().__init__``.
        Inspecting only the most-derived ``__init__`` would drop them.
        """
        accepted: set[str] = set()
        for klass in inspect.getmro(target_cls):
            init = vars(klass).get("__init__")
            if init is None:
                continue
            try:
                params = list(inspect.signature(init).parameters.values())
            except (TypeError, ValueError):
                # Not introspectable (e.g. some C-level __init__); stop here.
                break
            accepted.update(
                p.name
                for p in params
                if p.name != "self"
                and p.kind in (
                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
                    inspect.Parameter.KEYWORD_ONLY,
                )
            )
            if not any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
                break
        return accepted

    @classmethod
    def create(
        cls,
        name: str,
        **kwargs
    ) -> Interval:
        """Instantiate the constraint registered under ``name``.

        Args:
            name: Key into ``REGISTRY`` (e.g. ``"Interval"``, ``"Positive"``).
            **kwargs: Constructor arguments; keys the class cannot accept are
                ignored.

        Returns:
            A new constraint instance.

        Raises:
            ValueError: If ``name`` is not a registered constraint type.
        """
        if name not in cls.REGISTRY:
            raise ValueError(f"Unsupported constraint type: {name}.")
        constraint_cls = cls.REGISTRY[name]
        accepted = cls._accepted_kwargs(constraint_cls)
        params = {k: v for k, v in kwargs.items() if k in accepted}
        return constraint_cls(**params)



# -----------------------------------------------------------------------------
# Prior Factory
# -----------------------------------------------------------------------------

class PriorFactory:
    """Factory that creates parameter prior objects.

    Prior classes are looked up by name in ``REGISTRY``. Keyword arguments
    the selected class cannot accept are silently dropped, so a configuration
    mapping may carry unrelated keys.
    """

    # Maps a config name to the prior *class* (not an instance).
    REGISTRY: dict[str, type[Prior]] = {
        "GammaPrior": GammaPrior,
        "NormalPrior": NormalPrior,
        "UniformPrior": UniformPrior,
        "LogNormalPrior": LogNormalPrior,
    }

    @staticmethod
    def _accepted_kwargs(target_cls: type) -> set[str]:
        """Return the keyword argument names ``target_cls(...)`` accepts.

        When an ``__init__`` takes ``**kwargs`` the search continues into the
        next class in the MRO that defines ``__init__``, on the assumption
        that the extra keywords are forwarded there via ``super().__init__``.
        Inspecting only the most-derived ``__init__`` would drop them.
        """
        accepted: set[str] = set()
        for klass in inspect.getmro(target_cls):
            init = vars(klass).get("__init__")
            if init is None:
                continue
            try:
                params = list(inspect.signature(init).parameters.values())
            except (TypeError, ValueError):
                # Not introspectable (e.g. some C-level __init__); stop here.
                break
            accepted.update(
                p.name
                for p in params
                if p.name != "self"
                and p.kind in (
                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
                    inspect.Parameter.KEYWORD_ONLY,
                )
            )
            if not any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
                break
        return accepted

    @classmethod
    def create(
        cls,
        name: str,
        **kwargs
    ) -> Prior:
        """Instantiate the prior registered under ``name``.

        Args:
            name: Key into ``REGISTRY`` (e.g. ``"GammaPrior"``).
            **kwargs: Constructor arguments; keys the class cannot accept are
                ignored.

        Returns:
            A new prior instance.

        Raises:
            ValueError: If ``name`` is not a registered prior type.
        """
        if name not in cls.REGISTRY:
            raise ValueError(f"Unsupported prior type: {name}.")
        prior_cls = cls.REGISTRY[name]
        accepted = cls._accepted_kwargs(prior_cls)
        params = {k: v for k, v in kwargs.items() if k in accepted}
        return prior_cls(**params)



# -----------------------------------------------------------------------------
# Mean Factory
# -----------------------------------------------------------------------------

class MeanFactory:
    """Factory that creates GP mean objects.

    Mean classes are looked up by name in ``REGISTRY``. Constructor arguments
    whose name ends in ``constraint`` or ``prior`` and whose value is a
    mapping are built via ``ConstraintFactory`` / ``PriorFactory``; other
    values (including ``None`` or already-built objects) are passed through.
    Keyword arguments the selected class cannot accept are silently dropped.
    """

    # Maps a config name to the mean *class* (not an instance).
    REGISTRY: dict[str, type[Mean]] = {
        "ZeroMean": ZeroMean,
        "LinearMean": LinearMean,
        "ConstantMean": ConstantMean,
    }

    @staticmethod
    def _accepted_kwargs(target_cls: type) -> set[str]:
        """Return the keyword argument names ``target_cls(...)`` accepts.

        When an ``__init__`` takes ``**kwargs`` the search continues into the
        next class in the MRO that defines ``__init__``, on the assumption
        that the extra keywords are forwarded there via ``super().__init__``.
        Inspecting only the most-derived ``__init__`` would drop them.
        """
        accepted: set[str] = set()
        for klass in inspect.getmro(target_cls):
            init = vars(klass).get("__init__")
            if init is None:
                continue
            try:
                params = list(inspect.signature(init).parameters.values())
            except (TypeError, ValueError):
                # Not introspectable (e.g. some C-level __init__); stop here.
                break
            accepted.update(
                p.name
                for p in params
                if p.name != "self"
                and p.kind in (
                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
                    inspect.Parameter.KEYWORD_ONLY,
                )
            )
            if not any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
                break
        return accepted

    @staticmethod
    def _resolve(key: str, value):
        """Build a nested constraint/prior from a mapping spec, else pass through."""
        if not isinstance(value, Mapping):
            return value
        if key.endswith("constraint"):
            return ConstraintFactory.create(**value)
        if key.endswith("prior"):
            return PriorFactory.create(**value)
        return value

    @classmethod
    def create(
        cls,
        name: str,
        **kwargs
    ) -> Mean:
        """Instantiate the mean registered under ``name``.

        Args:
            name: Key into ``REGISTRY`` (e.g. ``"ConstantMean"``).
            **kwargs: Constructor arguments; keys the class cannot accept are
                ignored, nested constraint/prior specs are built recursively.

        Returns:
            A new mean module instance.

        Raises:
            ValueError: If ``name`` (or a nested spec's name) is not registered.
        """
        if name not in cls.REGISTRY:
            raise ValueError(f"Unsupported mean type: {name}.")
        mean_cls = cls.REGISTRY[name]
        accepted = cls._accepted_kwargs(mean_cls)
        params = {
            key: cls._resolve(key, value)
            for key, value in kwargs.items()
            if key in accepted
        }
        return mean_cls(**params)



# -----------------------------------------------------------------------------
# Kernel Factory
# -----------------------------------------------------------------------------

class KernelFactory:
    """Factory that creates GP kernel objects.

    Kernel classes are looked up by name in ``REGISTRY``. Constructor
    arguments given as mappings are built recursively, chosen by the
    argument name's suffix:

    * ``...constraint`` -> ``ConstraintFactory.create``
    * ``...prior`` -> ``PriorFactory.create``
    * ``...kernel`` / ``...covar_module`` -> this factory's ``create``
    * ``...kernels`` -> this factory's ``create`` for each mapping element

    Non-mapping values (``None``, already-built objects) are passed through
    unchanged. Keyword arguments the selected class cannot accept are
    silently dropped.
    """

    # Maps a config name to the kernel *class* (not an instance).
    REGISTRY: dict[str, type[Kernel]] = {
        "RBFKernel": RBFKernel,
        "MaternKernel": MaternKernel,
        "ScaleKernel": ScaleKernel,
    }

    @staticmethod
    def _accepted_kwargs(target_cls: type) -> set[str]:
        """Return the keyword argument names ``target_cls(...)`` accepts.

        When an ``__init__`` takes ``**kwargs`` the search continues into the
        next class in the MRO that defines ``__init__``, on the assumption
        that the extra keywords are forwarded there via ``super().__init__``.
        Inspecting only the most-derived ``__init__`` would drop them.
        """
        accepted: set[str] = set()
        for klass in inspect.getmro(target_cls):
            init = vars(klass).get("__init__")
            if init is None:
                continue
            try:
                params = list(inspect.signature(init).parameters.values())
            except (TypeError, ValueError):
                # Not introspectable (e.g. some C-level __init__); stop here.
                break
            accepted.update(
                p.name
                for p in params
                if p.name != "self"
                and p.kind in (
                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
                    inspect.Parameter.KEYWORD_ONLY,
                )
            )
            if not any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
                break
        return accepted

    @classmethod
    def _resolve(cls, key: str, value):
        """Build nested objects from mapping specs, else pass the value through."""
        if key.endswith("kernels"):
            if value is None:
                return None
            return [
                cls.create(**spec) if isinstance(spec, Mapping) else spec
                for spec in value
            ]
        if not isinstance(value, Mapping):
            return value
        if key.endswith("constraint"):
            return ConstraintFactory.create(**value)
        if key.endswith("prior"):
            return PriorFactory.create(**value)
        if key.endswith(("kernel", "covar_module")):
            # Use ``cls`` so subclasses with extended registries recurse
            # through their own REGISTRY.
            return cls.create(**value)
        return value

    @classmethod
    def create(
        cls,
        name: str,
        **kwargs
    ) -> Kernel:
        """Instantiate the kernel registered under ``name``.

        Args:
            name: Key into ``REGISTRY`` (e.g. ``"MaternKernel"``).
            **kwargs: Constructor arguments; keys the class cannot accept are
                ignored, nested constraint/prior/kernel specs are built
                recursively.

        Returns:
            A new kernel instance.

        Raises:
            ValueError: If ``name`` (or a nested spec's name) is not registered.
        """
        if name not in cls.REGISTRY:
            raise ValueError(f"Unsupported kernel type: {name}.")
        kernel_cls = cls.REGISTRY[name]
        accepted = cls._accepted_kwargs(kernel_cls)
        params = {
            key: cls._resolve(key, value)
            for key, value in kwargs.items()
            if key in accepted
        }
        return kernel_cls(**params)



# -----------------------------------------------------------------------------
# Likelihood Factory
# -----------------------------------------------------------------------------

class LikelihoodFactory:
    """Factory that creates GP likelihood objects.

    Likelihood classes are looked up by name in ``REGISTRY``. Constructor
    arguments whose name ends in ``constraint`` or ``prior`` and whose value
    is a mapping are built via ``ConstraintFactory`` / ``PriorFactory``;
    other values (including ``None`` or already-built objects) are passed
    through. Keyword arguments the selected class cannot accept are silently
    dropped.
    """

    # Maps a config name to the likelihood *class* (not an instance).
    REGISTRY: dict[str, type[Likelihood]] = {
        "GaussianLikelihood": GaussianLikelihood,
    }

    @staticmethod
    def _accepted_kwargs(target_cls: type) -> set[str]:
        """Return the keyword argument names ``target_cls(...)`` accepts.

        When an ``__init__`` takes ``**kwargs`` the search continues into the
        next class in the MRO that defines ``__init__``, on the assumption
        that the extra keywords are forwarded there via ``super().__init__``.
        Inspecting only the most-derived ``__init__`` would drop them.
        """
        accepted: set[str] = set()
        for klass in inspect.getmro(target_cls):
            init = vars(klass).get("__init__")
            if init is None:
                continue
            try:
                params = list(inspect.signature(init).parameters.values())
            except (TypeError, ValueError):
                # Not introspectable (e.g. some C-level __init__); stop here.
                break
            accepted.update(
                p.name
                for p in params
                if p.name != "self"
                and p.kind in (
                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
                    inspect.Parameter.KEYWORD_ONLY,
                )
            )
            if not any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
                break
        return accepted

    @staticmethod
    def _resolve(key: str, value):
        """Build a nested constraint/prior from a mapping spec, else pass through."""
        if not isinstance(value, Mapping):
            return value
        if key.endswith("constraint"):
            return ConstraintFactory.create(**value)
        if key.endswith("prior"):
            return PriorFactory.create(**value)
        return value

    @classmethod
    def create(
        cls,
        name: str,
        **kwargs
    ) -> Likelihood:
        """Instantiate the likelihood registered under ``name``.

        Args:
            name: Key into ``REGISTRY`` (e.g. ``"GaussianLikelihood"``).
            **kwargs: Constructor arguments; keys the class cannot accept are
                ignored, nested constraint/prior specs are built recursively.

        Returns:
            A new likelihood instance.

        Raises:
            ValueError: If ``name`` (or a nested spec's name) is not registered.
        """
        if name not in cls.REGISTRY:
            raise ValueError(f"Unsupported likelihood type: {name}.")
        likelihood_cls = cls.REGISTRY[name]
        accepted = cls._accepted_kwargs(likelihood_cls)
        params = {
            key: cls._resolve(key, value)
            for key, value in kwargs.items()
            if key in accepted
        }
        return likelihood_cls(**params)

