import argparse
import os
import shutil

import wandb

from nesim.bimt.loss import BIMTConfig
from nesim.experiments.cifar100 import (
    Cifar100Training,
    Cifar100TrainingConfig,
    Cifar100HyperParams,
)
from nesim.losses.nesim_loss import (
    NesimConfig,
)
from utils import get_run_name

# Command-line interface for the CIFAR-100 / ResNet-18 training run.
# Arguments are registered in the same order as they appear in --help.
parser = argparse.ArgumentParser(description="Trains a resnet18 on the CIFAR100 dataset")

# NeSim options. No default is given, so an omitted flag parses as None.
parser.add_argument("--nesim-config", help="Path to the nesim config json file", type=str)
parser.add_argument(
    "--nesim-apply-after-n-steps",
    help="number of steps after which we apply nesim",
    type=int,
)

# BIMT options. The value is forwarded to the run-name helper.
parser.add_argument("--bimt-scale", default=None, help="scale of bimt loss", type=float)

# Weight-initialisation switches. Exactly one is expected to be set; this is
# validated right after parsing.
parser.add_argument(
    "--pretrained",
    action="store_true",
    help="True will init resnet18 with imagenet weights",
)
parser.add_argument(
    "--no-pretrained",
    action="store_true",
    help="Will init resnet18 with random weights",
)

# Experiment tracking.
parser.add_argument("--wandb-log", action="store_true", help="Enable logging to wandb")

args = parser.parse_args()

# Exactly one of --pretrained / --no-pretrained must be supplied. This uses
# parser.error instead of `assert` so the check still runs under `python -O`
# (where asserts are stripped and `pretrained` would otherwise be left
# undefined). It also gives a normal argparse usage error rather than a traceback.
if args.pretrained == args.no_pretrained:
    parser.error("exactly one of --pretrained or --no-pretrained must be given")

# After the check above, --pretrained alone decides the initialisation:
# True -> ImageNet weights, False -> random init.
pretrained = args.pretrained


# Human-readable name for this run. It is used both for the checkpoint
# directory and (when enabled) as the wandb run name.
run_name = get_run_name(
    nesim_config=args.nesim_config,
    pretrained=args.pretrained,
    nesim_apply_after_n_steps=args.nesim_apply_after_n_steps,
    bimt_scale=args.bimt_scale,
)
checkpoint_dir = f"./checkpoints/cifar100/{run_name}"

# Start every run from an empty checkpoint directory. This is done in-process
# rather than via os.system("rm -rf ...") so that a run name containing spaces
# or shell metacharacters cannot split the path or inject commands. Only a
# missing directory is tolerated; any other failure raises.
try:
    shutil.rmtree(checkpoint_dir)
except FileNotFoundError:
    pass
os.makedirs(checkpoint_dir, exist_ok=True)

# Optimisation / scheduling hyperparameters. The checkpoint interval of 10**20
# steps is far beyond anything a 5-epoch run reaches, so periodic
# checkpointing is effectively disabled.
hyperparams = Cifar100HyperParams(
    apply_nesim_every_n_steps=args.nesim_apply_after_n_steps,
    save_checkpoint_every_n_steps=10**20,
    weight_decay=1e-5,
    batch_size=128,
    lr=5e-4,
)

# NeSim configuration, loaded from the JSON file passed via --nesim-config.
nesim_config = NesimConfig.from_json(args.nesim_config)

# NOTE(review): BIMT is currently disabled (bimt_config=None below), so
# --bimt-scale only influences the run name. Previous BIMT setup, kept for
# reference:
# layer_names = [
#     "layer4.0.conv1",
#     "layer4.0.conv2",
# ]
# bimt_config = BIMTConfig(
#     layer_names=layer_names,
#     distance_between_nearby_layers=0.2,
#     scale=args.bimt_scale,
#     device="cuda:0",
# )

# "DEFAULT" requests pretrained weights; None means random initialisation.
experiment_config = Cifar100TrainingConfig(
    hyperparams=hyperparams,
    nesim_config=nesim_config,
    bimt_config=None,
    weights=None if pretrained is not True else "DEFAULT",
    checkpoint_dir=checkpoint_dir,
    max_epochs=5,
    wandb_log=args.wandb_log,
)

# Open a wandb run before training, only when --wandb-log was given.
if args.wandb_log:
    wandb.init(
        project="nesim-cifar100",
        config=experiment_config,
        name=run_name,
    )

experiment = Cifar100Training(config=experiment_config)
experiment.run()
