import pandas as pd
import torch
from matplotlib import pyplot as plt

from mind_the_pad.model_analysis.shift_images import calc_shifts_pixel_based, shifted_image
from mind_the_pad.data.mnist import preprocessor_mnist
from mind_the_pad.train_mnist.utils import iter_mnist_exprmnt_with_model_loaded
from torchvision.datasets.mnist import MNIST
from mind_the_pad.paths import dataset_folder, plot_folder
import seaborn as sns


def _first_image_per_class(dataset, preprocess, n_classes=10):
    """Return ``[(label, preprocess(image)), ...]`` with the first sample of each class.

    ``dataset`` is iterated in order and yields ``(image, label)`` pairs; the first
    image seen for each label is preprocessed and kept. Iteration stops as soon as
    ``n_classes`` distinct labels have been collected. Labels appear in the result
    in order of first appearance.

    Raises:
        ValueError: if ``dataset`` is exhausted before ``n_classes`` labels are seen.
    """
    images = {}
    if n_classes <= 0:
        return []
    for image, label in dataset:
        if label not in images:
            images[label] = preprocess(image)
            if len(images) == n_classes:
                break
    else:
        raise ValueError(
            f"dataset exhausted after finding only {len(images)} of {n_classes} classes"
        )
    return list(images.items())


def _predict_shifts(model, image, shifts):
    """Run ``model`` on every shifted copy of ``image``.

    Returns ``(probs, applied)``: ``probs`` has one softmax row per shifted copy, in
    the order yielded by ``shifted_image``, and ``applied`` holds the shift value
    yielded alongside each copy (used for labelling the rows).
    """
    shifted, applied = [], []
    for shifted_x, shift in shifted_image(image, shifts):
        shifted.append(shifted_x)
        applied.append(shift)
    # One batched forward pass instead of one call per shift; the model is in
    # eval mode, so batching does not change the per-sample outputs.
    probs = model(torch.stack(shifted)).softmax(dim=-1)
    return probs, applied


def _plot_shift_probabilities(probs, shift_labels, path):
    """Save a heatmap of ``probs`` (rows = shifts, columns = classes) to ``path``."""
    fig, ax = plt.subplots()
    # Give the labels to heatmap itself: with the default yticklabels="auto"
    # seaborn may thin the ticks, and a later set_yticklabels would then attach
    # labels to the wrong rows.
    sns.heatmap(
        probs.cpu().numpy(),
        ax=ax,
        cbar=True,
        yticklabels=[str(shift) for shift in shift_labels],
    )
    ax.tick_params(axis='y', labelsize=8)
    ax.set_ylabel('Shifts')
    ax.set_xlabel(r'$\hat{p(y)}$')  # raw string: '\h' is not a valid escape
    fig.savefig(path)
    plt.close(fig)


def _plot_error_summary(results, aggfunc, title, path):
    """Save a padding x padding_mode heatmap of ``errors`` aggregated with ``aggfunc``."""
    table = results.pivot_table(
        values='errors', index='padding', columns='padding_mode', aggfunc=aggfunc
    )
    fig, ax = plt.subplots()
    sns.heatmap(table, annot=True, ax=ax)
    # pivot_table puts `index` on the rows (y axis) and `columns` on the x axis.
    ax.set_xlabel('Padding mode')
    ax.set_ylabel('Padding')
    ax.set_title(title)
    fig.savefig(path)
    plt.close(fig)


def _shift_count_label(shifts_images):
    """Describe how many shifts were evaluated per image ("N" or "MIN-MAX")."""
    counts = sorted({len(shifts) for shifts in shifts_images})
    if len(counts) == 1:
        return str(counts[0])
    return f'{counts[0]}-{counts[-1]}'


@torch.no_grad()
def main():
    """Evaluate how trained MNIST models react to translated test digits.

    The first test image of every digit class is preprocessed and shifted by each
    offset returned by ``calc_shifts_pixel_based``. For every experiment yielded by
    ``iter_mnist_exprmnt_with_model_loaded``:

    * a heatmap of the softmax output per shift is written to
      ``<experiment.path>/shifted_outputs/<digit>.png``;
    * the number of shifts whose argmax prediction differs from the true digit is
      recorded together with the experiment's padding and padding mode.

    The recorded error counts are then aggregated (mean and std) per
    (padding, padding_mode) and saved as annotated heatmaps under ``plot_folder``.
    """
    dataset = MNIST(dataset_folder, train=False)
    test_images = _first_image_per_class(dataset, preprocessor_mnist())
    # x.min() is passed as the second argument — presumably the background/fill
    # value of the preprocessed image; confirm against calc_shifts_pixel_based.
    shifts_images = [calc_shifts_pixel_based(x, x.min()) for _, x in test_images]
    print(shifts_images[0])

    results = []
    for experiment in iter_mnist_exprmnt_with_model_loaded():
        dst_folder = experiment.path / 'shifted_outputs'
        dst_folder.mkdir(exist_ok=True)
        experiment.model.eval()
        padding = experiment.data['padding']
        padding_mode = experiment.data['padding_mode']
        for (y, x), shifts in zip(test_images, shifts_images):
            probs, applied = _predict_shifts(experiment.model, x, shifts)
            _plot_shift_probabilities(probs, applied, dst_folder / f'{y}.png')
            errors = int((probs.argmax(dim=-1) != y).sum())
            results.append((padding, padding_mode, errors))
    results = pd.DataFrame(results, columns=['padding', 'padding_mode', 'errors'])
    # NOTE(review): if some experiments have padding_mode None/NaN, pivot_table's
    # grouping will likely drop them from the summary heatmaps; a previous
    # (commented-out) fix mapped falsy modes to 'null'. Confirm whether needed.

    n_shifts = _shift_count_label(shifts_images)
    _plot_error_summary(
        results,
        'mean',
        f'Average error over {n_shifts} shifts',
        plot_folder / 'mnist_average_error_by_padding.png',
    )
    _plot_error_summary(
        results,
        'std',
        f'Std error over {n_shifts} shifts',
        plot_folder / 'mnist_std_error_by_padding.png',
    )

if __name__ == "__main__":
    # Run the shift-robustness evaluation only when executed as a script,
    # not when this module is imported.
    main()
