from algorithm.domain import Unassigned
from typing import Callable, List, Union, Optional, Sequence
import ipdb
class Constraint:
    """Base class for all constraints.

    A constraint is called with the variables it covers, the mapping of
    variable -> domain, and the current (possibly partial) assignment dict.
    The base implementation accepts everything; subclasses override
    ``__call__`` (and optionally ``preProcess``) to implement real checks.
    """

    def __call__(self, variables: Sequence, domains: dict, assignments: dict, forwardcheck=False):
        """Check the constraint; the base class never rejects anything."""
        return True

    def preProcess(self, variables: Sequence, domains: dict, constraints: List[tuple], vconstraints: dict):
        """Resolve single-variable constraints before the search starts.

        When the constraint covers exactly one variable, every domain value
        that fails the constraint is removed, and the ``(self, variables)``
        registration is dropped from ``constraints`` and from
        ``vconstraints[variable]`` because nothing is left to check.
        Constraints over several variables are left untouched.
        """
        if len(variables) != 1:
            return
        variable = variables[0]
        domain = domains[variable]
        for candidate in domain[:]:
            if not self(variables, domains, {variable: candidate}):
                domain.remove(candidate)
        registration = (self, variables)
        constraints.remove(registration)
        vconstraints[variable].remove(registration)

    def forwardCheck(self, variables: Sequence, domains: dict, assignments: dict, _unassigned=Unassigned):
        """Prune the domain of the only unassigned variable, if there is exactly one.

        Each remaining value of that variable is tried by temporarily writing it
        into ``assignments``; values that make the constraint fail are hidden via
        ``domain.hideValue``. The temporary assignment is removed afterwards.

        Returns:
            False if the pruned domain ends up empty, True otherwise (including
            when zero or more than one variable is still unassigned).
        """
        pending = _unassigned
        for variable in variables:
            if variable in assignments:
                continue
            if pending is not _unassigned:
                # A second free variable: trial assignment of one is not conclusive.
                return True
            pending = variable

        if pending is _unassigned:
            # Everything is assigned; nothing to prune.
            return True

        domain = domains[pending]
        if domain:
            for candidate in domain[:]:
                assignments[pending] = candidate
                if not self(variables, domains, assignments):
                    domain.hideValue(candidate)
            del assignments[pending]
        return bool(domain)


class FunctionConstraint(Constraint):
    """Constraint backed by an arbitrary predicate over the constrained variables.

    The predicate receives one positional argument per constrained variable,
    in the order of ``variables``; variables without a value are passed as
    the ``Unassigned`` sentinel.
    """

    def __init__(self, func: Callable, assigned: bool = True):
        """Store the predicate.

        Args:
            func: Callable whose truthy result means the values are acceptable.
            assigned: When truthy, ``func`` is only evaluated once every
                constrained variable has a value; partial assignments pass
                unchecked (apart from forward checking).
        """
        self._func = func
        self._assigned = assigned

    def __call__(  # noqa: D102
        self,
        variables: Sequence,
        domains: dict,
        assignments: dict,
        forwardcheck=False,
        _unassigned=Unassigned,
    ):
        # One argument per variable, using the sentinel for the ones still free.
        arguments = [assignments[name] if name in assignments else _unassigned for name in variables]
        free = sum(1 for name in variables if name not in assignments)

        if free == 0:
            return self._func(*arguments)

        # Partial assignment: the predicate only runs when assigned is falsy.
        verdict = self._assigned or self._func(*arguments)
        if not verdict:
            return verdict

        # With exactly one free variable left, prune its domain by trial assignment.
        if forwardcheck and free == 1:
            return self.forwardCheck(variables, domains, assignments)
        return True


class AllDifferentConstraint(Constraint):
    """Constraint enforcing that all given variables take pairwise different values."""

    def __call__(  # noqa: D102
        self,
        variables: Sequence,
        domains: dict,
        assignments: dict,
        forwardcheck=False,
        _unassigned=Unassigned,
    ):
        # Insertion-ordered record of the values already taken.
        used = {}
        for name in variables:
            current = assignments.get(name, _unassigned)
            if current is _unassigned:
                continue
            if current in used:
                return False
            used[current] = True

        if not forwardcheck:
            return True

        # Hide every already-used value from each still-free variable.
        for name in variables:
            if name in assignments:
                continue
            domain = domains[name]
            for taken in used:
                if taken not in domain:
                    continue
                domain.hideValue(taken)
                if not domain:
                    return False
        return True


class AllEqualConstraint(Constraint):
    """Constraint enforcing that all given variables take the same value."""

    def __call__(   # noqa: D102
        self,
        variables: Sequence,
        domains: dict,
        assignments: dict,
        forwardcheck=False,
        _unassigned=Unassigned,
    ):
        # The first assigned value becomes the reference every other value must match.
        reference = _unassigned
        for name in variables:
            current = assignments.get(name, _unassigned)
            if reference is _unassigned:
                reference = current
                continue
            if current is _unassigned:
                continue
            if current != reference:
                return False

        if not forwardcheck or reference is _unassigned:
            return True

        # Each free variable may only keep the reference value.
        for name in variables:
            if name in assignments:
                continue
            domain = domains[name]
            if reference not in domain:
                return False
            for candidate in domain[:]:
                if candidate != reference:
                    domain.hideValue(candidate)
        return True


class MaxSumConstraint(Constraint):
    """Constraint enforcing that values of given variables sum up to a given amount.

    Negative values and negative multipliers are supported: a partial
    assignment is only rejected when even the most negative completion of the
    still-unassigned variables would exceed the maximum.

    Example:
        >>> problem = Problem()
        >>> problem.addVariables(["a", "b"], [1, 2])
        >>> problem.addConstraint(MaxSumConstraint(3))
        >>> sorted(sorted(x.items()) for x in problem.getSolutions())
        [[('a', 1), ('b', 1)], [('a', 1), ('b', 2)], [('a', 2), ('b', 1)]]
    """

    def __init__(self, maxsum: Union[int, float], multipliers: Optional[Sequence] = None):
        """Initialization method.

        Args:
            maxsum (number): Value to be considered as the maximum sum
            multipliers (sequence of numbers): If given, variable values
                will be multiplied by the given factors before being
                summed to be checked
        """
        self._maxsum = maxsum
        self._multipliers = multipliers

    def _weighted_variables(self, variables: Sequence) -> List[tuple]:
        """Pair each variable with its multiplier (1 when no multipliers were given).

        As with ``zip``, variables beyond the end of ``multipliers`` are ignored.
        """
        if self._multipliers:
            return list(zip(variables, self._multipliers))
        return [(variable, 1) for variable in variables]

    def preProcess(self, variables: Sequence, domains: dict, constraints: List[tuple], vconstraints: dict):
        """Prune values whose weighted contribution alone already exceeds ``maxsum``.

        This is only sound for a variable when every *other* variable's
        contribution is non-negative. If no variable can contribute a negative
        amount, all are pruned; if exactly one can, only that one is pruned; if
        several can, nothing is pruned.
        """
        Constraint.preProcess(self, variables, domains, constraints, vconstraints)

        weighted = self._weighted_variables(variables)
        # Test the weighted contribution so negative multipliers count as negatives too.
        variable_with_negative = None
        for variable, multiplier in weighted:
            if any(value * multiplier < 0 for value in domains[variable]):
                if variable_with_negative is not None:
                    # more than one variable can lower the sum: pruning is unsafe
                    return
                variable_with_negative = variable

        maxsum = self._maxsum
        for variable, multiplier in weighted:
            if variable_with_negative is not None and variable_with_negative != variable:
                continue
            domain = domains[variable]
            for value in domain[:]:
                if value * multiplier > maxsum:
                    domain.remove(value)

    def __call__(self, variables: Sequence, domains: dict, assignments: dict, forwardcheck=False):
        """Check the (weighted) sum of the assigned values against ``maxsum``.

        Returns False when the assigned values plus the most negative amount
        each unassigned variable could still contribute exceed ``maxsum``. With
        ``forwardcheck``, values of unassigned variables that would push that
        bound over ``maxsum`` are hidden, and False is returned if a domain
        becomes empty.
        """
        maxsum = self._maxsum
        total = 0
        unassigned = []
        for variable, multiplier in self._weighted_variables(variables):
            if variable in assignments:
                total += assignments[variable] * multiplier
            else:
                unassigned.append((variable, multiplier))
        if isinstance(total, float):
            total = round(total, 10)

        # Most negative amount each unassigned variable could still add. It is 0 for
        # non-negative contributions, reproducing the plain "total > maxsum" test.
        slacks = [
            min(0, min((value * multiplier for value in domains[variable]), default=0))
            for variable, multiplier in unassigned
        ]
        lower_bound = total + sum(slacks)
        if isinstance(lower_bound, float):
            lower_bound = round(lower_bound, 10)
        if lower_bound > maxsum:
            return False

        if forwardcheck:
            for (variable, multiplier), slack in zip(unassigned, slacks):
                # Bound on everything except this variable's own contribution.
                rest = lower_bound - slack
                domain = domains[variable]
                for value in domain[:]:
                    if rest + value * multiplier > maxsum:
                        domain.hideValue(value)
                if not domain:
                    return False
        return True


class ExactSumConstraint(Constraint):
    """Constraint enforcing that values of given variables sum exactly to a given amount.

    Negative values and negative multipliers are supported: a partial
    assignment is only rejected when even the most negative completion of the
    still-unassigned variables would exceed the target sum.

    Example:
        >>> problem = Problem()
        >>> problem.addVariables(["a", "b"], [1, 2])
        >>> problem.addConstraint(ExactSumConstraint(3))
        >>> sorted(sorted(x.items()) for x in problem.getSolutions())
        [[('a', 1), ('b', 2)], [('a', 2), ('b', 1)]]
    """

    def __init__(self, exactsum: Union[int, float], multipliers: Optional[Sequence] = None):
        """Initialization method.

        Args:
            exactsum (number): Value to be considered as the exact sum
            multipliers (sequence of numbers): If given, variable values
                will be multiplied by the given factors before being
                summed to be checked
        """
        self._exactsum = exactsum
        self._multipliers = multipliers

    def _weighted_variables(self, variables: Sequence) -> List[tuple]:
        """Pair each variable with its multiplier (1 when no multipliers were given).

        As with ``zip``, variables beyond the end of ``multipliers`` are ignored.
        """
        if self._multipliers:
            return list(zip(variables, self._multipliers))
        return [(variable, 1) for variable in variables]

    def preProcess(self, variables: Sequence, domains: dict, constraints: List[tuple], vconstraints: dict):
        """Prune values whose weighted contribution alone already exceeds ``exactsum``.

        This is only sound for a variable when every *other* variable's
        contribution is non-negative. If no variable can contribute a negative
        amount, all are pruned; if exactly one can, only that one is pruned; if
        several can, nothing is pruned.
        """
        Constraint.preProcess(self, variables, domains, constraints, vconstraints)

        weighted = self._weighted_variables(variables)
        variable_with_negative = None
        for variable, multiplier in weighted:
            if any(value * multiplier < 0 for value in domains[variable]):
                if variable_with_negative is not None:
                    # more than one variable can lower the sum: pruning is unsafe
                    return
                variable_with_negative = variable

        exactsum = self._exactsum
        for variable, multiplier in weighted:
            if variable_with_negative is not None and variable_with_negative != variable:
                continue
            domain = domains[variable]
            for value in domain[:]:
                if value * multiplier > exactsum:
                    domain.remove(value)

    def __call__(self, variables: Sequence, domains: dict, assignments: dict, forwardcheck=False):
        """Check the (weighted) sum of the assigned values against ``exactsum``.

        With every variable assigned, the sum must equal ``exactsum``.
        Otherwise False is returned when the assigned values plus the most
        negative amount each unassigned variable could still contribute exceed
        ``exactsum``. With ``forwardcheck``, values that would push that bound
        over ``exactsum`` are hidden, and False is returned if a domain empties.
        """
        exactsum = self._exactsum
        total = 0
        unassigned = []
        for variable, multiplier in self._weighted_variables(variables):
            if variable in assignments:
                total += assignments[variable] * multiplier
            else:
                unassigned.append((variable, multiplier))
        if isinstance(total, float):
            total = round(total, 10)

        if not unassigned:
            return total == exactsum

        # Most negative amount each unassigned variable could still add. It is 0 for
        # non-negative contributions, reproducing the plain "total > exactsum" test.
        slacks = [
            min(0, min((value * multiplier for value in domains[variable]), default=0))
            for variable, multiplier in unassigned
        ]
        lower_bound = total + sum(slacks)
        if isinstance(lower_bound, float):
            lower_bound = round(lower_bound, 10)
        if lower_bound > exactsum:
            return False

        if forwardcheck:
            for (variable, multiplier), slack in zip(unassigned, slacks):
                # Bound on everything except this variable's own contribution.
                rest = lower_bound - slack
                domain = domains[variable]
                for value in domain[:]:
                    if rest + value * multiplier > exactsum:
                        domain.hideValue(value)
                if not domain:
                    return False
        return True


class MinSumConstraint(Constraint):
    """Constraint enforcing that values of given variables sum at least to a given amount.

    Example:
        >>> problem = Problem()
        >>> problem.addVariables(["a", "b"], [1, 2])
        >>> problem.addConstraint(MinSumConstraint(3))
        >>> sorted(sorted(x.items()) for x in problem.getSolutions())
        [[('a', 1), ('b', 2)], [('a', 2), ('b', 1)], [('a', 2), ('b', 2)]]
    """

    def __init__(self, minsum: Union[int, float], multipliers: Optional[Sequence] = None):
        """Initialization method.

        Args:
            minsum (number): Value to be considered as the minimum sum
            multipliers (sequence of numbers): If given, variable values
                will be multiplied by the given factors before being
                summed to be checked
        """
        self._minsum = minsum
        self._multipliers = multipliers

    def __call__(self, variables: Sequence, domains: dict, assignments: dict, forwardcheck=False):
        """Accept any partial assignment; once complete, require sum >= minsum.

        No forward checking is done: without every value known, a low partial
        sum can still be raised by the remaining variables.
        """
        if not all(name in assignments for name in variables):
            return True

        if self._multipliers:
            terms = [assignments[name] * factor for name, factor in zip(variables, self._multipliers)]
        else:
            terms = [assignments[name] for name in variables]

        # Accumulate left to right so float results match a plain running total.
        total = 0
        for term in terms:
            total += term
        if isinstance(total, float):
            total = round(total, 10)
        return total >= self._minsum


class MaxProdConstraint(Constraint):
    """Constraint enforcing that the product of the given variables' values is at most a given amount.

    Factors below 1 (fractions, zero, negatives) can shrink the product or flip
    its sign, so partial products are only used for pruning when they are a
    genuine lower bound on the final product.

    Example:
        >>> problem = Problem()
        >>> problem.addVariables(["a", "b"], [1, 2])
        >>> problem.addConstraint(MaxProdConstraint(2))
        >>> sorted(sorted(x.items()) for x in problem.getSolutions())
        [[('a', 1), ('b', 1)], [('a', 1), ('b', 2)], [('a', 2), ('b', 1)]]
    """

    def __init__(self, maxprod: Union[int, float]):
        """Instantiate a MaxProdConstraint.

        Args:
            maxprod: Value to be considered as the maximum product
        """
        self._maxprod = maxprod

    def preProcess(self, variables: Sequence, domains: dict, constraints: List[tuple], vconstraints: dict):
        """Prune values that alone already force the product above ``maxprod``.

        A variable can only be pruned when all *other* factors are >= 1. If no
        variable has values below 1, all are pruned. If exactly one does, only
        that one is pruned. If several do, nothing is pruned.
        """
        Constraint.preProcess(self, variables, domains, constraints, vconstraints)

        variable_with_lt1 = None
        for variable in variables:
            if any(value < 1 for value in domains[variable]):
                if variable_with_lt1 is not None:
                    # more than one variable can shrink the product: pruning is unsafe
                    return
                variable_with_lt1 = variable

        maxprod = self._maxprod
        for variable in variables:
            if variable_with_lt1 is not None and variable_with_lt1 != variable:
                continue
            domain = domains[variable]
            for value in domain[:]:
                # With all other factors >= 1, a non-negative value is a lower bound on
                # the product (this also covers 0 when maxprod < 0). A negative value is
                # not, because further factors >= 1 only make the product smaller.
                if value >= 0 and value > maxprod:
                    domain.remove(value)

    def __call__(self, variables: Sequence, domains: dict, assignments: dict, forwardcheck=False):
        """Check the product of the assigned values against ``maxprod``.

        With every variable assigned, the product must not exceed ``maxprod``.
        For partial assignments, False is only returned when the partial
        product is non-negative and every unassigned variable's values are >= 1,
        since the partial product is then a lower bound on the final one.
        With ``forwardcheck``, values are hidden under the same soundness
        condition, and False is returned if a domain becomes empty.
        """
        maxprod = self._maxprod
        prod = 1
        unassigned = []
        for variable in variables:
            if variable in assignments:
                prod *= assignments[variable]
            else:
                unassigned.append(variable)
        if isinstance(prod, float):
            prod = round(prod, 10)

        if not unassigned:
            return prod <= maxprod

        # Unassigned variables whose remaining values could shrink the product.
        shrinkers = [variable for variable in unassigned if any(value < 1 for value in domains[variable])]
        if not shrinkers and prod >= 0 and prod > maxprod:
            return False

        if forwardcheck:
            for variable in unassigned:
                domain = domains[variable]
                # Only sound when every *other* remaining factor is >= 1.
                if not any(other != variable for other in shrinkers):
                    for value in domain[:]:
                        bound = prod * value
                        if bound >= 0 and bound > maxprod:
                            domain.hideValue(value)
                if not domain:
                    return False
        return True


class MinProdConstraint(Constraint):
    """Constraint enforcing that the product of the given variables' values is at least a given amount.

    The product is only checked once every constrained variable is assigned.
    """

    def __init__(self, minprod: Union[int, float]):
        """Instantiate a MinProdConstraint.

        Args:
            minprod: Value to be considered as the minimum product
        """
        self._minprod = minprod

    def preProcess(self, variables: Sequence, domains: dict, constraints: List[tuple], vconstraints: dict):
        """Remove zeros from all domains when ``minprod`` is positive.

        A zero factor forces the product to 0, which can never reach a
        positive minimum.
        """
        Constraint.preProcess(self, variables, domains, constraints, vconstraints)

        minprod = self._minprod
        if minprod > 0:
            for variable in variables:
                domain = domains[variable]
                for value in domain[:]:
                    if value == 0:
                        domain.remove(value)

    def __call__(self, variables: Sequence, domains: dict, assignments: dict, forwardcheck=False):
        """Accept partial assignments; once complete, require product >= minprod."""
        if any(variable not in assignments for variable in variables):
            return True

        # With each variable assigned, multiply the values.
        prod = 1
        for variable in variables:
            prod *= assignments[variable]
        if isinstance(prod, float):
            prod = round(prod, 10)
        return prod >= self._minprod


class InSetConstraint(Constraint):
    """Constraint enforcing that values of given variables are present in the given set.

    Example:
        >>> problem = Problem()
        >>> problem.addVariables(["a", "b"], [1, 2])
        >>> problem.addConstraint(InSetConstraint([1]))
        >>> sorted(sorted(x.items()) for x in problem.getSolutions())
        [[('a', 1), ('b', 1)]]
    """

    def __init__(self, set):
        """Initialization method.

        Args:
            set (set): Set of allowed values
        """
        self._set = set

    def __call__(self, variables, domains, assignments, forwardcheck=False):    # noqa: D102
        # Never reached: preProcess() prunes the domains and unregisters this constraint.
        raise RuntimeError("Can't happen")

    def preProcess(self, variables: Sequence, domains: dict, constraints: List[tuple], vconstraints: dict):
        """Remove every value outside the allowed set, then unregister the constraint."""
        allowed = self._set
        registration = (self, variables)
        for name in variables:
            domain = domains[name]
            rejected = [value for value in domain[:] if value not in allowed]
            for value in rejected:
                domain.remove(value)
            vconstraints[name].remove(registration)
        constraints.remove(registration)


class NotInSetConstraint(Constraint):
    """Constraint enforcing that values of given variables are not present in the given set.

    Example:
        >>> problem = Problem()
        >>> problem.addVariables(["a", "b"], [1, 2])
        >>> problem.addConstraint(NotInSetConstraint([1]))
        >>> sorted(sorted(x.items()) for x in problem.getSolutions())
        [[('a', 2), ('b', 2)]]
    """

    def __init__(self, set):
        """Initialization method.

        Args:
            set (set): Set of disallowed values
        """
        self._set = set

    def __call__(self, variables, domains, assignments, forwardcheck=False):    # noqa: D102
        # Never reached: preProcess() prunes the domains and unregisters this constraint.
        raise RuntimeError("Can't happen")

    def preProcess(self, variables: Sequence, domains: dict, constraints: List[tuple], vconstraints: dict):
        """Remove every value found in the disallowed set, then unregister the constraint."""
        forbidden = self._set
        registration = (self, variables)
        for name in variables:
            domain = domains[name]
            rejected = [value for value in domain[:] if value in forbidden]
            for value in rejected:
                domain.remove(value)
            vconstraints[name].remove(registration)
        constraints.remove(registration)


class SomeInSetConstraint(Constraint):
    """Constraint enforcing that at least some of the values of given variables must be present in a given set.

    Example:
        >>> problem = Problem()
        >>> problem.addVariables(["a", "b"], [1, 2])
        >>> problem.addConstraint(SomeInSetConstraint([1]))
        >>> sorted(sorted(x.items()) for x in problem.getSolutions())
        [[('a', 1), ('b', 1)], [('a', 1), ('b', 2)], [('a', 2), ('b', 1)]]
    """

    def __init__(self, set, n=1, exact=False):
        """Initialization method.

        Args:
            set (set): Set of values to be checked
            n (int): Minimum number of assigned values that should be
                present in set (default is 1)
            exact (bool): Whether the number of assigned values which
                are present in set must be exactly `n`
        """
        self._set = set
        self._n = n
        self._exact = exact

    def __call__(self, variables: Sequence, domains: dict, assignments: dict, forwardcheck=False):
        """Count assigned values inside the set and check the count against ``n``.

        While variables are still free, the call fails only if ``n`` can no
        longer be reached (or, in exact mode, has already been exceeded). When
        the free variables are exactly enough to reach ``n``, forward checking
        hides every value outside the set from them.
        """
        members = self._set
        target = self._n
        pending = [name for name in variables if name not in assignments]
        hits = 0
        for name in variables:
            if name in assignments:
                hits += assignments[name] in members

        if not pending:
            satisfied = (hits == target) if self._exact else (hits >= target)
            return bool(satisfied)

        reachable = hits + len(pending)
        if self._exact:
            if not (hits <= target <= reachable):
                return False
        elif target > reachable:
            return False

        if forwardcheck and target - hits == len(pending):
            # Every free variable has to take a value from the set.
            for name in pending:
                domain = domains[name]
                for value in domain[:]:
                    if value not in members:
                        domain.hideValue(value)
                if not domain:
                    return False
        return True


class SomeNotInSetConstraint(Constraint):
    """Constraint enforcing that at least some of the values of given variables must not be present in a given set.

    Example:
        >>> problem = Problem()
        >>> problem.addVariables(["a", "b"], [1, 2])
        >>> problem.addConstraint(SomeNotInSetConstraint([1]))
        >>> sorted(sorted(x.items()) for x in problem.getSolutions())
        [[('a', 1), ('b', 2)], [('a', 2), ('b', 1)], [('a', 2), ('b', 2)]]
    """

    def __init__(self, set, n=1, exact=False):
        """Initialization method.

        Args:
            set (set): Set of values to be checked
            n (int): Minimum number of assigned values that should not
                be present in set (default is 1)
            exact (bool): Whether the number of assigned values which
                are not present in set must be exactly `n`
        """
        self._set = set
        self._n = n
        self._exact = exact

    def __call__(self, variables: Sequence, domains: dict, assignments: dict, forwardcheck=False):
        """Count assigned values outside the set and check the count against ``n``.

        While variables are still free, the call fails only if ``n`` can no
        longer be reached (or, in exact mode, has already been exceeded). When
        the free variables are exactly enough to reach ``n``, forward checking
        hides every value inside the set from them.
        """
        excluded = self._set
        target = self._n
        pending = [name for name in variables if name not in assignments]
        outside = 0
        for name in variables:
            if name in assignments:
                outside += assignments[name] not in excluded

        if not pending:
            satisfied = (outside == target) if self._exact else (outside >= target)
            return bool(satisfied)

        reachable = outside + len(pending)
        if self._exact:
            if not (outside <= target <= reachable):
                return False
        elif target > reachable:
            return False

        if forwardcheck and target - outside == len(pending):
            # Every free variable has to take a value outside the set.
            for name in pending:
                domain = domains[name]
                for value in domain[:]:
                    if value in excluded:
                        domain.hideValue(value)
                if not domain:
                    return False
        return True


class TupleInSetConstraint(Constraint):
    """Constraint enforcing that the values of two variables must be one of the given tuples.

    Example:
        >>> problem = Problem()
        >>> problem.addVariables(["a", "b"], [1, 2, 3, 4])
        >>> problem.addConstraint(TupleInSetConstraint([(1, 2), (3, 4), (2, 4)]), ["a", "b"])
        >>> sorted(sorted(x.items()) for x in problem.getSolutions())
        [[('a', 1), ('b', 2)], [('a', 2), ('b', 4)], [('a', 3), ('b', 4)]]
    """

    def __init__(self, valid_tuples, costs=None):
        """Initialization method.

        Args:
            valid_tuples (list of tuples): List of valid (a, b) tuples.
            costs: Optional container indexed as ``costs[a, b]`` giving the cost
                of each valid pair; only needed by ``get_min_cost`` and ``get_gap``.
        """
        self.valid_tuples = set(valid_tuples)
        self.costs = costs

    def __call__(self, variables, domains, assignments, forwardcheck=False):
        """Accept the pair only if it is one of the valid tuples.

        With one side assigned and ``forwardcheck`` set, values of the other
        side that cannot form a valid tuple are hidden; returns False if that
        empties its domain.
        """
        var1, var2 = variables
        if var1 in assignments and var2 in assignments:
            return (assignments[var1], assignments[var2]) in self.valid_tuples

        if forwardcheck and var1 in assignments:
            a = assignments[var1]
            for b in list(domains[var2]):  # Copy to avoid iterator issues
                if (a, b) not in self.valid_tuples:
                    domains[var2].hideValue(b)
            if not domains[var2]:
                return False
        elif forwardcheck and var2 in assignments:
            b = assignments[var2]
            for a in list(domains[var1]):  # Copy to avoid iterator issues
                if (a, b) not in self.valid_tuples:
                    domains[var1].hideValue(a)
            if not domains[var1]:
                return False
        return True

    def get_ddeg(self, variables, domains):
        """Return how many (a, b) pairs from the current domains are valid tuples."""
        var1, var2 = variables
        valid = self.valid_tuples
        return sum(1 for a in domains[var1] for b in domains[var2] if (a, b) in valid)

    def get_min_cost(self, variables, domains):
        """Return the lowest cost among valid pairs in the current domains, or None if there are none."""
        var1, var2 = variables
        valid = self.valid_tuples
        return min(
            (self.costs[a, b] for a in domains[var1] for b in domains[var2] if (a, b) in valid),
            default=None,
        )

    def get_gap(self, variables, domains, assignments):
        """Return the relative gap between the assigned pair's cost and the cheapest valid pair.

        Computed as ``(cost - min_cost) / cost``. Returns 0 when the assigned
        pair is not valid, when no valid pair remains in the domains, or when
        the assigned pair's cost is zero (the relative gap is undefined there).
        """
        var1, var2 = variables
        pair = (assignments[var1], assignments[var2])
        if pair not in self.valid_tuples:
            return 0
        cost = self.costs[pair]
        min_cost = self.get_min_cost(variables, domains)
        if min_cost is None or not cost:
            return 0
        return (cost - min_cost) / cost


class TupleNotInSetConstraint(Constraint):
    """Constraint enforcing that the values of two variables must not be one of the given tuples.

    Example:
        >>> problem = Problem()
        >>> problem.addVariables(["a", "b"], [1, 2, 3, 4])
        >>> problem.addConstraint(TupleNotInSetConstraint([(1, 2), (3, 4), (2, 4)]), ["a", "b"])
        >>> sorted(sorted(x.items()) for x in problem.getSolutions())
        [[('a', 1), ('b', 1)], [('a', 1), ('b', 3)], [('a', 1), ('b', 4)], [('a', 2), ('b', 1)], [('a', 2), ('b', 2)], [('a', 2), ('b', 3)], [('a', 3), ('b', 1)], [('a', 3), ('b', 2)], [('a', 3), ('b', 3)], [('a', 4), ('b', 1)], [('a', 4), ('b', 2)], [('a', 4), ('b', 3)], [('a', 4), ('b', 4)]]
    """

    def __init__(self, forbidden_tuples):
        """Initialization method.

        Args:
            forbidden_tuples (list of tuples): List of forbidden (a, b) tuples.
        """
        self.forbidden_tuples = set(forbidden_tuples)

    def __call__(self, variables, domains, assignments, forwardcheck=False):
        """Reject the pair if it is one of the forbidden tuples.

        With one side assigned and ``forwardcheck`` set, values of the other
        side that would form a forbidden tuple are hidden; returns False if
        that empties its domain.
        """
        var1, var2 = variables
        if var1 in assignments and var2 in assignments:
            return (assignments[var1], assignments[var2]) not in self.forbidden_tuples

        if forwardcheck and var1 in assignments:
            a = assignments[var1]
            for b in list(domains[var2]):  # Copy to avoid iterator issues
                if (a, b) in self.forbidden_tuples:
                    domains[var2].hideValue(b)
            if not domains[var2]:
                return False
        elif forwardcheck and var2 in assignments:
            b = assignments[var2]
            for a in list(domains[var1]):  # Copy to avoid iterator issues
                if (a, b) in self.forbidden_tuples:
                    domains[var1].hideValue(a)
            if not domains[var1]:
                return False
        return True

    def get_ddeg(self, variables, domains):
        """Return how many (a, b) pairs from the current domains are not forbidden."""
        var1, var2 = variables
        forbidden = self.forbidden_tuples
        return sum(1 for a in domains[var1] for b in domains[var2] if (a, b) not in forbidden)