R"""Determine whether a set of expressions have elementary antiderivatives.

The generated CSV has two columns and no header row: the expression (in sympy
format) and "Has EAD". The "Has EAD" column is "1", "0", or "-", where "-" means
the CAS interface could not decide (it returned None or raised an error) whether
the expression has an elementary antiderivative.


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1

python3 scripts1/data_gen/antiderivative/generate_ead_indicator.py \
    --input_file=/tmp/test_expressions \
    --output_file=/tmp/test_expressions_ead \
    --seconds_per_attempt=5 \
    --n_processes=1

"""
import collections
import csv
import multiprocessing as mp
import os
from typing import Sequence

from absl import app
from absl import flags
import sympy as sp
from sympy.parsing.sympy_parser import parse_expr
from tqdm import tqdm

from em.datasets.antiderivative.cas import sympy_cas
from em.util.color_util import cu


# CAS backends selectable through --cas; the first entry is the default.
_CAS_OPTIONS = ['sympy']


FLAGS = flags.FLAGS

flags.DEFINE_string(
    "input_file",
    None,
    "Path to file to read expressions from. Should be in the format of 1 expression per line.",
)
flags.DEFINE_string("output_file", None, "Path to csv file to write expressions to.")

flags.DEFINE_list(
    'symbols',
    ['x'],
    'The names of symbols when reading from sympy. The first will be the `d_variable`.'
)

flags.DEFINE_enum("cas", _CAS_OPTIONS[0], _CAS_OPTIONS, 'Which CAS backend to query.')

flags.DEFINE_integer(
    "seconds_per_attempt",
    1,
    'Number of seconds passed to the CAS as the limit for each expression.',
)

# At least one worker is required: with zero workers nothing ever consumes the
# input queue and the script would wait for results forever.
flags.DEFINE_integer(
    "n_processes",
    1,
    'Number of worker processes to run in parallel.',
    lower_bound=1,
)


def get_cas(cas_str: str):
    """Build the CAS wrapper selected by ``cas_str``.

    Args:
        cas_str: Name of the CAS backend. Only ``'sympy'`` is supported.

    Returns:
        A new ``sympy_cas.SympyCas`` instance.

    Raises:
        ValueError: If ``cas_str`` names an unsupported backend.
    """
    # Raise explicitly rather than ``assert``: assertions are stripped under
    # ``python -O``, which would silently accept any backend name.
    if cas_str != 'sympy':
        raise ValueError(f"Unsupported CAS {cas_str!r}; expected 'sympy'.")
    return sympy_cas.SympyCas()


def make_symbols_dict(symbols: Sequence[str]):
    """Map each symbol name to a sympy ``Symbol``, keeping the input order.

    Every symbol is created with ``real=True`` and ``nonzero=True``.

    Args:
        symbols: Names of the symbols to create.

    Returns:
        An ``OrderedDict`` from name to the corresponding ``sp.Symbol``.
    """
    # TODO: Allow a way to specify those for which nonzero=False.
    name_to_symbol = collections.OrderedDict()
    for name in symbols:
        name_to_symbol[name] = sp.Symbol(name, real=True, nonzero=True)
    return name_to_symbol


class EadProcess(mp.Process):
    """Worker process that labels expressions taken from a queue.

    Each item read from ``input_queue`` is an expression string. The worker
    parses it with sympy, asks the CAS whether it has an elementary
    antiderivative, and puts ``(expression_string, label)`` on
    ``output_queue``, where ``label`` is ``'1'``, ``'0'`` or ``'-'``.

    Exactly one result is produced per input string, including strings that
    cannot be parsed, so a consumer counting results never waits for a result
    that will not arrive. Putting ``None`` on ``input_queue`` asks the worker
    to stop.
    """

    def __init__(
        self,
        input_queue: mp.Queue,
        output_queue: mp.Queue,
        cas: str,
        symbols: Sequence[str],
        seconds_per_attempt: int,
    ):
        """Store the queues and build the CAS and symbol table.

        Args:
            input_queue: Queue of expression strings to label.
            output_queue: Queue receiving ``(expression_string, label)`` tuples.
            cas: Name of the CAS backend, forwarded to ``get_cas``.
            symbols: Symbol names; the first one is the variable handed to
                the CAS as ``d_variable``.
            seconds_per_attempt: Value forwarded to the CAS with every query.
        """
        super().__init__()
        self.input_queue = input_queue
        self.output_queue = output_queue
        self.seconds_per_attempt = seconds_per_attempt
        self.cas = get_cas(cas)
        self.symbols = make_symbols_dict(symbols)
        self.d_variable = list(self.symbols.values())[0]

    def has_elementary_antiderivative(self, expr):
        """Query the CAS for ``expr`` and encode the answer as a string.

        Args:
            expr: Parsed expression to hand to the CAS.

        Returns:
            ``'1'`` if the CAS answer is truthy, ``'0'`` if it is falsy, and
            ``'-'`` if the CAS returned ``None`` or raised an exception.
        """
        try:
            has_ead = self.cas.has_elementary_antiderivative(expr, self.d_variable, self.seconds_per_attempt)
        except (ValueError, AttributeError, TypeError, OverflowError, NotImplementedError):
            # These error types are treated as "unknown" without logging.
            has_ead = None
        except Exception as e:
            # Anything else is also "unknown", but is printed for inspection.
            print(cu.hlr(e))
            has_ead = None

        if has_ead is None:
            return '-'
        elif has_ead:
            return '1'
        else:
            return '0'

    def run(self):
        """Label expressions from ``input_queue`` until a ``None`` arrives."""
        while True:
            expr_str = self.input_queue.get()
            if expr_str is None:
                # Optional shutdown sentinel; lets the worker exit cleanly.
                break
            try:
                expr = parse_expr(expr_str, evaluate=True, local_dict=self.symbols)
            except Exception as e:
                # A malformed or blank line must not kill the worker: the
                # consumer would then wait forever for the missing result.
                print(f'Could not parse {expr_str!r}: {e!r}')
                has_ead = '-'
            else:
                has_ead = self.has_elementary_antiderivative(expr)
            self.output_queue.put((expr_str, has_ead))


def main(_):
    """Label every expression in --input_file and write a CSV to --output_file.

    Starts ``--n_processes`` workers, feeds them one stripped, non-blank line
    of the input file each, and writes one ``expression,label`` row per result
    in the order the results arrive.

    Raises:
        app.UsageError: If a required flag is missing or --n_processes < 1.
        RuntimeError: If a worker process dies before all results arrived.
    """
    # Local import: only needed for the ``queue.Empty`` timeout signal.
    import queue

    # Validate before spawning anything so a bad invocation fails immediately.
    if not FLAGS.input_file or not FLAGS.output_file:
        raise app.UsageError('Both --input_file and --output_file must be given.')
    if FLAGS.n_processes < 1:
        raise app.UsageError('--n_processes must be at least 1.')

    input_file = os.path.expanduser(FLAGS.input_file)
    output_file = os.path.expanduser(FLAGS.output_file)
    poll_seconds = 1.0  # How often to check worker health while waiting.

    input_queue = mp.Queue()
    output_queue = mp.Queue()

    processes = [
        EadProcess(
            input_queue=input_queue,
            output_queue=output_queue,
            cas=FLAGS.cas,
            symbols=FLAGS.symbols,
            seconds_per_attempt=FLAGS.seconds_per_attempt,
        )
        for _ in range(FLAGS.n_processes)
    ]

    try:
        for p in processes:
            p.start()

        # NOTE: We read everything into memory, which could cause issues for
        # large input files.
        n_examples = 0
        with open(input_file, 'r') as fi:
            for line in fi:
                expr_str = line.strip()
                if not expr_str:
                    # Blank lines hold no expression; skip them.
                    continue
                input_queue.put(expr_str)
                n_examples += 1

        # Start writing to the output.
        with open(output_file, 'w', newline='') as fo:
            writer = csv.writer(fo, delimiter=',')
            for _ in tqdm(range(n_examples)):
                while True:
                    try:
                        expr, has_ead = output_queue.get(timeout=poll_seconds)
                        break
                    except queue.Empty:
                        # Workers loop forever, so a dead one means a crash
                        # and its in-flight result will never arrive.
                        if not all(p.is_alive() for p in processes):
                            raise RuntimeError(
                                'A worker process exited unexpectedly; '
                                'the output file is incomplete.'
                            )
                writer.writerow([expr, has_ead])
    finally:
        # Always reap the workers: they are non-daemonic and block on the
        # input queue, so leaving them running would hang interpreter exit.
        started = [p for p in processes if p.pid is not None]
        for p in started:
            p.terminate()
        for p in started:
            p.join()
        # On an error path unread items may remain buffered; do not let the
        # queue's feeder thread block process exit trying to flush them.
        input_queue.cancel_join_thread()


# Script entry point: hand control to absl, which parses the command-line
# flags defined in this file and then calls main().
if __name__ == "__main__":
    app.run(main)
