
import subprocess

import json
import os
from os.path import dirname, abspath
# Repository root: two directory levels above this script's own location.
d = dirname(dirname(abspath(__file__)))
# Expose the repository root to child interpreters started from this process
# (they inherit os.environ) so they can import project modules. Prepend rather
# than overwrite, so any PYTHONPATH entries the user already set are kept.
_inherited_pythonpath = os.environ.get("PYTHONPATH")
os.environ["PYTHONPATH"] = (
    os.pathsep.join((d, _inherited_pythonpath)) if _inherited_pythonpath else d
)
import copy

def run(file, **kwargs):
    """Run ``file`` in a separate ``python3`` process with JSON-encoded options.

    All keyword arguments are serialised with ``json.dumps`` and passed to the
    child as its single command-line argument (``sys.argv[1]`` in the child).
    Tuples therefore arrive as JSON arrays. A non-JSON-serialisable value
    raises ``TypeError`` before any process is started. The child receives
    this process's environment (``os.environ``), including PYTHONPATH.

    Args:
        file: Path of the Python script to execute; a relative path is
            resolved against the current working directory.
        **kwargs: Options forwarded to the child script.

    Returns:
        subprocess.CompletedProcess: The finished child process. Check
        ``returncode`` to detect a failed run, because a non-zero exit
        status is not raised as an exception.
    """
    call = ["python3", file, json.dumps(kwargs)]
    # Blocks until the child exits; the result is returned so callers can
    # inspect the exit status instead of having it silently discarded.
    return subprocess.run(call, env=os.environ)



if __name__ == "__main__":
    ws_norm = -1
    lr_net = 1e-3
    world_token_count = 8
    embedding_size = 16
    n_blocks = 2
    dim_feedforward = 1024
    n_heads = 2
    num_rel = 3
    u_n_pre_blocks = 2
    u_n_blocks = 2
    u_n_post_blocks = 0
    num_entities = 28
    train_samples_per_step = (0, 75)
    test_samples_per_step = 75
    world_state_size = embedding_size * world_token_count * n_heads
    u_world_token_count = world_token_count
    pos_enc = 1.0
    model_inits = dict([("world_state_size", world_state_size), ("n_blocks", n_blocks),  ("embedding_size", embedding_size), ("num_relations", num_rel), ("num_entities", num_entities),
                        ("world_token_count", world_token_count), ("nheads", n_heads), ("dim_feedforward", dim_feedforward), ("pos_enc", pos_enc)])
    model_inits["special_mode"] = {"triple_norm":True, "triple_embedding": True, "positional_input": 0}
    u_model_inits = copy.deepcopy(model_inits)
    u_model_inits["world_token_count"] = u_world_token_count
    u_model_inits["n_pre_blocks"] = u_n_pre_blocks
    u_model_inits["n_blocks"] = u_n_blocks
    u_model_inits["n_post_blocks"] = u_n_post_blocks
    model_inits["special_mode"]["decoder"] = True
    model_inits["special_mode"]["transformer_layer"] = "IndependentTransformerDecoderLayer"
    u_model_inits["special_mode"]["decoder"] = True
    u_model_inits["special_mode"]["transformer_layer"] = "TransformerDecoderLayer"
    u_model_args = {"model_inits": u_model_inits, "model_name": "TransformerUpdaterV2"}
    model_inits["num_single_entities"] = 0
    model_args = {"model_inits": model_inits, "model_name": "TransformerExtractorV2"}
    lr_net_upd = lr_net
    u_model_save_path = "./mnist_2_models/modern_variable_input.pst"
    schedule = None
    run("mnist_2.py",
        updater_model_vars=u_model_args,
        model_vars=model_args,
        save_path=u_model_save_path,
        save_reset_freq=10000,  test_log_freq=100,
        smoothing_eps=0,
        updater_smoothing_eps=0,
        updater_lr_net=lr_net_upd,
        lr_net=lr_net,
        optimizer_net = "Adam",
        updater_optimizer_net = "Adam",
        n_worlds=128,
        train_samples_per_step=train_samples_per_step,
        n_updater_updates=1,
        n_updater_steps=8,
        test_samples_per_step=75,
        gradient_step_ratio = 0,
        world_resets=10000000,
        lr_scheduler = schedule,
        updater_lr_scheduler = schedule,
        )