import subprocess

import json
import os
from os.path import dirname, abspath
# Put the parent of this script's directory on PYTHONPATH so the scripts
# launched through the inherited environment can import project modules.
# The directory is prepended rather than assigned, so any PYTHONPATH entries
# the user already had are kept (after it) instead of being discarded.
d = dirname(dirname(abspath(__file__)))
os.environ["PYTHONPATH"] = os.pathsep.join(
    entry for entry in (d, os.environ.get("PYTHONPATH")) if entry
)
import copy

def run(file, **kwargs):
    """Run ``file`` as a Python script, passing ``kwargs`` as one JSON argument.

    The child process receives the whole keyword-argument dict, serialised
    with ``json.dumps``, as ``sys.argv[1]``. It is started with this
    process's ``os.environ`` (including any ``PYTHONPATH`` set earlier in
    this module) and this call blocks until it exits.

    Args:
        file: Path of the script to execute. A relative path is resolved by
            the child against the current working directory.
        **kwargs: JSON-serialisable settings forwarded to the script.

    Returns:
        The ``subprocess.CompletedProcess`` of the child, so callers can
        check ``returncode``. A non-zero exit status is not raised.

    Raises:
        TypeError: if a value in ``kwargs`` is not JSON-serialisable.
    """
    import sys  # imported locally to keep this launcher helper self-contained

    # Prefer the interpreter running this launcher over whatever "python3"
    # resolves to on PATH. The two can differ inside virtualenvs or conda
    # envs. Fall back to "python3" when the interpreter path is unknown.
    interpreter = sys.executable or "python3"
    call = [interpreter, file, json.dumps(kwargs)]
    return subprocess.run(call, env=os.environ)



def build_run_kwargs():
    """Build the keyword arguments forwarded to ``pathfinder_2.py``.

    Returns:
        A JSON-serialisable dict containing:

        - the extractor config (``model_vars``);
        - the updater config (``updater_model_vars``);
        - the optimisation, logging and curriculum settings.

        Key order, at every nesting level, matches the order in which the
        values end up in the JSON command-line argument.
    """
    # If > 0, restrict the magnitude of the world state to unit length under this norm.
    # NOTE(review): ws_norm is never forwarded to pathfinder_2.py below, so it
    # currently has no effect -- wire it through or remove it.
    ws_norm = -1
    lr_net = 3e-4
    world_token_count = 32
    embedding_size = 16
    n_blocks = 4
    dim_feedforward = 2048
    n_heads = 4
    num_rel = 1
    u_n_pre_blocks = 0  # updater layers without cross attention with the input, before the other layers
    u_n_blocks = 4
    u_n_post_blocks = 0  # updater layers without cross attention with the input, after the other layers
    num_entities, n_updater_steps = 32, 8
    world_state_size = embedding_size * world_token_count * n_heads
    u_world_token_count = world_token_count

    def sine_encoding(max_len):
        # Shared settings for the world-offset and input positional encodings.
        return {"max_len": max_len, "dropout": 0.1, "scale": 1, "init": "sine"}

    # Settings shared by the extractor and the updater.
    model_inits = {
        "world_state_size": world_state_size,
        "n_blocks": n_blocks,
        "embedding_size": embedding_size,
        "num_relations": num_rel,
        "num_entities": num_entities,
        "world_token_count": world_token_count,
        "nheads": n_heads,
        "dim_feedforward": dim_feedforward,
        "special_mode": {"triple_norm": True, "triple_embedding": True},
    }
    # Deep copy so that the nested "special_mode" dicts of the two models are
    # independent and can be customised separately below.
    u_model_inits = copy.deepcopy(model_inits)
    u_model_inits["world_token_count"] = u_world_token_count
    u_model_inits["n_pre_blocks"] = u_n_pre_blocks
    u_model_inits["n_blocks"] = u_n_blocks
    u_model_inits["n_post_blocks"] = u_n_post_blocks

    model_inits["special_mode"]["decoder"] = True
    model_inits["special_mode"]["transformer_layer"] = "IndependentTransformerDecoderLayer"
    u_model_inits["special_mode"]["decoder"] = True
    u_model_inits["special_mode"]["transformer_layer"] = "TransformerDecoderLayer"

    model_inits["world_offset"] = sine_encoding(world_token_count)
    u_model_inits["world_offset"] = sine_encoding(u_world_token_count)
    # Subtracts the world offset at the end of the updater, since the updater
    # adds the world offset at the beginning of every step.
    u_model_inits["special_mode"]["subtract_world_offset"] = True

    model_inits["input_encoding"] = sine_encoding(num_entities ** 2)
    u_model_inits["input_encoding"] = sine_encoding(num_entities ** 2)
    # Extractor-only setting; added after the deep copy so the updater does not get it.
    model_inits["num_single_entities"] = 1

    u_model_args = {"model_inits": u_model_inits, "model_name": "TransformerUpdater"}
    model_args = {"model_inits": model_inits, "model_name": "TransformerExtractor"}
    lr_net_upd = lr_net  # the updater uses the same learning rate as the extractor
    u_model_save_path = "./pathfinder_models/run.pst"
    n_worlds = 64  # batch size
    schedule = ("linear_warmup", [], {})
    # threshold_sample_length is the number of logged steps that are averaged
    # and compared to the threshold.
    curriculum = {"threshold_sample_length": 50, "threshold": .95, "metric": "accIndirect"}

    return {
        "updater_model_vars": u_model_args,
        "model_vars": model_args,
        "save_path": u_model_save_path,
        "save_reset_freq": 10000,
        "test_log_freq": 100,
        "smoothing_eps": 0,
        "updater_lr_net": lr_net_upd,
        "lr_net": lr_net,
        "optimizer_net": "Adam",
        "updater_optimizer_net": "Adam",
        "n_worlds": n_worlds,
        "n_updater_updates": 1,
        "n_updater_steps": n_updater_steps,
        "gradient_step_ratio": 0,
        "world_resets": 100000000,
        "class_loss_scale": 1,
        "detach_class_loss": False,
        # Sample recall queries with equal probability from all previous steps.
        "sampling_schedule": "uniform",
        "lr_scheduler": schedule,
        "updater_lr_scheduler": schedule,
        "curriculum": curriculum,
        "vis_batch_size": 0,
        "random_seed": 1,
    }


if __name__ == "__main__":
    run("pathfinder_2.py", **build_run_kwargs())



