from os.path import dirname, abspath
d = dirname(dirname(abspath(__file__)))
import sys
sys.path.append(d)

from utils.datahandler import ConwayDataHandlerMutate
from models.modeltemplate import TransformerUpdater, reinit_model, init_model
from utils import model_functions
from utils.model_args import ConwayDefaults, merge_two_configs, read_params
from argparse import Namespace
import numpy as np
import torch as t
import random
import wandb

# Resolve the target device before loading so the checkpoint can be mapped
# onto it: without ``map_location``, tensors saved from a GPU run fail to
# deserialise on a CPU-only host.
device = "cuda" if t.cuda.is_available() else "cpu"

CHECKPOINT_PATH = "./conway_mutate_models/run.pst_checkpoint_51"
# The checkpoint dict provides the run "CONFIG" plus the "model_state_dict"
# and "updater_state_dict" weights restored below.
loaded = t.load(CHECKPOINT_PATH, map_location=device)

# NOTE: this aliases loaded["CONFIG"] (same dict object), so the overrides
# applied to loaded["CONFIG"] below are visible through MERGED_CONFIG too.
MERGED_CONFIG = loaded["CONFIG"]

# Fixed seeds for torch, numpy and Python's ``random`` for reproducible runs.
t.manual_seed(5)
np.random.seed(5)
random.seed(5)

# Force the ``triple_norm`` special mode on for both networks before they are
# built. NOTE(review): presumably needed so the constructed architecture
# matches the saved weights — confirm against the training run.
loaded["CONFIG"]["model_vars"]["model_inits"]["special_mode"]["triple_norm"] = True
loaded["CONFIG"]["updater_model_vars"]["model_inits"]["special_mode"]["triple_norm"] = True

# Rebuild both networks from the (patched) config, restore their weights and
# switch them to eval mode.
model = init_model(loaded["CONFIG"])
model.load_state_dict(loaded["model_state_dict"])
model.eval()

updater = init_model(loaded["CONFIG"], arg_name="updater_model_vars")
updater.load_state_dict(loaded["updater_state_dict"])
updater.eval()




# Move both networks onto the GPU when one was detected; on CPU-only hosts
# they are left where they already are.
if device == "cuda":
    for _net in (model, updater):
        _net.cuda()

# Data handler for the mutate variant, constructed with batch_count=None.
datahandler = ConwayDataHandlerMutate({"batch_count": None})

# Fill in the core defaults, then wrap the config in a Namespace so it can be
# read with attribute access.
MERGED_CONFIG = Namespace(**ConwayDefaults.add_core_defaults(MERGED_CONFIG))

# Exploration-run overrides applied on top of the checkpoint's config.
vars(MERGED_CONFIG).update(
    n_worlds=5,
    n_updater_steps=102,
    world_length=105,
    test_log_freq=1,
)
# Default init_scale to -1 only when the config does not already provide it.
vars(MERGED_CONFIG).setdefault("init_scale", -1)

wandb.init(project="conway_mutate_exploration", config=MERGED_CONFIG)

# Initial world states, created once before the loop.
# NOTE(review): world_states is not re-created per iteration; whether
# train_advance_state updates it in place (so each "reset" continues from the
# previous state) cannot be told from here — confirm this is intended.
world_states = model_functions.init_random_world_state(MERGED_CONFIG, datahandler)

# Number of outer exploration iterations; each is logged under its own index.
N_WORLD_RESETS = 100

for world_reset in range(N_WORLD_RESETS):
    datahandler.init_run(MERGED_CONFIG)
    # Advance the worlds with both training flags off; the callbacks handle
    # visualisation and logging to wandb for this iteration.
    model_functions.train_advance_state(
        MERGED_CONFIG, datahandler, world_states, updater, model,
        train_updater=False, train_extractor=False,
        callbacks=[
            model_functions.visualise_model(wandb, world_reset, init=False),
            model_functions.log_train_advance_state(wandb, world_reset),
        ],
    )

# Close the wandb run explicitly so buffered logs are flushed and the run is
# marked finished, instead of relying on the interpreter-exit hook.
wandb.finish()
