import wandb
import os
from nesim.losses.nesim_loss import (
    NesimConfig,
)
import argparse
from nesim.experiments.imagenet import (
    ImageNetTraining,
    ImageNetTrainingConfig,
    ImageNetHyperParams,
)
from nesim.losses.cross_layer_correlation.loss import (
    CrossLayerCorrelationLossConfig,
)
from nesim.losses.cross_layer_correlation.loss import (
    SinglePairConfig,
    CrossLayerCorrelationLossConfig,
)
from utils import get_run_name
from fastapi.encoders import jsonable_encoder

# ---- command-line interface -------------------------------------------------
parser = argparse.ArgumentParser(
    description="Trains a resnet18 on the imagenet dataset"
)

# Each entry is (flag, keyword arguments for parser.add_argument). The tuple
# order is the order in which the options are registered, and therefore the
# order they appear in `--help`.
_CLI_OPTIONS = (
    (
        "--nesim-config",
        {"type": str, "help": "Path to the nesim config json file"},
    ),
    (
        "--cross-layer-correlation-loss-config",
        {
            "type": str,
            "help": "Path to the cross-layer-correlation-loss config json file",
            "default": None,
        },
    ),
    (
        "--nesim-apply-after-n-steps",
        {"type": int, "help": "number of steps after which we apply nesim"},
    ),
    (
        "--num-epochs",
        {"type": int, "help": "number of epochs of training", "default": 50},
    ),
    (
        "--bimt-scale",
        {"type": float, "help": "scale of bimt loss", "default": None},
    ),
    # --pretrained / --no-pretrained are independent boolean switches; the
    # "exactly one of them" rule is enforced after parsing.
    (
        "--pretrained",
        {
            "help": "True will init resnet18 with imagenet weights",
            "action": "store_true",
        },
    ),
    (
        "--no-pretrained",
        {
            "help": "Will init resnet18 with random weights",
            "action": "store_true",
        },
    ),
    (
        "--resume-from-checkpoint",
        {"help": "will load these weights", "default": None, "type": str},
    ),
    (
        "--apply-sorted-weights-init-filename",
        {
            "default": None,
            "help": "Will sort weights in layer names specified in json file",
        },
    ),
    (
        "--wandb-log",
        {"action": "store_true", "help": "Enable logging to wandb"},
    ),
)

for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)

args = parser.parse_args()

# --pretrained and --no-pretrained are two independent store_true flags, so
# exactly one of them must be passed. This check uses parser.error rather than
# `assert`: asserts are stripped under `python -O`, which would leave
# `pretrained` unbound and crash later with a NameError. parser.error prints
# the usage line and exits with status 2.
if args.pretrained == args.no_pretrained:
    parser.error("exactly one of --pretrained / --no-pretrained must be given")

# After the check above the two flags are complementary, so --pretrained alone
# determines whether to start from pretrained ImageNet weights.
pretrained = args.pretrained

# Cross-layer correlation ("wiring cost") loss is optional: it is only built
# when a config json path was supplied on the command line; otherwise the
# experiment receives None.
cross_layer_correlation_loss_config = None
if args.cross_layer_correlation_loss_config is None:
    print("Not applying any cross layer correlation loss")
else:
    clc_config_path = args.cross_layer_correlation_loss_config
    print(f"Applying CrossLayerCorrelationLossConfig: {clc_config_path}")
    cross_layer_correlation_loss_config = CrossLayerCorrelationLossConfig.from_json(
        clc_config_path
    )


# Derive a run name from the CLI options. It is used both as the wandb run
# name and as the checkpoint sub-directory below.
run_name = get_run_name(
    nesim_config=args.nesim_config,
    cross_layer_correlation_loss_config=args.cross_layer_correlation_loss_config,
    pretrained=pretrained,
    nesim_apply_after_n_steps=args.nesim_apply_after_n_steps,
    bimt_scale=args.bimt_scale,
    apply_sorted_weights_init_filename=args.apply_sorted_weights_init_filename,
)
checkpoint_dir = f"./checkpoints/imagenet/{run_name}"
# Create the directory directly instead of shelling out to `mkdir -p`.
# run_name is derived from user-supplied paths, so the old unquoted shell
# string broke on spaces or shell metacharacters, and its exit status was
# silently ignored. os.makedirs raises OSError if creation genuinely fails.
os.makedirs(checkpoint_dir, exist_ok=True)

"""
Original ResNet paper: XXXX

PyTorch recipe:
[x] - epochs: 90 -> configurable here via --num-epochs (default 50)
[x] - batch_size: 1024
[x] - weight_decay: 1e-4
[x] - scheduler: StepLR, multiplying the lr by 0.1 every 30 epochs
[x] - minimum lr: 0.
[x] - optimizer: SGD with momentum 0.9
[x] - initial_learning_rate: 0.1
[x] - image preprocessing: ClassificationPresetTrain / ClassificationPresetEval
"""
# Training hyperparameters, taken from the PyTorch recipe listed in the note
# above; only the nesim schedule comes from the command line.
hyperparams = ImageNetHyperParams(
    # optimizer (SGD per the recipe note)
    lr=0.1,
    momentum=0.9,
    weight_decay=1e-4,
    batch_size=1024,
    # learning-rate schedule: scale lr by `scheduler_gamma` every
    # `scheduler_step_size` epochs (per the recipe note)
    scheduler_step_size=30,
    scheduler_gamma=0.1,
    # bookkeeping / regularizer schedule
    save_checkpoint_every_n_steps=1000,
    # NOTE(review): the CLI flag is described as "number of steps after which
    # we apply nesim" but is passed as an "every n steps" period here --
    # confirm which semantics ImageNetHyperParams actually implements.
    apply_nesim_every_n_steps=args.nesim_apply_after_n_steps,
)

# setting up nesim stuff
nesim_config = NesimConfig.from_json(args.nesim_config)

# Location of the pre-converted ImageNet dataset. The default is the openmind
# cluster path. Set IMAGENET_CACHE_DIR to run elsewhere (e.g. the barlow copy
# at '/research/datasets/imagenet_converted/') without editing this file.
imagenet_cache_dir = os.environ.get(
    "IMAGENET_CACHE_DIR", "/om2/user/mayukh09/datasets/imagenet_converted"
)

experiment_config = ImageNetTrainingConfig(
    model_name="resnet18",
    hyperparams=hyperparams,
    nesim_config=nesim_config,
    wandb_log=args.wandb_log,
    # "DEFAULT" asks for pretrained weights; None means random initialisation.
    weights="DEFAULT" if pretrained else None,
    checkpoint_dir=checkpoint_dir,
    max_epochs=args.num_epochs,
    bimt_config=None,
    cross_layer_correlation_loss_config=cross_layer_correlation_loss_config,
    skip_initial_validation_run=True,
    apply_sorted_weights_init_filename=args.apply_sorted_weights_init_filename,
    cache_dir=imagenet_cache_dir,
    resume_from_checkpoint=args.resume_from_checkpoint,  ## this overrides the weights arg
)

# Optionally mirror this run to Weights & Biases. It reuses run_name (also the
# checkpoint sub-directory name) and logs the JSON-encoded experiment config.
if args.wandb_log:
    wandb_config = jsonable_encoder(experiment_config)
    wandb.init(
        # serious runs go here; the project "nesim" was used for dummy runs
        project="nesim-imagenet-resnet18",
        name=run_name,
        config=wandb_config,
    )

# Build the training experiment from the assembled config and start training.
experiment = ImageNetTraining(config=experiment_config)
experiment.run()
