R"""Cleans up EAD datasets, fixing some known issues.

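Does the following:
- Ensures rational functions (in the first `--symbols` variable) always say they have an
  elementary antiderivative (label '1').
- Removes examples with infinities and NaNs, examples that fail to parse, and examples
  longer than --max_expr_length characters.
- Deduplicates examples.
- Eliminates overlap between splits: an expression kept for train is dropped from
  validation and test, and one kept for validation is dropped from test.
- Can combine examples from multiple files into one split.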

"""
import collections
import csv
import os
from typing import Sequence
import sys

from absl import app
from absl import flags

import sympy as sp
from sympy.parsing.sympy_parser import parse_expr

from tqdm import tqdm

from em.datasets.antiderivative import sympy_util as sp_util
from em.util.color_util import cu


FLAGS = flags.FLAGS


# Not all of the split flags have to be set; main() skips splits with no files.
flags.DEFINE_list('train_files', [], 'Path to files to create train split from.')
flags.DEFINE_list('validation_files', [], 'Path to files to create validation split from.')
flags.DEFINE_list('test_files', [], 'Path to files to create test split from.')

flags.DEFINE_string(
    "output_path",
    None,
    "Path to csv file to write expressions to. Must end with .csv. A /path/to/filename.csv will "
    "be written as /path/to/filename.$SPLIT.csv."
)

flags.DEFINE_list(
    'symbols',
    ['x'],
    'The names of symbols when reading from sympy. The first will be the `d_variable`.'
)


flags.DEFINE_integer(
    'max_expr_length',
    16 * 1024,
    'Expression strings longer than this many characters are dropped.'
)


def make_symbols_dict(symbols: Sequence[str]):
    """Map each symbol name to a real, nonzero sympy Symbol, preserving order.

    Args:
        symbols: Symbol names, e.g. the values of ``--symbols``.

    Returns:
        A ``collections.OrderedDict`` from name to ``sp.Symbol``, in the order
        the names were given.
    """
    name_to_symbol = collections.OrderedDict()
    for name in symbols:
        # TODO: Allow a way to specify those for which nonzero=False.
        name_to_symbol[name] = sp.Symbol(name, real=True, nonzero=True)
    return name_to_symbol


def read_in_files(filepaths: Sequence[str]):
    """Read every row of every CSV file in ``filepaths``, in order.

    ``~`` in each path is expanded. Files are decoded as UTF-8 so the result
    does not depend on the platform's locale encoding.

    Args:
        filepaths: Paths of the CSV files to read.

    Returns:
        One list containing all rows (each a list of strings), concatenated in
        file order. Returns ``[]`` when ``filepaths`` is empty.
    """
    rows = []
    for filepath in filepaths:
        # newline='' lets the csv module handle newlines inside quoted fields.
        with open(os.path.expanduser(filepath), 'r', newline='', encoding='utf-8') as f:
            rows.extend(csv.reader(f))
    return rows


def clean_up_rows(rows, blacklist_exp_strs, symbols_dict, d_variable):
    """Filter, deduplicate and relabel ``[expr_str, label]`` CSV rows.

    A row is dropped when it does not have exactly two fields, or when its
    expression string:
    - is longer than ``--max_expr_length``,
    - is in ``blacklist_exp_strs``,
    - repeats an earlier row,
    - fails to parse, or
    - parses to an expression containing an infinity or NaN.

    Rows whose expression is a rational function of ``d_variable`` get label
    ``'1'``.

    Args:
        rows: Iterable of CSV rows (lists of strings).
        blacklist_exp_strs: Container of expression strings to drop, e.g. the
            ones already written for an earlier split.
        symbols_dict: Name -> sympy symbol mapping, used as ``local_dict`` when
            parsing.
        d_variable: Symbol passed to ``is_rational_function``.

    Returns:
        List of ``(expr_str, label)`` tuples in first-seen order.
    """
    expr_str_to_label = collections.OrderedDict()
    # Expression strings already rejected (parse error, inf/nan). Remembering
    # them avoids re-parsing every duplicate of a bad expression.
    rejected_expr_strs = set()

    for row_index, row in enumerate(tqdm(rows)):
        # csv.reader yields [] for blank lines. Such rows, and rows with extra
        # columns, would otherwise crash the unpacking below.
        if len(row) != 2:
            print(f'Skipping row {row_index}: expected 2 fields, got {len(row)}.')
            continue
        expr_str, label = row

        if len(expr_str) > FLAGS.max_expr_length:
            continue

        if (expr_str in blacklist_exp_strs
                or expr_str in expr_str_to_label
                or expr_str in rejected_expr_strs):
            continue

        try:
            # NOTE(review): sympy's parse_expr evaluates its input with eval(),
            # so the input files must be trusted.
            expr = parse_expr(expr_str, evaluate=True, local_dict=symbols_dict)

            if sp_util.has_inf_nan(expr):
                rejected_expr_strs.add(expr_str)
                continue

            if expr.is_rational_function(d_variable):
                label = '1'

        # Parsing arbitrary strings can fail in many ways. Log and drop the row
        # rather than abort the whole run.
        except Exception as e:
            print(cu.hlr(e))
            rejected_expr_strs.add(expr_str)
            continue

        expr_str_to_label[expr_str] = label

    return list(expr_str_to_label.items())


def write_rows(split, rows, output_path=None):
    """Write ``rows`` as CSV to ``<output_path without .csv>.<split>.csv``.

    Args:
        split: Split name inserted into the filename, e.g. ``'train'``.
        rows: Iterable of rows, each a sequence of fields.
        output_path: Base path ending in ``.csv``. ``~`` is expanded. Defaults
            to ``--output_path``.

    Raises:
        ValueError: if the output path is empty or does not end with ``.csv``.
    """
    if output_path is None:
        output_path = FLAGS.output_path
    # An explicit check (rather than assert) still runs under `python -O`.
    if not output_path or not output_path.endswith('.csv'):
        raise ValueError(f'output_path must end with .csv, got {output_path!r}')

    filepath = os.path.expanduser(f"{output_path[:-len('.csv')]}.{split}.csv")

    # newline='' lets the csv writer control line endings itself.
    with open(filepath, 'w', newline='', encoding='utf-8') as fo:
        writer = csv.writer(fo, delimiter=',')
        writer.writerows(rows)


def main(_):
    """Clean the train, validation and test splits and write one CSV per split.

    Splits are processed in the order train, validation, test. Every
    expression kept for an earlier split is blacklisted for later ones, so no
    expression string appears in two output splits. Splits whose file flag is
    empty are skipped.

    Raises:
        app.UsageError: if ``--output_path`` is unset or does not end with
            ``.csv``, or if ``--symbols`` is empty.
    """
    if not FLAGS.output_path or not FLAGS.output_path.endswith('.csv'):
        raise app.UsageError('--output_path must be set and end with .csv.')
    if not FLAGS.symbols:
        raise app.UsageError('--symbols must name at least one symbol.')

    # Input expressions may exceed the csv module's default field size limit.
    # sys.maxsize overflows the C long that field_size_limit takes on
    # platforms where long is 32 bits, so fall back to the 32-bit maximum.
    try:
        csv.field_size_limit(sys.maxsize)
    except OverflowError:
        csv.field_size_limit(2**31 - 1)

    symbols_dict = make_symbols_dict(FLAGS.symbols)
    # The first symbol is the `d_variable` (see the --symbols help text).
    d_variable = next(iter(symbols_dict.values()))

    seen_example_strs = set()

    def run_for_split(split: str):
        """Clean one split's files, write them, and blacklist their expressions."""
        filepaths = getattr(FLAGS, f'{split}_files')
        if not filepaths:
            return

        og_rows = read_in_files(filepaths)
        new_rows = clean_up_rows(og_rows, seen_example_strs, symbols_dict, d_variable)

        seen_example_strs.update(r[0] for r in new_rows)

        write_rows(split, new_rows)

    # Order matters: earlier splits take priority when removing overlap.
    for split in ('train', 'validation', 'test'):
        run_for_split(split)


if __name__ == '__main__':
    # Let absl parse the command-line flags before invoking main().
    app.run(main)


# NOTE(review): The string literal below is inert. It holds a disabled
# multiprocessing variant of this script (ParsingProcess workers take
# (expr_str, label) pairs from an input queue and push results to an output
# queue) and is kept for reference only. Known issues to fix if it is revived:
#   - ParsingProcess.__init__ accepts `symbols` but builds its symbols dict
#     from FLAGS.symbols instead.
#   - Results are collected in completion order, so with n_processes > 1 the
#     output row order does not match the input order.
#   - output_queue.get() has no timeout. If a worker dies mid-item, main
#     blocks forever.
#   - Workers loop forever and are only stopped via terminate().
"""
import collections
import csv
import multiprocessing as mp
import os
from typing import Sequence
import sys

from absl import app
from absl import flags

import sympy as sp
from sympy.parsing.sympy_parser import parse_expr

from tqdm import tqdm

from em.datasets.antiderivative import sympy_util as sp_util

from em.util.color_util import cu


FLAGS = flags.FLAGS


# Note that not all of these have to set.
flags.DEFINE_list('train_files', [], 'Path to files to create test split from.')
flags.DEFINE_list('validation_files', [], 'Path to files to create validation split from.')
flags.DEFINE_list('test_files', [], 'Path to files to create test split from.')

flags.DEFINE_string(
    "output_path",
    None,
    "Path to csv file to write expressions to. Must end with .csv. A /path/to/filename.csv will "
    "be written as /path/to/filename.$SPLIT.csv."
)

flags.DEFINE_list(
    'symbols',
    ['x'],
    'The names of symbols when reading from sympy. The first will be the `d_variable`.'
)


flags.DEFINE_integer('max_expr_length', 16 * 1024, '')

flags.DEFINE_integer("n_processes", 1, '')


class ParsingProcess(mp.Process):
    def __init__(
        self,
        input_queue: mp.Queue,
        output_queue: mp.Queue,
        symbols: Sequence[str],
    ):
        super().__init__()
        self.input_queue = input_queue
        self.output_queue = output_queue
        self.symbols_dict = make_symbols_dict(FLAGS.symbols)
        self.d_variable = list(self.symbols_dict.values())[0]

    def _process_expr_str(self, expr_str: str, label: str):
        expr = parse_expr(expr_str, evaluate=True, local_dict=self.symbols_dict)

        if sp_util.has_inf_nan(expr):
            return None

        if expr.is_rational_function(self.d_variable):
            label = '1'

        return expr_str, label

    def process_expr_str(self, expr_str: str, label: str):
        try:
            return self._process_expr_str(expr_str, label)
        except Exception as e:
            print(cu.hlr(e))
            return None

    def run(self):
        while True:
            expr_str, label = self.input_queue.get()
            results = self.process_expr_str(expr_str, label)
            self.output_queue.put(results)


def make_symbols_dict(symbols: Sequence[str]):
    return collections.OrderedDict(
        # TODO: Allow a way to specify those for which nonzero=False.
        (s, sp.Symbol(s, real=True, nonzero=True))
        for s in symbols
    )


def read_in_files(filepaths: Sequence[str]):
    ret = []
    for filepath in filepaths:
        with open(os.path.expanduser(filepath), 'r', newline='') as f:
            reader = csv.reader(f)
            for row in reader:
                ret.append(row)
    return ret


def clean_up_rows(rows, blacklist_exp_strs, input_queue):
    seen_expr_strs = set()

    n_passed_to_queue = 0

    for expr_str, label in tqdm(rows):
        if len(expr_str) > FLAGS.max_expr_length:
            continue

        if expr_str in blacklist_exp_strs or expr_str in seen_expr_strs:
            continue

        seen_expr_strs.add(expr_str)

        n_passed_to_queue += 1
        input_queue.put((expr_str, label))

    return n_passed_to_queue


def write_rows(split, rows):
    assert FLAGS.output_path.endswith('.csv')

    filepath = f'{FLAGS.output_path[:-4]}.{split}.csv'
    filepath = os.path.expanduser(filepath)

    with open(filepath, 'w', newline='') as fo:
        writer = csv.writer(fo, delimiter=',')
        writer.writerows(rows)


def main(_):
    assert FLAGS.output_path.endswith('.csv')

    csv.field_size_limit(sys.maxsize)

    seen_example_strs = set()

    #

    input_queue = mp.Queue()
    output_queue = mp.Queue()

    processes = [
        ParsingProcess(
            input_queue=input_queue,
            output_queue=output_queue,
            symbols=FLAGS.symbols,
        )
        for _ in range(FLAGS.n_processes)
    ]
    for p in processes:
        p.start()

    #

    def run_for_split(split: str):
        filepaths = getattr(FLAGS, f'{split}_files')
        if not filepaths:
            return

        og_rows = read_in_files(filepaths)
        n_passed_to_queue = clean_up_rows(og_rows, seen_example_strs, input_queue)

        new_rows = []
        for _ in tqdm(range(n_passed_to_queue)):
            response = output_queue.get()
            if response is None:
                continue
            expr_str, label = response
            new_rows.append([expr_str, label])
            seen_example_strs.add(expr_str)

        write_rows(split, new_rows)

    run_for_split('train')
    run_for_split('validation')
    run_for_split('test')

    for p in processes:
        p.terminate()


if __name__ == "__main__":
    app.run(main)

"""