import matplotlib.pyplot as plt

from mind_the_pad.model_analysis.visualize_intermediate_results import generator_conv_relu_output_dummy_image, relu_output_for_imshow
from mind_the_pad.paths import plot_folder
from mind_the_pad.train_mnist.utils import iter_mnist_exprmnt_with_model_loaded, num_conv_relu_layers


def _save_dummy_relu_outputs(exprm_path, model):
    """Save the conv+ReLU outputs of ``model`` on the dummy 28x28x1 image; return them.

    Writes one ``<layer_name>_output.png`` per yielded layer plus a combined
    ``layers_output.png`` into ``exprm_path / 'dummy_relu_outputs'`` and returns the
    ReLU outputs in the order ``generator_conv_relu_output_dummy_image`` yields them.
    """
    dummy_image_path = exprm_path / 'dummy_relu_outputs'
    dummy_image_path.mkdir(exist_ok=True)
    num_cols = num_conv_relu_layers(model)
    fig, axs = plt.subplots(ncols=num_cols, figsize=(num_cols * 2.2, 1.4))
    # A fresh "current" pyplot figure is opened before each yielded layer, saved, then
    # closed. NOTE(review): this presumably exists because the generator also draws
    # into the current figure (besides ``axs``) — confirm against the generator.
    plt.subplots()
    relu_outputs = []
    for layer_name, relu_output, _ax_plot in generator_conv_relu_output_dummy_image(model, 28, 28, 1, axs):
        plt.savefig(dummy_image_path / f'{layer_name}_output.png')
        plt.close()
        plt.subplots()
        relu_outputs.append(relu_output)
    plt.close()  # discard the last, never-used current figure
    fig.savefig(dummy_image_path / 'layers_output.png')
    plt.close(fig)  # otherwise one combined figure leaks per experiment
    return relu_outputs


def _plot_relu_output_grid(dict_padding, out_path):
    """Plot one row of ReLU outputs per (padding, padding_mode) key and save to ``out_path``.

    Does nothing (besides printing a message) when there is nothing to plot.
    """
    nrows = len(dict_padding)
    ncols = max(map(len, dict_padding.values()), default=0)
    if nrows == 0 or ncols == 0:
        # plt.subplots cannot build a grid with zero rows/columns.
        print('No ReLU outputs to plot; skipping summary figure.')
        return
    # figsize is (width, height): width grows with columns, height with rows.
    # squeeze=False keeps ``axs`` 2-D even for a single row or column.
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(4.0 * ncols, 3.4 * nrows), squeeze=False)
    for i, ((padding_type, padding_mode), relu_outputs) in enumerate(dict_padding.items()):
        for j in range(ncols):
            ax = axs[i, j]
            if j < len(relu_outputs):
                ax.imshow(relu_output_for_imshow(relu_outputs[j]))
                ax.set_xticks([])
                ax.set_yticks([])
            else:
                ax.axis('off')  # hide cells left empty by a shorter row
        axs[i, 0].set_title(f'{padding_type=} {padding_mode=}')
    fig.savefig(out_path)
    plt.close(fig)


def main():
    """Visualize conv+ReLU outputs on a dummy image for every MNIST experiment.

    Per-experiment plots are written to ``<experiment>/dummy_relu_outputs``. A summary
    grid with one row per (padding, padding_mode) pair is written to
    ``plot_folder / 'relu_output_dummy_images_mnist.png'``.
    """
    dict_padding = {}
    for exprm_path, exprm_data, model in iter_mnist_exprmnt_with_model_loaded():
        print(str(exprm_path))
        key = (exprm_data['padding'], exprm_data['padding_mode'])
        # NOTE(review): experiments with identical padding settings overwrite each other here.
        dict_padding[key] = _save_dummy_relu_outputs(exprm_path, model)
    _plot_relu_output_grid(dict_padding, plot_folder / 'relu_output_dummy_images_mnist.png')

if __name__ == "__main__":
    # Run the analysis only when this file is executed as a script, not on import.
    main()