from nesim.utils.feature_vis.generator import ConvLayerFeaturevisGenerator
import torch.nn as nn
import os
from nesim.utils.getting_modules import get_module_by_name
import os
import argparse
import torch
from nesim.losses.nesim_loss import (
    NesimConfig,
)
from nesim.experiments.cifar100 import Cifar100TrainingConfig, Cifar100HyperParams
from utils import get_run_name
from nesim.bimt.loss import BIMTConfig, BIMTLoss
import torchvision.models as models
from nesim.vis.image_grid import make_grid_from_list_of_images
from PIL import Image
from nesim.utils.grid_size import find_rectangle_dimensions

parser = argparse.ArgumentParser(
    description=(
        "Renders feature-visualization grids for conv layers of a resnet18 "
        "trained on CIFAR-100"
    )
)
parser.add_argument(
    "--nesim-config", type=str, help="Path to the nesim config json file"
)
parser.add_argument(
    "--nesim-apply-after-n-steps",
    type=int,
    help="number of steps after which we apply nesim",
)
parser.add_argument("--bimt-scale", type=float, help="scale of bimt loss", default=None)

# Exactly one of --pretrained / --no-pretrained must be passed.
# A required mutually exclusive group lets argparse report a proper usage error.
# The previous `assert` is skipped under `python -O`, which would leave
# `pretrained` undefined further down.
init_group = parser.add_mutually_exclusive_group(required=True)
init_group.add_argument(
    "--pretrained",
    help="True will init resnet18 with imagenet weights",
    action="store_true",
)
init_group.add_argument(
    "--no-pretrained",
    help="Will init resnet18 with random weights",
    action="store_true",
)
parser.add_argument("--wandb-log", action="store_true", help="Enable logging to wandb")
args = parser.parse_args()

# The group guarantees exactly one flag is set, so --pretrained alone decides.
# This flag also feeds into run_name, i.e. which checkpoint directory is loaded.
pretrained = args.pretrained


# Name of the run whose checkpoint is visualized. It determines checkpoint_dir
# below. Presumably this mirrors the naming used by the training script — verify.
run_name = get_run_name(
    nesim_config=args.nesim_config,
    pretrained=args.pretrained,
    nesim_apply_after_n_steps=args.nesim_apply_after_n_steps,
    bimt_scale=args.bimt_scale,
)

checkpoint_dir = f"./checkpoints/cifar100/{run_name}"
neuron_atlas_folder = "./neuron_atlas/"

# In the visible script the experiment config is only read for its `weights`
# field. These hyperparameters are passed along to build it.
hyperparam_kwargs = dict(
    lr=5e-4,
    batch_size=256,
    weight_decay=1e-5,
    save_checkpoint_every_n_steps=300,
    apply_nesim_every_n_steps=args.nesim_apply_after_n_steps,
)
hyperparams = Cifar100HyperParams(**hyperparam_kwargs)

# NeSim settings loaded from the JSON file given via --nesim-config.
nesim_config = NesimConfig.from_json(args.nesim_config)

# Conv layers to visualize. The same list is handed to the BIMT config.
layer_names = ["layer4.0.conv1", "layer4.0.conv2"]

# NOTE(review): the BIMT scale is fixed at 0.1 here; --bimt-scale only
# influences run_name above. Confirm this is intended.
bimt_config = BIMTConfig(
    layer_names=layer_names,
    distance_between_nearby_layers=0.2,
    scale=0.1,
    device="cuda:0",
)

# "DEFAULT" requests torchvision's pretrained (ImageNet) weights; None = random init.
weights = "DEFAULT" if pretrained else None

# NOTE(review): wandb logging is hard-disabled; --wandb-log is parsed but unused.
experiment_config = Cifar100TrainingConfig(
    hyperparams=hyperparams,
    nesim_config=nesim_config,
    wandb_log=False,
    weights=weights,
    checkpoint_dir=checkpoint_dir,
    max_epochs=20,
    bimt_config=bimt_config,
)
model = models.resnet18(weights=experiment_config.weights)
# Replace the classifier head with a 100-way linear layer to match the
# CIFAR-100 checkpoint being loaded.
model.fc = nn.Linear(512, 100)

# Map tensors to CPU. The model above lives on CPU, and without map_location a
# checkpoint saved on a GPU this machine lacks (or a different GPU index)
# cannot be deserialized. load_state_dict copies the values into the model anyway.
checkpoint_path = os.path.join(checkpoint_dir, "best", "best_model.ckpt")
sd = torch.load(checkpoint_path, map_location="cpu")["state_dict"]

# Drop the "model." prefix and the ".layer." indirection from every key.
# Presumably these come from the training wrapper, and the stripped keys match
# a plain torchvision resnet18.
sd_fixed = {
    key.replace("model.", "").replace(".layer.", "."): value
    for key, value in sd.items()
}

model.load_state_dict(sd_fixed)

import shutil

# Feature-visualization optimisation settings passed to every generator.
render_kwargs = dict(
    scale_max=1.2, scale_min=0.8, iters=220, lr=3e-3, grad_clip=1.0, rotate_degrees=10
)

# Folder that receives one stitched grid image per layer.
grid_output_folder = os.path.join(neuron_atlas_folder, "outputs")

for layer_name in layer_names:
    generator = ConvLayerFeaturevisGenerator(
        model=model,
        render_kwargs=render_kwargs,
        target_layer=get_module_by_name(module=model, name=layer_name),
        batch_size=32,
        width=32,
        height=32,
        standard_deviation=0.01,
    )

    folder_name = os.path.join(
        neuron_atlas_folder, f"{run_name}_layer_name_{layer_name}"
    )
    # Start from an empty per-layer folder.
    # shutil/os replace `rm -rf` / `mkdir -p` via the shell, so paths with spaces
    # or shell metacharacters are handled safely. ignore_errors mirrors the old
    # behavior of ignoring the shell's exit status.
    shutil.rmtree(folder_name, ignore_errors=True)
    os.makedirs(folder_name, exist_ok=True)

    filenames = generator.generate(output_folder=folder_name)
    size = find_rectangle_dimensions(area=len(filenames))
    print("SIZE", size)

    # Copy each image into memory so its file handle is closed right away
    # instead of staying open for the whole run.
    images = []
    for filename in filenames:
        with Image.open(filename) as image:
            images.append(image.copy())

    grid_image = make_grid_from_list_of_images(
        images=images,
        height=size.height,
        width=size.width,
    )

    grid_filename = os.path.join(
        grid_output_folder, f"{run_name}_layer_name_{layer_name}.png"
    )
    # Nothing above creates the outputs folder; ensure it exists before saving.
    os.makedirs(grid_output_folder, exist_ok=True)
    print(f"saving: {grid_filename}")
    grid_image.save(grid_filename)
