R"""Determine whether a set of expressions have elementary antiderivatives.

The generated CSV has two columns: Expression (in sympy format), Has EAD.
The "Has EAD" column is one of "1", "0", or "-". A "-" means the answer is
undetermined: the CAS call timed out, crashed, or otherwise failed, so nothing
is known about the existence of an elementary antiderivative.


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1

python3 scripts1/data_gen/antiderivative/generate_ead_indicator_wolfram3.py \
    --input_file=/tmp/test_expressions \
    --output_file=/tmp/test_expressions_ead \
    --seconds_per_attempt=1 \
    --max_evaluators=4 \
    --examples_per_chunk=8 \
    --flush_every=4


# Run this if I need to clean up:
pkill -9 -f WolframKernel


"""
import asyncio
import collections
import csv
import itertools
import json
import logging
import multiprocessing as mp
import os
import subprocess
from typing import List, Sequence

from absl import app
from absl import flags

import sympy as sp
from sympy.parsing import mathematica
from sympy.parsing.sympy_parser import parse_expr
from tqdm import tqdm

from em.datasets.antiderivative import sympy_util as sp_util
from em.util.color_util import cu


# Command-line configuration. Every flag carries a help string so that
# `--help` documents how the value is used by this script.
FLAGS = flags.FLAGS

flags.DEFINE_string(
    "input_file",
    None,
    "Path to file to read expressions from. Should be in the format of 1 expression per line.",
)
flags.DEFINE_string("output_file", None, "Path to csv file to write expressions to.")

flags.DEFINE_string(
    "wolfram_kernel_path",
    "/usr/local/bin/WolframKernel",
    "Path to the WolframKernel executable.",
)

flags.DEFINE_list(
    'symbols',
    ['x'],
    'The names of symbols when reading from sympy. The first will be the `d_variable`.'
)

flags.DEFINE_float(
    "seconds_per_attempt",
    1.0,
    'Time limit, in seconds, given to each individual integral (passed to TimeConstrained).',
)
flags.DEFINE_float(
    "timeout_safety_factor",
    1.25,
    'Multiplier applied to (expressions in a chunk * seconds_per_attempt) to obtain the '
    'hard timeout of one WolframKernel call.',
)

flags.DEFINE_integer("max_evaluators", 4, 'Number of worker processes to start.')

flags.DEFINE_integer(
    "examples_per_chunk",
    8,
    'Number of expressions sent to a single WolframKernel call.',
)

flags.DEFINE_integer("flush_every", 4, 'Note that this is in terms of chunks instead of examples.')


###################################################################################


def make_symbols_dict(symbols: Sequence[str]):
    """Map each symbol name to a sympy Symbol, preserving the given order.

    Every symbol is created with the assumptions ``real=True`` and
    ``nonzero=True``.

    Args:
        symbols: Names of the symbols to create.

    Returns:
        An ``OrderedDict`` from name to ``sp.Symbol``, in the order of `symbols`.
    """
    by_name = collections.OrderedDict()
    for name in symbols:
        # TODO: Allow a way to specify those for which nonzero=False.
        by_name[name] = sp.Symbol(name, real=True, nonzero=True)
    return by_name


def get_chunks_iterator(fi, examples_per_chunk=None):
    """Yield the lines of `fi` grouped into lists of stripped strings.

    Args:
        fi: Any iterable of strings (e.g. an open text file); each item is
            stripped of surrounding whitespace before being added to a chunk.
        examples_per_chunk: Number of lines per yielded chunk. When ``None``
            (the default) the value of ``FLAGS.examples_per_chunk`` is used.

    Yields:
        Lists holding exactly `examples_per_chunk` stripped lines, followed by
        one shorter list with the leftover lines, if any. If the chunk length
        never equals `examples_per_chunk` (e.g. a non-positive size), all
        lines end up in that single trailing list.
    """
    if examples_per_chunk is None:
        examples_per_chunk = FLAGS.examples_per_chunk

    chunk = []
    for line in fi:
        chunk.append(line.strip())
        if len(chunk) == examples_per_chunk:
            yield chunk
            chunk = []

    # Emit whatever is left over after the last full chunk.
    if chunk:
        yield chunk


###################################################################################


# String that Wolfram's TimeConstrained returns (as its fail expression) when a
# single integral exceeds its time limit; see `_make_integral_expression`.
WOLFRAM_FAILEXPR = '$Aborted'


class EadProcess(mp.Process):
    """Worker process that labels expressions via WolframKernel.

    The process repeatedly takes a list of expression strings from
    `input_queue`, integrates all of them in one WolframKernel invocation
    (wrapped in the `timeout` command), and puts a list of
    ``(expression, label)`` pairs on `output_queue`.

    Labels are ``'1'`` (the antiderivative is elementary according to
    ``sp_util.is_elementary``), ``'0'`` (it is not, or the Wolfram result
    could not be parsed), and ``'-'`` (undetermined: the integral or the
    whole kernel call failed or timed out).
    """

    def __init__(
        self,
        input_queue: mp.Queue,
        output_queue: mp.Queue,
        wolfram_kernel_path: str,
        symbols: Sequence[str],
        seconds_per_attempt: float,
        safety_factor: float,
    ):
        """Store the configuration of the worker.

        Args:
            input_queue: Queue from which lists of expression strings are read.
            output_queue: Queue onto which lists of (expression, label) pairs are put.
            wolfram_kernel_path: Path to the WolframKernel executable (`~` is expanded).
            symbols: Symbol names used when parsing expressions; the first
                one is the integration variable.
            seconds_per_attempt: Time limit for each individual integral.
            safety_factor: Multiplier used when computing the hard timeout
                of a whole kernel call.

        Raises:
            ValueError: If `symbols` is empty.
        """
        super().__init__()
        self.input_queue = input_queue
        self.output_queue = output_queue
        self.wolfram_kernel_path = os.path.expanduser(wolfram_kernel_path)
        self.seconds_per_attempt = seconds_per_attempt
        self.symbols = make_symbols_dict(symbols)
        if not self.symbols:
            raise ValueError(
                'At least one symbol is required; the first one is the integration variable.')
        self.d_variable_name = next(iter(self.symbols))
        self.safety_factor = safety_factor

        # NOTE: Using this parser will result in the parsed statement NOT
        # being equivalent to the mathematica statement. However, its status
        # of being elementary will be the same.
        self.mathematica_parser = mathematica.MathematicaParser({
            'RootSum[x,y]': '(y)',
        })

    #######################################################

    def _make_integral_expression(self, expr_str: str) -> str:
        """Build the Wolfram code that integrates one sympy-format expression.

        The integral is wrapped in TimeConstrained so that exceeding
        `seconds_per_attempt` yields the string WOLFRAM_FAILEXPR instead.
        """
        expr = parse_expr(expr_str, evaluate=True, local_dict=self.symbols)
        m_expr = sp.mathematica_code(expr)

        wolfram_cmd = f'ToString[InputForm[FunctionExpand[Integrate[{m_expr}, {self.d_variable_name}]]]]'
        return f'TimeConstrained[{wolfram_cmd}, {self.seconds_per_attempt}, "{WOLFRAM_FAILEXPR}"]'

    def _make_integral_expressions(self, expr_strs: Sequence[str]) -> str:
        """Build the Wolfram program that prints the list of all integrals and exits."""
        middle = ", ".join(self._make_integral_expression(expr) for expr in expr_strs)
        return 'Print[{' + middle + '}];Exit[]'

    def _make_command(self, expr_strs: Sequence[str]) -> List[str]:
        """Build the argv that runs WolframKernel on `expr_strs` under `timeout`.

        The hard time limit of the whole call is
        ``safety_factor * len(expr_strs) * seconds_per_attempt`` seconds.
        """
        n_expressions = len(expr_strs)
        timeout_secs = self.safety_factor * n_expressions * self.seconds_per_attempt
        return [
            # The timeout commands.
            'timeout',
            '--signal=SIGKILL',
            f'{timeout_secs}s',
            #
            # The WolframKernel commands.
            self.wolfram_kernel_path,
            '-noprompt',
            '-run',
            self._make_integral_expressions(expr_strs),
        ]

    #######################################################

    def _label_from_integral(self, integral: str) -> str:
        """Turn one Wolfram result string into a label ('1', '0' or '-')."""
        if integral == WOLFRAM_FAILEXPR:
            return '-'
        try:
            # '#1' and '&' (pure-function syntax) are rewritten before parsing.
            expr_str = self.mathematica_parser.parse(integral.replace('#1', 'r').replace('&', ' '))
            result_expr = sp.sympify(expr_str)
        except ValueError:
            # This likely implies that the integral is non-elementary.
            return '0'
        is_elem = sp_util.is_elementary(result_expr)
        return '1' if is_elem else '0'

    def _create_labels_from_stdout(self, stdout: str) -> List[str]:
        """Parse the last line of the kernel's stdout into a list of labels.

        Assumes that the returncode was 0, i.e. the command completed
        successfully. The last line is expected to look like
        ``{"result 1", "result 2", ...}``.

        Raises:
            ValueError: If the last line is not wrapped in braces or its
                contents cannot be decoded as a JSON list.
        """
        last_line = stdout.strip().split('\n')[-1]
        # Explicit check instead of `assert`: asserts vanish under `python -O`,
        # and an empty stdout would otherwise fail with an opaque IndexError.
        if not (last_line.startswith('{') and last_line.endswith('}')):
            raise ValueError(f'Unexpected WolframKernel output: {last_line!r}')
        integrals = json.loads(f'[{last_line[1:-1]}]')
        return [self._label_from_integral(integral) for integral in integrals]

    #######################################################

    def _process_inputs(self, expr_strs: Sequence[str]) -> List[tuple]:
        """Run one WolframKernel call on `expr_strs` and pair each with its label.

        Raises:
            ValueError: If the kernel output is malformed or does not contain
                exactly one result per expression.
        """
        if not expr_strs:
            # Nothing to integrate. Also avoids a `timeout 0s ...` invocation,
            # which GNU timeout interprets as "no timeout at all".
            return []

        cmd = self._make_command(expr_strs)
        called = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

        # Labels used whenever the whole call failed.
        undetermined = len(expr_strs) * ['-']

        if called.returncode == 0:
            # Valid result.
            labels = self._create_labels_from_stdout(called.stdout)
            if len(labels) != len(expr_strs):
                # Without this check `zip` below would silently drop rows.
                raise ValueError(
                    f'Expected {len(expr_strs)} results from WolframKernel but got {len(labels)}.')

        elif called.returncode == -9:
            # The process was terminated by signal 9 (SIGKILL), which is what
            # the `timeout --signal=SIGKILL` wrapper uses: treated as a timeout.
            print(cu.hly('The call to WolframKernel timed out.'))
            # TODO: Add binary-tree-style retry when we get a timeout?
            # Or just redo this better with sending messages to running WolframKernel.
            labels = undetermined

        elif called.returncode == -11:
            # Call dumped core. This is usually stuff from Mathematica like:
            # FactorSquareFree::lrgexp: Exponent is out of bounds for function FactorSquareFree.
            # General::nomem: The current computation was aborted because there was insufficient memory available to complete the computation.
            print(cu.hly('The call to WolframKernel dumped core.'))
            labels = undetermined

        else:
            print(cu.hlr(f'Unrecognized return code {called.returncode}.'))
            print(cu.hlr(f'Stdout: {called.stdout}'))
            print(cu.hlr(f'Stderr: {called.stderr}'))
            labels = undetermined

        return list(zip(expr_strs, labels))

    def process_inputs(self, expr_strs: Sequence[str]) -> List[tuple]:
        """Label `expr_strs`, never raising: any failure labels the whole chunk '-'."""
        try:
            return self._process_inputs(expr_strs)

        except Exception as e:
            # Deliberate best-effort boundary: one bad chunk must not kill the
            # worker. The exception type is printed because `str(e)` alone can
            # be empty or ambiguous.
            print(cu.hlr('Exception thrown when processing inputs:'))
            print(cu.hlr(f'{type(e).__name__}: {e}'))

            labels = len(expr_strs) * ['-']
            return list(zip(expr_strs, labels))

    #######################################################

    def run(self):
        """Serve chunks from `input_queue` forever; the parent ends the process."""
        while True:
            expr_strs = self.input_queue.get()
            results = self.process_inputs(expr_strs)
            self.output_queue.put(results)


###################################################################################


def main(_):
    """Label every expression of --input_file and write the CSV to --output_file.

    Worker processes are started, every chunk of input lines is queued for
    them, and then one result list per chunk is collected and written as CSV
    rows of (expression, label). Results are written in the order in which
    they arrive on the output queue.

    Raises:
        app.UsageError: If --input_file or --output_file is not given.
    """
    # Fail with a usage message rather than a TypeError from expanduser(None).
    if FLAGS.input_file is None or FLAGS.output_file is None:
        raise app.UsageError('Both --input_file and --output_file must be specified.')

    input_file = os.path.expanduser(FLAGS.input_file)
    output_file = os.path.expanduser(FLAGS.output_file)

    #

    input_queue = mp.Queue()
    output_queue = mp.Queue()

    processes = [
        EadProcess(
            input_queue=input_queue,
            output_queue=output_queue,
            wolfram_kernel_path=FLAGS.wolfram_kernel_path,
            symbols=FLAGS.symbols,
            seconds_per_attempt=FLAGS.seconds_per_attempt,
            safety_factor=FLAGS.timeout_safety_factor,
        )
        for _ in range(FLAGS.max_evaluators)
    ]

    # Only processes that were actually started may be terminated and joined.
    started = []
    try:
        for p in processes:
            p.start()
            started.append(p)

        #

        # TODO: Don't read everything in at once, probably.
        n_chunks = 0
        with open(input_file, 'r') as fi:
            for expr_strs in get_chunks_iterator(tqdm(fi)):
                input_queue.put(expr_strs)
                n_chunks += 1

        #

        flush_every = FLAGS.flush_every
        with open(output_file, 'w', newline='') as fo:
            writer = csv.writer(fo, delimiter=',')
            for i in tqdm(range(n_chunks)):
                rows = output_queue.get()
                writer.writerows(rows)

                # A non-positive value disables the periodic flush instead of
                # failing with ZeroDivisionError.
                if flush_every > 0 and ((i + 1) % flush_every) == 0:
                    fo.flush()
                    os.fsync(fo.fileno())

    finally:
        # The workers loop forever, so they must be stopped explicitly, also
        # when an error occurred above; otherwise the program hangs at exit
        # waiting for them.
        for p in started:
            p.terminate()
        for p in started:
            p.join()


if __name__ == "__main__":
    app.run(main)
