import json
import os
import subprocess
import sys
from os.path import abspath, dirname
# Parent of the directory containing this script (presumably the project
# root, given that it is exported on PYTHONPATH for the child script).
d = dirname(dirname(abspath(__file__)))
# Put ``d`` first on PYTHONPATH so the child process can import modules from
# it, while keeping any entries the caller had already configured instead of
# overwriting them.
_existing_pythonpath = os.environ.get("PYTHONPATH")
os.environ["PYTHONPATH"] = (
    os.pathsep.join([d, _existing_pythonpath]) if _existing_pythonpath else d
)


def run(file, **kwargs):
    """Run a Python script in a child process, passing ``kwargs`` as JSON.

    The child is started as ``<interpreter> <file> '<json>'``. All keyword
    arguments are serialized into one JSON string, which becomes the script's
    first command-line argument. JSON has no tuple type, so tuple values
    arrive in the child as lists.

    Args:
        file: Path of the script to execute. A relative path is resolved
            against the current working directory.
        **kwargs: JSON-serializable settings forwarded to the script.

    Returns:
        The ``subprocess.CompletedProcess`` for the child. The exit status is
        not checked here; inspect ``returncode`` to detect failures.

    Raises:
        TypeError: If a keyword argument value is not JSON-serializable.
    """
    # Launch the child with the interpreter running this launcher rather than
    # whatever "python3" resolves to on PATH, which may be a different
    # environment. Fall back to "python3" if Python cannot report its own
    # executable.
    interpreter = sys.executable or "python3"
    call = [interpreter, file, json.dumps(kwargs)]
    # Pass the current environment explicitly, including any PYTHONPATH
    # changes made in this process.
    return subprocess.run(call, env=os.environ)



if __name__ == "__main__":
    ws_norm = -1
    lr_net = 1e-4
    world_token_count = 32
    num_dig = 10
    num_inf = 1
    embedding_size = 64
    n_blocks = 2
    dim_feedforward = 2048
    n_heads = 4
    num_rel = 3
    u_n_pre_blocks = 2 #number of layers without cross attention with the input, before other layers
    u_n_blocks = 2
    u_n_post_blocks = 0 #number of layers without cross attention with the input, after other layers
    num_entities = 8
    world_state_size = embedding_size * world_token_count * n_heads
    u_world_token_count = world_token_count
    pos_enc = 8.0
    model_inits = dict([("world_state_size", world_state_size), ("n_blocks", n_blocks),  ("embedding_size", embedding_size), ("num_relations", num_rel), ("num_entities", num_entities),
                        ("world_token_count", world_token_count), ("nheads", n_heads), ("dim_feedforward", dim_feedforward), ("pos_enc", pos_enc)])
    model_inits["special_mode"] = {"triple_norm":True}
    u_model_inits = model_inits.copy()
    u_model_inits["world_token_count"] = u_world_token_count
    u_model_inits["n_pre_blocks"] = u_n_pre_blocks
    u_model_inits["n_blocks"] = u_n_blocks
    u_model_inits["n_post_blocks"] = u_n_post_blocks
    u_model_inits["special_mode"] = {"triple_norm":True}
    u_model_args = {"model_inits": u_model_inits, "model_name": "TransformerUpdaterV2"}
    model_inits["special_mode"]["decoder"] = True
    u_model_inits["special_mode"]["decoder"] = True
    u_model_inits["special_mode"]["transformer_layer"] = "TransformerDecoderLayer"
    model_args = {"model_inits": model_inits, "model_name": "TransformerExtractorV2"}
    lr_net_upd = lr_net
    u_model_save_path = "./conway_mutate_models/run.pst"
    run("conway_mutate.py",
        updater_model_vars=u_model_args,
        model_vars=model_args,
        save_path=u_model_save_path,
        save_reset_freq=1000,  test_log_freq=100,
        smoothing_eps=0,
        updater_lr_net=lr_net_upd,
        lr_net=lr_net,
        optimizer_net = "Adam",
        updater_optimizer_net = "Adam",
        n_worlds=256, #batch size
        n_updater_updates=1, #number of times to repeat the same batch in a row
        n_updater_steps=5, #length of world trajectory
        world_resets=10000000,
        world_length=10,
        mutation_counts=(0,9), #range of number of mutations per step
        )



