"""
plot example iamges and their predictions and certainties
"""

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

import os
import os.path as osp

import argparse

import seaborn as sns
from matplotlib.pyplot import rc
from plotting.colors_and_styles import method_colors, method_linestyles, method_markers, method_names

from models.neural_de import neural_de
from models.modules import downsampling, convolutions, fc_layers
from datasets.datasets import image_data

parser = argparse.ArgumentParser()
parser.add_argument('--experiment', type=str, choices=['mnist', 'cifar10', 'svhn'], default='mnist')
parser.add_argument('--width', type=int, default=50)
parser.add_argument('--state_width', type=int, default=10)
parser.add_argument('--nimages', type=int, default=4)
# Optional seed for numpy's global RNG so the same example images can be
# redrawn. The default (None) leaves the RNG unseeded, as before.
parser.add_argument('--seed', type=int, default=None)
args = parser.parse_args()

# At least one image is needed to build a non-empty figure.
if args.nimages < 1:
    parser.error('--nimages must be at least 1')


sns.set_style('whitegrid')
rc('font', family='serif')

# get data
data = image_data(dataset=args.experiment, train=False)

# choose which images to test on: args.nimages distinct indices drawn uniformly
if args.seed is not None:
    np.random.seed(args.seed)
possible_indices = np.arange(len(data))
choices = np.random.choice(possible_indices, size=args.nimages, replace=False, p=None)

# select values based on the images being used
# Per-dataset settings, keyed by the --experiment choice:
# (number of input image channels, human-readable class names in label order).
_digit_labels = [str(digit) for digit in range(10)]
_dataset_settings = {
    'mnist': (1, _digit_labels),
    'cifar10': (3, ['Airplane', 'Automobile', 'Bird', 'Cat', 'Deer',
                    'Dog', 'Frog', 'Horse', 'Ship', 'Truck']),
    'svhn': (3, _digit_labels),
}
# argparse restricts --experiment to the keys above, so this lookup cannot miss.
in_channels, classes = _dataset_settings[args.experiment]

# Shared by every dataset: hidden width passed to fc_layers and number of classes.
nhidden = 50
nclasses = 10

# Keyword arguments forwarded to model.evaluate (presumably solver tolerances).
kwargs = {'rtol': 1e-3, 'atol': 1e-3}


# Assemble the neural DE: encoder -> drift (no diffusion term) -> decoder.
encoder = downsampling(in_channels, args.state_width)
# Probe the encoder with one test image to size the drift and decoder modules.
probe_image = data[0][0].unsqueeze(0)
shape, vector_size = encoder.get_shape(probe_image)
decoder = fc_layers(vector_size, nhidden, nclasses, True)
drift = convolutions(args.state_width, args.width, shape)
diffusion = None
model = neural_de(drift, diffusion, encoder, decoder, backprop_option='adjoint_gq')


# NOTE(review): the checkpoint directory hard-codes '16' and '50'. If these are
# the state_width/width used in training, the --state_width default of 10 will
# not match this checkpoint. Confirm.
folder = osp.join('results', args.experiment, '16', '50', 'adjoint_gq', '3')
checkpoint_path = osp.join(folder, 'trained_model.pth')
state_dict = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(state_dict)


# add to the plots:
def add_image(index):
    """Draw one selected test image and annotate it with the model's prediction.

    The image ``data[choices[index]]`` is drawn in subplot ``index + 1`` of the
    ``height`` x ``width`` grid. Its x-label shows the predicted class, the
    softmax probability of that class (as a percentage) and the true class.

    Reads the module-level ``data``, ``choices``, ``model``, ``kwargs``,
    ``classes``, ``height``, ``width``, ``axis_fontsize`` and ``args``.

    Args:
        index: 0-based position in ``choices``; also selects the subplot.
    """
    x, t, y = data[choices[index]]
    x = x.unsqueeze(0)  # add a leading batch dimension of size 1
    logits = model.evaluate(x, t, **kwargs)
    probabilities = F.softmax(logits, dim=1)[0]  # scores of the single sample
    pred_idx = int(torch.argmax(probabilities).item())
    # Convert to a plain float so the label formatting does not rely on tensor formatting.
    confidence = 100 * probabilities[pred_idx].item()
    pred = classes[pred_idx]
    true = classes[int(y)]
    img_numpy = x.squeeze().numpy()

    # The subplot position comes from ``index`` itself, so the function works
    # no matter how or from where it is called.
    ax = plt.subplot(height, width, index + 1)
    if args.experiment == 'mnist':
        ax.imshow(img_numpy, cmap='Greys_r')
    else:
        # Move the leading axis (presumably channels) last, as imshow expects.
        ax.imshow(np.moveaxis(img_numpy, 0, -1))
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel('Prediction: {}\nConfidence: {:.4f}%\nTrue Class: {}'.format(pred, confidence, true),
                  fontsize=axis_fontsize)
    

# Figure layout: a single row with one subplot per selected image.
height = 1
width = args.nimages
axis_fontsize = 16


fig = plt.figure(figsize=[4*width, 4*height])
fig.subplots_adjust(hspace=0.0, wspace=0.1)


for i in range(args.nimages):
    add_image(i)

# save figure; create the output directory first, otherwise savefig raises
# FileNotFoundError when plotting/plots does not exist yet
out_dir = osp.join('plotting', 'plots')
os.makedirs(out_dir, exist_ok=True)
plt.savefig(osp.join(out_dir, args.experiment+'_preds.pdf'), bbox_inches='tight')