import argparse
from operator import itemgetter
from typing import Union

import pytorch_lightning

import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader

from mind_the_pad.paths import dataset_folder
import torch
from mind_the_pad.paths import plot_folder, metrics_folder
from mind_the_pad.data.mnist import letters_mnist_test_dataset
from mind_the_pad.model_analysis.shift_images import calc_shifts_pixel_based, shift_image
from mind_the_pad.ssd_coco.eval_ssd import set_padding_mode_
from mind_the_pad.train_mnist.model import MnistPadding
from mind_the_pad.train_mnist.utils import iter_emnist_exprmnt_with_model_loaded, iter_emnist_exprmnt
import seaborn as sns
import matplotlib.pyplot as plt


class ShiftedEMNIST(Dataset):
    """
    Class-balanced subset of the EMNIST letters test set in which every kept image is served once per shift.

    The first ``samples_per_y`` images of every label are taken from ``letters_mnist_test_dataset()`` in
    dataset order. Each kept image is stored together with the shifts returned by ``calc_shifts_pixel_based``
    and is exposed as ``SHIFTS_PER_SAMPLE`` consecutive items, one per shift.
    """

    # Number of distinct labels that must be collected before sampling stops.
    NUM_CLASSES = 26
    # Items served per image; assumes calc_shifts_pixel_based returns at least this many shifts -- TODO confirm
    SHIFTS_PER_SAMPLE = 9

    def __init__(self, samples_per_y=10):
        """
        :param samples_per_y: number of images kept for each label
        :raises ValueError: if the test set ends before every label has ``samples_per_y`` images
        """
        y_samples = dict()
        for X, y in letters_mnist_test_dataset():
            kept = y_samples.setdefault(y, [])
            if len(kept) < samples_per_y:
                kept.append(X)
            all_labels_seen = len(y_samples) >= self.NUM_CLASSES
            if all_labels_seen and all(len(samples) == samples_per_y for samples in y_samples.values()):
                break
        else:
            # The loop ran out of data without completing every label: fail with an explicit message
            # instead of letting the iterator's exhaustion escape from the constructor.
            raise ValueError(
                f'test set exhausted before collecting {samples_per_y} samples '
                f'for each of {self.NUM_CLASSES} labels')
        # X.min() is handed to calc_shifts_pixel_based -- presumably the background value; verify in shift_images
        self.samples = [(X, y, calc_shifts_pixel_based(X, X.min())) for y, Xs in y_samples.items() for X in Xs]

    def __len__(self):
        """Return the number of (image, shift) combinations."""
        return len(self.samples) * self.SHIFTS_PER_SAMPLE

    def __getitem__(self, i):
        """
        :param i: flat index; ``i // SHIFTS_PER_SAMPLE`` selects the image, ``i % SHIFTS_PER_SAMPLE`` the shift
        :return: tuple ``(shifted image, label, shift)``
        """
        i_sample, i_shift = divmod(i, self.SHIFTS_PER_SAMPLE)
        X, y, shifts = self.samples[i_sample]
        shift = shifts[i_shift]
        return shift_image(X, shift), y, shift


def plot_shifted_images_sample(emnist_samples):
    """
    Plot an example of shifted images in a 3x3 grid
    """
    # NOTE: this docstring doubles as the argparse help text of --shifted-images-sample (see parse_args).
    sample_stream = iter(emnist_samples)
    # Skip the items of the first ten images (nine items per image) ...
    for _ in range(10):
        for _ in range(9):
            next(sample_stream)
    # ... and take the nine items that follow.
    grid_items = [next(sample_stream) for _ in range(9)]
    # Order the items by their shift: second component first, first component as tie-breaker.
    grid_items.sort(key=lambda item: item[2][0] + item[2][1] * 100)

    vshifts = ['top', '', 'down']
    hshifts = ['left', '', 'right']
    fig, axs = plt.subplots(3, 3)
    for i, (ax, (x_shifted, _, _)) in enumerate(zip(axs.reshape([-1]), grid_items)):
        row, col = divmod(i, 3)
        ax.set_title(f'{vshifts[row]} {hshifts[col]}')
        # permute(1, 2, 0) moves the first axis last before the image is handed to imshow
        ax.imshow(x_shifted.permute(1, 2, 0), cmap='binary')
        ax.set_xticks([])
        ax.set_yticks([])
    fig.tight_layout()
    fig.savefig(plot_folder / 'shifted_image_sample.png')
    plt.close()


def compute_dataset_or_load():
    """
    Return the per-prediction evaluation table, computing and caching it on first use.

    The table is cached as ``eval_emnist.csv`` inside ``dataset_folder``. When the file is missing, every
    experiment yielded by ``iter_emnist_exprmnt_with_model_loaded`` is run on every item of ``ShiftedEMNIST``
    and one row per (experiment, item) is stored with the columns ``shift``, ``y``, ``batch_norm``,
    ``input_pad``, ``padding_mode``, ``perplexity`` and ``error``.

    :return: tuple ``(dataset, emnist_samples)`` with the table as a DataFrame and the ``ShiftedEMNIST`` instance
    """
    dataset_path = dataset_folder / 'eval_emnist.csv'
    emnist_samples = ShiftedEMNIST()
    result_folder = plot_folder / 'emnist_shifts'
    # parents/exist_ok: do not fail when plot_folder itself is missing or the folder appears concurrently
    result_folder.mkdir(parents=True, exist_ok=True)

    if dataset_path.exists():
        print('loading from', str(dataset_path))
        dataset = pd.read_csv(dataset_path)
        # The CSV stores each shift as its string form; turn it back into a Python object.
        # NOTE(review): eval() executes whatever the file contains -- only acceptable because the file is
        # written by this script. Consider ast.literal_eval once the stored format is confirmed to be
        # plain Python literals.
        dataset['shift'] = dataset['shift'].apply(eval)
        return dataset, emnist_samples

    rows = []
    for exprmnt_data in iter_emnist_exprmnt_with_model_loaded('cuda'):
        model = exprmnt_data.model.to('cpu').eval()
        padding_mode = exprmnt_data.data['padding_mode']
        random_pad_input = exprmnt_data.data['random_pad_input']
        print(str(exprmnt_data.path))
        has_bn = exprmnt_data.data['batch_norm']
        for (x_shifted, y, shift) in emnist_samples:
            y -= 1  # presumably the dataset labels start at 1 while the model classes start at 0 -- verify
            # unsqueeze(0) adds the leading batch axis the model is called with
            y_pred = model(x_shifted.unsqueeze(0)).softmax(dim=-1)
            # NOTE(review): calc_perplexity passes y_pred to cross_entropy, which applies a log-softmax
            # itself, while y_pred is already softmax-normalised here -- confirm this is intended.
            rows.append(dict(shift=shift, y=y, batch_norm=has_bn,
                             input_pad=random_pad_input, padding_mode=padding_mode,
                             perplexity=calc_perplexity(y_pred, y).item(),
                             error=calc_prediction_error(y_pred, y).item()))
    dataset = pd.DataFrame(rows)
    # Make sure the cache folder exists so the computed results are not lost on a failing write.
    dataset_path.parent.mkdir(parents=True, exist_ok=True)
    dataset.to_csv(dataset_path, index=False)
    return dataset, emnist_samples


def avg_error_by_shift_padding(dataset):
    """
    Plot the mean prediction error per shift as heatmaps, one per padding mode.

    Saves one standalone heatmap per padding mode and a combined 2x2 figure into ``plot_folder``. Heatmap rows
    are indexed by the first shift component (labelled 'Vertical shift') and columns by the second (labelled
    'Horizontal shift'). Grid cells whose shift does not occur in ``dataset`` keep the value 0.

    :param dataset: DataFrame with the columns 'error', 'padding_mode' and 'shift' (2-item sequences)
    """
    padding_modes = ['circular', 'reflect', 'replicate', 'zeros']
    grid_cols = 2

    aggr = dataset.pivot_table('error', 'shift', 'padding_mode')
    print(aggr)
    aggr = aggr.reset_index()
    x_shift = aggr['shift'].apply(itemgetter(0))
    y_shift = aggr['shift'].apply(itemgetter(1))
    # Translate the shifts so that the smallest one lands on index 0 of the heatmap.
    x_translation = x_shift.min()
    y_translation = y_shift.min()
    rows = (x_shift - x_translation).to_numpy()
    cols = (y_shift - y_translation).to_numpy()
    n_rows = rows.max() + 1
    n_cols = cols.max() + 1
    row_ticks = list(range(n_rows))
    row_labels = list(range(x_translation, x_translation + n_rows))
    col_ticks = list(range(n_cols))
    col_labels = list(range(y_translation, y_translation + n_cols))

    heatmaps = {}
    for padding_mode in padding_modes:
        heatmap_shifts = np.zeros((n_rows, n_cols))
        heatmap_shifts[rows, cols] = aggr[padding_mode].to_numpy()
        heatmaps[padding_mode] = heatmap_shifts

    # Standalone figures: the colorbar is derived from the drawn image so that both always agree.
    for padding_mode, heatmap_shifts in heatmaps.items():
        single_fig, single_ax = plt.subplots(figsize=(10, 6))
        image = single_ax.imshow(heatmap_shifts)
        single_ax.set_yticks(row_ticks)
        single_ax.set_yticklabels(row_labels)
        single_ax.set_xticks(col_ticks)
        single_ax.set_xticklabels(col_labels)
        single_ax.set_ylabel('Vertical shift')
        single_ax.set_xlabel('Horizontal shift')
        single_ax.set_title(f'Average error by shift with padding mode = {padding_mode}')
        single_fig.colorbar(image, ax=single_ax)
        single_fig.savefig(plot_folder / f'heatmap_error_by_shift_{padding_mode=}.png')
        plt.close(single_fig)

    # Combined figure: a single colorbar is only meaningful when all panels share one normalisation.
    shared_norm = plt.Normalize(min(np.nanmin(heatmap) for heatmap in heatmaps.values()),
                                max(np.nanmax(heatmap) for heatmap in heatmaps.values()))
    fig, axs = plt.subplots(nrows=2, ncols=grid_cols, figsize=(16, 9))
    axs = axs.ravel()
    for i, (padding_mode, heatmap_shifts) in enumerate(heatmaps.items()):
        ax = axs[i]
        ax.set_title(f'padding mode = {padding_mode}', fontsize='x-large')
        image = ax.imshow(heatmap_shifts, norm=shared_norm)
        if i % grid_cols == 0:
            # y ticks only on the first column of the grid
            ax.set_yticks(row_ticks)
            ax.set_yticklabels(row_labels)
            ax.set_ylabel('Vertical shift')
        else:
            ax.set_yticks([])
        ax.set_xticks(col_ticks)
        ax.set_xticklabels(col_labels)
        ax.set_xlabel('Horizontal shift')

    fig.colorbar(image, ax=axs.tolist())
    fig.suptitle('Prediction error by letter shift and padding mode', fontsize='xx-large')
    fig.savefig(plot_folder / 'heatmap_error_by_shift_padding_mode.png')
    # Close the figure so it does not stay the current figure for later pyplot/seaborn calls.
    plt.close(fig)


def average_error_by_shift_without_batchnorm(dataset):
    """
    Report the mean error per shift for the rows without batch norm and plot it against the shift distance.

    Writes the per-shift mean and count as markdown to ``metrics_folder / 'avg_error_by_shift.md'`` and saves
    a regression plot of mean error over ``abs(shift[0]) + abs(shift[1])`` to
    ``plot_folder / 'abs_distance_mean_error.png'``.

    :param dataset: DataFrame with a boolean 'batch_norm' column and the columns 'error' and 'shift'
    """
    avg_error_per_shift = dataset[~dataset['batch_norm']].pivot_table('error', 'shift', aggfunc=['mean', 'count'])
    avg_error_per_shift.to_markdown(metrics_folder / 'avg_error_by_shift.md')
    avg_error_per_shift = avg_error_per_shift.reset_index()
    # The shift values may already be parsed (compute_dataset_or_load evals them when it loads the CSV);
    # eval() on a non-string raises TypeError, so only string values are parsed here.
    avg_error_per_shift['shift'] = avg_error_per_shift['shift'].apply(
        lambda shift: eval(shift) if isinstance(shift, str) else shift)
    avg_error_per_shift['abs_distance'] = avg_error_per_shift['shift'].apply(lambda x: abs(x[0]) + abs(x[1]))
    avg_error_col = avg_error_per_shift['mean']['error']
    # Draw on a dedicated figure instead of whatever figure happens to be current.
    fig, ax = plt.subplots()
    sns.regplot(x=avg_error_per_shift['abs_distance'], y=avg_error_col, ax=ax)
    fig.savefig(plot_folder / 'abs_distance_mean_error.png')
    plt.close(fig)


def calc_perplexity_by_batchnorm_input_pad(dataset):
    """
    Tabulate the mean perplexity by batch norm (rows) and input padding (columns), including margins.

    The table is printed and written to ``metrics_folder`` as markdown, CSV and LaTeX.

    :param dataset: DataFrame with the columns 'perplexity', 'batch_norm' and 'input_pad'
    """
    markdown_path = metrics_folder / 'avg_perplexity_shifts_emnist.md'
    csv_path = metrics_folder / 'avg_perplexity_shifts_emnist.csv'
    latex_path = metrics_folder / 'avg_perplexity_shifts_emnist.tex'

    perplexity_table = dataset.pivot_table(values='perplexity', index='batch_norm', columns='input_pad',
                                           margins=True)
    print(perplexity_table)
    perplexity_table.to_markdown(markdown_path, index=True)
    perplexity_table.to_csv(csv_path)
    perplexity_table.to_latex(latex_path)


def plot_avg_error_by_shift_input_pad_bn(dataset):
    """
    Plot the mean prediction error per shift as heatmaps, one per (batch norm, input size) combination.

    Saves one standalone heatmap per combination and a combined 2x2 figure into ``plot_folder``. The input
    size is ``input_pad + 28``. Heatmap rows are indexed by the first shift component (labelled
    'Vertical shift') and columns by the second (labelled 'Horizontal shift'). Grid cells whose shift does
    not occur in ``dataset`` keep the value 0. ``dataset`` itself is left unmodified.

    :param dataset: DataFrame with the columns 'error', 'shift', 'batch_norm' (truthy = with batch norm)
        and 'input_pad'
    """
    bn_labels = ['bn', 'no_bn']
    input_sizes = [28, 29]

    # assign() returns a new frame: the caller's 'batch_norm' column stays boolean for the other reports.
    labelled = dataset.assign(
        batch_norm=dataset['batch_norm'].apply(lambda x: 'bn' if x else 'no_bn'),
        input_size=dataset['input_pad'] + 28,
    )
    aggr = labelled.pivot_table('error', 'shift', ['batch_norm', 'input_size'], aggfunc='mean').reset_index()
    # The shift values may already be parsed (compute_dataset_or_load evals them when it loads the CSV);
    # eval() on a non-string raises TypeError, so only string values are parsed here.
    shifts = aggr['shift'].apply(lambda shift: eval(shift) if isinstance(shift, str) else shift)
    x_shift = shifts.apply(itemgetter(0))
    y_shift = shifts.apply(itemgetter(1))
    # Translate the shifts so that the smallest one lands on index 0 of the heatmap.
    x_translation = x_shift.min()
    y_translation = y_shift.min()
    rows = (x_shift - x_translation).to_numpy()
    cols = (y_shift - y_translation).to_numpy()
    n_rows = rows.max() + 1
    n_cols = cols.max() + 1
    row_ticks = list(range(n_rows))
    row_labels = list(range(x_translation, x_translation + n_rows))
    col_ticks = list(range(n_cols))
    col_labels = list(range(y_translation, y_translation + n_cols))

    heatmaps = {}
    for bn in bn_labels:
        for input_size in input_sizes:
            heatmap_shifts = np.zeros((n_rows, n_cols))
            heatmap_shifts[rows, cols] = aggr[bn][input_size].to_numpy()
            heatmaps[bn, input_size] = heatmap_shifts

    # Standalone figures: the colorbar is derived from the drawn image so that both always agree.
    for (bn, input_size), heatmap_shifts in heatmaps.items():
        single_fig, single_ax = plt.subplots(figsize=(10, 6))
        image = single_ax.imshow(heatmap_shifts)
        single_ax.set_yticks(row_ticks)
        single_ax.set_yticklabels(row_labels)
        single_ax.set_xticks(col_ticks)
        single_ax.set_xticklabels(col_labels)
        single_ax.set_ylabel('Vertical shift')
        single_ax.set_xlabel('Horizontal shift')
        single_ax.set_title('Average error by shift')
        single_fig.colorbar(image, ax=single_ax)
        single_fig.savefig(plot_folder / f'heatmap_error_by_shift_{bn=}_{input_size=}.png')
        plt.close(single_fig)

    # Combined figure: a single colorbar is only meaningful when all panels share one normalisation.
    shared_norm = plt.Normalize(min(np.nanmin(heatmap) for heatmap in heatmaps.values()),
                                max(np.nanmax(heatmap) for heatmap in heatmaps.values()))
    fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(20, 12))
    for i, bn in enumerate(bn_labels):
        for j, input_size in enumerate(input_sizes):
            ax = axs[i, j]
            ax.set_title(f'With{"" if bn == "bn" else "out"} batch norm and input size = {input_size}',
                         fontsize='xx-large')
            image = ax.imshow(heatmaps[bn, input_size], norm=shared_norm)
            if j == 0:
                # y ticks only on the first column of the grid
                ax.set_yticks(row_ticks)
                ax.set_yticklabels(row_labels)
                ax.set_ylabel('Vertical shift')
            else:
                ax.set_yticks([])
            if i == (axs.shape[0] - 1):
                # x ticks only on the last row of the grid
                ax.set_xticks(col_ticks)
                ax.set_xticklabels(col_labels)
                ax.set_xlabel('Horizontal shift')
            else:
                ax.set_xticks([])
    fig.colorbar(image, ax=axs.ravel().tolist())
    fig.suptitle('Prediction error by letter shift, batch norm and input size', fontsize='xx-large')
    fig.savefig(plot_folder / 'heatmap_error_by_shift_bn_input_size.png')
    # Close the figure so it does not stay the current figure for later pyplot/seaborn calls.
    plt.close(fig)


from torch.nn.functional import cross_entropy


def calc_perplexity(y_pred: torch.Tensor, y_true: Union[torch.Tensor, int]):
    """
    Compute the per-sample perplexity, i.e. ``exp(cross_entropy)``, across a batch

    ``y_pred`` is passed unchanged to ``torch.nn.functional.cross_entropy``, which applies a log-softmax to
    its input itself.

    :param y_pred: 2-d tensor with one row of class scores per sample
    :param y_true: 1-d tensor with one class index per sample, or a single int label shared by the whole batch
    :return: 1-d tensor with one perplexity value per sample
    """
    if isinstance(y_true, int):
        # Broadcast the shared label on y_pred's device; a CPU label tensor would make cross_entropy
        # fail for predictions that live on another device.
        y_true = torch.full((y_pred.shape[0],), y_true, dtype=torch.long, device=y_pred.device)
    assert len(y_pred.shape) == 2, 'y_pred must be 2-dimensional'
    assert len(y_true.shape) == 1, 'y_true must be 1-dimensional'
    assert y_true.shape[0] == y_pred.shape[0], 'y_pred and y_true must have the same batch size'
    ce = cross_entropy(y_pred, y_true, reduction='none')
    return ce.exp()


def calc_prediction_error(y_pred: torch.Tensor, y_true: int):
    """
    Flag wrong predictions: 1.0 where the arg-max over the last axis differs from ``y_true``, else 0.0

    :param y_pred: tensor of class scores; the predicted class is the arg-max over the last axis
    :param y_true: true label the predicted classes are compared against
    :return: float tensor of 0.0 / 1.0 flags
    """
    predicted_class = y_pred.argmax(dim=-1)
    is_wrong = predicted_class.ne(y_true)
    return is_wrong.float()


def plots_y_preds(result_folder, y_preds_gathered, y_shifts):
    """
    Save one heatmap of the averaged predictions per (label, batch norm, input padding) combination.

    :param result_folder: folder the images are written to
    :param y_preds_gathered: mapping ``(y, has_bn, random_pad_input) -> list of prediction tensors``; the
        tensors of one entry are stacked and averaged over the stacking axis
    :param y_shifts: mapping ``y -> shifts`` used as the y tick labels of the heatmap
    """
    for (y, has_bn, random_pad_input), y_preds_list in y_preds_gathered.items():
        shifts = y_shifts[y]
        avg_y_preds = torch.stack(y_preds_list, dim=0).mean(dim=0)
        # Draw on a dedicated figure instead of whatever figure happens to be current.
        fig, ax = plt.subplots()
        sns.heatmap(avg_y_preds, annot=False, ax=ax)
        ax.set_yticklabels(shifts, fontsize=8)
        ax.set_ylabel('Shift')
        # Raw string: '\h' is not a valid escape sequence in a regular string literal.
        ax.set_xlabel(r'$\hat{p(y)}$')
        ax.set_title(
            f'{y=} with{"" if has_bn else "out"} batch norm and input size {(28 + random_pad_input, 28 + random_pad_input)}')
        fig.savefig(result_folder / f'{y=}_{has_bn=}_{random_pad_input=}.png')
        plt.close(fig)


def eval_model_test_time_only_wo_uneven_padding():
    """
    Test the zero-padded models with circular padding and input padding switched on at test time only.

    Only experiments with ``random_pad_input == 0``, ``padding_mode == 'zeros'`` and
    ``padding_type == 'same'`` are considered. Each of them is rebuilt from its ``last.ckpt`` checkpoint,
    switched to padding mode 'circular' with ``random_pad_input = 1`` and run on the EMNIST letters test
    set. The accuracy of every model is printed, followed by the list of all accuracies.
    """
    test_dl = DataLoader(letters_mnist_test_dataset(), batch_size=32)
    accuracies = []
    for exprmnt in iter_emnist_exprmnt():
        # Skip the models that were trained with random input padding.
        if exprmnt.data['random_pad_input'] != 0:
            continue
        padding_type = exprmnt.data['padding_type']
        padding_mode = exprmnt.data['padding_mode']
        # Keep only the models trained with zero 'same' padding.
        if not (padding_mode == 'zeros' and padding_type == 'same'):
            continue

        model = MnistPadding(**exprmnt.data)
        checkpoint = torch.load(exprmnt.path / 'last.ckpt', 'cpu')
        model.load_state_dict(checkpoint['state_dict'])
        # Change the padding behaviour after loading, i.e. for testing only.
        set_padding_mode_(model, 'circular')
        model.random_pad_input = 1

        pytorch_lightning.Trainer().test(model, test_dl)
        avg_accuracy = model.accuracy.compute()
        print(avg_accuracy)
        accuracies.append(avg_accuracy)
        model.accuracy.reset()
    print(accuracies)


def plot_cross_table_shifts(dataset):
    """
    Plot the mean prediction error by shift direction as an annotated 3x3 heatmap.

    The first shift component is categorised as top / center / bottom, the second as left / center / right
    (negative / zero / otherwise). The table is printed and saved to
    ``plot_folder / 'shifts_by_direction.png'``. ``dataset`` itself is left unmodified.

    :param dataset: DataFrame with the columns 'error' and 'shift' (2-item sequences)
    """

    def direction(value, negative, positive):
        # Map a signed shift component to its direction label.
        if value == 0:
            return 'center'
        return negative if value < 0 else positive

    # assign() returns a new frame, so the helper columns do not leak into the caller's DataFrame.
    categorised = dataset.assign(
        x_shift_category=dataset['shift'].apply(lambda shift: direction(shift[0], 'top', 'bottom')),
        y_shift_category=dataset['shift'].apply(lambda shift: direction(shift[1], 'left', 'right')),
    )
    result = categorised.pivot_table('error', 'x_shift_category', 'y_shift_category')
    # Fix the row/column order; pivot_table sorts the categories alphabetically.
    result = result.loc[['top', 'center', 'bottom'], ['left', 'center', 'right']]
    print(result)
    # Draw on a dedicated figure instead of whatever figure happens to be current.
    fig, ax = plt.subplots()
    sns.heatmap(result, annot=True, fmt='.3%', ax=ax)
    ax.set_ylabel('Vertical shift')
    ax.set_xlabel('Horizontal shift')
    ax.set_title('Prediction error by shift type')
    fig.savefig(plot_folder / 'shifts_by_direction.png')
    plt.close(fig)

@torch.no_grad()
def main():
    """
    Command line entry point: load (or compute) the evaluation table and run every report whose flag is set.

    Runs with gradient tracking disabled.
    """
    args = parse_args()

    dataset, emnist_samples = compute_dataset_or_load()

    # (flag, report) pairs in execution order
    reports = [
        (args.shifted_images_sample, lambda: plot_shifted_images_sample(emnist_samples)),
        (args.err_shift_wo_batchnorm, lambda: average_error_by_shift_without_batchnorm(dataset)),
        (args.err_shift_padding, lambda: avg_error_by_shift_padding(dataset)),
        (args.perplexity_bn_pad, lambda: calc_perplexity_by_batchnorm_input_pad(dataset)),
        (args.eval_model_test_time_only_wo_uneven_padding, lambda: eval_model_test_time_only_wo_uneven_padding()),
        (args.plot_cross_table_shifts, lambda: plot_cross_table_shifts(dataset)),
    ]
    for requested, run_report in reports:
        if requested:
            run_report()


def argparser_name(f):
    """
    Derive the command line option of a function: its name with '_' replaced by '-', prefixed with '--'

    :param f: object with a ``__name__`` attribute
    :return: the option string
    """
    dashed_name = '-'.join(f.__name__.split('_'))
    return f'--{dashed_name}'


def parse_args():
    """
    Parse the command line flags that select the reports to run.

    ``--all`` switches every boolean flag on. When no flag is set at all, the parser reports a usage error
    and exits.

    :return: the parsed ``argparse.Namespace``
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--shifted-images-sample', action='store_true', help=plot_shifted_images_sample.__doc__)
    parser.add_argument('--err-shift-wo-batchnorm', action='store_true')
    parser.add_argument('--err-shift-padding', action='store_true')
    parser.add_argument('--perplexity-bn-pad', action='store_true')
    parser.add_argument('--eval-model-test-time-only-wo-uneven-padding', action='store_true')
    parser.add_argument(argparser_name(plot_cross_table_shifts), action='store_true')
    parser.add_argument('--all', action='store_true')
    args = parser.parse_args()
    print(args)
    flags = vars(args)
    if args.all:
        # Only existing attributes are overwritten, so iterating the namespace dict while setting is safe.
        for varname, value in flags.items():
            if isinstance(value, bool):
                setattr(args, varname, True)
    if not any(flags.values()):
        # parser.error prints the usage and exits; unlike an assert it is not stripped under ``python -O``.
        parser.error('set at least one flag')
    return args


# Run the command line interface only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
