from os.path import dirname, abspath
# Parent of the directory that contains this script.
d = dirname(dirname(abspath(__file__)))
import sys
# Append that directory to the module search path before the project imports
# below run -- presumably the `utils` and `models` packages live there
# (TODO confirm against the repository layout).
sys.path.append(d)

from utils.datahandler import MNISTDataHandler
from models.modeltemplate import TransformerUpdater, reinit_model, init_model
from argparse import Namespace
import numpy as np
import torch as t
import random
import torch.nn as nn
from utils import model_functions

# Pick the compute device once; everything below falls back to CPU when no GPU is present.
device = "cuda" if t.cuda.is_available() else "cpu"
import wandb
# Restore the saved checkpoint. `map_location=device` remaps any stored tensors
# onto the device that is actually available, so a checkpoint containing CUDA
# tensors can still be deserialized on a CPU-only machine instead of raising.
# NOTE(review): the path is relative to the current working directory, not to
# this file -- run the script from the directory that holds `mnist_2_models/`.
loaded = t.load("./mnist_2_models/modern_variable_input.pst_checkpoint_11", map_location=device)


# Seed torch, numpy and the stdlib RNG (in that order) with the same fixed value
# before any model is constructed.
for _seed_fn in (t.manual_seed, np.random.seed, random.seed):
    _seed_fn(5)

# Rebuild the model from the checkpointed config, restore its weights and put
# it into evaluation mode.
model = init_model(loaded["CONFIG"])
model.load_state_dict(loaded['model_state_dict'])
model.eval()

# Same procedure for the updater, built from the "updater_model_vars" entry.
updater = init_model(loaded["CONFIG"], arg_name="updater_model_vars")
updater.load_state_dict(loaded["updater_state_dict"])
updater.eval()

# Move both networks to the GPU when one is available (no-op on CPU).
model.to(device)
updater.to(device)

# Expose the checkpointed config as attribute-style options, overriding the
# two settings used for this exploration run.
MERGED_CONFIG = Namespace(**{**loaded["CONFIG"], "n_worlds": 128, "test_log_freq": 1})

wandb.init(project="mnist_2_exploration", config=MERGED_CONFIG)
datahandler = MNISTDataHandler({"train":False})

# NOTE(review): created here but not referenced by any other line visible in
# this script -- confirm whether it is still needed.
loss_func = nn.BCEWithLogitsLoss()

# Gradient tracking is disabled for the whole run.
with t.no_grad():
    for reset_idx in range(100):
        # Draw a fresh random world state for this reset.
        world_states = model_functions.init_random_world_state(MERGED_CONFIG, datahandler)

        # Callbacks that hand visualisation and logging over to wandb,
        # tagged with the index of the current reset.
        reset_callbacks = [
            model_functions.visualise_model(wandb, reset_idx, init=False),
            model_functions.log_train_advance_state(wandb, reset_idx),
        ]

        # Advance the state with training switched off for both the updater
        # and the extractor.
        model_functions.train_advance_state(
            MERGED_CONFIG,
            datahandler,
            world_states,
            updater,
            model,
            train_updater=False,
            train_extractor=False,
            callbacks=reset_callbacks,
        )


# NOTE(review): this is registered only after the loop above has finished --
# confirm that watching the updater at this point is intended.
wandb.watch(updater)



