R"""Determine whether a set of expressions have elementary antiderivatives.

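The generated CSV has two columns and no header row: the expression (in sympy
format) and "Has EAD" (whether it has an elementary antiderivative). "Has EAD"
is "1" or "0", or "-" when the CAS interface returned None (existence could not
be determined) or the kernel failed to return an answer. Rows are written in
completion order, which may differ from input order when --n_processes > 1.

Example usage: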
cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


python3 scripts1/data_gen/antiderivative/generate_ead_indicator_sympy.py \
    --input_file=/tmp/test_expressions \
    --output_file=/tmp/test_expressions_ead \
    --seconds_per_attempt=5 \
    --python_name=python3 \
    --kernel_max_runtime_seconds=120 \
    --n_processes=2


"""
import collections
import csv
import multiprocessing as mp
import os
import queue
import subprocess
from typing import Optional, Sequence

from absl import app
from absl import flags

import sympy as sp
from tqdm import tqdm

from em.util.color_util import cu

from scripts1.data_gen.antiderivative._has_ead_sympy_kernel import IS_READY_MSG


# Script run (under `timeout`) by every worker as its sympy kernel subprocess.
KERNEL_FILEPATH = os.path.join(os.path.dirname(__file__), '_has_ead_sympy_kernel.py')


FLAGS = flags.FLAGS

flags.DEFINE_string(
    "input_file",
    None,
    "Path to file to read expressions from. Should be in the format of 1 expression per line.",
)
flags.DEFINE_string(
    "output_file", None, "Path of the CSV file to write (expression, label) rows to."
)

flags.DEFINE_string("python_name", 'python', "Name of python executable.")

flags.DEFINE_list(
    'symbols',
    ['x'],
    'The names of symbols when reading from sympy. The first will be the `d_variable`.'
)

flags.DEFINE_integer(
    "seconds_per_attempt",
    1,
    "Forwarded to the sympy kernel as its --seconds_per_attempt flag.",
)
# NOTE(review): the default of 1 s is far below the 120 s used in the module
# docstring example; the kernel is restarted whenever this limit kills it.
flags.DEFINE_integer(
    "kernel_max_runtime_seconds",
    1,
    "Wall-clock limit for each kernel subprocess: `timeout` SIGKILLs it after this "
    "many seconds and a new kernel is started on demand.",
)

flags.DEFINE_integer(
    "n_processes", 1, "Number of worker processes, each running its own kernel."
)

flags.DEFINE_integer(
    "flush_every", 50, "Flush and fsync the output CSV after every this many rows."
)

###################################################################################


def make_symbols_dict(symbols: Sequence[str]):
    """Build an ordered name -> sympy Symbol mapping.

    Every symbol is created with ``real=True`` and ``nonzero=True``. Insertion
    order follows ``symbols`` (a repeated name keeps its first position), so
    callers can treat the first entry as the integration variable.

    Args:
        symbols: Symbol names.

    Returns:
        A ``collections.OrderedDict`` mapping each name to its ``sp.Symbol``.
    """
    # TODO: Allow a way to specify those for which nonzero=False.
    mapping = collections.OrderedDict()
    for name in symbols:
        mapping[name] = sp.Symbol(name, real=True, nonzero=True)
    return mapping

###################################################################################


class EadProcess(mp.Process):
    """Worker process that labels expressions via a long-lived sympy kernel.

    The worker pulls expression strings from ``input_queue``, writes each one as
    a line to the stdin of a kernel subprocess (``KERNEL_FILEPATH`` run under
    ``timeout --signal=SIGKILL``), reads one response line back, and puts
    ``(expr_str, label)`` on ``output_queue``. The kernel is started lazily and
    restarted whenever it dies, e.g. when ``timeout`` kills it after
    ``kernel_max_runtime_seconds``.
    """

    def __init__(
        self,
        input_queue: mp.Queue,
        output_queue: mp.Queue,
        symbols: Sequence[str],
        seconds_per_attempt: int,
        kernel_max_runtime_seconds: int,
        python_name: Optional[str] = None,
    ):
        """Store the worker configuration; no kernel is started here.

        Args:
            input_queue: Queue of expression strings to label.
            output_queue: Queue receiving ``(expr_str, label)`` tuples.
            symbols: Symbol names forwarded to the kernel; the first one becomes
                ``self.d_variable``.
            seconds_per_attempt: Forwarded to the kernel's
                ``--seconds_per_attempt`` flag.
            kernel_max_runtime_seconds: Wall-clock limit passed to ``timeout``.
            python_name: Python executable that runs the kernel. Defaults to
                ``FLAGS.python_name``. It is read here, in the parent process,
                because with the ``spawn`` start method the child re-imports
                this module without parsing flags, and absl refuses to read
                unparsed flags.
        """
        super().__init__()
        self.input_queue = input_queue
        self.output_queue = output_queue
        self.seconds_per_attempt = seconds_per_attempt
        self.symbols = make_symbols_dict(symbols)
        self.symbols_str = ','.join(symbols)
        self.d_variable = list(self.symbols.values())[0]
        self.kernel_max_runtime_seconds = kernel_max_runtime_seconds
        self.python_name = python_name if python_name is not None else FLAGS.python_name

        self._kernel_process: Optional[subprocess.Popen] = None

    def _make_command(self):
        """Return the argv that launches one kernel subprocess."""
        return [
            # Hard wall-clock limit: the kernel is SIGKILLed after this long and
            # restarted on demand by process_input.
            'timeout',
            '--signal=SIGKILL',
            f'{self.kernel_max_runtime_seconds}s',
            # The sympy kernel itself.
            self.python_name,
            KERNEL_FILEPATH,
            f'--seconds_per_attempt={self.seconds_per_attempt}',
            f'--symbols={self.symbols_str}',
        ]

    def _shutdown_kernel(self):
        """Stop (if still running), close, and reap the current kernel, if any."""
        proc = self._kernel_process
        if proc is None:
            return
        self._kernel_process = None
        if proc.poll() is None:
            # This signals the `timeout` wrapper; GNU timeout is expected to
            # forward the signal to the kernel -- NOTE(review): verify elsewhere.
            proc.terminate()
        for stream in (proc.stdin, proc.stdout):
            try:
                stream.close()
            except OSError:
                # Closing stdin of a dead process may raise BrokenPipeError
                # while flushing buffered data; the fd is closed regardless.
                pass
        try:
            proc.wait(timeout=5)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()

    def _initialize_kernel(self):
        """Start a fresh kernel and block until it prints ``IS_READY_MSG``.

        Raises:
            EOFError: if the kernel exits (closes stdout) before becoming ready.
        """
        self._shutdown_kernel()
        self._kernel_process = subprocess.Popen(
            self._make_command(),
            stdout=subprocess.PIPE,
            stdin=subprocess.PIPE,
            text=True,
        )

        # TODO: Threading or something instead of this?
        while True:
            line = self._kernel_process.stdout.readline()
            if not line:
                # readline() returns '' only at EOF; without this check a kernel
                # that dies during startup would make this loop spin forever.
                raise EOFError('Kernel exited before becoming ready.')
            if line.strip() == IS_READY_MSG:
                return

    def _process_input(self, expr_str: str) -> str:
        """Send one expression to the kernel and return its stripped response line.

        Raises:
            BrokenPipeError: if the kernel's stdin is already closed.
            EOFError: if the kernel closes stdout before answering.
        """
        self._kernel_process.stdin.write(f'{expr_str}\n')
        self._kernel_process.stdin.flush()
        response = self._kernel_process.stdout.readline()
        if not response:
            # EOF: the kernel died (e.g. killed by `timeout`) mid-expression.
            raise EOFError('Kernel closed stdout before responding.')
        return response.strip()

    def process_input(self, expr_str: str) -> str:
        """Return the kernel's label for ``expr_str``, or '-' after two failed attempts.

        If the kernel dies during startup or while answering, it is reaped and a
        fresh kernel is tried once more.
        """
        for _ in range(2):
            try:
                if self._kernel_process is None:
                    self._initialize_kernel()
                return self._process_input(expr_str)
            except (BrokenPipeError, EOFError):
                # Most likely the `timeout` limit expired; the next attempt
                # starts a new kernel.
                self._shutdown_kernel()

        print(cu.hlr(f'Tried twice and failed to get a kernel response for: {expr_str}'))
        return '-'

    def run(self):
        """Label expressions from ``input_queue`` forever (there is no stop sentinel)."""
        while True:
            expr_str = self.input_queue.get()
            label = self.process_input(expr_str)
            self.output_queue.put((expr_str, label))


###################################################################################


def main(_):
    
    input_file = os.path.expanduser(FLAGS.input_file)
    output_file = os.path.expanduser(FLAGS.output_file)

    #
    
    input_queue = mp.Queue()
    output_queue = mp.Queue()

    processes = [
        EadProcess(
            input_queue=input_queue,
            output_queue=output_queue,
            symbols=FLAGS.symbols,
            seconds_per_attempt=FLAGS.seconds_per_attempt,
            kernel_max_runtime_seconds=FLAGS.kernel_max_runtime_seconds,
        )
        for _ in range(FLAGS.n_processes)
    ]
    for p in processes:
        p.start()

    #

    # TODO: Don't read everything in at once, probably.
    n_examples = 0
    with open(input_file, 'r') as fi:
        for expr_str in tqdm(fi):
            input_queue.put(expr_str.strip())
            n_examples += 1

    #

    with open(output_file, 'w', newline='') as fo:
        writer = csv.writer(fo, delimiter=',')
        for i in tqdm(range(n_examples)):
            row = output_queue.get()
            writer.writerow(row)
    
            if ((i + 1) % FLAGS.flush_every) == 0:
                fo.flush()
                os.fsync(fo)

    #

    for p in processes:
        p.terminate()


if __name__ == "__main__":
    # Fail fast with a usage error instead of a TypeError from
    # os.path.expanduser(None) inside main().
    flags.mark_flags_as_required(["input_file", "output_file"])
    app.run(main)
