from nesim.utils.feature_vis.generator import (
    ConvLayerFeaturevisGenerator,
    LinearLayerFeaturevisGenerator,
)
from neuro.models.neuro_model.config import NeuroModelConfig

from nesim.utils.json_stuff import load_json_as_dict
from neuro.models.neuro_model.model import NeuroModel
from neuro.utils.model_building import make_convmapper_config_from_size_sequences
from PIL import Image
from nesim.utils.grid_size import find_rectangle_dimensions
from nesim.vis.image_grid import make_grid_from_list_of_images
import os
import argparse

# Command-line interface: the script needs a JSON config and a checkpoint for
# the brain-response predictor.
parser = argparse.ArgumentParser(
    description=(
        "Generates feature visualizations for a trained convmapper and tiles "
        "them into a single grid image"
    )
)

# Both arguments are mandatory: nothing below can run without them, so let
# argparse report a missing one instead of failing later with an obscure error.
parser.add_argument(
    "--config", type=str, required=True, help="path to the JSON config file"
)
parser.add_argument(
    "--checkpoint", type=str, required=True, help="path to the checkpoint file"
)

args = parser.parse_args()
config = load_json_as_dict(filename=args.config)

# Explicit check instead of `assert`, which is stripped when Python runs with -O.
if not os.path.exists(args.checkpoint):
    raise FileNotFoundError(f"checkpoint file not found: {args.checkpoint}")


# Every model-related setting lives under the "model" section of the config.
model_settings = config["model"]

# Configuration of the conv mapper, derived from the layer size sequences.
conv_mapper_config = make_convmapper_config_from_size_sequences(
    conv_layer_size_sequence=model_settings["conv_layer_size_sequence"],
    linear_layer_size_sequence=model_settings["linear_layer_size_sequence"],
    reduce_fn=model_settings["conv_mapper_reduce_fn"],
    activation=model_settings["conv_mapper_activation"],
    conv_layer_kernel_size=model_settings["conv_mapper_kernel_size"],
)

# Full model configuration: the conv mapper is used as the brain-response
# predictor, and the image encoder is hooked at the configured layer.
neuro_model_config = NeuroModelConfig(
    brain_response_predictor_config=conv_mapper_config,
    image_encoder=model_settings["neuro_model_config_image_encoder"],
    hook_layer_name=model_settings["intermediate_layer_name"],
    reduce_fn=model_settings["neuro_model_reduce_fn"],
)

# Build the model on the configured device, then load the checkpoint given on
# the command line into its brain-response predictor.
model = NeuroModel(config=neuro_model_config, device=model_settings["device"])
model.brain_response_predictor.load(args.checkpoint)

# Keyword arguments passed to the feature-visualization generator as
# `render_kwargs`.
render_kwargs = dict(
    scale_max=1.0, scale_min=1.0, iters=120, lr=1e-3, grad_clip=1.0, rotate_degrees=5
)

output_folder = config["neuron_atlas"]["output_folder"]

# Remove PNGs left over from a previous run. This is done with the standard
# library rather than a shell `rm` so that folder names containing spaces or
# shell metacharacters are handled safely. Like the shell glob `*.png`, hidden
# files are skipped, and a missing folder is not an error.
if os.path.isdir(output_folder):
    for stale_name in os.listdir(output_folder):
        stale_path = os.path.join(output_folder, stale_name)
        if (
            stale_name.endswith(".png")
            and not stale_name.startswith(".")
            and os.path.isfile(stale_path)
        ):
            os.remove(stale_path)

# Visualize the last conv layer of the brain-response predictor; the grid
# dimensions are derived from its number of output channels.
target_layer = model.brain_response_predictor.conv_layers[-1]
num_channels = target_layer.out_channels
size = find_rectangle_dimensions(area=num_channels)

feature_vis_generator = ConvLayerFeaturevisGenerator(
    model=model,
    target_layer=target_layer,
    batch_size=64,
    render_kwargs=render_kwargs,
)

# Alternative kept for reference: visualize a linear layer instead of the last
# conv layer.
# size = find_rectangle_dimensions(
#     area = model.brain_response_predictor.linear_mapper.model[-3].out_features
# )
# feature_vis_generator = LinearLayerFeaturevisGenerator(
#     model=model,
#     target_layer=model.brain_response_predictor.linear_mapper.model[-3],
#     batch_size = 32,
#     render_kwargs=render_kwargs,
# )

# The return value of `generate` is not used; the rendered images are located
# by name below.
# NOTE(review): assumes `generate` writes one `{idx}.png` per channel into
# `output_folder` -- confirm against the generator implementation.
feature_vis_generator.generate(output_folder=output_folder)

# One image per output channel, in channel order. The count follows the layer
# (previously hard-coded to 256) so it always matches the grid size above.
filenames = [
    os.path.join(output_folder, f"{idx}.png") for idx in range(num_channels)
]
print("SIZE", size)
grid_image = make_grid_from_list_of_images(
    images=[Image.open(filename) for filename in filenames],
    height=size.height,
    width=size.width,
)
grid_image.save(config["neuron_atlas"]["result_filename"])
