## Don't run this file directly. Use mnist_exp_runner.py

from utils.datahandler import ConwayDataHandlerMutate
from models.modeltemplate import TransformerUpdater, reinit_model, init_model
from utils.model_args import ConwayDefaults, merge_two_configs, read_params
from argparse import Namespace
import numpy as np
import torch as t
import random
import torch.nn as nn
from utils import model_functions
from matplotlib import pyplot as plt
import argparse

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = "cuda" if t.cuda.is_available() else "cpu"
import wandb
CONFIG = read_params()  # any changes to previous config must be specified here


def _seed_everything(seed_value):
    """Seed torch, numpy and the stdlib ``random`` module with one value."""
    t.manual_seed(seed_value)
    np.random.seed(seed_value)
    random.seed(seed_value)


# Stays None unless a pretrained updater is restored from the checkpoint below.
updater = None
if "load_path" in CONFIG:
    print("Using pretrained extractor model")
    # map_location remaps stored tensors onto the current device, so a
    # checkpoint written on a GPU machine can be restored on a CPU-only host.
    # NOTE(review): t.load unpickles the file -- only load trusted checkpoints.
    loaded = t.load(CONFIG["load_path"], map_location=device)
    MERGED_CONFIG = merge_two_configs(loaded["CONFIG"], CONFIG)
    seed = CONFIG["random_seed"]
    _seed_everything(seed)
    # The extractor is built from the checkpoint's own CONFIG (not the merged
    # one), presumably so its architecture matches the stored state dict.
    model = init_model(loaded["CONFIG"])
    model.to(device)
    model.load_state_dict(loaded['model_state_dict'])
    if "updater_model_vars" in loaded["CONFIG"]:
        print("Using pretrained updater model")
        updater = init_model(loaded["CONFIG"], arg_name="updater_model_vars")
        updater.to(device)
        updater.load_state_dict(loaded["updater_state_dict"])
        updater.train()
else:
    MERGED_CONFIG = ConwayDefaults.add_core_defaults(CONFIG)
    seed = CONFIG["random_seed"]
    _seed_everything(seed)
    model = init_model(CONFIG)
    model.to(device)
model.train()
if updater is None:
    # No pretrained updater was restored: create a fresh one from defaults
    # derived from the extractor's "model_vars".
    ConwayDefaults.add_updater_defaults(MERGED_CONFIG, MERGED_CONFIG["model_vars"])
    print(MERGED_CONFIG["updater_model_vars"])
    updater = init_model(MERGED_CONFIG, arg_name="updater_model_vars")
    # Scale down the weight of the last LayerNorm found in the updater's
    # transformer (presumably to keep its initial outputs small -- TODO confirm).
    last = [layer for layer in updater.transformer.modules() if isinstance(layer, nn.LayerNorm)][-1]
    last.weight.data *= .08
    updater.train()
    updater.to(device)
print(MERGED_CONFIG)

# TODO: make a correct checkpoint that would include both the model and the
# new thing (presumably the updater).
wandb.init(project="conway_mutate", config=MERGED_CONFIG)
# From here on the config is accessed by attribute instead of by key.
MERGED_CONFIG = Namespace(**MERGED_CONFIG)
wandb.watch(updater)
datahandler = ConwayDataHandlerMutate()

# Maps the optimizer name given in the config to the torch optimizer class.
optimizers = {"SGD": t.optim.SGD, "Adam": t.optim.Adam}

train_samples_per_step = MERGED_CONFIG.train_samples_per_step

# Each network pairs its own optimizer type with its own learning rate:
# "updater_*" keys configure the updater, the plain keys the extractor.
# (Previously the two optimizer-type keys were crossed between the networks.)
optimizer_updater = optimizers[MERGED_CONFIG.updater_optimizer_net](
    updater.parameters(), lr=MERGED_CONFIG.updater_lr_net)
optimizer_extractor = optimizers[MERGED_CONFIG.optimizer_net](
    model.parameters(), lr=MERGED_CONFIG.lr_net)
# A class-weighted variant (pos_weight=1/0.15) is intentionally left disabled.
loss_func = nn.BCEWithLogitsLoss()

# Optimize world states and model parameters for the provided batch data.
# Counted explicitly so the final checkpoint index is defined even when the
# loop body never runs (world_resets == 0).
completed_resets = 0
for reset_idx in range(MERGED_CONFIG.world_resets):
    # Every reset starts from freshly initialised random world states.
    world_states = model_functions.init_random_world_state(MERGED_CONFIG, datahandler)
    results = model_functions.train_advance_state(
        MERGED_CONFIG, datahandler, world_states, updater, model, loss_func,
        train_extractor=optimizer_extractor,
        train_updater=optimizer_updater,
        callbacks=[
            model_functions.log_train_advance_state(wandb, reset_idx),
            model_functions.visualise_model(wandb, reset_idx, init=False, gif_mode=False),
            model_functions.save_joint_callback(reset_idx),
        ])
    completed_resets = reset_idx + 1
# Save once after all resets; the index equals the number of completed resets.
model_functions.save_joint_model(MERGED_CONFIG, model, updater, checkpoint=completed_resets)
