from nesim.vis.weights_pca import WeightsPCAViewer
from nesim.vis.video import generate_video
import torch.nn as nn
import os

import os
import argparse

from nesim.losses.nesim_loss import (
    NesimConfig,
)
from nesim.experiments.cifar100 import Cifar100TrainingConfig, Cifar100HyperParams
from utils import get_run_name
from nesim.bimt.loss import BIMTConfig, BIMTLoss
import torchvision.models as models

# Directory that the rendered PCA videos are written into.
videos_save_dir = "./videos/pca"

# Command-line interface. The arguments mirror the ones used to name a run, so
# that the matching checkpoint directory can be located again.
parser = argparse.ArgumentParser(
    description=(
        "Renders PCA videos of resnet18 layer weights across cifar100 "
        "training checkpoints"
    )
)
parser.add_argument(
    "--nesim-config", type=str, help="Path to the nesim config json file"
)
parser.add_argument(
    "--nesim-apply-after-n-steps",
    type=int,
    help="number of steps after which we apply nesim",
)
parser.add_argument("--bimt-scale", type=float, help="scale of bimt loss", default=None)
# Exactly one of --pretrained / --no-pretrained has to be passed. Declaring
# them as a required mutually-exclusive group makes argparse reject invalid
# combinations with a clear usage error at parse time.
pretrained_group = parser.add_mutually_exclusive_group(required=True)
pretrained_group.add_argument(
    "--pretrained",
    help="True will init resnet18 with imagenet weights",
    action="store_true",
)
pretrained_group.add_argument(
    "--no-pretrained",
    help="Will init resnet18 with random weights",
    action="store_true",
)
parser.add_argument("--wandb-log", action="store_true", help="Enable logging to wandb")
args = parser.parse_args()

# Exactly one of --pretrained / --no-pretrained must be given. This is checked
# with an explicit error rather than `assert`: asserts are stripped under
# `python -O`, which would leave `pretrained` unbound for invalid input.
if args.pretrained == args.no_pretrained:
    parser.error("Exactly one of --pretrained or --no-pretrained must be passed")

# After the check above the two flags always disagree, so --pretrained alone
# determines the value.
pretrained = args.pretrained


# Rebuild the run name from the CLI arguments; it is used below to locate the
# checkpoint directory and to name the output videos.
nesim_step_interval = args.nesim_apply_after_n_steps
run_name = get_run_name(
    nesim_config=args.nesim_config,
    pretrained=args.pretrained,
    nesim_apply_after_n_steps=nesim_step_interval,
    bimt_scale=args.bimt_scale,
)

# Directory holding the checkpoints of this particular run.
checkpoint_dir = "./checkpoints/cifar100/{}".format(run_name)

# Hyperparameters of the run. `save_checkpoint_every_n_steps` is also the
# stride used further down when enumerating checkpoint files.
hyperparams = Cifar100HyperParams(
    lr=5e-4,
    batch_size=256,
    weight_decay=1e-5,
    apply_nesim_every_n_steps=nesim_step_interval,
    save_checkpoint_every_n_steps=20,
)

# Load the nesim settings from the json file given on the command line.
nesim_config = NesimConfig.from_json(args.nesim_config)

# Names of the layers whose weights are visualised.
layer_names = [f"layer4.0.conv{conv_idx}" for conv_idx in (1, 2)]

# NOTE(review): `scale` is fixed at 0.1 here and --bimt-scale is only used for
# the run name; likewise --wandb-log is not forwarded (wandb_log=False below).
# Confirm this is intended for a visualisation-only script.
bimt_config = BIMTConfig(
    device="cuda:0",
    scale=0.1,
    distance_between_nearby_layers=0.2,
    layer_names=layer_names,
)

# Weights identifier passed on to torchvision: "DEFAULT" when --pretrained was
# requested, otherwise None.
if pretrained is True:
    initial_weights = "DEFAULT"
else:
    initial_weights = None

experiment_config = Cifar100TrainingConfig(
    hyperparams=hyperparams,
    nesim_config=nesim_config,
    bimt_config=bimt_config,
    checkpoint_dir=checkpoint_dir,
    weights=initial_weights,
    max_epochs=20,
    wandb_log=False,
)

# resnet18 with its final fully-connected layer replaced by a 100-output head.
model = models.resnet18(weights=experiment_config.weights)
model.fc = nn.Linear(in_features=512, out_features=100)

# Pass the model through BIMT's module initialisation to avoid state dict
# loading errors (presumably the checkpoints were saved from a model prepared
# the same way - TODO confirm).
bimt_loss = BIMTLoss.from_config(config=bimt_config)
model = bimt_loss.init_modules_for_training(model)


# Exclusive upper bound on the training step indices whose checkpoints are
# visualised.
max_train_step_idx = 4000

# One checkpoint path per saved step, using the same stride as
# `save_checkpoint_every_n_steps`.
checkpoint_filenames = [
    os.path.join(checkpoint_dir, "all", f"train_step_idx_{train_step_idx}.pth")
    for train_step_idx in range(
        0, max_train_step_idx, hyperparams.save_checkpoint_every_n_steps
    )
]

# Make sure the output directory exists before any frames are computed, so a
# missing folder cannot fail the run at write time.
os.makedirs(videos_save_dir, exist_ok=True)

# Render one video per layer.
for layer_name in layer_names:
    filename = os.path.join(videos_save_dir, f"{run_name}_layer_name_{layer_name}.mp4")
    weights_pca = WeightsPCAViewer(
        model=model,
        checkpoint_filenames=checkpoint_filenames,
        layer_name=layer_name,
        device="cuda:0",
        scale_by_magnitude=True,
    )

    print(f"writing: {len(weights_pca)} frames to: {filename}")
    generate_video(
        # Indexed explicitly: only len() and integer indexing are known to be
        # supported by the viewer from this file.
        list_of_pil_images=[weights_pca[idx] for idx in range(len(weights_pca))],
        framerate=10,
        filename=filename,
        ## width, height
        size=(1024, 512),
    )
