from task import Task
import os
import itertools
import subprocess
from tokenizer import NumberTokenizer
from dataset import NumberDataset
import pickle
import tqdm


# Root directory of the generated datasets; each entry below saves to
# <dataset_path>/<number type>/<variant>/ via its "--save_path" argument.
dataset_path = "benchmark/numbers/"
# Base random seed. Each dataset entry below uses random_seed + <its index>;
# process() additionally offsets it per tokenization job.
random_seed = 208171

# Argument lists forwarded to data_generate.py, one generator subprocess per
# entry (launched from main() when run with --dataset). task_dataset_list
# refers to these datasets through the same save paths.
dataset_list = [
    # "add" datasets (also reused by sub / multiply / div / mod tasks below).
    ["-d", "Integer", "-n", "100000", "--min_len", "1", "--max_len", "8", "--min_valid_len", "3", "--max_valid_len", "20", "--valid_nums", "1000", "--min_test_len", "3", "--max_test_len", "20", "--test_nums", "1000", "--test_shorter_len", "0.5", "--save_path", os.path.join(dataset_path, "Integer", "add"), "--skip_check_token", '--random_seed', str(random_seed+0)],
    ["-d", "Float", "-n", "100000", "--min_len", "1", "--max_len", "8", "--min_valid_len", "3", "--max_valid_len", "20", "--valid_nums", "1000", "--min_test_len", "3", "--max_test_len", "20", "--test_nums", "1000", "--test_shorter_len", "0.5", "--save_path", os.path.join(dataset_path, "Float", "add"), "--skip_check_token", '--random_seed', str(random_seed+1)],
    # NOTE: Fraction uses min valid/test length 1 instead of 3 like the others.
    ["-d", "Fraction", "-n", "100000", "--min_len", "1", "--max_len", "8", "--min_valid_len", "1", "--max_valid_len", "20", "--valid_nums", "1000", "--min_test_len", "1", "--max_test_len", "20", "--test_nums", "1000", "--test_shorter_len", "0.5", "--save_path", os.path.join(dataset_path, "Fraction", "add"), "--skip_check_token", '--random_seed', str(random_seed+2)],
    ["-d", "ScientificNotation", "-n", "100000", "--min_len", "1", "--max_len", "8", "--min_valid_len", "3", "--max_valid_len", "20", "--valid_nums", "1000", "--min_test_len", "3", "--max_test_len", "20", "--test_nums", "1000", "--test_shorter_len", "0.5", "--save_path", os.path.join(dataset_path, "ScientificNotation", "add"), "--skip_check_token", '--random_seed', str(random_seed+3)],

    # For some number types, comparison is much easier than addition, so these
    # "compare" datasets use longer lengths (training up to 20, valid/test up to 100).
    ["-d", "Integer", "-n", "100000", "--min_len", "1", "--max_len", "20", "--min_valid_len", "5", "--max_valid_len", "100", "--valid_nums", "1000", "--min_test_len", "5", "--max_test_len", "100", "--test_nums", "1000", "--test_shorter_len", "0.5", "--save_path", os.path.join(dataset_path, "Integer", "compare"), "--skip_check_token", '--random_seed', str(random_seed+4)],
    ["-d", "Float", "-n", "100000", "--min_len", "1", "--max_len", "20", "--min_valid_len", "3", "--max_valid_len", "100", "--valid_nums", "1000", "--min_test_len", "3", "--max_test_len", "100", "--test_nums", "1000", "--test_shorter_len", "0.5", "--save_path", os.path.join(dataset_path, "Float", "compare"), "--skip_check_token", '--random_seed', str(random_seed+5)],
    ["-d", "ScientificNotation", "-n", "100000", "--min_len", "1", "--max_len", "20", "--min_valid_len", "3", "--max_valid_len", "100", "--valid_nums", "1000", "--min_test_len", "3", "--max_test_len", "100", "--test_nums", "1000", "--test_shorter_len", "0.5", "--save_path", os.path.join(dataset_path, "ScientificNotation", "compare"), "--skip_check_token", '--random_seed', str(random_seed+6)],

    # A harder compare variant (--harder_compare), where the two numbers share
    # more digits. Only the Integer entry also passes "--same_len true".
    ["-d", "Integer", "-n", "100000", "--min_len", "1", "--max_len", "20", "--min_valid_len", "5", "--max_valid_len", "100", "--valid_nums", "1000", "--min_test_len", "5", "--max_test_len", "100", "--test_nums", "1000", "--test_shorter_len", "0.5", "--save_path", os.path.join(dataset_path, "Integer", "compare_harder"), "--skip_check_token", "--harder_compare", "--same_len", 'true', '--random_seed', str(random_seed+7)],
    ["-d", "Float", "-n", "100000", "--min_len", "1", "--max_len", "20", "--min_valid_len", "3", "--max_valid_len", "100", "--valid_nums", "1000", "--min_test_len", "3", "--max_test_len", "100", "--test_nums", "1000", "--test_shorter_len", "0.5", "--save_path", os.path.join(dataset_path, "Float", "compare_harder"), "--skip_check_token", "--harder_compare", '--random_seed', str(random_seed+8)],
    ["-d", "ScientificNotation", "-n", "100000", "--min_len", "1", "--max_len", "20", "--min_valid_len", "3", "--max_valid_len", "100", "--valid_nums", "1000", "--min_test_len", "3", "--max_test_len", "100", "--test_nums", "1000", "--test_shorter_len", "0.5", "--save_path", os.path.join(dataset_path, "ScientificNotation", "compare_harder"), "--skip_check_token", "--harder_compare", '--random_seed', str(random_seed+9)],
]

# (Task, dataset directory) pairs; every pair is tokenized for the train, valid
# and test splits when the script runs with --task.
# The Task arguments presumably are (task name, first operand type, second
# operand type or "int"/"none", result type) — TODO confirm against task.Task.
# Several tasks reuse another task's dataset (e.g. sub/multiply/div use "add").
task_dataset_list = [
    (Task("add", "Integer", "Integer", "Integer"), os.path.join(dataset_path, "Integer", "add")),
    (Task("add", "Float", "Float", "Float"), os.path.join(dataset_path, "Float", "add")),
    (Task("add", "Fraction", "Fraction", "Fraction"), os.path.join(dataset_path, "Fraction", "add")),
    (Task("add_easy", "Fraction", "Fraction", "Fraction"), os.path.join(dataset_path, "Fraction", "add")),
    (Task("add", "ScientificNotation", "ScientificNotation", "ScientificNotation"), os.path.join(dataset_path, "ScientificNotation", "add")),
    
    # sub reuses the add datasets; the operands are swapped by a preprocessing step inside the dataset.
    (Task("sub", "Integer", "Integer", "Integer"), os.path.join(dataset_path, "Integer", "add")),
    (Task("sub", "Float", "Float", "Float"), os.path.join(dataset_path, "Float", "add")),
    (Task("sub", "Fraction", "Fraction", "Fraction"), os.path.join(dataset_path, "Fraction", "add")),
    (Task("sub", "ScientificNotation", "ScientificNotation", "ScientificNotation"), os.path.join(dataset_path, "ScientificNotation", "add")), # sub can use the same dataset as add, with a swap preprocess in the dataset
    
    (Task("max", "Integer", "Integer", "Integer"), os.path.join(dataset_path, "Integer", "compare")), 
    (Task("max", "Float", "Float", "Float"), os.path.join(dataset_path, "Float", "compare")),
    (Task("max", "Fraction", "Fraction", "Fraction"), os.path.join(dataset_path, "Fraction", "add")), # comparing fractions is hard, so reuse the shorter Fraction "add" dataset (there is no Fraction "compare" dataset)
    (Task("max", "ScientificNotation", "ScientificNotation", "ScientificNotation"), os.path.join(dataset_path, "ScientificNotation", "compare")),
    
    (Task("max_hard", "Integer", "Integer", "Integer"), os.path.join(dataset_path, "Integer", "compare_harder")), # the two integers are generated to share some digits
    (Task("max_hard", "Float", "Float", "Float"), os.path.join(dataset_path, "Float", "compare_harder")), # the two floats are more likely to share the integer part and some digits
    (Task("max_hard", "ScientificNotation", "ScientificNotation", "ScientificNotation"), os.path.join(dataset_path, "ScientificNotation", "compare_harder")), # more likely to share the same exponent and some digits
    
    (Task("multiply_hard", "Integer", "Integer", "Integer"), os.path.join(dataset_path, "Integer", "add")),
    (Task("multiply_hard", "Float", "Float", "Float"), os.path.join(dataset_path, "Float", "add")),
    (Task("multiply_hard", "Fraction", "Fraction", "Fraction"), os.path.join(dataset_path, "Fraction", "add")),
    (Task("multiply_hard", "ScientificNotation", "ScientificNotation", "ScientificNotation"), os.path.join(dataset_path, "ScientificNotation", "add")),
    
    (Task("multiply_easy", "Integer", "Integer", "Integer"), os.path.join(dataset_path, "Integer", "add")),
    (Task("multiply_easy", "Float", "Float", "Float"), os.path.join(dataset_path, "Float", "add")),
    (Task("multiply_easy", "Fraction", "Fraction", "Fraction"), os.path.join(dataset_path, "Fraction", "add")),
    (Task("multiply_easy", "ScientificNotation", "ScientificNotation", "ScientificNotation"), os.path.join(dataset_path, "ScientificNotation", "add")),
    
    (Task("digit_max", "Integer", "Integer", "Integer"), os.path.join(dataset_path, "Integer", "compare")),
    (Task("digit_max", "Float", "Float", "Float"), os.path.join(dataset_path, "Float", "compare")),
    
    (Task("digit_add", "Integer", "Integer", "Integer"), os.path.join(dataset_path, "Integer", "compare")),
    (Task("digit_add", "Float", "Float", "Float"), os.path.join(dataset_path, "Float", "compare")),
    
    (Task("get_digit", "Integer", "int", "int"), os.path.join(dataset_path, "Integer", "compare")),
    (Task("get_digit", "Float", "int", "int"), os.path.join(dataset_path, "Float", "compare")),
    
    (Task("length", "Integer", "none", "int"), os.path.join(dataset_path, "Integer", "compare")),
    (Task("length", "Float", "none", "int"), os.path.join(dataset_path, "Float", "compare")),
    
    (Task("truediv", "Integer", "Integer", "Fraction"), os.path.join(dataset_path, "Integer", "add")),
    (Task("truediv", "Fraction", "Fraction", "Fraction"), os.path.join(dataset_path, "Fraction", "add")),
    (Task("floordiv", "Integer", "Integer", "Integer"), os.path.join(dataset_path, "Integer", "add")),
    (Task("mod", "Integer", "Integer", "Integer"), os.path.join(dataset_path, "Integer", "add")),
    (Task("mod_easy", "Integer", "Integer", "Integer"), os.path.join(dataset_path, "Integer", "add")),
    
    (Task("to_float", "Fraction", "none", "Float"), os.path.join(dataset_path, "Fraction", "add")),
    (Task("to_float", "ScientificNotation", "none", "Float"), os.path.join(dataset_path, "ScientificNotation", "compare")),
    
    (Task("to_scient", "Integer", "none", "ScientificNotation"), os.path.join(dataset_path, "Integer", "compare")),
    (Task("to_scient", "Float", "none", "ScientificNotation"), os.path.join(dataset_path, "Float", "compare")),
    
    (Task("count", "Integer", "int", "int"), os.path.join(dataset_path, "Integer", "compare")),
    
    (Task("sig", "Integer", "int", "ScientificNotation"), os.path.join(dataset_path, "Integer", "compare")),
    (Task("sig", "Float", "int", "ScientificNotation"), os.path.join(dataset_path, "Float", "compare")),
]

def record(task_name: str, digit: int, string: str, dict_: dict[str, dict[int, list[str]]]) -> None:
    """Append *string* to the bucket ``dict_[task_name][digit]``.

    Missing task and digit buckets are created on first use. *dict_* is
    mutated in place and nothing is returned.
    """
    per_digit = dict_.setdefault(task_name, {})
    per_digit.setdefault(digit, []).append(string)
    
def process(task, dataset_dir, name, subprocess_idx, continue_, reverse, pad, cache_path):
    """Tokenize one split of one task's dataset and bucket the exported strings by digit count.

    Args:
        task: the task object. Its ``str()`` form with the first 5 and the last
            characters stripped (presumably ``"Task(" ... ")"``) is turned into
            an underscore-joined task key.
        dataset_dir: directory containing the pickled split file.
        name: split file name, e.g. ``"train.pkl"``.
        subprocess_idx: index of this job. It offsets the tokenizer's random
            seed and is returned alongside the result.
        continue_: if true and a cached result for this task/split exists,
            return it instead of recomputing.
        reverse: forwarded to ``NumberTokenizer`` as ``reverse_rep``.
        pad: forwarded to ``NumberTokenizer`` as ``number_pad``.
        cache_path: root directory of the per-task result cache.

    Returns:
        ``(name, subprocess_idx, target_dict)``, where ``target_dict`` maps
        task key -> digit count -> list of exported strings. On a cache hit
        this is the tuple stored by the earlier run.
    """
    task_name = "_".join(str(task)[5:-1].split(", "))
    split_stem = os.path.splitext(name)[0]
    cache_dir = os.path.join(cache_path, task_name)
    cache_file = os.path.join(cache_dir, split_stem + ".pkl")

    if continue_ and os.path.exists(cache_file):
        # Cache files are written only by this function; unpickling is safe
        # only because they are trusted local artifacts.
        with open(cache_file, "rb") as f:
            return pickle.load(f)

    tokenizer = NumberTokenizer(task=task, reverse_rep=reverse, random_seed=random_seed + subprocess_idx * 8198, number_pad=pad)
    # Load the split and close the file immediately; the data is fully
    # materialized by pickle.load, so the handle is not needed afterwards.
    with open(os.path.join(dataset_dir, name), "rb") as f:
        raw_split = pickle.load(f)
    dataset = NumberDataset(raw_split, tokenizer=tokenizer, task=task, training=True, trunc=None, return_numbers=True)

    target_dict = {}
    for d in tqdm.tqdm(dataset, desc=f"Solving {str(task)} {split_stem}", leave=False):
        tokens = d["tokens"]
        input_number = d["numbers"]
        if isinstance(input_number, tuple):
            # Binary task: bucket by the larger operand's digit count. A plain
            # int second operand (e.g. an index) contributes 0.
            digit = max(input_number[0].digit, 0 if isinstance(input_number[1], int) else input_number[1].digit)
        else:
            digit = input_number.digit
        record(
            task_name=task_name,
            digit=digit,
            string=tokenizer.export(tokens),
            dict_=target_dict,
        )

    result = (name, subprocess_idx, target_dict)
    os.makedirs(cache_dir, exist_ok=True)
    # Write to a temporary file and atomically rename it, so an interrupted
    # run never leaves a truncated cache file for --continue_ to load.
    tmp_file = cache_file + ".tmp"
    with open(tmp_file, "wb") as wf:
        pickle.dump(result, wf)
    os.replace(tmp_file, cache_file)
    return result
    
def _generate_datasets(continue_: bool) -> None:
    """Run ``data_generate.py`` once per entry of ``dataset_list``, concurrently.

    All generators are started before any is waited on. With *continue_*,
    entries whose ``--save_path`` already exists are skipped.

    Raises:
        SystemExit: if any generator exits with a non-zero status, so that task
            generation never runs on missing or partial datasets.
    """
    import sys

    running = []
    for dataset_args in dataset_list:
        path = dataset_args[dataset_args.index("--save_path") + 1]
        if continue_ and os.path.exists(path):
            print("Skip", path)
            continue
        print("Generating dataset", path)
        # Pass an argv list (no shell) so that paths containing spaces or
        # shell metacharacters are forwarded verbatim. sys.executable keeps
        # the child on the same interpreter/environment as this script.
        running.append((path, subprocess.Popen([sys.executable, "data_generate.py", *dataset_args])))

    # Wait for every child, not only up to the first failure, so none is orphaned.
    failed = [path for path, proc in running if proc.wait() != 0]
    if failed:
        sys.exit("Dataset generation failed for: " + ", ".join(failed))
    print("All dataset is generated.")


def _merge_split_results(results) -> dict[str, dict[str, dict[int, list[str]]]]:
    """Merge per-job ``(name, idx, target_dict)`` results into one dict per split.

    Returns ``{"train": ..., "valid": ..., "test": ...}``. Each value maps
    task key -> digit -> strings, concatenated in the order of *results*.
    Results whose name matches none of the splits are ignored.
    """
    merged: dict[str, dict[str, dict[int, list[str]]]] = {"train": {}, "valid": {}, "test": {}}
    for name, _, target_dict in results:
        split = next((s for s in merged if name.endswith(f"{s}.pkl")), None)
        if split is None:
            continue
        split_dict = merged[split]
        for task_name, digit_dict in target_dict.items():
            task_bucket = split_dict.setdefault(task_name, {})
            for digit, string_list in digit_dict.items():
                task_bucket.setdefault(digit, []).extend(string_list)
    return merged


def _generate_tasks(continue_: bool, reverse: str, pad: bool, save_path: str, cache_path: str) -> None:
    """Tokenize every (task, dataset) pair for all splits and write ``<split>.json`` files into *save_path*."""
    import json
    import multiprocessing

    jobs = [
        (task, dataset_dir, split_file, idx, continue_, reverse, pad, cache_path)
        for idx, ((task, dataset_dir), split_file) in enumerate(
            itertools.product(task_dataset_list, ["train.pkl", "valid.pkl", "test.pkl"])
        )
    ]
    # One worker per CPU rather than one per job: the jobs are CPU-bound, and
    # each worker holds a full dataset in memory, so oversubscribing only
    # costs RAM. Results do not depend on worker assignment (seeds come from idx).
    n_workers = max(1, min(len(jobs), os.cpu_count() or 1))
    with multiprocessing.Pool(processes=n_workers) as pool:
        # starmap returns results in job order, so no re-sort is needed.
        # Sorting on the returned idx could misorder stale cached results.
        results = pool.starmap(process, jobs)

    os.makedirs(save_path, exist_ok=True)
    for split, split_dict in _merge_split_results(results).items():
        with open(os.path.join(save_path, f"{split}.json"), "w") as wf:
            json.dump(split_dict, wf, indent=2)


def main() -> None:
    """Command-line entry point: optionally generate datasets (-d) and/or task splits (-t)."""
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--dataset", action="store_true", help="generate dataset")
    parser.add_argument("-t", "--task", action="store_true", help="generate tasks")
    parser.add_argument('-r', "--reverse", type=str, default='no', help="reverse the number string representation.")
    parser.add_argument('-p', '--pad', action="store_true")
    parser.add_argument("-c", "--continue_", action="store_true", help="continue the last task")
    args = parser.parse_args()

    # Output and cache directories are kept separate per (pad, reverse) variant.
    suffix = ("_pad" if args.pad else "") + ("_reverse" if args.reverse != "no" else "")
    save_path = "benchmark/tasks" + suffix + "/"
    cache_path = "benchmark/cache" + suffix + "/"

    if args.dataset:
        _generate_datasets(args.continue_)

    if args.task:
        _generate_tasks(args.continue_, args.reverse, args.pad, save_path, cache_path)

    print("Done!")


if __name__ == "__main__":
    main()
