from nesim.utils.feature_vis.generator import (
    ConvLayerFeaturevisGenerator,
    LinearLayerFeaturevisGenerator,
)
from neuro.models.neuro_model.config import NeuroModelConfig

from nesim.utils.json_stuff import load_json_as_dict
from neuro.models.neuro_model.model import NeuroModel
from neuro.utils.model_building import make_convmapper_config_from_size_sequences
from PIL import Image
from nesim.utils.grid_size import find_rectangle_dimensions
import os
import argparse
from nesim.utils.getting_modules import get_module_by_name
from nesim.grid.two_dimensional import BaseGrid2dConv
import matplotlib.pyplot as plt
from nesim.utils.normalization import normalize_tensor

# Command-line interface. This script compares the weights of one layer of the
# brain response predictor between two checkpoints and plots the difference.
parser = argparse.ArgumentParser(
    description=(
        "Plots the weight change (end - start) of a convmapper layer "
        "between two checkpoints"
    )
)

# Every argument is used further down, so let argparse reject a missing one
# with a usage message instead of failing later on a None value.
parser.add_argument("--config", type=str, required=True, help="config file")
parser.add_argument(
    "--start-checkpoint", type=str, required=True, help="start checkpoint file"
)
parser.add_argument(
    "--end-checkpoint", type=str, required=True, help="end checkpoint file"
)
parser.add_argument(
    "--layer-name",
    type=str,
    required=True,
    help="name of the layer inside the brain response predictor, e.g. conv_layers.0",
)
parser.add_argument(
    "--output-filename", type=str, required=True, help="output image filename"
)

args = parser.parse_args()
config = load_json_as_dict(filename=args.config)

# Validate the checkpoint paths explicitly: `assert` is stripped under
# `python -O`, and a bare AssertionError does not say which path is missing.
for checkpoint_path in (args.start_checkpoint, args.end_checkpoint):
    if not os.path.exists(checkpoint_path):
        parser.error(f"checkpoint file not found: {checkpoint_path}")


# Shorthand for the "model" section of the JSON config; every setting read
# below comes from it.
model_settings = config["model"]

# Configuration of the conv mapper, which is passed to NeuroModelConfig as the
# brain response predictor.
conv_mapper_config = make_convmapper_config_from_size_sequences(
    conv_layer_size_sequence=model_settings["conv_layer_size_sequence"],
    linear_layer_size_sequence=model_settings["linear_layer_size_sequence"],
    reduce_fn=model_settings["conv_mapper_reduce_fn"],
    activation=model_settings["conv_mapper_activation"],
    conv_layer_kernel_size=model_settings["conv_mapper_kernel_size"],
)

neuro_model_config = NeuroModelConfig(
    brain_response_predictor_config=conv_mapper_config,
    image_encoder=model_settings["neuro_model_config_image_encoder"],
    hook_layer_name=model_settings["intermediate_layer_name"],
    reduce_fn=model_settings["neuro_model_reduce_fn"],
)


def _load_model_and_layer(checkpoint_path):
    """Build a fresh NeuroModel and load `checkpoint_path` into its predictor.

    Returns the model together with the module named by `--layer-name`,
    looked up inside the model's brain response predictor.
    """
    fresh_model = NeuroModel(
        config=neuro_model_config, device=model_settings["device"]
    )
    fresh_model.brain_response_predictor.load(checkpoint_path)
    layer = get_module_by_name(
        module=fresh_model.brain_response_predictor, name=args.layer_name
    )
    return fresh_model, layer


# One model instance per checkpoint, so the start and end layers are distinct
# module objects. `model` ends up bound to the end-checkpoint model.
model, start_layer = _load_model_and_layer(args.start_checkpoint)
model, end_layer = _load_model_and_layer(args.end_checkpoint)

# Grid height/width are derived from the layer's number of output channels.
size = find_rectangle_dimensions(area=end_layer.out_channels)

# Build both grids on the device the model was created on. A hard-coded
# "cuda:0" fails on CPU-only machines and can disagree with the config.
grid_device = config["model"]["device"]
start_grid_container = BaseGrid2dConv(
    conv_layer=start_layer, height=size.height, width=size.width, device=grid_device
)
end_grid_container = BaseGrid2dConv(
    conv_layer=end_layer, height=size.height, width=size.width, device=grid_device
)

# Alternative definitions of the plotted quantity are kept below for reference;
# exactly one assignment to `delta_weight_grid` should be active.

## e-s / |s|
# delta_weight_grid = ((end_grid_container.grid - start_grid_container.grid).mean(-1) / start_grid_container.grid.abs().mean(-1))

## e-s / |e|
# delta_weight_grid = ((end_grid_container.grid - start_grid_container.grid).mean(-1) / end_grid_container.grid.abs().mean(-1))
# delta_weight_grid = ((end_grid_container.grid - start_grid_container.grid) / end_grid_container.grid.abs()).mean(-1)


## e-s (active): signed difference, averaged over the last axis of the grid
delta_weight_grid = (end_grid_container.grid - start_grid_container.grid).mean(-1)


## magnitude increase
# delta_weight_grid = (end_grid_container.grid.abs() - start_grid_container.grid.abs()).sum(-1)

## large final
# delta_weight_grid = (end_grid_container.grid.abs()).mean(-1)

## distance travelled
# delta_weight_grid = (end_grid_container.grid - start_grid_container.grid).abs().sum(-1)

# Move the result off the autograd graph and onto the host for matplotlib.
delta_weight_grid = delta_weight_grid.detach().cpu().numpy()

fig = plt.figure(figsize=(10, 10))
# plt.imshow(normalize_tensor(delta_weight_grid.detach().cpu()))
plt.imshow(delta_weight_grid, cmap="coolwarm")

plt.colorbar()
plt.title(
    f"end_weights - start_weights\nlayer: conv_mapper.{args.layer_name}\nstart: {args.start_checkpoint}\nend: {args.end_checkpoint}",
    fontsize=17,
)
# Save before showing: with interactive backends the figure can be torn down
# once the window opened by plt.show() is closed, so saving afterwards is
# unreliable. Saving first also guarantees the file exists on headless runs.
fig.savefig(args.output_filename)
plt.show()

"""
python3 delta_weights.py --config configs/ours.json \
--start-checkpoint step_checkpoints/ours/train_step_idx_0.pth \
--end-checkpoint step_checkpoints/ours/train_step_idx_57000.pth \
--layer-name conv_layers.0 \
--output-filename delta_weights_ours.jpg


python3 delta_weights.py --config configs/baseline.json \
--start-checkpoint step_checkpoints/baseline/train_step_idx_0.pth \
--end-checkpoint step_checkpoints/baseline/train_step_idx_57000.pth \
--layer-name conv_layers.0 \
--output-filename delta_weights_baseline.jpg
"""
