import torch
from matplotlib import pyplot as plt

from mind_the_pad.paths import plot_folder
from mind_the_pad.model_analysis.visualize_intermediate_results import generator_conv_relu_output, relu_output_for_imshow
from mind_the_pad.train_mnist.utils import num_conv_relu_layers

# Directory that receives one standalone image per conv+ReLU layer.
intermediate_relu_folder = plot_folder / 'intermediate_relu_ssd_coco'
# exist_ok=True avoids the check-then-create race of an exists()/mkdir() pair;
# parents=True also creates plot_folder itself if it is missing (main() saves
# the combined figure directly into plot_folder as well).
intermediate_relu_folder.mkdir(parents=True, exist_ok=True)

def main():
    """Plot the ReLU outputs of every conv+ReLU layer of NVIDIA's SSD model.

    Loads ``nvidia_ssd`` from the ``NVIDIA/DeepLearningExamples:torchhub``
    torch.hub repository, then for each ``(layer name, ReLU output)`` pair
    yielded by ``generator_conv_relu_output``:

    * draws it, titled with the layer name, into one column of a combined
      single-row figure saved as ``plot_folder / 'ssd_coco_relu_outputs.png'``;
    * saves it on its own as ``intermediate_relu_folder / 'layer_{i}.png'``.
    """
    ssd_model = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd')

    # squeeze=False always returns a 2-D array of Axes, so column indexing
    # also works when the model has exactly one conv+ReLU layer (the default
    # squeeze=True would return a bare Axes object in that case).
    fig, axs = plt.subplots(1, num_conv_relu_layers(ssd_model), squeeze=False)
    row_axes = axs[0]
    # NOTE(review): 640, 480, 3 are forwarded to generator_conv_relu_output;
    # presumably the input image dimensions/channels — confirm in that helper.
    for i, (lname, relu_output) in enumerate(generator_conv_relu_output(ssd_model, 640, 480, 3)):
        o = relu_output_for_imshow(relu_output)
        row_axes[i].imshow(o, cmap='viridis')
        row_axes[i].set_title(lname)
        # Also save this layer as a standalone image, closing its figure
        # immediately so open figures do not accumulate across layers.
        fig1, ax = plt.subplots()
        ax.imshow(o, cmap='viridis')
        fig1.tight_layout()
        fig1.savefig(intermediate_relu_folder / f'layer_{i}.png')
        plt.close(fig1)
    fig.savefig(plot_folder / 'ssd_coco_relu_outputs.png')
    # Release the combined figure as well; previously only the per-layer
    # figures were closed.
    plt.close(fig)

if __name__ == '__main__':
    # Run the plotting pipeline only when executed as a script, not on import.
    main()