"""SymPy implementation of the needed CAS capabilities."""
import dataclasses
from typing import Union

import sympy as sp
from sympy.integrals.risch import NonElementaryIntegral

from . import cas_abcs
from .. import sympy_util as sp_util
from .. import constants

# Local alias of the project-wide constant. SympyCas returns False (not
# elementary) when any node of a computed antiderivative has a ``.func`` that
# is a member of this collection. Presumably it holds SymPy special-function
# classes (erf, Si, Ei, ...) -- confirm against ``constants``.
INTEGRAL_FUNCS = constants.INTEGRAL_FUNCS


@dataclasses.dataclass()
class IntegrationStats:
    """Placeholder record, named for holding integration statistics.

    It declares no fields yet, so all instances compare equal and share the
    repr ``IntegrationStats()``.
    """


class SympyCas(cas_abcs.CasAbc):
    """SymPy-backed implementation of the CAS capabilities in ``cas_abcs.CasAbc``."""

    def _has_elementary_antiderivative(self, expr, d_variable) -> Union[bool, None]:
        """Decide whether ``expr`` has an elementary antiderivative in ``d_variable``.

        SymPy's Risch algorithm (``risch=True``) is tried first. Any result that
        is not a ``NonElementaryIntegral`` is re-evaluated with SymPy's default
        integration methods. The result is then screened for anything that is
        not a clean elementary closed form.

        Args:
            expr: SymPy expression to integrate.
            d_variable: Symbol to integrate with respect to.

        Returns:
            ``True`` if an antiderivative was found that passes every check.
            ``False`` if the integrand contains inf/nan, or if the
            antiderivative is (or appears to be) non-elementary.
            ``None`` if SymPy raised ``NotImplementedError``, i.e. it cannot
            tell. It is up to the caller to decide what to do with that
            information.
        """
        # Reject invalid integrands before running the (potentially very
        # expensive) integration on them.
        if sp_util.has_inf_nan(expr):
            return False

        try:
            F = sp.integrate(expr, d_variable, risch=True)
        except NotImplementedError:
            # SymPy documents that risch=True may raise this when it cannot
            # determine elementarity: the "cannot tell" case.
            return None

        # With risch=True, SymPy documents a NonElementaryIntegral result as
        # proof that no elementary antiderivative exists.
        if isinstance(F, NonElementaryIntegral):
            return False

        # A plain unevaluated Integral only means the Risch implementation
        # could not handle this integrand. Evaluating again lets SymPy's other
        # integration methods try.
        try:
            F = F.doit()
        except NotImplementedError:
            return None

        # Leftover integrals (NonElementaryIntegral subclasses sp.Integral, so
        # it is covered too), piecewise results and inf/nan all mean there is
        # no clean elementary closed form.
        if sp_util.has_inf_nan(F) or F.has(sp.Integral) or F.has(sp.Piecewise):
            return False

        # Antiderivatives that use any function listed in INTEGRAL_FUNCS
        # (presumably special functions such as erf/Si -- see constants) are
        # not treated as elementary.
        if any(node.func in INTEGRAL_FUNCS for node in sp.preorder_traversal(F)):
            return False

        return True
