import argparse
import os
import shutil

import wandb

from nesim.losses.nesim_loss import (
    NesimConfig,
)
from nesim.experiments.mnist import MNISTHyperParams, MnistTrainingConfig, MNISTTraining
from nesim.bimt.loss import BIMTConfig
from utils import get_run_name

# Command-line interface for a single MNIST training run.
parser = argparse.ArgumentParser(
    description="Trains a simple 3 layer MLP on the MNIST dataset"
)
# --nesim-config and --hidden-size have no usable default: the config path is
# loaded with NesimConfig.from_json() and folded into the run name, and the
# hidden size is handed to the model config. Marking them required makes
# argparse reject an incomplete command line at parse time (before any
# checkpoint directory is touched) instead of passing None downstream.
parser.add_argument(
    "--nesim-config",
    type=str,
    required=True,
    help="Path to the nesim config json file",
)
parser.add_argument(
    "--hidden-size",
    type=int,
    required=True,
    help="hidden size of the model",
)
# NOTE(review): optional, so max_epochs receives None when omitted — confirm
# MnistTrainingConfig treats None sensibly or make this required as well.
parser.add_argument("--num-epochs", type=int, help="number of epochs of training")
# NOTE(review): this value is passed on as
# MNISTHyperParams(apply_nesim_every_n_steps=...). The flag says "after n steps"
# but the hyperparameter says "every n steps"; confirm which is intended.
parser.add_argument(
    "--nesim-apply-after-n-steps",
    type=int,
    help="number of steps after which we apply nesim",
)
parser.add_argument("--wandb-log", action="store_true", help="Enable logging to wandb")
args = parser.parse_args()

# Derive a run identifier from the key CLI settings; it names both the
# checkpoint directory and (when enabled) the wandb run.
run_name = get_run_name(
    nesim_config=args.nesim_config,
    hidden_size=args.hidden_size,
    nesim_apply_after_n_steps=args.nesim_apply_after_n_steps,
)

# Start every run from an empty checkpoint directory.
# Use the stdlib instead of shelling out. run_name is derived from CLI
# arguments, so splicing it into `rm -rf ...` would let spaces or shell
# metacharacters delete unintended paths. os.system() also discards failures.
checkpoint_dir = f"./checkpoints/mnist/{run_name}"
try:
    shutil.rmtree(checkpoint_dir)
except FileNotFoundError:
    pass  # nothing to clear on the first run with this name
# Equivalent of `mkdir -p`: creates intermediate directories, no error if present.
os.makedirs(checkpoint_dir, exist_ok=True)
# ---- training hyperparameters ---------------------------------------------
# Optimiser settings are fixed here. Only the nesim interval comes from the CLI.
# NOTE(review): save_checkpoint_every_n_steps is set to one million, which
# presumably means "effectively never checkpoint mid-run". Confirm.
hyperparams = MNISTHyperParams(
    lr=1e-3,
    batch_size=64,
    weight_decay=1e-5,
    save_checkpoint_every_n_steps=1_000_000,
    apply_nesim_every_n_steps=args.nesim_apply_after_n_steps,
)

# ---- regularisation configs -----------------------------------------------
# The nesim settings are loaded from the JSON file given via --nesim-config.
nesim_config = NesimConfig.from_json(args.nesim_config)

# BIMT is switched off for this script. To turn it on, replace None with e.g.
#     BIMTConfig(
#         layer_names=["4", "7"],
#         distance_between_nearby_layers=0.2,
#         scale=10.0,
#         device="cuda:0",
#     )
bimt_config = None

# ---- experiment assembly --------------------------------------------------
# Bundle the hyperparameters, regularisation configs and CLI settings into
# the single config object consumed by MNISTTraining. MNIST data lives
# under ./data.
experiment_config = MnistTrainingConfig(
    hyperparams=hyperparams,
    nesim_config=nesim_config,
    hidden_size=args.hidden_size,
    checkpoint_dir=checkpoint_dir,
    wandb_log=args.wandb_log,
    max_epochs=args.num_epochs,
    data_dir="./data",
    bimt_config=bimt_config,
)

# ---- launch ---------------------------------------------------------------
# A wandb run is opened only when --wandb-log was passed.
# NOTE(review): wandb.init documents `config` as a dict / argparse namespace /
# absl flags object. Confirm MnistTrainingConfig is recorded as intended.
if args.wandb_log:
    wandb.init(
        project="nesim",
        name=run_name,
        config=experiment_config,
    )

experiment = MNISTTraining(config=experiment_config)
experiment.run()
