import argparse
import os
import re
import shutil

import wandb

from nesim.bimt.loss import BIMTConfig
from nesim.experiments.cifar100 import (
    Cifar100Training,
    Cifar100TrainingConfig,
    Cifar100HyperParams,
)
from nesim.losses.nesim_loss import (
    NesimConfig,
)
from nesim.utils.folder import get_filenames_in_a_folder
from nesim.utils.json_stuff import dict_to_json, load_json_as_dict
from utils import get_run_name


# Command-line interface for a single CIFAR-100 training run.
parser = argparse.ArgumentParser(
    description="Trains a resnet18 on the CIFAR100 dataset"
)
# The nesim config is loaded unconditionally further down
# (NesimConfig.from_json), so the flag is required: let argparse report a
# missing value instead of failing later on a None path.
parser.add_argument(
    "--nesim-config",
    type=str,
    required=True,
    help="Path to the nesim config json file",
)
# Forwarded as `apply_nesim_every_n_steps` to the hyperparameters, i.e. the
# interval (in steps) at which the nesim loss is applied.
parser.add_argument(
    "--nesim-apply-after-n-steps",
    type=int,
    help="apply nesim every N training steps",
)
parser.add_argument("--num-epochs", type=int, help="number of epochs of training")

# NOTE(review): the BIMT config is currently disabled (bimt_config=None), so
# this value is only used to build the run name.
parser.add_argument("--bimt-scale", type=float, help="scale of bimt loss", default=None)

parser.add_argument(
    "--load-weights-path",
    type=str,
    help="""Will init resnet18 with weights. 
    1. Set DEFAULT for imagenet weights
    2. None for random weights
    3. or set it as path to a lightning checkpoint""",
)
parser.add_argument("--wandb-log", action="store_true", help="Enable logging to wandb")
parser.add_argument("--wandb-fresh", action="store_true", help="start fresh wandb run")

args = parser.parse_args()

# Derive a descriptive run name from the CLI settings; it names both the
# checkpoint directory and (for fresh runs) the wandb run.
run_name = get_run_name(
    nesim_config=args.nesim_config,
    pretrained=False,
    nesim_apply_after_n_steps=args.nesim_apply_after_n_steps,
    bimt_scale=args.bimt_scale,
)
checkpoint_dir = f"./checkpoints/cifar100/{run_name}"

# Start every run with an empty checkpoint directory. Use shutil/os rather
# than shelling out to `rm -rf` / `mkdir -p`: the run name is interpolated
# into the path, and a shell command would break on (or execute) spaces and
# metacharacters. A failure to create the directory now raises instead of
# being silently ignored.
# NOTE(review): this also wipes existing checkpoints when resuming a wandb
# run (no --wandb-fresh); confirm that is intended.
shutil.rmtree(checkpoint_dir, ignore_errors=True)
os.makedirs(checkpoint_dir, exist_ok=True)

# CIFAR-100 has 50,000 training images. Deriving the checkpoint interval from
# the batch size keeps it at one checkpoint per epoch if the batch size is
# changed: ceil(50_000 / 128) == 391, the value that was previously
# hard-coded. Assumes the training loader keeps the final partial batch
# (no drop_last) — TODO confirm against Cifar100Training.
CIFAR100_NUM_TRAIN_IMAGES = 50_000
BATCH_SIZE = 128
STEPS_PER_EPOCH = -(-CIFAR100_NUM_TRAIN_IMAGES // BATCH_SIZE)  # ceiling division

hyperparams = Cifar100HyperParams(
    lr=5e-4,
    batch_size=BATCH_SIZE,
    weight_decay=1e-5,
    save_checkpoint_every_n_steps=STEPS_PER_EPOCH,
    apply_nesim_every_n_steps=args.nesim_apply_after_n_steps,
)

# setting up nesim stuff
nesim_config = NesimConfig.from_json(args.nesim_config)

# BIMT regularisation is currently disabled (bimt_config=None below). To
# re-enable it, build a BIMTConfig along these lines and pass it instead:
# layer_names = [
#     "layer4.0.conv1",
#     "layer4.0.conv2",
# ]
# bimt_config = BIMTConfig(
#     layer_names=layer_names,
#     distance_between_nearby_layers=0.2,
#     scale=args.bimt_scale,
#     device="cuda:0",
# )

experiment_config = Cifar100TrainingConfig(
    hyperparams=hyperparams,
    nesim_config=nesim_config,
    wandb_log=args.wandb_log,
    weights=args.load_weights_path,
    checkpoint_dir=checkpoint_dir,
    max_epochs=args.num_epochs,
    bimt_config=None,
)

# Bookkeeping shared between consecutive invocations: the wandb run id to
# resume and the path of the most recent checkpoint. The file may not exist
# yet (e.g. the very first, fresh run); it is (re)written at the end.
RESUME_FROM_PATH = "./resume_from.json"
if os.path.exists(RESUME_FROM_PATH):
    resume_from_data = load_json_as_dict(RESUME_FROM_PATH)
else:
    resume_from_data = {}

if args.wandb_log:
    if args.wandb_fresh:
        wandb_run = wandb.init(
            project="nesim-cifar100", name=run_name, config=experiment_config
        )
        # Remember the new run's id so a later non-fresh invocation resumes
        # *this* run rather than whatever run was stored previously.
        resume_from_data["wandb_run_id"] = wandb_run.id
    else:
        if "wandb_run_id" not in resume_from_data:
            raise SystemExit(
                f"{RESUME_FROM_PATH} has no 'wandb_run_id' to resume; "
                "pass --wandb-fresh to start a new wandb run."
            )
        wandb.init(
            project="nesim-cifar100",
            id=resume_from_data["wandb_run_id"],
            resume="allow",
        )

experiment = Cifar100Training(config=experiment_config)
experiment.run()


def _natural_sort_key(name):
    """Return a sort key that orders embedded integers numerically.

    Plain string sorting puts "epoch=9" after "epoch=10"; splitting on digit
    runs and comparing those chunks as ints fixes that. re.split with a
    capturing group always alternates text / digits (starting with text), so
    keys of two names only ever compare str with str and int with int.
    """
    return [int(tok) if tok.isdigit() else tok for tok in re.split(r"(\d+)", name)]


all_models_saved_so_far = get_filenames_in_a_folder(checkpoint_dir + "/all/")
if not all_models_saved_so_far:
    raise RuntimeError(
        f"No checkpoints found in {checkpoint_dir}/all/ after training; "
        "cannot record a model_path to resume from."
    )
# Pick the latest checkpoint, assuming filenames encode training progress
# with numbers (e.g. epoch/step) — TODO confirm the checkpoint naming scheme.
model_will_be_saved_at = max(all_models_saved_so_far, key=_natural_sort_key)

resume_from_data["model_path"] = model_will_be_saved_at
dict_to_json(resume_from_data, RESUME_FROM_PATH)

# for baseline: python3 train.py --nesim-config ./nesim_configs/baseline.json --load-weights-path None --nesim-apply-after-n-steps 1 --num-epochs 4 --wandb-log --wandb-fresh
