""" Condition module. """
from .label_tensor import LabelTensor
from .geometry import Location
from .equation.equation import Equation

def dummy(a):
    """Placeholder for testing purposes: ignore ``a`` and return ``None``."""
    return

class Condition:
    """
    The class `Condition` is used to represent the constraints (physical
    equations, boundary conditions, etc.) that should be satisfied in the
    problem at hand. Condition objects are used to formulate the PINA
    :obj:`pina.problem.abstract_problem.Abstract_Problem` object.
    Conditions can be specified in three ways:

        1. By specifying the input and output points of the condition; in such a
        case, the model is trained to produce the output points given the input
        points.

        2. By specifying the location and the equation of the condition; in such
        a case, the model is trained to minimize the equation residual by
        evaluating it at some samples of the location.

        3. By specifying the input points and the equation of the condition; in
        such a case, the model is trained to minimize the equation residual by
        evaluating it at the passed input points.

    An optional ``data_weight`` keyword (default ``1.0``) may be passed
    together with any of the three combinations above.

    Only the attributes that were actually passed are assigned; because the
    class uses ``__slots__``, reading one of the others raises
    ``AttributeError``.

    Example::

    >>> example_domain = Span({'x': [0, 1], 'y': [0, 1]})
    >>> def example_dirichlet(input_, output_):
    >>>     value = 0.0
    >>>     return output_.extract(['u']) - value
    >>> example_input_pts = LabelTensor(
    >>>     torch.tensor([[0, 0, 0]]), ['x', 'y', 'z'])
    >>> example_output_pts = LabelTensor(torch.tensor([[1, 2]]), ['a', 'b'])
    >>> 
    >>> Condition(
    >>>     input_points=example_input_pts,
    >>>     output_points=example_output_pts)
    >>> Condition(
    >>>     location=example_domain,
    >>>     equation=example_dirichlet)
    >>> Condition(
    >>>     input_points=example_input_pts,
    >>>     equation=example_dirichlet)

    """

    __slots__ = [
        'input_points', 'output_points', 'location', 'equation',
        'data_weight'
    ]

    def _dictvalue_isinstance(self, dict_, key_, class_):
        """
        Check whether ``dict_[key_]`` is an instance of ``class_``.

        :param dict_: dictionary to inspect.
        :param key_: key whose value is checked.
        :param class_: class (or tuple of classes) accepted by `isinstance`.
        :return: ``True`` if ``key_`` is absent from ``dict_`` (nothing to
            validate) or if its value is an instance of ``class_``;
            ``False`` otherwise.
        """
        if key_ not in dict_:
            return True

        return isinstance(dict_[key_], class_)

    def __init__(self, *args, **kwargs):
        """
        Constructor for the `Condition` class.

        Accepts keyword arguments only. Besides the optional ``data_weight``
        (default ``1.0``), exactly one of the following pairs must be given:
        ``input_points``/``output_points``, ``location``/``equation`` or
        ``input_points``/``equation``.

        :raises ValueError: if positional arguments are passed, or if the
            keyword arguments are not one of the accepted combinations.
        :raises TypeError: if a passed value is not an instance of the
            expected class (`LabelTensor` for the points, `Location` for
            ``location``, `Equation` for ``equation``).
        """
        # Removed first so that it does not take part in the combination
        # check below.
        self.data_weight = kwargs.pop('data_weight', 1.0)

        if args:
            raise ValueError('Condition takes only the following keyword arguments: {`input_points`, `output_points`, `location`, `equation`, `data_weight`}.')

        accepted_combinations = (
            {'input_points', 'output_points'},
            {'location', 'equation'},
            {'input_points', 'equation'},
        )
        if set(kwargs) not in accepted_combinations:
            raise ValueError(f'Invalid keyword arguments {kwargs.keys()}.')

        # (keyword, required class, error message), checked in this order.
        type_requirements = (
            ('input_points', LabelTensor,
             '`input_points` must be a LabelTensor.'),
            ('output_points', LabelTensor,
             '`output_points` must be a LabelTensor.'),
            ('location', Location, '`location` must be a Location.'),
            ('equation', Equation, '`equation` must be a Equation.'),
        )
        for keyword, required_class, message in type_requirements:
            if not self._dictvalue_isinstance(kwargs, keyword, required_class):
                raise TypeError(message)

        for key, value in kwargs.items():
            setattr(self, key, value)
