R"""Determine whether a set of expressions have elementary antiderivatives.

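The generated CSV has two columns, Expression (in sympy format) and Has EAD
("has elementary antiderivative"); no header row is written. The "Has EAD"
column is "1" (an elementary antiderivative was found), "0" (the antiderivative
is not elementary, or the CAS result could not be parsed), or "-" (unknown: the
CAS interface returned no result, or returned `$Aborted`).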


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1

python3 scripts1/data_gen/antiderivative/generate_ead_indicator_wolfram2.py \
    --input_file=/tmp/test_expressions \
    --output_file=/tmp/test_expressions_ead \
    --seconds_per_attempt=1 \
    --max_evaluators=8 \
    --flush_every=32


"""
import asyncio
import collections
import csv
import logging
import itertools
import os
from typing import Sequence

from absl import app
from absl import flags

import sympy as sp
from sympy.parsing import mathematica
from sympy.parsing.sympy_parser import parse_expr
from tqdm import tqdm

from wolframclient.evaluation import WolframLanguageSession

from em.datasets.antiderivative import sympy_util as sp_util


FLAGS = flags.FLAGS

flags.DEFINE_string(
    "input_file",
    None,
    "Path to file to read expressions from. Should be in the format of 1 expression per line.",
)
flags.DEFINE_string("output_file", None, "Path to csv file to write expressions to.")

# Passed straight to WolframLanguageSession as its kernel argument.
flags.DEFINE_string(
    "wolfram_kernel_path",
    None,
    "Path to the Wolfram kernel executable used to start the session.",
)

flags.DEFINE_list(
    'symbols',
    ['x'],
    'The names of symbols when reading from sympy. The first will be the `d_variable`.'
)

flags.DEFINE_integer(
    "seconds_per_attempt",
    1,
    "Timeout (in seconds) passed to the Wolfram session for each Integrate call.",
)
# NOTE(review): not read anywhere in this script; kept so existing command
# lines that pass --max_evaluators keep working.
flags.DEFINE_integer("max_evaluators", 4, "Currently unused by this script.")

flags.DEFINE_integer(
    "flush_every",
    50,
    "Flush and fsync the output CSV after every this many rows.",
    lower_bound=1,  # Must be positive: main() uses it as a modulus.
)

# Read by read_chunk_of_expressions / get_chunks_iterator; without this
# definition, accessing FLAGS.examples_per_chunk fails at runtime.
flags.DEFINE_integer(
    "examples_per_chunk",
    100,
    "Number of expressions per chunk for the chunked-reading helpers.",
    lower_bound=1,
)


def make_symbols_dict(symbols: Sequence[str]):
    """Map each symbol name to a real, nonzero sympy Symbol.

    The returned OrderedDict keeps the order of `symbols`; it is used as the
    `local_dict` for `parse_expr` so names in the input resolve to these
    symbols.
    """
    # TODO: Allow a way to specify those for which nonzero=False.
    symbol_map = collections.OrderedDict()
    for name in symbols:
        symbol_map[name] = sp.Symbol(name, real=True, nonzero=True)
    return symbol_map


def make_wolfram_input(expr, d_variable_name: str) -> str:
    """Build a Wolfram Language query integrating `expr` w.r.t. `d_variable_name`.

    The Integrate call is wrapped in ToString[InputForm[...]] so the kernel
    replies with the antiderivative rendered as an InputForm string.
    """
    integrand = sp.mathematica_code(expr)
    template = 'ToString[InputForm[Integrate[{}, {}]]]'
    return template.format(integrand, d_variable_name)


def read_chunk_of_expressions(fi, symbols, examples_per_chunk=None):
    """Parse up to `examples_per_chunk` lines from `fi` into sympy expressions.

    Args:
      fi: Iterable of lines (e.g. an open text file); consumed lazily.
      symbols: Mapping passed to `parse_expr` as `local_dict`.
      examples_per_chunk: Maximum number of lines to read. When None, falls
        back to `FLAGS.examples_per_chunk`.

    Returns:
      A list of parsed expressions; empty once `fi` is exhausted.
    """
    if examples_per_chunk is None:
        examples_per_chunk = FLAGS.examples_per_chunk
    return [
        parse_expr(line.strip(), evaluate=True, local_dict=symbols)
        for line in itertools.islice(fi, examples_per_chunk)
    ]


def get_chunks_iterator(fi, symbols, examples_per_chunk=None):
    """Yield lists of parsed expressions, `examples_per_chunk` at a time.

    Args:
      fi: Iterable of lines (e.g. an open text file); consumed lazily.
      symbols: Mapping passed to `parse_expr` as `local_dict`.
      examples_per_chunk: Chunk size. When None, falls back to
        `FLAGS.examples_per_chunk`.

    Yields:
      Lists of parsed expressions. Every list has exactly
      `examples_per_chunk` items except possibly the last, which holds the
      remainder. Nothing is yielded for empty input.

    Raises:
      ValueError: If `examples_per_chunk` is not positive. With a non-positive
        size the length check never matches, so all input would be collected
        into one unbounded chunk.
    """
    if examples_per_chunk is None:
        examples_per_chunk = FLAGS.examples_per_chunk
    if examples_per_chunk < 1:
        raise ValueError(
            f'examples_per_chunk must be positive, got {examples_per_chunk}.')

    chunk = []
    for line in fi:
        chunk.append(parse_expr(line.strip(), evaluate=True, local_dict=symbols))
        if len(chunk) == examples_per_chunk:
            yield chunk
            chunk = []

    if chunk:
        yield chunk


def get_label(result):
    """Convert a raw Wolfram evaluation result into a "Has EAD" CSV label.

    Returns:
      '-' when there is no result (None) or its string form is '$Aborted';
      '0' when the result string cannot be parsed by the Mathematica parser
      (raises ValueError) or the parsed expression is not elementary;
      '1' when `sp_util.is_elementary` accepts the parsed antiderivative.
    """
    if result is None:
        return '-'
    text = str(result)
    if text == '$Aborted':
        return '-'

    try:
        antiderivative = mathematica.mathematica(text)
    except ValueError:
        # Treated as non-elementary: the assumption (inherited from the
        # original code) is that an unparseable result usually means the
        # integral is not elementary.
        return '0'

    return '1' if sp_util.is_elementary(antiderivative) else '0'


def to_csv_rows(chunk, wolfram_results):
    """Pair each expression with the label derived from its Wolfram result.

    Args:
      chunk: Sequence of expressions.
      wolfram_results: Sequence of raw Wolfram results, aligned with `chunk`.

    Returns:
      A list of `(str(expr), label)` tuples, one per expression, where the
      label comes from `get_label`.

    Raises:
      ValueError: If `chunk` and `wolfram_results` differ in length. An
        explicit check is used instead of `assert`, which is stripped under
        `python -O`; there, `zip` would silently truncate.
    """
    if len(chunk) != len(wolfram_results):
        raise ValueError(
            f'Got {len(chunk)} expressions but {len(wolfram_results)} results.')
    return [
        (str(expr), get_label(result))
        for expr, result in zip(chunk, wolfram_results)
    ]


def main(_):
    """Label each input expression by whether Wolfram finds an elementary antiderivative.

    Reads one sympy-format expression per line from --input_file and asks the
    Wolfram kernel to integrate it with respect to the first --symbols entry.
    Writes `[input_line, label]` rows to --output_file; see `get_label` for
    the label meanings. Blank input lines are skipped.
    """
    # Only errors are logged; presumably this silences chatty library logging
    # (e.g. from wolframclient) so it doesn't drown the progress bar.
    logging.getLogger().setLevel(logging.ERROR)

    symbols = make_symbols_dict(FLAGS.symbols)
    d_variable_name = FLAGS.symbols[0]

    input_file = os.path.expanduser(FLAGS.input_file)
    output_file = os.path.expanduser(FLAGS.output_file)
    flush_every = FLAGS.flush_every

    with WolframLanguageSession(FLAGS.wolfram_kernel_path) as session:

        with open(input_file, 'r') as fi, open(output_file, 'w', newline='') as fo:
            writer = csv.writer(fo, delimiter=',')
            rows_written = 0
            for raw_line in tqdm(fi):
                line = raw_line.strip()
                if not line:
                    # An empty string is not a parseable expression; skip
                    # blank lines instead of aborting the whole run.
                    continue

                expr = parse_expr(line, evaluate=True, local_dict=symbols)
                wolfram_expr = make_wolfram_input(expr, d_variable_name)
                result = session.evaluate(wolfram_expr, timeout=FLAGS.seconds_per_attempt)

                writer.writerow([line, get_label(result)])
                rows_written += 1

                # Periodically force rows to disk so partial results survive
                # a crash or kill partway through a long run.
                if flush_every > 0 and rows_written % flush_every == 0:
                    fo.flush()
                    os.fsync(fo.fileno())


if __name__ == "__main__":
    # Fail fast with a usage error instead of a TypeError from
    # os.path.expanduser(None) inside main().
    flags.mark_flags_as_required(["input_file", "output_file"])
    app.run(main)
