R"""Helper script for determining whether expressions have an elementary antiderivative.

Intended to be run in the background and communicated with via stdin and stdout: after printing
a readiness line, it reads one expression per line from stdin and writes one label per line to
stdout ('1' = elementary antiderivative, '0' = no elementary antiderivative, '-' = indeterminate).

It is quite conservative about deciding whether an expression has an elementary antiderivative.

Only integrals that are instances of NonElementaryIntegral will be deemed to have no elementary
antiderivative. Only integrals that are computed and elementary will be deemed to have an
elementary antiderivative. Everything else will be indeterminate.
"""
import collections
import itertools
import sys
from typing import Sequence

from absl import app
from absl import flags

import sympy as sp
from sympy.integrals.risch import NonElementaryIntegral
from sympy.parsing.sympy_parser import parse_expr

from em.datasets.antiderivative import misc_util
from em.datasets.antiderivative import sympy_util as sp_util


# Sentinel line that `main` writes to stdout once the worker is about to start
# reading expressions from stdin.
# NOTE(review): 'STDINT' looks like a typo for 'STDIN', but the peer process
# presumably waits for this exact string -- change both sides together.
IS_READY_MSG = 'LISTENING_ON_STDINT'


FLAGS = flags.FLAGS

# The flags below are only defined when this file runs as a script. `has_ead`
# reads FLAGS.seconds_per_attempt and `main` reads FLAGS.symbols, so code that
# imports this module must make sure those flags are defined and parsed itself.
if __name__ == "__main__":
    # Time limit, in seconds, applied to each integration attempt in `has_ead`.
    flags.DEFINE_integer("seconds_per_attempt", 1, '')

    # Names of the sympy symbols available when parsing expressions; the first
    # one is the variable of integration (see `_has_ead`).
    flags.DEFINE_list(
        'symbols',
        ['x'],
        'The names of symbols when reading from sympy. The first will be the `d_variable`.'
    )


def make_symbols_dict(symbols: Sequence[str]):
    """Map each symbol name to a real, nonzero sympy Symbol, preserving order.

    Args:
      symbols: Symbol names, in order. If a name repeats, the later Symbol
        replaces the earlier one but keeps its original position.

    Returns:
      A collections.OrderedDict from name to sympy.Symbol.
    """
    # TODO: Allow a way to specify those for which nonzero=False.
    table = collections.OrderedDict()
    for name in symbols:
        table[name] = sp.Symbol(name, real=True, nonzero=True)
    return table


def _has_ead(expr_str: str, symbols) -> str:
    """Integrate `expr_str` and label the outcome (no time limit applied here).

    Args:
      expr_str: Expression text understood by sympy's `parse_expr`.
      symbols: Ordered mapping from symbol name to sympy symbol. It is used as
        the parser's local namespace, and its first value is the variable of
        integration.

    Returns:
      '0' if sympy yields a NonElementaryIntegral (either directly or after
      `doit()`), '1' if the evaluated result passes `sp_util.is_elementary`,
      and '-' otherwise.
    """
    # Unpacking exactly one value: an empty mapping raises ValueError here.
    (wrt,) = itertools.islice(symbols.values(), 1)
    # NOTE(review): parse_expr evaluates its input with eval(); this is only
    # safe if the peer writing to stdin is trusted.
    integrand = parse_expr(expr_str, evaluate=True, local_dict=symbols)

    result = sp.integrate(integrand, wrt, risch=True)
    if isinstance(result, NonElementaryIntegral):
        return '0'

    # Force evaluation of anything sympy left unevaluated, then re-check.
    result = result.doit()
    if isinstance(result, NonElementaryIntegral):
        return '0'
    return '1' if sp_util.is_elementary(result) else '-'


def has_ead(expr_str: str, symbols) -> str:
    """Label `expr_str` by whether it has an elementary antiderivative.

    Runs `_has_ead` under a time limit of `FLAGS.seconds_per_attempt` seconds.
    Never raises for an individual expression: this is the per-request boundary
    of a long-running worker, so failures are reported as indeterminate.

    Args:
      expr_str: Expression text understood by sympy's `parse_expr`.
      symbols: Ordered mapping from symbol name to sympy symbol (presumably the
        result of `make_symbols_dict`); the first entry is the integration
        variable.

    Returns:
      '0' if sympy produced a NonElementaryIntegral, '1' if an elementary
      antiderivative was computed, and '-' otherwise, including on timeout or
      on any error while parsing/integrating.
    """
    try:
        return misc_util.timeout(FLAGS.seconds_per_attempt)(_has_ead)(expr_str, symbols)
    except misc_util.TimeoutError:
        return '-'
    except Exception as err:  # pylint: disable=broad-except
        # A malformed or unsupported expression must not kill the worker and
        # leave the peer blocked waiting for a label. Report it on stderr
        # (stdout carries the protocol) and answer "indeterminate".
        sys.stderr.write(f'has_ead: error on {expr_str!r}: {err!r}\n')
        sys.stderr.flush()
        return '-'


def main(_):
    """Serve antiderivative queries over stdin/stdout until stdin is closed.

    Writes IS_READY_MSG once. Then, for every non-blank line read from stdin,
    it writes exactly one label line produced by `has_ead` to stdout. It returns
    when stdin reaches EOF.
    """
    sys.stdout.write(f'{IS_READY_MSG}\n')
    sys.stdout.flush()
    symbols = make_symbols_dict(FLAGS.symbols)
    # readline() returns '' only at EOF (a blank input line is still '\n').
    # Stopping on '' avoids spinning forever once the peer closes our stdin.
    for line in iter(sys.stdin.readline, ''):
        expr_str = line.strip()
        if not expr_str:
            continue
        label = has_ead(expr_str, symbols)
        # Flush per answer so the peer sees each label immediately.
        sys.stdout.write(f'{label}\n')
        sys.stdout.flush()


if __name__ == '__main__':
    # Script entry point: hand control to absl's app runner with `main`.
    app.run(main)
