R"""Generates a CSV dataset of examples for a synthetic component (scomp).

Example usage:

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1

python3 scripts1/data_gen/antiderivative/generate_scomp_ds.py \
    --scomp=exex \
    --expressions_csv=~/Desktop/projects_data/extract_merge1/antiderivative/datasets/expressions001_ead.3M.04.5s.csv \
    --output_file=/tmp/test_exex.csv \
    --n_examples=10_000

"""
import csv
import os
import sys

from absl import app
from absl import flags
import sympy as sp
from tqdm import tqdm

from em.datasets.antiderivative.scomps import exex
from em.util.color_util import cu


FLAGS = flags.FLAGS

flags.DEFINE_enum('scomp', None, ['exex'], 'Name of the scomp to generate a dataset for.')

flags.DEFINE_string("expressions_csv", None, "Path to csv file containing labeled expressions.")

flags.DEFINE_string("output_file", None, "Path to csv file to write output to.")

flags.DEFINE_integer("n_examples", None, 'Number of examples to generate.')

# main() takes a modulo by this value, so zero would raise ZeroDivisionError;
# reject it at flag-parsing time instead.
flags.DEFINE_integer(
    "flush_every", 128, 'Flush output buffer after writing this many lines.',
    lower_bound=1)

# These flags default to None and main() uses each of them unconditionally, so
# fail at flag-parsing time with a usage message instead of part-way through
# main().
flags.mark_flags_as_required(['scomp', 'expressions_csv', 'output_file', 'n_examples'])


def main(_):
    """Samples expressions from the scomp generator and writes them to a CSV file.

    Draws `--n_examples` samples from an `exex.ExexGenerator` backed by
    `--expressions_csv` and writes each one as an `(expression, label)` row to
    `--output_file`. A sample that fails to generate or write is reported and
    skipped, so the output may contain fewer than `--n_examples` rows.

    Args:
        _: Positional command-line arguments passed by `app.run`; unused.

    Raises:
        app.UsageError: If `--scomp` is not a supported value.
    """
    if FLAGS.scomp != 'exex':
        # Raised explicitly rather than asserted so the check survives `-O`.
        raise app.UsageError(f'Unsupported --scomp value: {FLAGS.scomp!r}')

    # NOTE: Only really needed for local testing due to the files I have on hand.
    csv.field_size_limit(sys.maxsize)

    # `--flag=~/path` is typically not tilde-expanded by the shell, so expand it
    # here, the same way the output path is handled below.
    expressions_csv = FLAGS.expressions_csv
    if expressions_csv is not None:
        expressions_csv = os.path.expanduser(expressions_csv)

    # TODO: Allow these parameters to be set. Probably just add a bunch of flags.
    gen = exex.ExexGenerator(
        expressions_source=exex.CsvExpressionsSource(expressions_csv),
        p_constant_type={
            'small_int': 0.5,
            'moderate_int': 0.2,
            'rational': 0.15,
        }
    )

    output_file = os.path.expanduser(FLAGS.output_file)
    n_written = 0
    with open(output_file, 'w', newline='') as fo:
        writer = csv.writer(fo, delimiter=',')
        for _ in tqdm(range(FLAGS.n_examples)):
            try:
                expr, label = gen.sample_expression()
                writer.writerow((str(expr), str(label)))
            except (KeyboardInterrupt, SystemExit):
                # Ctrl-C and interpreter shutdown must stop the run rather than
                # be treated as one more failed sample.
                raise
            except BaseException as e:
                # Best effort: report the failed sample and move on.
                # BaseException is kept from the original handler in case the
                # generator raises non-Exception errors -- TODO confirm whether
                # plain `Exception` would suffice.
                print(cu.hlr(e))
                continue

            # Flush on the number of rows actually written, so a failed sample
            # can never cause a periodic flush to be skipped.
            n_written += 1
            if n_written % FLAGS.flush_every == 0:
                fo.flush()
                os.fsync(fo.fileno())


# Script entry point: absl parses the command-line flags, then calls main().
if __name__ == "__main__":
    app.run(main)
