"""Stuff related to matching expressions."""
import abc
import dataclasses
import collections
import functools
import itertools
from typing import Any, Iterable, List, Optional, Union
import uuid

import sympy as sp

# Type aliases

# A SymPy expression head compared against ``expr.func`` (for example the
# ``sp.Add`` / ``sp.Mul`` operations used by ``RingOpsUntil``).
# TODO: Narrow this from ``Any`` to a precise type.
SpFunction = Any


###################################################################################


def _ensure_list(x: Any) -> bool:
    if isinstance(x, (tuple, list)):
        return list(x)
    else:
        return [x]


def _ensure_expr(x: Union[sp.Expr, str]) -> sp.Expr:
    """Coerce ``x`` to a SymPy expression.

    Args:
        x: An ``sp.Expr`` (returned unchanged) or a string, which is parsed
            with ``sp.sympify``.

    Returns:
        The SymPy object for ``x``. NOTE(review): ``sympify`` may return a
        non-``Expr`` object for some strings (e.g. relations); not checked here.

    Raises:
        TypeError: If ``x`` is neither an ``sp.Expr`` nor a ``str``.
    """
    if isinstance(x, sp.Expr):
        return x
    if isinstance(x, str):
        # NOTE(review): ``sympify`` evaluates strings via ``eval``; only pass
        # trusted input here.
        return sp.sympify(x)
    raise TypeError(
        f'Expected a sympy Expr or str, got {type(x).__name__}: {x!r}'
    )


def _ensure_matcher(x):
    """Coerce ``x`` to a ``Matcher``.

    Matchers are returned unchanged. Anything else is wrapped in ``Literal``.
    ``Literal`` only accepts SymPy expressions or strings, so numbers are
    converted first. Ints (and bools) become their string form. Floats
    become a SymPy number.
    """
    if isinstance(x, Matcher):
        return x
    if isinstance(x, int):
        x = f'{x}'
    elif isinstance(x, float):
        # sympify(float) yields an ``sp.Expr`` (Float / oo / nan), which
        # ``Literal`` accepts; previously floats raised TypeError.
        x = sp.sympify(x)
    return Literal(x)


def _random_symbol():
    """Return a fresh SymPy symbol named by a random UUID4 hex string.

    Used to create placeholder symbols unlikely to collide with symbols
    already present in an expression.
    """
    unique_name = uuid.uuid4().hex
    return sp.Symbol(unique_name)


###########################################################


def _sort_by_contains(exprs: Iterable[sp.Expr]) -> List[sp.Expr]:
    def cmp(x, y):
        if x.has(y):
            return -1
        elif y.has(x):
            return 1
        return 0
    return list(sorted(exprs, key=functools.cmp_to_key(cmp)))


###################################################################################


class Matcher(abc.ABC):
    """Abstract base for predicates over SymPy expressions.

    Subclasses implement ``match(expr, x)``, which reports whether ``expr``
    has the structure the matcher describes. ``x`` is the variable of
    interest; composite matchers forward it to their nested matchers.
    """

    @abc.abstractmethod
    def match(self, expr: sp.Expr, x: sp.Symbol) -> bool:
        """Return whether ``expr`` is accepted by this matcher."""
        raise NotImplementedError


###################################################################################


class Literal(Matcher):
    """Match an expression exactly equal to one of the given literals.

    Comparison uses SymPy's structural ``==``, so no simplification is done.
    For example ``Literal('(x + 1)**2')`` does not match
    ``x**2 + 2*x + 1``.

    Args:
        literals: A single ``sp.Expr``/string, or a list/tuple of them.
            Strings are parsed with ``sp.sympify`` via ``_ensure_expr``.
    """

    def __init__(self, literals: Union[str, sp.Expr, Iterable[Union[str, sp.Expr]]]):
        super().__init__()
        self.literals = [_ensure_expr(lit) for lit in _ensure_list(literals)]

    def match(self, expr: sp.Expr, x: sp.Symbol) -> bool:
        # Name the loop variable ``literal`` so it does not shadow the ``x``
        # symbol argument.
        return any(literal == expr for literal in self.literals)


###################################################################################


class Function(Matcher):
    """Match an application of one of the given function heads.

    ``expr`` matches when ``expr.func`` is one of ``functions``. If an
    ``argument_matcher`` is supplied, the expression's single argument must
    also satisfy it.

    Raises (from ``match``):
        NotImplementedError: if ``argument_matcher`` is set and the matched
            expression does not have exactly one argument.
    """

    def __init__(
        self,
        functions: Union[SpFunction, Iterable[SpFunction]],
        argument_matcher: Optional[Matcher] = None,
    ):
        super().__init__()
        self.functions = _ensure_list(functions)
        self.argument_matcher = argument_matcher

    def match(self, expr: sp.Expr, x: sp.Symbol) -> bool:
        if expr.func not in self.functions:
            return False
        if self.argument_matcher is None:
            return True
        # Explicit raise instead of ``assert``: an assert is stripped under
        # ``python -O``. That would silently check only ``args[0]`` (or fail
        # with IndexError for zero-argument expressions).
        if len(expr.args) != 1:
            raise NotImplementedError('TODO: Support multi-adic functions.')
        return self.argument_matcher.match(expr.args[0], x)


###################################################################################

class Add(Matcher):
    """Match a sum whose terms can be paired one-to-one with ``term_matchers``.

    Addition is treated as commutative, so matchers need not follow the
    order of ``expr.args``. Each matcher must be assigned a distinct term
    that it matches. Extra, unassigned terms in the sum are allowed.
    """

    def __init__(
        self,
        term_matchers: Iterable[Matcher],
    ):
        super().__init__()
        self.term_matchers = list(term_matchers)

    def match(self, expr: sp.Expr, x: sp.Symbol) -> bool:
        if expr.func != sp.Add:
            return False

        terms = expr.args
        # Pigeonhole: more matchers than terms can never be assigned distinctly.
        if len(terms) < len(self.term_matchers):
            return False

        # candidates[m] lists the indices of the terms matcher m accepts.
        candidates = [
            [i for i, term in enumerate(terms) if matcher.match(term, x)]
            for matcher in self.term_matchers
        ]
        return self._has_distinct_assignment(candidates)

    @staticmethod
    def _has_distinct_assignment(candidates: List[List[int]]) -> bool:
        """Return True if every row can be given a distinct index from its list.

        This is bipartite matching (matchers vs. term indices) solved with
        augmenting paths. It runs in polynomial time, unlike enumerating the
        Cartesian product of all candidate lists. An empty ``candidates`` is
        trivially satisfiable.
        """
        owner = {}  # term index -> index of the matcher currently assigned it

        def assign(m: int, visited: set) -> bool:
            for t in candidates[m]:
                if t in visited:
                    continue
                visited.add(t)
                # Take a free term, or move its current owner to another term.
                if t not in owner or assign(owner[t], visited):
                    owner[t] = m
                    return True
            return False

        return all(assign(m, set()) for m in range(len(candidates)))


###################################################################################


class _OperationsUntilMatch(Matcher):
    """Breadth-first search for a sub-expression accepted by ``matcher``.

    The root is always tested. The search descends into a node's arguments
    only when that node's ``func`` is one of ``operations``.
    """

    def __init__(self, matcher: Matcher, operations: Iterable[SpFunction]):
        super().__init__()
        self.operations = list(operations)
        self.matcher = matcher

    def match(self, expr: sp.Expr, x: sp.Symbol) -> bool:
        # Walk one depth level at a time; same visit order as a FIFO queue.
        frontier = [expr]
        while frontier:
            upcoming = []
            for node in frontier:
                if self.matcher.match(node, x):
                    return True
                if node.func in self.operations:
                    upcoming.extend(node.args)
            frontier = upcoming
        return False


###########################################################


class RingOpsUntil(_OperationsUntilMatch):
    """Search through ring operations (addition, subtraction, multiplication).

    Only ``Add`` and ``Mul`` nodes are descended into. SymPy has no separate
    subtraction or division node. Division is typically represented as
    ``Mul(a, Pow(b, -1))``, with some simplification in certain cases.
    ``Pow`` is not descended into, so the inside of a denominator is not
    searched. However, the other factors of a division still are, so this
    matcher can match parts of some divisions.
    """
    OPS = (sp.Add, sp.Mul)

    def __init__(self, matcher: Matcher):
        super().__init__(matcher, operations=self.OPS)


class AnyUntil(Matcher):
    """Breadth-first search of every sub-expression for a ``matcher`` hit.

    Unlike ``RingOpsUntil``, descent is not restricted to particular
    operations: the arguments of every visited node are explored.
    """

    def __init__(self, matcher: Matcher):
        self.matcher = matcher

    def match(self, expr: sp.Expr, x: sp.Symbol) -> bool:
        # Process one depth level at a time; same order as a FIFO queue.
        frontier = [expr]
        while frontier:
            upcoming = []
            for node in frontier:
                if self.matcher.match(node, x):
                    return True
                upcoming.extend(node.args)
            frontier = upcoming
        return False


###################################################################################


class Constant(Matcher):
    """Match expressions that SymPy reports as constant with respect to ``x``."""

    def match(self, expr: sp.Expr, x: sp.Symbol) -> bool:
        # TODO: Maybe do something with free variable?
        # ``Expr.is_constant`` can return None when constancy cannot be
        # decided. Treat that as "no match" so the result is a real bool.
        return bool(expr.is_constant(x))


class Polynomial(Matcher):
    """Match expressions that SymPy reports as polynomial in ``x``."""

    def match(self, expr: sp.Expr, x: sp.Symbol) -> bool:
        is_poly = expr.is_polynomial(x)
        return is_poly


class PolynomialIn(Matcher):
    """Match expressions that are polynomial in the given sub-expressions.

    Each term is replaced by a fresh placeholder symbol, in the order
    produced by ``_sort_by_contains`` (intended: containing terms before
    the terms they contain). The result is then checked with
    ``is_polynomial`` over those placeholders. The ``x`` argument of
    ``match`` is not used.
    """

    def __init__(self, terms: Union[sp.Expr, Iterable[sp.Expr]]):
        ordered_terms = _sort_by_contains(_ensure_list(terms))
        self.term_to_symbol = collections.OrderedDict()
        for term in ordered_terms:
            self.term_to_symbol[term] = _random_symbol()

    def match(self, expr: sp.Expr, x: sp.Symbol) -> bool:
        substituted = expr
        for term in self.term_to_symbol:
            substituted = substituted.subs(term, self.term_to_symbol[term])
        placeholders = list(self.term_to_symbol.values())
        return substituted.is_polynomial(*placeholders)


###################################################################################


class Pow(Matcher):
    """Match ``base**exponent`` where each side satisfies its own matcher.

    ``base`` and ``exponent`` may be matchers or values accepted by
    ``_ensure_matcher``. Omitting ``exponent`` accepts any exponent.
    """

    def __init__(self, base, exponent=None):
        self.base = _ensure_matcher(base)
        if exponent is None:
            self.exponent = None
        else:
            self.exponent = _ensure_matcher(exponent)

    def match(self, expr: sp.Expr, x: sp.Symbol) -> bool:
        if expr.func != sp.Pow:
            return False
        base, exponent = expr.args
        # The exponent matcher is only consulted once the base has matched.
        if not self.base.match(base, x):
            return False
        return self.exponent is None or bool(self.exponent.match(exponent, x))


###################################################################################
###################################################################################
