"""Utilities for SymPy."""
from typing import Optional, Sequence

import sympy as sp

from em.util.color_util import cu

from .misc_util import timeout, TimeoutError
from . import constants

# Re-exported from `constants`. `is_elementary` passes each entry to
# `isinstance` when deciding whether a traversed node is elementary.
ELEMENTARY_SP_OPERATORS = constants.ELEMENTARY_SP_OPERATORS


def simplify(f, seconds: Optional[int] = None):
    """Simplify an expression, falling back to the input on any failure.

    Runs `sp.simplify(f)` and returns the result. `f` itself is returned
    unchanged instead when:
    - the simplification raises (including the module's `TimeoutError`), or
    - the simplified expression contains a Dummy symbol.

    Args:
        f: expression to simplify.
        seconds: optional time budget. When not None, the simplification is
            wrapped with `timeout(seconds)`; when None it runs unbounded.

    Returns:
        The simplified expression, or `f` unchanged.
    """
    def _simplify(f):
        try:
            f2 = sp.simplify(f)
            if any(s.is_Dummy for s in f2.free_symbols):
                return f
            return f2
        # `TimeoutError` is listed explicitly because it comes from
        # `misc_util` and may not derive from `Exception` -- TODO confirm.
        except (TimeoutError, Exception):
            return f

    if seconds is not None:
        _simplify = timeout(seconds)(_simplify)
    try:
        return _simplify(f)
    except TimeoutError:
        # If the `timeout` wrapper raises outside the inner `try` (depends on
        # its implementation, not visible here), still fall back to `f`.
        return f


def rewrite_sympy_expr(expr, rewrite_functions: Sequence[str]):
    """Apply a sequence of named rewrite passes to a SymPy expression.

    Passes are applied left to right, each one on the output of the previous
    one. Supported names: 'expand', 'factor', 'expand_log', 'logcombine',
    'powsimp' (the three log/power passes use `force=True`) and 'simplify'
    (this module's `simplify` with a 1-second budget).

    Raises:
        ValueError: when a name is not one of the supported passes.
    """
    # Lambdas keep the SymPy lookups lazy: nothing is resolved until a pass
    # is actually used.
    passes = {
        'expand': lambda e: sp.expand(e),
        'factor': lambda e: sp.factor(e),
        'expand_log': lambda e: sp.expand_log(e, force=True),
        'logcombine': lambda e: sp.logcombine(e, force=True),
        'powsimp': lambda e: sp.powsimp(e, force=True),
        'simplify': lambda e: simplify(e, seconds=1),
    }
    result = expr
    for name in rewrite_functions:
        rewrite = passes.get(name)
        if rewrite is None:
            raise ValueError(f'Invalid rewrite function: {name}')
        result = rewrite(result)
    return result


def count_occurrences(expr):
    """Count atom occurrences in an expression.

    Recursively walks the expression tree and returns a dict mapping each atom
    to the number of times it appears as a leaf.

    Every non-atom node is handled the same way, with no assumption about its
    number of arguments: the counts of all of its arguments are summed. Nodes
    with several arguments that are not Add/Mul/Pow are therefore supported,
    and nodes with no arguments contribute nothing.
    """
    if expr.is_Atom:
        return {expr: 1}
    result = {}
    for arg in expr.args:
        for atom, count in count_occurrences(arg).items():
            result[atom] = result.get(atom, 0) + count
    return result


def count_occurrences2(expr):
    """Count atom occurrences in an expression.

    Walks `expr` in pre-order with `sp.preorder_traversal` and tallies every
    node that is an atom. Returns a dict mapping atom -> number of occurrences,
    with keys in first-seen order.
    """
    atoms = (node for node in sp.preorder_traversal(expr) if node.is_Atom)
    tally = {}
    for atom in atoms:
        tally[atom] = tally.get(atom, 0) + 1
    return tally


def remove_root_constant_terms(expr, variables, mode):
    """Remove root constant terms from a non-constant SymPy expression.

    Only the root node is inspected:
    - mode 'add' / 'mul': if the root is an Add / Mul, drop the arguments that
      contain none of `variables`. A single remaining argument replaces the
      root entirely.
    - mode 'pow': if the root is a Pow, return its exponent when the base is
      constant, its base when the exponent is constant, and the Pow itself
      otherwise.
    In every other case `expr` is returned unchanged.

    Args:
        expr: SymPy expression containing at least one of `variables`.
        variables: a single symbol, or a list / tuple / set of symbols.
        mode: one of 'add', 'mul', 'pow'.
    """
    # Tuples and sets are used as-is, like lists. Wrapping them in a
    # one-element list would make `x in variables` never match.
    if not isinstance(variables, (list, tuple, set, frozenset)):
        variables = [variables]
    assert mode in ['add', 'mul', 'pow']
    assert any(x in variables for x in expr.free_symbols)

    def _depends(e):
        # True when `e` contains at least one of `variables`.
        return any(x in variables for x in e.free_symbols)

    if (mode == 'add' and expr.is_Add) or (mode == 'mul' and expr.is_Mul):
        args = [arg for arg in expr.args if _depends(arg)]
        if len(args) == 1:
            return args[0]
        if len(args) < len(expr.args):
            return expr.func(*args)
        return expr
    if mode == 'pow' and expr.is_Pow:
        assert len(expr.args) == 2
        base, exponent = expr.args
        if not _depends(base):
            return exponent
        if not _depends(exponent):
            return base
    return expr


def remove_mul_const(f, variables):
    """Remove the multiplicative constant factor of an expression, and return it.

    Splits the factors of the product `f` into those that contain at least one
    of `variables` and those that contain none.

    Args:
        f: SymPy expression.
        variables: a single symbol, or a list / tuple / set of symbols.

    Returns:
        `(variable_part, constant_part)` as two `sp.Mul` products (an empty
        group yields `sp.Mul()`). If `f` is not a Mul, returns `(f, 1)` with
        the plain Python integer 1, whether or not `f` itself is constant.
    """
    if not f.is_Mul:
        return f, 1
    # Tuples and sets are used as-is, like lists. Wrapping them would make
    # every factor look constant.
    if not isinstance(variables, (list, tuple, set, frozenset)):
        variables = [variables]
    var_args = []
    cst_args = []
    for arg in f.args:
        if any(var in arg.free_symbols for var in variables):
            var_args.append(arg)
        else:
            cst_args.append(arg)
    return sp.Mul(*var_args), sp.Mul(*cst_args)


def extract_non_constant_subtree(expr, variables):
    """Extract a non-constant sub-tree from an equation.

    Each round strips the root with `remove_root_constant_terms` in the modes
    'add', 'mul' and 'pow' (in that order), then descends through nodes that
    have exactly one argument. Rounds repeat until one leaves the expression
    unchanged.
    """
    current = expr
    while True:
        before = current
        for mode in ('add', 'mul', 'pow'):
            current = remove_root_constant_terms(current, variables, mode)
        # Unwrap single-argument nodes down to their only child.
        while len(current.args) == 1:
            (current,) = current.args
        if current == before:
            return current


def reindex_coefficients(expr, coefficients):
    """Re-index coefficients (i.e. if a1 is there and not a0, replace a1 by a0, and recursively).

    The coefficients used by `expr` are renamed so that they form a prefix of
    `coefficients`, keeping their relative order in `coefficients`.

    Args:
        expr: SymPy expression.
        coefficients: ordered sequence of coefficient symbols.

    Returns:
        `expr` with its coefficients renamed.
    """
    free = expr.free_symbols
    # Follow the order of `coefficients` instead of sorting by name. A name
    # sort puts 'a10' before 'a2', which makes the sequential `subs` below map
    # two distinct coefficients onto the same target once there are more
    # than 10 of them. With positional order, each target is free when used.
    present = [c for c in coefficients if c in free]
    for target, current in zip(coefficients, present):
        if target != current:
            expr = expr.subs(current, target)
    return expr


def reduce_coefficients(expr, variables, coefficients):
    """Reduce coefficients in an expression.

    For each coefficient `a` present in `expr`, find the first sub-expression
    (pre-order) that contains `a` and none of `variables`, such that replacing
    every occurrence of it removes `a` from the expression. Replace that
    sub-expression by `a`. Repeat until the expression stops changing.

    `sqrt(x)*y*sqrt(1/a0)` -> `a0*sqrt(x)*y`
    `x**(-cos(a0))*y**cos(a0)` -> `x**(-a0)*y**a0`
    """
    # A Dummy is unique, so it cannot collide with a user symbol that happens
    # to be named 'temp' (a plain Symbol('temp') would).
    temp = sp.Dummy('temp')
    while True:
        last = expr
        for a in coefficients:
            if a not in expr.free_symbols:
                continue
            for subexp in sp.preorder_traversal(expr):
                if a in subexp.free_symbols and not any(var in subexp.free_symbols for var in variables):
                    p = expr.subs(subexp, temp)
                    if a in p.free_symbols:
                        # `a` also appears outside this sub-expression.
                        continue
                    expr = p.subs(temp, a)
                    # `expr` changed: stop traversing the stale tree.
                    break
        if last == expr:
            break
    return expr


def simplify_const_with_coeff(expr, coeff):
    """Simplify expressions with constants and coefficients.

    Finds the first node (pre-order) that has `coeff` as a direct argument. If
    that node is an Add or a Mul with constant (symbol-free) arguments, those
    constants are folded into `coeff` by substituting `coeff - c` (Add) or
    `coeff / c` (Mul) for `coeff` throughout `expr`.

    `sqrt(10) * a0 * x` -> `a0 * x`
    `sin(a0 + x + 9/7)` -> `sin(a0 + x)`
    `a0 + x + 9` -> `a0 + x`
    """
    assert coeff.is_Atom
    for node in sp.preorder_traversal(expr):
        if any(arg == coeff for arg in node.args):
            break
    else:
        # No node has `coeff` as a direct argument.
        return expr
    if not (node.is_Add or node.is_Mul):
        return expr
    constants = [arg for arg in node.args if not arg.free_symbols]
    if not constants:
        return expr
    folded = node.func(*constants)
    new_coeff = coeff - folded if node.is_Add else coeff / folded
    return expr.subs(coeff, new_coeff)


def simplify_equa_diff(_eq, required=None):
    """Simplify a differential equation by removing non-zero factors.

    `_eq` is factored with `sp.factor`. If the result is not a product, the
    original (unfactored) `_eq` is returned. Otherwise, factors whose
    `is_nonzero` is True are dropped. When `required` is given, only the
    factors that contain `required` are kept. Fails an assertion if no factor
    survives.
    """
    factored = sp.factor(_eq)
    if not factored.is_Mul:
        return _eq
    kept = [
        factor
        for factor in factored.args
        if not factor.is_nonzero and (required is None or factor.has(required))
    ]
    assert len(kept) >= 1
    if len(kept) == 1:
        return kept[0]
    return factored.func(*kept)


def smallest_with_symbols(expr, symbols):
    """Return the smallest sub-tree in an expression that contains all given symbols.

    Descends from the root through single-argument nodes, and through nodes
    where exactly one argument contains any of `symbols`. Stops at the first
    node where that is not the case.
    """
    node = expr
    while True:
        assert all(x in node.free_symbols for x in symbols)
        if len(node.args) == 1:
            node = node.args[0]
            continue
        holders = [arg for arg in node.args if any(x in arg.free_symbols for x in symbols)]
        if len(holders) != 1:
            return node
        node = holders[0]


def smallest_with(expr, symbol):
    """Return the smallest sub-tree in an expression that contains a given symbol.

    Descends while exactly one argument of the current node contains `symbol`
    and that argument is not `symbol` itself. The current node is returned:
    - when several arguments contain `symbol`;
    - when the only containing argument is `symbol` (so the result is the
      direct parent of `symbol`);
    - when no argument contains it, e.g. when `expr` is `symbol` itself.
      Previously this case raised IndexError.
    """
    assert symbol in expr.free_symbols
    node = expr
    while True:
        candidates = [arg for arg in node.args if symbol in arg.free_symbols]
        if len(candidates) != 1 or candidates[0] == symbol:
            return node
        node = candidates[0]


def clean_degree2_solution(expr, x, a8, a9):
    """Clean solutions of second order differential equations.

    For each coefficient `a` in (`a8`, `a9`), let `small = smallest_with(expr, a)`.
    If `small` is an Add or a Mul in which `a` occurs exactly once, as a direct
    argument, then:
    - if `small` also contains `x`, its arguments that are neither `a` nor
      dependent on `x` are dropped;
    - otherwise `small` is replaced by `a` altogether.
    Passes over (`a8`, `a9`) repeat until a full pass leaves `expr` unchanged.

    Args:
        expr: expression to clean (presumably an ODE solution -- confirm with callers).
        x: symbol deciding which arguments of `small` are kept
            (presumably the independent variable).
        a8, a9: coefficient symbols to clean (presumably the two free
            constants of a second order solution).

    Returns:
        The cleaned expression.

    NOTE(review): the function returns as soon as either coefficient is absent
    from `expr`, even in the middle of a pass. E.g. when `a8` is present but
    `a9` is not, `a8` is processed once and the function returns without
    iterating to a fixed point. Confirm this early exit is intended.
    """
    last = expr
    while True:
        for a in [a8, a9]:
            if a not in expr.free_symbols:
                return expr
            small = smallest_with(expr, a)
            if small.is_Add or small.is_Mul:
                # Atom -> occurrence count within `small`.
                counts = count_occurrences2(small)
                if counts[a] == 1 and a in small.args:
                    if x in small.free_symbols:
                        # Keep only `a` and the arguments that depend on `x`.
                        expr = expr.subs(small, small.func(*[arg for arg in small.args if arg == a or x in arg.free_symbols]))
                    else:
                        # `small` does not involve `x`: collapse it to `a`.
                        expr = expr.subs(small, a)
        # Stop once a full pass over (a8, a9) made no change.
        if expr == last:
            break
        last = expr
    return expr


def has_inf_nan(*args):
    """Detect whether some expressions contain a NaN / Infinity symbol.

    Each expression is tested with `.has(...)` against `sp.nan`, `sp.oo`,
    `-sp.oo` and `sp.zoo`, stopping at the first hit. Returns False when no
    expression is given.
    """
    return any(
        expr.has(sp.nan) or expr.has(sp.oo) or expr.has(-sp.oo) or expr.has(sp.zoo)
        for expr in args
    )


def has_I(*args):
    """Detect whether some expressions contain complex numbers.

    Concretely, returns True as soon as one of the expressions contains the
    imaginary unit `sp.I` (checked with `.has`). Returns False when no
    expression is given.
    """
    return any(expr.has(sp.I) for expr in args)

###############################################################################


def is_elementary(expr) -> bool:
    """Return True if every node of `expr` is an elementary building block.

    A node is accepted when it is one of the following:
    - a Symbol or an Integer/Rational literal;
    - one of the constants E, pi, I;
    - an instance of one of the classes in `ELEMENTARY_SP_OPERATORS`.

    NOTE: Special functions whose arguments are only constants might be OK but
    will be considered non-elementary by this function.
    """
    for subexp in sp.preorder_traversal(expr):
        # Check the traversed node itself. Testing `expr` here would only
        # ever examine the root and miss nested non-elementary nodes.
        if isinstance(subexp, (sp.Symbol, sp.Integer, sp.Rational)) or subexp in (sp.E, sp.pi, sp.I):
            continue
        if not any(isinstance(subexp, op_type) for op_type in ELEMENTARY_SP_OPERATORS):
            return False
    return True
