from matplotlib import pyplot as plt

from mind_the_pad.data.mnist import letters_mnist_test_dataset
from mind_the_pad.model_analysis.visualize_intermediate_results import generator_conv_relu_output, \
    relu_output_for_imshow
from mind_the_pad.train_mnist.utils import iter_emnist_exprmnt_with_model_loaded_same_only
from mind_the_pad.paths import plot_folder

# Directory that receives the per-layer images saved by main().
# NOTE: the spelling 'intemediate' in the folder name is kept as-is so the
# output location does not change for existing runs.
dest_folder = plot_folder / 'emnist_circular_intemediate_output'
# parents=True also creates plot_folder when it is missing; exist_ok=True
# replaces the separate exists() check and so avoids the check-then-create race.
dest_folder.mkdir(parents=True, exist_ok=True)


def main():
    """Save the conv/ReLU outputs of a circular-padding model as images.

    Takes the first item of the dataset returned by
    ``letters_mnist_test_dataset()``, feeds it (with one extra leading
    dimension) together with the first available circular-padding model to
    ``generator_conv_relu_output``, and writes one ``layer_<i>.png`` into
    ``dest_folder`` for every item the generator yields.

    Raises:
        RuntimeError: if no trained model with circular padding was found.
    """
    test_dataset = letters_mnist_test_dataset()
    X, _y = next(iter(test_dataset))  # only the input is needed, not the label

    model = get_first_trained_model_circular_padding_mode()
    if model is None:
        # The lookup returns None when no experiment matches; fail here with a
        # clear message instead of an obscure error further downstream.
        raise RuntimeError(
            "no trained experiment with padding_mode == 'circular' was found"
        )

    # unsqueeze(0) adds a leading dimension of size 1 -- presumably the batch
    # axis the model expects (TODO confirm against the model definition).
    model_input = X.unsqueeze(0)
    layer_outputs = generator_conv_relu_output(model, model_input)
    for i, (_lname, relu_output) in enumerate(layer_outputs):
        output = relu_output_for_imshow(relu_output)
        plt.imshow(output)
        plt.tight_layout()
        plt.savefig(dest_folder / f'layer_{i}.png')
        # Close the current figure so the next layer starts from a fresh one.
        plt.close()


def get_first_trained_model_circular_padding_mode():
    """Return the model of the first experiment that uses circular padding.

    Iterates over ``iter_emnist_exprmnt_with_model_loaded_same_only('cpu')``
    and picks the first experiment whose ``data['padding_mode']`` equals
    ``'circular'``, printing its path before returning its model.

    Returns:
        The ``model`` attribute of the first matching experiment, or ``None``
        when no experiment matches.
    """
    # 'cpu' is presumably the device the models are loaded on -- confirm in
    # iter_emnist_exprmnt_with_model_loaded_same_only.
    candidates = iter_emnist_exprmnt_with_model_loaded_same_only('cpu')
    for candidate in candidates:
        if candidate.data['padding_mode'] == 'circular':
            print('evaluating model at path', candidate.path)
            return candidate.model
    # No experiment with circular padding was found.
    return None


# Run the visualization only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
