from os.path import dirname, abspath
# Directory two levels above this file, i.e. the parent of the folder that
# contains this script.
d = dirname(dirname(abspath(__file__)))
import sys
# Append that directory to the import search path so the `utils` and `models`
# imports below can be resolved from it (presumably the project root).
sys.path.append(d)

from utils.datahandler import PathfinderDataHandler
from models.modeltemplatev2 import reinit_model, init_model
from argparse import Namespace
import numpy as np
import torch as t
import random
import torch.nn as nn
from utils import model_functions

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if t.cuda.is_available() else "cpu"
import wandb

# Checkpoint evaluated by this script.
load_path = "./pathfinder_models/run.pst_checkpoint_11"
# map_location remaps the saved tensor storages onto a device that exists on
# this host. Without it, a checkpoint written from a CUDA run cannot be
# deserialized on a CPU-only machine.
loaded = t.load(load_path, map_location=device)

# Configuration stored alongside the weights; still a plain dict at this point.
MERGED_CONFIG = loaded["CONFIG"]

# Seed every RNG the script imports so repeated runs draw the same numbers.
t.manual_seed(5)
np.random.seed(5)
random.seed(5)

# Switch the "triple_norm" special mode on in both model configurations
# before the networks are constructed from them.
for section in ("model_vars", "updater_model_vars"):
    loaded["CONFIG"][section]["model_inits"]["special_mode"]["triple_norm"] = True

# Main model: build from the checkpoint config, restore weights, set eval mode.
model = init_model(loaded["CONFIG"])
model.load_state_dict(loaded['model_state_dict'])
model.eval()

# Updater model: same procedure, using its own config section and weights.
updater = init_model(loaded["CONFIG"], arg_name="updater_model_vars")
updater.load_state_dict(loaded["updater_state_dict"])
updater.eval()

# Print the total parameter count of each network (model first, then updater).
for network in (model, updater):
    print(sum(param.numel() for param in network.parameters()))

# Move both networks to the GPU when one is available.
if t.cuda.is_available():
    for network in (model, updater):
        network.cuda()

loss_func = nn.BCEWithLogitsLoss()

# Turn the checkpoint's config dict into a Namespace, layering the
# evaluation-time settings on top of whatever was saved.
MERGED_CONFIG = Namespace(**{**MERGED_CONFIG, "n_worlds": 50, "test_log_freq": 1})
# Configs that lack a sampling schedule get "uniform"; an existing value is kept.
vars(MERGED_CONFIG).setdefault("sampling_schedule", "uniform")
# Record which checkpoint produced this run in the logged config.
MERGED_CONFIG.load_path = load_path
wandb.init(project="pathfinder_exploration", config=MERGED_CONFIG)

from collections import defaultdict

# Total number of worlds covered by the evaluation sweep (presumably the size
# of the validation split -- confirm against the dataset).
N_EVAL_SAMPLES = 20000

# Validate with an explicit exception: an `assert` is removed when Python runs
# with -O, and a non-positive batch size would otherwise surface as a
# ZeroDivisionError from the modulo below.
if MERGED_CONFIG.n_worlds <= 0 or N_EVAL_SAMPLES % MERGED_CONFIG.n_worlds != 0:
    raise ValueError("batch size must be a divisor of 20000 for dataset metrics")

# Validation data, unshuffled; the grid size is taken from the model config.
datahandler = PathfinderDataHandler({
    "split": "val",
    "grid_size": MERGED_CONFIG.model_vars["model_inits"]["num_entities"],
    "shuffle": False,
})

# Gradients are not needed for evaluation.
with t.no_grad():
    # Metric name -> list of logged values, filled by the logging callback.
    logs = defaultdict(list)
    for world_reset in range(N_EVAL_SAMPLES // MERGED_CONFIG.n_worlds):
        # Start each iteration from a freshly initialised batch of world states.
        world_states = model_functions.init_random_world_state(MERGED_CONFIG, datahandler)
        # Both training flags are off, so neither network is meant to be
        # updated. The callbacks presumably compute the classification loss,
        # record it into `logs`, and cap the rollout at 11 steps -- see
        # utils.model_functions to confirm.
        model_functions.train_advance_state(
            MERGED_CONFIG, datahandler, world_states, updater, model,
            train_updater=False, train_extractor=False,
            callbacks=[
                model_functions.classification_loss_callback(loss_func),
                model_functions.log_train_advance_state(logs, world_reset),
                model_functions.max_step_callback(world_reset, 11),
            ],
        )
    # Aggregate everything collected over the sweep and send it to wandb.
    model_functions.log_dict_mean(logs, MERGED_CONFIG, wandb)


# NOTE(review): this is registered only after the evaluation sweep above has
# finished, and nothing in this script runs a backward pass afterwards, so it
# likely has no observable effect here -- confirm whether it is still needed.
wandb.watch(updater)



