import os
import json
import time
from z3 import Solver, sat, Not
import argparse
from src.xlogominidatagen.code_synthesizer import CodeSyn
from src.xlogomini.components.world.world import World
from src.xlogomini.components.code.xlogo_code import Code
from src.xlogomini.components.constraints.code_constraints import CodeConstraints
from src.xlogomini.utils.formulas import exactly_the_same
from src.xlogomini.utils.model_conversions import model2values
from src.xlogomini.utils.load_data import load_code_json, load_cons_json, load_world_json

if __name__ == '__main__':
    # Enumerate mutations of a reference task's code (together with its
    # constraints) via a z3 solver, de-duplicate them by their JSON string
    # form, and optionally save them to
    # ./data/code/{task_id}_cinc={exact_code_inc}.json.
    parser = argparse.ArgumentParser(description='')
    # reference task whose code / constraints / world are loaded
    parser.add_argument('--task_id', type=str, help='', default="9a")
    # print a sample code every `log_interval` enumerated solver models
    parser.add_argument('--log_interval', type=int, help='', default=10000)
    # write the collected codes to ./data/code when set
    parser.add_argument('--save', action='store_true', help='')
    # forwarded to CodeSyn.properties(exact_code_inc=...); also part of the output filename
    parser.add_argument('--exact_code_inc', type=int, help='', default=0)
    # added to the world's rows and cols when bounding the synthesized code
    parser.add_argument('--grid_size_inc', type=int, help='', default=0)
    # stop once this many distinct codes have been collected
    parser.add_argument('--n_max', type=int, help='', default=20000)
    args = parser.parse_args()

    # Constraint increments are only allowed when exact_code_inc > 2.
    # NOTE(review): rationale for the threshold is not visible here — confirm.
    max_cons_inc = 0 if args.exact_code_inc <= 2 else 1

    code_json = load_code_json(args.task_id)
    cons_json = load_cons_json(args.task_id)
    world = World.init_from_json(load_world_json(args.task_id))

    code_syn = CodeSyn(code_js=code_json, cons_js=cons_json)
    code_syn.mutate(n_blks_insert_hetero=2,
                    n_blks_insert_homog=1,
                    prob_insert_rep=0)

    s = Solver()
    # rows and cols are used to avoid patterns like "fd fd ..." that would let
    # the turtle go out of the grid. The bound is the world size plus
    # --grid_size_inc (default 0); a positive increment makes it looser.
    s.add(code_syn.properties(rows=world.rows + args.grid_size_inc,
                              cols=world.cols + args.grid_size_inc,
                              max_code_inc=4,
                              max_code_dec=0,
                              exact_code_inc=args.exact_code_inc,
                              max_rep_body_inc=1,
                              max_rep_body_dec=1,
                              max_rep_times_inc=1,
                              max_rep_times_dec=1,
                              max_cons_dec=0,
                              max_cons_inc=max_cons_inc))
    print("--- Ref Code ---")
    print(Code(code_json))
    print("----------------")

    cnt = 0  # number of solver models enumerated (may map to duplicate codes)
    code_set = set()  # string forms of the codes collected so far
    code_list = []  # distinct codes, in discovery order
    start_time = time.time()
    # Test the cheap size limit first so no extra (expensive) solver call
    # is made once n_max distinct codes have been collected.
    while len(code_list) < args.n_max and s.check() == sat:
        cnt += 1
        model_values = model2values(code_syn.vars, s.model())
        # Exclude the current assignment so later checks presumably return
        # different models (exactly_the_same builds the blocking formula).
        s.add(Not(exactly_the_same(code_syn.vars, model_values)))

        # Distinct models may serialize to the same JSON; keep the first only.
        code_cons_json = code_syn.to_json(model_values)
        code_key = str(code_cons_json)
        if code_key not in code_set:
            code_set.add(code_key)
            code_list.append(code_cons_json)

        if cnt % args.log_interval == 0:
            print(f"{cnt} codes synthesized!")
            print(Code(code_cons_json['code_json']))
            print(CodeConstraints(code_cons_json['constraints']))
            print("--------------------")

    if args.save:
        os.makedirs('./data/code', exist_ok=True)

        # Drop entries identical to the reference task. The comparison is on
        # the string form of the whole {"code_json", "constraints"} dict, so
        # constraints take part in it. If at most one code was collected,
        # keep it even if it equals the reference.
        ref_key = str({"code_json": code_json, "constraints": cons_json})
        if len(code_list) <= 1:
            code_list_without_ref_code = list(code_list)
        else:
            code_list_without_ref_code = [cout for cout in code_list
                                          if str(cout) != ref_key]

        # save to file
        with open(f'./data/code/{args.task_id}_cinc={args.exact_code_inc}.json', 'w',
                  encoding='utf-8') as f:
            json.dump(code_list_without_ref_code, f)

    print(f"Time spent: {time.time() - start_time}")
    print(f"Total {cnt} codes synthesized for task {args.task_id}")
    print(f"Non-duplicated codes: {len(code_set)}")
