import matplotlib.pyplot as plt
import torch

from mind_the_pad.model_analysis.plot_model_parameters import heatmap_avg_kernel_filters

from mind_the_pad.train_mnist.model import build_model_by_padding_size_mode
from mind_the_pad.train_mnist.utils import iter_mnist_exprmnt, num_conv_layers


@torch.no_grad()
def plot_avg_filters():
    """Save average-kernel heatmaps for every MNIST experiment with a trained model.

    For each ``(exprmnt_path, exprmnt_data)`` yielded by ``iter_mnist_exprmnt()``:

    * skip the experiment (with a printed notice) if ``<exprmnt_path>/model.pth``
      does not exist;
    * rebuild the model from ``exprmnt_data['padding']`` and
      ``exprmnt_data['padding_mode']`` and load the saved state dict onto CPU;
    * write one ``<module name>_avg_kernel.png`` per item yielded by
      ``heatmap_avg_kernel_filters`` plus a combined ``all_avg_kernels.png`` into
      ``<exprmnt_path>/plot_avg_kernel/`` (created if missing).

    Gradient tracking is disabled for the whole call via ``torch.no_grad()``.
    All figures created here are closed before moving to the next experiment.
    """
    for exprmnt_path, exprmnt_data in iter_mnist_exprmnt():
        model_params_path = exprmnt_path / 'model.pth'
        # Check for the checkpoint first so no model is built for skipped experiments.
        if not model_params_path.exists():
            print(str(exprmnt_path), "doesn't have final model.pth saved. Skipping")
            continue

        padding_size = exprmnt_data['padding']
        padding_mode = exprmnt_data['padding_mode']
        model = build_model_by_padding_size_mode(padding_size, padding_mode)
        model.load_state_dict(torch.load(model_params_path, map_location='cpu'))

        avg_kernel_folder = exprmnt_path / 'plot_avg_kernel'
        avg_kernel_folder.mkdir(exist_ok=True)

        # Overview figure: one subplot per conv layer, excluding 1x1 convolutions.
        fig, axs = plt.subplots(ncols=num_conv_layers(model, skip_1x1_convs=True))
        # A fresh figure is made current before the loop; presumably the
        # per-module plt.savefig calls below save what heatmap_avg_kernel_filters
        # drew on the current figure. NOTE(review): confirm against the helper.
        scratch_fig, _ = plt.subplots()
        try:
            for mod_name, _ax in heatmap_avg_kernel_filters(model, axs):
                plt.savefig(avg_kernel_folder / f'{mod_name}_avg_kernel.png')
                plt.close()
            fig.savefig(avg_kernel_folder / 'all_avg_kernels.png')
        finally:
            # Release both figures (no-op if already closed); without this every
            # experiment leaks open figures.
            plt.close(fig)
            plt.close(scratch_fig)

if __name__ == "__main__":
    # Script entry point: generate the average-kernel plots for all experiments.
    plot_avg_filters()
