## Don't run this file directly. Use mnist_exp_runner.py

from utils.datahandler import PathfinderDataHandler
from models.modeltemplatev2 import reinit_model, init_model
from utils.model_args import PathfinderDefaults, merge_two_configs, read_params
from argparse import Namespace
import numpy as np
import torch as t
import random
import torch.nn as nn
from utils import model_functions
from matplotlib import pyplot as plt
import argparse

# Run on the GPU whenever CUDA reports itself as available, otherwise on the CPU.
if t.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
import wandb
# Run parameters as returned by read_params().
CONFIG = read_params()  # any changes to previous config must be specified here



# Build the extractor model (and, when the checkpoint carries one, the updater).
# `updater` stays None unless a pretrained updater is restored below; the code
# after this block creates a fresh one in that case.
updater = None
if "load_path" in CONFIG:
    print("Using pretrained extractor model")
    # map_location remaps stored tensors onto the device selected above, so a
    # checkpoint written on a GPU machine can still be opened on a CPU-only one.
    loaded = t.load(CONFIG["load_path"], map_location=device)
    # Combine the checkpoint's config with the config of the current run.
    MERGED_CONFIG = merge_two_configs(loaded["CONFIG"], CONFIG)
    # Seed every RNG in use before any model is constructed.
    seed = CONFIG["random_seed"]
    t.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # The architecture comes from the checkpoint's own config so that the saved
    # state dict matches the constructed model.
    model = init_model(loaded["CONFIG"])
    model.to(device)
    model.load_state_dict(loaded['model_state_dict'])
    if "updater_model_vars" in loaded["CONFIG"]:
        print("Using pretrained updater model")
        updater = init_model(loaded["CONFIG"], arg_name="updater_model_vars")
        updater.to(device)
        updater.load_state_dict(loaded["updater_state_dict"])
        updater.train()
else:
    # Fresh run: complete the user config with the core defaults.
    MERGED_CONFIG = PathfinderDefaults.add_core_defaults(CONFIG)
    loaded = None
    # Seed every RNG in use before the model is constructed.
    seed = CONFIG["random_seed"]
    t.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    model = init_model(CONFIG)
    model.to(device)

# No pretrained updater was restored: build one from the default updater settings.
if updater is None:
    PathfinderDefaults.add_updater_defaults(MERGED_CONFIG, MERGED_CONFIG["model_vars"])
    print(MERGED_CONFIG["updater_model_vars"])
    updater = init_model(MERGED_CONFIG, arg_name="updater_model_vars")
    # Scale the weight of the last LayerNorm found in the updater's transformer by 0.08.
    layer_norms = [
        module
        for module in updater.transformer.modules()
        if isinstance(module, nn.LayerNorm)
    ]
    last = layer_norms[-1]
    last.weight.data *= .08
    updater.to(device)

# Put both networks into training mode.
model.train()
updater.train()

# Report the element count of every parameter tensor, then the totals per network.
model_param_sizes = [param.numel() for param in model.parameters()]
updater_param_sizes = [param.numel() for param in updater.parameters()]
print(model_param_sizes)
print(updater_param_sizes)
print(sum(model_param_sizes))
print(sum(updater_param_sizes))

print(MERGED_CONFIG)

# Start experiment tracking with the merged config (still a plain dict here).
wandb.init(project="pathfinder_2", config=MERGED_CONFIG)  # TODO: make a correct checkpoint that would include both the model and the new thing
# From here on the config is accessed through attributes instead of dict keys.
MERGED_CONFIG = Namespace(**MERGED_CONFIG)
wandb.watch(updater)
datahandler = PathfinderDataHandler({"grid_size": MERGED_CONFIG.model_vars["model_inits"]["num_entities"]})

# Optimizer classes selectable by name from the config.
optimizers = {"SGD": t.optim.SGD, "Adam": t.optim.Adam}

train_samples_per_step = MERGED_CONFIG.train_samples_per_step

# Each network takes its own optimizer type and learning rate: the plain
# `*_net` keys configure the extractor, the `updater_*` keys configure the
# updater. (The optimizer-type keys were previously crossed between the two.)
optimizer_extractor = optimizers[MERGED_CONFIG.optimizer_net](model.parameters(), lr=MERGED_CONFIG.lr_net)
optimizer_updater = optimizers[MERGED_CONFIG.updater_optimizer_net](updater.parameters(), lr=MERGED_CONFIG.updater_lr_net)

# Called later with the world-reset index to produce a per-reset callback.
scheduler_callback = model_functions.get_scheduler(MERGED_CONFIG, optimizer_extractor, optimizer_updater)


# When resuming from a checkpoint, restore whichever optimizer states it contains.
if loaded is not None:
    for state_key, optimizer, network in (
        ("modelopt", optimizer_extractor, model),
        ("updateropt", optimizer_updater, updater),
    ):
        if state_key in loaded:
            model_functions.set_opt_state(optimizer, network.parameters(), loaded[state_key])

# Main loss: BCE on logits with the configured positive-class weight.
pos_weight = t.FloatTensor([MERGED_CONFIG.pos_weight]).to(device)
loss_func = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
# Unweighted BCE on logits, handed to the classification-loss callback.
loss_func_classif = nn.BCEWithLogitsLoss()
# Optimize world states and model parameters for the provided batch data

# With a curriculum, start from a reduced step budget and track recent scores.
if MERGED_CONFIG.curriculum:
    initial_steps = MERGED_CONFIG.curriculum.get("min_steps", 2)
    MERGED_CONFIG.max_steps = max(initial_steps, MERGED_CONFIG.gradient_step_ratio)
    scores = []

# One iteration per world reset: assemble callbacks, draw fresh random world
# states, train, and (with a curriculum) possibly raise the step budget.
for l in range(MERGED_CONFIG.world_resets):
    callbacks = [
        model_functions.classification_loss_callback(loss_func_classif),
        model_functions.save_joint_callback(l),
        scheduler_callback(l),
    ]
    if MERGED_CONFIG.vis_batch_size:
        callbacks.append(
            model_functions.visualise_model(
                wandb, l, init=False, gif_mode=False, labeled=True, progressive=True
            )
        )

    if MERGED_CONFIG.curriculum:
        # In curriculum mode the logger also returns the metrics container read below.
        log_callback, metrics = model_functions.log_train_advance_state(wandb, l, True)
        callbacks.extend(
            [log_callback, model_functions.max_step_callback(l, MERGED_CONFIG.max_steps, wandb)]
        )
    else:
        callbacks.append(model_functions.log_train_advance_state(wandb, l))

    world_states = model_functions.init_random_world_state(MERGED_CONFIG, datahandler)
    results = model_functions.train_advance_state(
        MERGED_CONFIG, datahandler, world_states, updater, model, loss_func,
        train_extractor=optimizer_extractor,
        train_updater=optimizer_updater,
        callbacks=callbacks,
    )

    # Curriculum bookkeeping only applies while the step budget can still grow.
    if not (MERGED_CONFIG.curriculum and metrics
            and MERGED_CONFIG.max_steps < MERGED_CONFIG.n_updater_steps):
        continue

    # Record the tracked metric at the last step of the current budget.
    last_step = MERGED_CONFIG.max_steps - 1
    scores.append(metrics[last_step][MERGED_CONFIG.curriculum.get("metric", "accIndirect")])
    sample_length = MERGED_CONFIG.curriculum.get("threshold_sample_length", 50)
    if len(scores) > sample_length:
        # Keep only the newest `sample_length` scores, then compare their mean to the threshold.
        scores = scores[1:]
        mean_score = sum(scores) / sample_length
        if mean_score > MERGED_CONFIG.curriculum.get("threshold", .9):
            # Threshold passed: enlarge the step budget and restart the score window.
            MERGED_CONFIG.max_steps += max(1, MERGED_CONFIG.gradient_step_ratio)
            scores = []



# Save the final joint checkpoint (both networks plus both optimizers).
# The checkpoint index is the configured number of world resets, which equals
# the last loop index + 1 after the loop has run to completion. Reading it from
# the config avoids a NameError on the loop variable when world_resets is 0.
model_functions.save_joint_model(MERGED_CONFIG, model, updater,
                                 optimizers=[optimizer_extractor, optimizer_updater],
                                 checkpoint=MERGED_CONFIG.world_resets)