import argparse
import json
import os
import random

from tqdm import trange, tqdm

from src.template_utils.template import LogicalTemplate, generate_sample, generate_sample_complex, \
    ComplexLogicalTemplate

def _generate_templates(args):
    """Return the de-duplicated ComplexLogicalTemplate objects for every (i, j) cell.

    For each ``i`` in ``range(args.i_start, args.i_end)`` and each ``j`` in
    ``range(i + args.i_add_start, i + args.i_add_end)``, ``args.generate_num``
    samples are drawn with ``generate_sample_complex`` and wrapped in
    ``ComplexLogicalTemplate``. Duplicates are removed within each (i, j) cell
    only; templates are returned in generation order.
    """
    max_objnum = args.max_objnum
    max_cnum = args.max_cnum
    max_pnum = args.max_pnum
    rule_window = args.rule_window

    dataset = []
    for i in trange(args.i_start, args.i_end):
        # leave=False removes each finished inner bar instead of stacking one
        # completed bar per value of i on the terminal.
        for j in trange(i + args.i_add_start, i + args.i_add_end, leave=False):
            lts = []
            for _ in range(args.generate_num):
                sample = generate_sample_complex(i, j, max_cnum=max_cnum, max_pnum=max_pnum,
                                                 rule_window=rule_window, max_objnum=max_objnum)
                lts.append(ComplexLogicalTemplate(
                    **sample, max_cnum=max_cnum, max_pnum=max_pnum,
                    max_objnum=max_objnum))

            # dict.fromkeys de-duplicates exactly like set() (the first of several
            # equal templates is kept) but preserves insertion order. Iterating a
            # set orders elements by __hash__; if that hash is string-based
            # (NOTE(review): not visible here — confirm in template.py) the order
            # would change between runs under PYTHONHASHSEED, defeating args.seed.
            dataset.extend(dict.fromkeys(lts))
    return dataset


def main(args):
    """Generate complex logical templates and save the DLV-validated ones as JSONL.

    Templates are produced by ``_generate_templates``. Each one whose
    ``self_check_by_dlv()`` returns a truthy value is serialized with
    ``json.dumps(lt.__dict__())`` and written as one line to ``args.save_path``
    (e.g. ``logicalDatasets/complex_templates_1.jsonl``). The number of saved
    templates is printed at the end.

    Args:
        args: parsed ``argparse.Namespace`` carrying ``seed``, ``save_path``,
            ``max_objnum``, ``max_cnum``, ``max_pnum``, ``rule_window``,
            ``i_start``, ``i_end``, ``i_add_start``, ``i_add_end`` and
            ``generate_num``.

    Raises:
        FileExistsError: if ``args.save_path`` already exists (checked before
            generation and again, atomically, when the file is created); an
            existing file is never overwritten.
    """
    # Seed the global RNG so generate_sample_complex draws are reproducible.
    random.seed(args.seed)

    saved_path = args.save_path
    # Fail fast before the (potentially long) generation step. An explicit raise
    # is used instead of ``assert`` so the guard survives ``python -O``.
    if os.path.exists(saved_path):
        raise FileExistsError(f"{saved_path} already exists")

    dataset = _generate_templates(args)

    count = 0
    # Mode 'x' creates the file exclusively: if something created it while we
    # were generating, this raises FileExistsError instead of clobbering it.
    with open(saved_path, 'x', encoding='utf-8') as f:
        for lt in tqdm(dataset):
            if lt.self_check_by_dlv():
                f.write(json.dumps(lt.__dict__()) + '\n')
                count += 1
    print(f"total {count} templates saved to {saved_path}")

if __name__ == '__main__':
    # Command-line entry point: parse generation options and run main().
    parser = argparse.ArgumentParser(description="building template samples")
    parser.add_argument('--max_objnum', type=int, help='The max number of objects', default=1)
    parser.add_argument('--max_cnum', type=int, help='The max number of concepts', default=5)
    parser.add_argument('--max_pnum', type=int, help='The max number of properties', default=1)
    parser.add_argument('--rule_window', type=int, help='The window size of rules', default=999)
    parser.add_argument('--seed', type=int, help='The random seed', default=42)
    parser.add_argument('--save_path', type=str, help='The path of saving the dataset (must not already exist)',
                        required=True)

    # main() iterates i over range(i_start, i_end) and, for each i,
    # j over range(i + i_add_start, i + i_add_end).
    parser.add_argument('--i_start', type=int, help='The start value of i (inclusive)', default=5)
    parser.add_argument('--i_end', type=int, help='The end value of i (exclusive)', default=8)
    parser.add_argument('--i_add_start', type=int,
                        help='Offset added to i to get the start value of j (inclusive)', default=3)
    parser.add_argument('--i_add_end', type=int,
                        help='Offset added to i to get the end value of j (exclusive)', default=14)
    parser.add_argument('--generate_num', type=int,
                        help='The number of samples to generate per (i, j) pair', default=100)

    args = parser.parse_args()
    print(args)
    main(args)
