from nesim.utils.feature_vis.generator import (
    ConvLayerFeaturevisGenerator,
    LinearLayerFeaturevisGenerator,
)
from neuro.models.neuro_model.config import NeuroModelConfig

from nesim.utils.json_stuff import load_json_as_dict
from neuro.models.neuro_model.model import NeuroModel
from utils import make_convmapper_config_from_size_sequences
from PIL import Image
from nesim.utils.grid_size import find_rectangle_dimensions
from nesim.vis.image_grid import make_grid_from_list_of_images
import os
import argparse

# Command-line interface: which training-step checkpoints to visualize and which
# config describes the model architecture.
parser = argparse.ArgumentParser(
    description=(
        "Renders feature visualizations of a convmapper's last conv layer for a "
        "range of training-step checkpoints and saves one grid image per checkpoint"
    )
)

parser.add_argument("--start", type=int, help="first training step index (inclusive)", default=0)
parser.add_argument("--stop", type=int, help="stop training step index (exclusive)", default=11970)
parser.add_argument("--step", type=int, help="stride between visualized training steps", default=1000)
# Both paths are mandatory: without them the script fails below anyway, so let
# argparse report a clear usage error instead of a TypeError.
parser.add_argument("--config", type=str, required=True, help="path to the JSON config file")
parser.add_argument(
    "--checkpoint",
    type=str,
    required=True,
    help="path to a checkpoint file (validated to exist)",
)

args = parser.parse_args()

config = load_json_as_dict(filename=args.config)
# NOTE(review): --checkpoint is only checked for existence here; the loop below
# loads checkpoints from ./step_checkpoints instead. Confirm this is intended.
# Explicit check instead of `assert`, which is stripped under `python -O`.
if not os.path.exists(args.checkpoint):
    raise FileNotFoundError(f"checkpoint not found: {args.checkpoint}")

# Build the brain-response predictor (convmapper) config and wrap it, together
# with the image encoder settings, in a NeuroModel. Every setting is read from
# the "model" section of the config file.
model_section = config["model"]

conv_mapper_config = make_convmapper_config_from_size_sequences(
    conv_layer_size_sequence=model_section["conv_layer_size_sequence"],
    linear_layer_size_sequence=model_section["linear_layer_size_sequence"],
    reduce_fn=model_section["conv_mapper_reduce_fn"],
    activation=model_section["conv_mapper_activation"],
    conv_layer_kernel_size=model_section["conv_mapper_kernel_size"],
)

neuro_model_config = NeuroModelConfig(
    brain_response_predictor_config=conv_mapper_config,
    image_encoder=model_section["neuro_model_config_image_encoder"],
    hook_layer_name=model_section["intermediate_layer_name"],
    reduce_fn=model_section["neuro_model_reduce_fn"],
)

model = NeuroModel(config=neuro_model_config, device=model_section["device"])


# Rendering settings handed to the feature-visualization generator.
render_kwargs = {
    "scale_max": 1.0,
    "scale_min": 1.0,
    "iters": 120,
    "lr": 3e-3,
    "grad_clip": 1.0,
    "rotate_degrees": 10,
}

# Target: the final conv layer of the brain-response predictor. The grid shape
# is chosen so it has room for one tile per output channel of that layer.
target_conv_layer = model.brain_response_predictor.conv_layers[-1]

size = find_rectangle_dimensions(area=target_conv_layer.out_channels)

feature_vis_generator = ConvLayerFeaturevisGenerator(
    model=model,
    target_layer=target_conv_layer,
    batch_size=64,
    render_kwargs=render_kwargs,
    quiet=True,
)

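# Alternative target: visualize a layer of the linear mapper instead of the last
# conv layer. To switch, disable the conv-layer `size` / generator setup above
# and uncomment the lines below.
# size = find_rectangle_dimensions(
#     area=model.brain_response_predictor.linear_mapper.model[-3].out_features
# )
# feature_vis_generator = LinearLayerFeaturevisGenerator(
#     model=model,
#     target_layer=model.brain_response_predictor.linear_mapper.model[-3],
#     batch_size=32,
#     render_kwargs=render_kwargs,
# )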

"""
neuron_atlas_video:
    atlases:
        - checkpoint_filename_1.jpg
        - checkpoint_filename_2.jpg

    checkpoint_featurevis:
        - checkpoint_filename_1/
            - 0.png
            - 1.png
        - checkpoint_filename_2/
            - 0.png
            - 1.png
"""

vis_root = "./neuron_atlas_video/checkpoint_featurevis"
frames_root = "./neuron_atlas_video/atlases"
os.system(f"mkdir -p {vis_root}")
os.system(f"mkdir -p {frames_root}")


# One checkpoint per visualized training step, named train_step_idx_<step>.pth
# and stored under ./step_checkpoints; `stop` is exclusive, as in range().
checkpoint_dir = "./step_checkpoints"
checkpoint_filenames = [
    f"train_step_idx_{step_idx}.pth"
    for step_idx in range(args.start, args.stop, args.step)
]
checkpoint_paths = [os.path.join(checkpoint_dir, name) for name in checkpoint_filenames]

print(f"start: {args.start} stop: {args.stop} step: {args.step}")
print(f"Will visualize {len(checkpoint_filenames)} checkpoints")

def _load_image_copy(path):
    """Return an in-memory copy of the image at *path* and close its file.

    ``Image.open`` is lazy and keeps the file handle open. Opening one tile per
    channel for every checkpoint would otherwise leak file descriptors.
    """
    with Image.open(path) as img:
        # copy() forces the pixel data to load before the file is closed.
        return img.copy()


# For each checkpoint: load its weights into the predictor, render the
# feature-vis images, and stitch them into a single grid frame.
frame_filenames = []
for filename, path in zip(checkpoint_filenames, checkpoint_paths):
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(path):
        raise FileNotFoundError(f"missing checkpoint: {path}")

    model.brain_response_predictor.load(path)

    images_folder = os.path.join(vis_root, filename)
    os.makedirs(images_folder, exist_ok=True)
    tile_paths = feature_vis_generator.generate(output_folder=images_folder)

    grid_image = make_grid_from_list_of_images(
        images=[_load_image_copy(tile_path) for tile_path in tile_paths],
        height=size.height,
        width=size.width,
    )
    # The frame is named after the checkpoint file, e.g. train_step_idx_0.pth.jpg.
    frame_filename = os.path.join(frames_root, f"{filename}.jpg")
    grid_image.save(frame_filename)
    print(f"saved {frame_filename}")
    frame_filenames.append(frame_filename)

print("DONE")
