R"""Determine whether a set of expressions have elementary antiderivatives.

The generated CSV has two columns, Expression (in sympy format) and Has EAD
(has an elementary antiderivative), and no header row. The "Has EAD" column is
"1" (elementary antiderivative found), "0" (judged non-elementary), or "-"
(undetermined: the CAS did not finish within --seconds_per_attempt seconds).


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1

python3 scripts1/data_gen/antiderivative/generate_ead_indicator_wolfram.py \
    --input_file=/tmp/test_expressions \
    --output_file=/tmp/test_expressions_ead \
    --seconds_per_attempt=5 \
    --max_evaluators=8 \
    --examples_per_chunk=32


"""
import collections
import csv
import logging
import itertools
import os
from typing import Sequence

from absl import app
from absl import flags

import sympy as sp
from sympy.parsing import mathematica
from sympy.parsing.sympy_parser import parse_expr
from tqdm import tqdm

from wolframclient.evaluation import parallel_evaluate

from em.datasets.antiderivative import sympy_util as sp_util


FLAGS = flags.FLAGS

flags.DEFINE_string(
    "input_file",
    None,
    "Path to file to read expressions from. Should be in the format of 1 expression per line.",
)
flags.DEFINE_string("output_file", None, "Path to csv file to write expressions to.")

flags.DEFINE_string(
    "wolfram_kernel_path",
    None,
    "Path to the Wolfram kernel; passed to wolframclient's parallel_evaluate as its evaluator spec.",
)

flags.DEFINE_list(
    'symbols',
    ['x'],
    'The names of symbols when reading from sympy. The first will be the `d_variable`.'
)

flags.DEFINE_integer(
    "seconds_per_attempt",
    1,
    'Time limit (seconds) given to Wolfram TimeConstrained for each integral; '
    'integrals that exceed it are labelled "-".',
)
flags.DEFINE_integer(
    "max_evaluators",
    4,
    'Maximum number of Wolfram evaluators that parallel_evaluate may run at once.',
)

flags.DEFINE_integer(
    "examples_per_chunk",
    2048,
    'Number of expressions evaluated per parallel batch; results are written '
    'to the output file after every batch.',
)

# Fail fast at flag-parsing time instead of with a TypeError from
# os.path.expanduser(None) inside main().
flags.mark_flags_as_required(["input_file", "output_file"])


def make_symbols_dict(symbols: Sequence[str]):
    """Build an ordered mapping from each name in `symbols` to a SymPy Symbol.

    Every Symbol is created with the assumptions real=True and nonzero=True.
    Keys keep the order in which they appear in `symbols`. The result is
    suitable as `parse_expr`'s `local_dict`.
    """
    table = collections.OrderedDict()
    for name in symbols:
        # TODO: Allow a way to specify those for which nonzero=False.
        table[name] = sp.Symbol(name, real=True, nonzero=True)
    return table


def make_wolfram_input(expr, d_variable_name: str, seconds_per_attempt=None) -> str:
    """Build the Wolfram Language query that integrates `expr` w.r.t. a variable.

    Args:
      expr: SymPy expression to integrate; converted with `sp.mathematica_code`.
      d_variable_name: Name of the integration variable, inserted verbatim.
      seconds_per_attempt: Time limit handed to `TimeConstrained`. When None
        (the default) the --seconds_per_attempt flag is used.

    Returns:
      A Wolfram Language expression string. Evaluating it yields the
      antiderivative as an InputForm string, or `$Aborted` if the time limit
      is exceeded.
    """
    if seconds_per_attempt is None:
        seconds_per_attempt = FLAGS.seconds_per_attempt
    m_expr = sp.mathematica_code(expr)
    # The ToString[] is needed for the result to actually be returned in InputForm.
    # TimeConstrained wraps the whole conversion so a timeout surfaces as $Aborted.
    return (
        f'TimeConstrained[ToString[InputForm[Integrate[{m_expr}, {d_variable_name}]]], '
        f'{seconds_per_attempt}]'
    )


def read_chunk_of_expressions(fi, symbols):
    """Parse up to --examples_per_chunk lines from `fi` into SymPy expressions.

    At most that many lines are consumed from the iterator. Each line is
    stripped of surrounding whitespace and handed to `parse_expr`, with
    `symbols` as the local namespace. The list is shorter than the limit
    once `fi` is exhausted.
    """
    limit = FLAGS.examples_per_chunk
    head = itertools.islice(fi, limit)
    return [
        parse_expr(raw.strip(), evaluate=True, local_dict=symbols)
        for raw in head
    ]


def get_chunks_iterator(fi, symbols, examples_per_chunk=None):
    """Lazily parse expressions from `fi` and yield them in fixed-size lists.

    Args:
      fi: Iterable of text lines, one SymPy-format expression per line.
      symbols: Mapping of names to SymPy symbols, used as `parse_expr`'s
        `local_dict`.
      examples_per_chunk: Maximum number of expressions per yielded list.
        When None (the default) the --examples_per_chunk flag is used.

    Yields:
      Lists of parsed expressions. Every list except possibly the last holds
      exactly `examples_per_chunk` items. Blank (whitespace-only) lines are
      skipped, and an input with no expressions yields nothing.
    """
    if examples_per_chunk is None:
        examples_per_chunk = FLAGS.examples_per_chunk

    chunk = []
    for line in fi:
        text = line.strip()
        if not text:
            # parse_expr cannot parse an empty string; without this skip a
            # stray blank line (e.g. at end of file) would abort the whole run.
            continue
        chunk.append(parse_expr(text, evaluate=True, local_dict=symbols))
        if len(chunk) == examples_per_chunk:
            yield chunk
            chunk = []

    # Flush the final, possibly partial, chunk.
    if chunk:
        yield chunk


def get_label(result):
    """Turn one Wolfram evaluation result into a "Has EAD" label.

    Returns:
      '-' when `str(result)` is '$Aborted' (TimeConstrained ran out of time,
      so the answer is unknown).
      '0' when the result string cannot be parsed by sympy's Mathematica
      parser (ValueError), or when `sp_util.is_elementary` rejects it.
      '1' otherwise.
    """
    text = str(result)
    if text == '$Aborted':
        return '-'

    try:
        antiderivative = mathematica.mathematica(text)
    except ValueError:
        # Heuristic: an unparseable answer is taken to mean the integral is
        # non-elementary.
        return '0'

    return '1' if sp_util.is_elementary(antiderivative) else '0'


def to_csv_rows(chunk, wolfram_results):
    """Pair each expression with its label to form CSV rows.

    Args:
      chunk: Sequence of SymPy expressions, in submission order.
      wolfram_results: Sequence of Wolfram results, in the same order as `chunk`.

    Returns:
      A list of `(str(expr), label)` tuples, with each label from `get_label`.

    Raises:
      ValueError: If `chunk` and `wolfram_results` differ in length.
    """
    # An explicit raise (rather than `assert`) keeps this check under
    # `python -O`. Without the check, zip would silently drop unmatched items.
    if len(chunk) != len(wolfram_results):
        raise ValueError(
            f'Got {len(wolfram_results)} Wolfram results for {len(chunk)} expressions.'
        )
    return [
        (str(expr), get_label(result))
        for expr, result in zip(chunk, wolfram_results)
    ]
    
        
def main(_):
    """Label every expression in --input_file and write CSV rows to --output_file.

    Expressions are parsed in chunks of --examples_per_chunk, and each chunk is
    integrated in parallel through wolframclient. The labelled rows are flushed
    and fsync'ed after every chunk, so completed chunks are on disk even if a
    later one fails.
    """
    # Raise the root logger's threshold to ERROR (presumably to quiet library
    # log output during long runs).
    logging.getLogger().setLevel(logging.ERROR)

    symbols = make_symbols_dict(FLAGS.symbols)
    # The first configured symbol is the variable of integration.
    d_variable_name = FLAGS.symbols[0]

    in_path = os.path.expanduser(FLAGS.input_file)
    out_path = os.path.expanduser(FLAGS.output_file)

    with open(in_path, 'r') as fi:
        with open(out_path, 'w', newline='') as fo:
            writer = csv.writer(fo, delimiter=',')
            for chunk in get_chunks_iterator(tqdm(fi), symbols):
                queries = [make_wolfram_input(e, d_variable_name) for e in chunk]
                results = parallel_evaluate(
                    queries,
                    FLAGS.wolfram_kernel_path,
                    max_evaluators=FLAGS.max_evaluators,
                )
                writer.writerows(to_csv_rows(chunk, results))
                # Push this chunk's rows all the way to disk before moving on.
                fo.flush()
                os.fsync(fo)


# def main(_):
#     symbols = make_symbols_dict(FLAGS.symbols)
#     d_variable_name = FLAGS.symbols[0]

#     # TODO: Don't read the whole input file into memory.
#     #
#     # NOTE: We read everything into memory, which could cause issues for
#     # large input files.
#     exprs = []
#     input_file = os.path.expanduser(FLAGS.input_file)
#     with open(input_file, 'r') as fi:
#         for line in fi:
#             expr = parse_expr(line.strip(), evaluate=True, local_dict=symbols)
#             exprs.append(expr)

#     n_examples = len(exprs)

#     wolfram_expressions = [
#         make_wolfram_input(expr, d_variable_name)
#         for expr in exprs
#     ]

#     # TODO: At least break the expressions up into chunks when doing this. Or write
#     # code to handle results as they appear.
#     wolfram_results = parallel_evaluate(wolfram_expressions, max_evaluators=FLAGS.max_evaluators)

#     for result in wolfram_results:
#         result = str(result)
#         if result == '$Aborted':
#             result_expr = None
#         else:
#             try:
#                 result_expr = mathematica.mathematica(result)
#             except ValueError:
#                 # This likely implies that the integral is non-elementary.
#                 result_expr = None

#         if result_expr is not None:
#             print(sp_util.is_elementary(result_expr), result_expr)
#         print(result_expr)


if __name__ == "__main__":
    # absl's app.run parses the command-line flags into FLAGS, then calls main.
    app.run(main)
