#!/usr/bin/env python3

# TODO:
# - Add option to include new functions in ALL_FUNCS

import listops

# import listops.data.functions as fs
from listops.data import ListOpsDataset, functions as fs
import argparse

# Command-line interface for the generator. Every option is declared once in
# the table below as (flag names, add_argument keyword options) and registered
# in that order, so the --help listing follows the table top to bottom.
_ARGUMENT_SPECS = (
    # Output location and dataset size.
    (('--save_dir',),
     dict(type=str, default='../data/listops/', help='Directory to save the dataset')),
    (('--num_train',),
     dict(type=int, default=20000, help='Number of training examples')),
    (('--num_test',),
     dict(type=int, default=2000, help='Number of test examples')),
    # Shape of the generated expression trees.
    (('--max_depth',),
     dict(type=int, default=3, help='Maximum depth of the tree')),
    (('--min_children',),
     dict(type=int, default=2, help='Minimum number of children for a node')),
    (('--max_children',),
     dict(type=int, default=3, help='Maximum number of children for a node')),
    (('--func_node_prob',),
     dict(type=float, default=0.25, help='Probability of a node being a function node')),
    (('--seed',),
     dict(type=int, default=42, help='Random seed for reproducibility')),
    # Boolean switches: both are off unless the flag is passed.
    # --Polish selects Polish notation for the dataset.
    (('--Polish',),
     dict(action='store_true', help='Whether to use Polish notation for the dataset')),
    (('--noperm',),
     dict(action='store_true', help='Whether to exclude permutations of the input sequences')),
    # Value domain of the operations.
    (('--mod',),
     dict(type=int, default=10, help='Modulus for operations')),
    (('--input_vals',),
     dict(type=int, nargs='+', default=[],
          help='(Optional if mod not used) Input values for the dataset')),
    # Names of the functions to use; --ops is an alias for --funcs_to_use.
    (('--funcs_to_use', '--ops'),
     dict(type=str, nargs='+', default=['max', 'median', 'add'],
          help='Functions to use in the dataset')),
)

parser = argparse.ArgumentParser(description='Generate ListOps dataset.')
for _flags, _options in _ARGUMENT_SPECS:
    parser.add_argument(*_flags, **_options)


args = parser.parse_args()

MOD = args.mod
# Without explicit --input_vals, fall back to the integers 0 .. MOD-1.
INPUT_VALS = args.input_vals or list(range(MOD))
ALL_FUNCS = [min, max, fs.median, fs.add, fs.prod, fs.nadd, fs.parmax]

# Resolve the requested names to positions in ALL_FUNCS *before* the list is
# passed through fs.get_funcs_mod: the lookup relies on each function's
# __name__ (presumably not preserved by the mod-wrapped versions -- TODO confirm).
all_func_names = {f.__name__: i for i, f in enumerate(ALL_FUNCS)}

# Report every unknown name at once, together with the names that are valid,
# so the user can fix the command line in a single pass.
unknown_funcs = [name for name in args.funcs_to_use if name not in all_func_names]
if unknown_funcs:
    # TODO: optionally allow registering new functions instead of failing.
    raise ValueError(
        f"Function(s) {', '.join(unknown_funcs)} not in the list of "
        f"available functions: {', '.join(all_func_names)}"
    )

# Order (and any repetition) of the user's selection is kept as given.
funcs_to_use_idx = [all_func_names[name] for name in args.funcs_to_use]

# Build the mod-MOD variants once, then pick the selection out of that same
# list by index, so each entry of FUNCS_TO_USE is the very object that also
# appears in ALL_FUNCS. This assumes fs.get_funcs_mod returns the functions
# in the same order it received them -- TODO confirm.
ALL_FUNCS = fs.get_funcs_mod(ALL_FUNCS, MOD)
FUNCS_TO_USE = [ALL_FUNCS[i] for i in funcs_to_use_idx]

# Gather every generator setting in one place, then hand them to the dataset
# constructor as keyword arguments.
dataset_config = {
    'max_depth': args.max_depth,
    'min_children': args.min_children,
    'max_children': args.max_children,
    'input_set': INPUT_VALS,
    'func_node_prob': args.func_node_prob,
    'all_funcs': ALL_FUNCS,
    'funcs_to_use': FUNCS_TO_USE,
    'polish_notation': args.Polish,
    # Permutations are excluded (when --noperm is given) to avoid duplicates.
    'exclude_permutation': args.noperm,
    'seed': args.seed,
}
dataset = ListOpsDataset(**dataset_config)

# Build the train/test examples with the requested sizes, then save them
# under the chosen directory.
dataset.prepare_train_test_data(
    num_train=args.num_train,
    num_test=args.num_test,
)
dataset.save(save_dir=args.save_dir)
# Alternative HDF5 output (currently disabled):
# dataset.save_hdf5(save_dir=args.save_dir)