from collections import OrderedDict
from operator import attrgetter

from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

from rf_helper import *
from visualize_helper import *


# Side length (pixels) of the center crop fed to the network.
input_resolution = 224
# Side length of the square resize applied before cropping: the crop size
# scaled by 256/224, computed with integer arithmetic (256 for a 224 crop).
resize_size = input_resolution * 256 // 224


# Preprocessing pipeline applied to every dataset image:
#   1. bicubic resize to a fixed resize_size x resize_size square (a (h, w)
#      size does not preserve the aspect ratio),
#   2. center crop to input_resolution x input_resolution,
#   3. conversion to a tensor,
#   4. per-channel normalization (these mean/std values match the usual
#      ImageNet channel statistics).
transform = transforms.Compose(
    [
        transforms.Resize(
            (resize_size, resize_size),
            interpolation=transforms.InterpolationMode.BICUBIC,
        ),
        transforms.CenterCrop(input_resolution),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Hooked layers as (dotted attribute path in the model, scale factor) pairs,
# transposed into two parallel lists. Each scale is paired with the layer on
# the same row; its exact meaning is defined by analysis_single_layer —
# TODO confirm in rf_helper.
layer_list, scale_list = map(list, zip(
    ('V1.conv1', 16),
    ('V1.conv2', 8),
    ('V2.conv1', 4),
    ('V2.conv2', 4),
    ('V2.conv3', 4),
    ('V4.conv1', 2),
    ('V4.conv2', 2),
    ('V4.conv3', 2),
    ('IT.conv1', 1),
    ('IT.conv2', 1),
    ('IT.conv3', 1),
))


# batch_size=1 and shuffle=False: one image per batch, always in the
# dataset's fixed order.
dataset = ImageFolder('...', transform=transform)  # path to the dataset
data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=2, pin_memory=True)

model = ...  # load CORnet-S

# Filled by the forward hooks registered below: layer name -> latest output.
hook_dict = OrderedDict()
# Resolve each dotted name in layer_list (e.g. 'V1.conv1' -> model.V1.conv1)
# instead of repeating the layer list by hand, so the hooked modules can never
# drift out of sync with the names their outputs are stored under.
layers = [attrgetter(name)(model) for name in layer_list]

def get_hook(name, store=None):
    """Return a forward hook that records a layer's output under ``name``.

    The returned callable has the ``hook(module, input, output)`` signature
    used by ``nn.Module.register_forward_hook``. Each call overwrites the
    previous entry for ``name``, so only the most recent forward pass is kept.

    Args:
        name: Key under which the layer output is stored.
        store: Optional mapping to write into. When omitted, the hook writes
            into the module-level ``hook_dict``, looked up each time the hook
            fires (the original behaviour).

    Returns:
        The hook function. It returns ``None``, so the layer's output is
        passed through unchanged.
    """
    def hook(module, inputs, output):
        # Resolve the default target lazily so it follows whatever object the
        # global ``hook_dict`` name is bound to when the hook fires.
        target = hook_dict if store is None else store
        # Stored as-is, NOT detached: the tensor stays attached to the autograd
        # graph. The Jacobian-based analysis downstream (use_jacobian=True) may
        # rely on this — confirm before adding ``.detach()``.
        target[name] = output
    return hook

# Attach one forward hook per layer; each hook stores that layer's output in
# hook_dict under the matching name from layer_list. A length mismatch would
# silently mis-pair names and modules, so fail loudly instead. The returned
# handles are kept so the hooks can later be detached with handle.remove().
if len(layers) != len(layer_list):
    raise ValueError(
        f"layers has {len(layers)} entries but layer_list has {len(layer_list)}"
    )
hook_handles = [
    layer.register_forward_hook(get_hook(name))
    for name, layer in zip(layer_list, layers)
]

# Run the per-layer analysis. analysis_single_layer receives the model, the
# data loader and hook_dict; presumably it runs the model so the registered
# forward hooks fill hook_dict — TODO confirm in rf_helper.
analysis_output_dict = {
    'cornet_s': analysis_single_layer(
        model=model, model_type='cornet_s',
        data_loader=data_loader,
        layer_list=layer_list,
        scale_list=scale_list,
        hook_dict=hook_dict,
        max_image_num=None, device="cuda",
    ),
}

# The analysis result is unpacked into two values.
# NOTE(review): this rebinds the global ``hook_dict``. The forward hooks look
# that name up each time they fire, so any later forward pass writes into the
# returned object rather than the dict created earlier.
average_image_dict_average, hook_dict = analysis_output_dict['cornet_s']

rfs = create_plots(
    model_type='cornet_s',
    num_layers=len(layer_list),  # derived, so it always matches the hooked layers
    pixel_of_interest=[],
    average_image_dict=average_image_dict_average,
    input_output_plane='output',
    use_attention=True,
    use_jacobian=True,
    hook_dict=hook_dict,
    use_middle=True,
    path='...',  # output directory
)

print('Effective receptive field sizes:', ','.join(map(str, rfs)))
