import jax.numpy as jnp
from jax import random, jit
import os
import matplotlib
#matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import seaborn as sns
import argparse
from plot_utils import HandlerRect

from jax.config import config

from data import MNIST, RandomData, CIFAR10
from utils import acc_multi
from kernels import NeuralTangent, NeuralTangentConv, RandomFeatures

# Command-line options for the noise sweep experiment.
parser = argparse.ArgumentParser(
    description='Sweep noise levels and compare test vs. leave-one-out loss and accuracy.')
parser.add_argument('--n_train', default=1000, type=int, help='number of training samples')
parser.add_argument('--n_test', default=1000, type=int, help='number of test samples')
parser.add_argument('--classes', default=10, type=int,
                    help='10 uses one-hot multi-class targets; any other value uses binary targets')
parser.add_argument('--num', default=10, type=int,
                    help='number of evenly spaced noise levels in [0, 1]')
parser.add_argument('--runs', default=1, type=int,
                    help='repetitions per noise level (forwarded to data.randomize)')
# Restricted to the datasets this script can actually load; any other value
# would otherwise only fail later with a NameError on `data`.
parser.add_argument('--dataset', default='MNIST', type=str, choices=['MNIST', 'CIFAR10'],
                    help='dataset to load')
parser.add_argument('--reg', default=0, type=float,
                    help='regularisation strength; 0 is forwarded to the model as reg=None')
# Restricted to the models this script constructs; any other value would
# otherwise only fail later with a NameError on `model`.
parser.add_argument('--model', default='NTK', type=str, choices=['NTK', 'RF'],
                    help='kernel model: neural tangent kernel (NTK) or random features (RF)')
parser.add_argument('--width', default=None, type=int,
                    help='currently not referenced by this script')

args = parser.parse_args()

# Figures are written here; create the directory up front so that
# plt.savefig does not fail when it does not exist yet.
save_path = os.path.dirname(os.path.realpath(__file__)) + '/store/plots/'
os.makedirs(save_path, exist_ok=True)

# Use 64-bit floats in JAX; set before any arrays or PRNG keys are created.
config.update("jax_enable_x64", True)

key = random.PRNGKey(200)
key, train_key, test_key, feat_key, noise_key = random.split(key, 5)

n_test = args.n_test
n_train = args.n_train

# `--reg 0` is forwarded to the model as reg=None.
reg = None if args.reg == 0 else args.reg

# Input dimensionality of the flattened images for each known dataset.
_INPUT_DIMS = {'MNIST': 784, 'FashionMNIST': 784, 'CIFAR10': 3072}
if args.dataset not in _INPUT_DIMS:
    raise ValueError(
        f"Unsupported dataset {args.dataset!r}; expected one of {sorted(_INPUT_DIMS)}")
dim = _INPUT_DIMS[args.dataset]

# 10 classes -> one-hot multi-class targets; any other count -> binary targets.
one_hot = args.classes == 10
binary = not one_hot

runs = args.runs

# Noise levels swept by the main loop: `args.num` evenly spaced values in [0, 1].
noise_levels = jnp.linspace(0, 1, args.num)

# One PRNG key per noise level: the sweep indexes noise_keys[i] with i ranging
# over noise_levels, so the split count must be len(noise_levels), not `runs`.
# JAX clamps out-of-bounds indices instead of raising, so a shorter key array
# would silently make several noise levels share the same key.
noise_keys = random.split(noise_key, len(noise_levels))

# Per-noise-level statistics filled by the sweep: means ...
train_losses, test_losses, train_accs, test_accs, crosses, cross_accs = [], [], [], [], [], []
# ... and the matching standard deviations (despite the `_vars` suffix).
train_l_vars, test_l_vars, train_accs_vars, test_accs_vars, crosses_vars, cross_accs_vars = [], [], [], [], [], []
num_correct = []

# Load the dataset once so the kernel model can be built on its inputs.
if args.dataset == 'MNIST':
    data = MNIST(n_train=n_train, n_test=n_test, binary=binary, flat=True, one_hot=one_hot)
elif args.dataset == 'CIFAR10':
    data = CIFAR10(n_train=n_train, n_test=n_test, binary=binary, flat=True, one_hot=one_hot)
else:
    # Fail here with a clear message instead of a NameError on `data` below.
    raise ValueError(
        f"No data loader for dataset {args.dataset!r}; supported: 'MNIST', 'CIFAR10'")

# Alternative convolutional NTK, kept for reference:
#model = NeuralTangentConv(dim=dim, widths=[100, 100], filter_shapes=[[5, 5], [5, 5]], strides=[[3, 3], [3, 3]],
#                          x_train=data.x_train, x_test=data.x_test, depth=3, reg=None, classes=10)
if args.model == 'RF':
    m = 50  # second dimension of the random projection Z, shape (dim, m)
    # NOTE(review): the same `key` both draws Z and is passed to RandomFeatures,
    # while `feat_key` (split above) is never used -- confirm this is intended.
    Z = random.normal(key=key, shape=(dim, m))
    model = RandomFeatures(dim=dim, m=m, key=key, x_train=data.x_train, x_test=data.x_test, reg=reg, Z=Z)
elif args.model == 'NTK':
    model = NeuralTangent(dim=dim, x_train=data.x_train, x_test=data.x_test, depth=5, reg=reg, classes=args.classes)
else:
    # Fail here with a clear message instead of a NameError on `model` later.
    raise ValueError(f"Unsupported model {args.model!r}; supported: 'RF', 'NTK'")

def _append_stats(values, means, stds):
    """Append the mean and standard deviation of ``values`` to ``means`` and ``stds``."""
    means.append(jnp.mean(values))
    stds.append(jnp.std(values))


# For every noise level: reload a clean copy of the dataset, randomise it, and
# record train/test/leave-one-out loss and accuracy statistics.
for level_idx, level in enumerate(noise_levels):
    if args.dataset == 'MNIST':
        data = MNIST(n_train=n_train, n_test=n_test, binary=binary, flat=True, one_hot=one_hot)
    elif args.dataset == 'CIFAR10':
        data = CIFAR10(n_train=n_train, n_test=n_test, binary=binary, flat=True, one_hot=one_hot)

    data.randomize(noise_level=level, key=noise_keys[level_idx], runs=runs)

    # Training set: summed squared error over axes 1 and 2, scaled by 1/n_train.
    train_pred = model.predict_fn(x=data.x_train, labels=data.y_train, t=jnp.inf, train=True)
    train_residual_sq = jnp.square(data.y_train - train_pred)
    _append_stats(1/n_train * jnp.sum(train_residual_sq, axis=(1, 2)), train_losses, train_l_vars)
    _append_stats(acc_multi(data.y_train, train_pred), train_accs, train_accs_vars)

    # Test set: predictions are fitted on the (noisy) training labels.
    test_pred = model.predict_fn(x=data.x_test, labels=data.y_train, t=jnp.inf, test=True)
    test_residual_sq = jnp.square(data.y_test - test_pred)
    _append_stats(1/n_test * jnp.sum(test_residual_sq, axis=(1, 2)), test_losses, test_l_vars)
    _append_stats(acc_multi(data.y_test, test_pred), test_accs, test_accs_vars)

    # Leave-one-out estimate, scored against data.y_correct.
    loo_loss, loo_acc = model.leave_one_out(labels=data.y_train, pred_labels=data.y_correct)
    _append_stats(loo_loss, crosses, crosses_vars)
    _append_stats(loo_acc, cross_accs, cross_accs_vars)

    print(level)

sns.set(font_scale=1.3)
sns.set_style('whitegrid')

# Proxy patches used as legend handles (drawn by HandlerRect).
rect1 = patches.Rectangle((0, 0), 2, 2, facecolor='#ff7f78')
rect2 = patches.Rectangle((0, 0), 2, 2, facecolor='#439c74')

# Vectorised mean/std of the leave-one-out loss per noise level.
loo_loss_mean = jnp.asarray(crosses)
loo_loss_std = jnp.asarray(crosses_vars)

plt.plot(noise_levels, test_losses, c='#ff7f78', alpha=0.4)
# Test loss is shaded down to zero; test_l_vars holds its per-level standard
# deviation if a +/- std band is preferred instead.
plt.fill_between(noise_levels, test_losses, alpha=0.2, color='#ff7f78')
plt.plot(noise_levels, crosses, c='#439c74')
# Leave-one-out loss gets a +/- 1 standard-deviation band.
plt.fill_between(x=noise_levels, y1=loo_loss_mean + loo_loss_std,
                 y2=loo_loss_mean - loo_loss_std, alpha=0.2, color='#439c74')

plt.legend(labels=['Test Loss', 'LOO Loss'], loc='upper left',
           bbox_to_anchor=(0.0, 1.125),
           ncol=2, fancybox=False, shadow=False, frameon=False, handles=(rect1, rect2),
           handler_map={patches.Rectangle: HandlerRect()})
plt.xlabel('Amount of Noise')
plt.ylabel('Loss')
plt.savefig(save_path + 'noisy_losses' + args.dataset, dpi=500, bbox_inches="tight")
plt.show()
plt.close()


# Vectorised mean/std of the leave-one-out accuracy per noise level.
loo_acc_mean = jnp.asarray(cross_accs)
loo_acc_std = jnp.asarray(cross_accs_vars)

# Colour passed via keyword, matching the loss figure.
plt.plot(noise_levels, test_accs, c='#ff7f78', alpha=0.4)
# Test accuracy is shaded down to zero; test_accs_vars holds its per-level
# standard deviation if a +/- std band is preferred instead.
plt.fill_between(noise_levels, y1=test_accs, alpha=0.2, color='#ff7f78')

plt.plot(noise_levels, cross_accs, c='#439c74')
# Leave-one-out accuracy gets a +/- 1 standard-deviation band.
plt.fill_between(x=noise_levels, y1=loo_acc_mean + loo_acc_std,
                 y2=loo_acc_mean - loo_acc_std, alpha=0.2, color='#439c74')

plt.xlabel('Amount of Noise')
plt.ylabel('Accuracy')
plt.legend(labels=['Test Accuracy', 'LOO Accuracy'], loc='upper left',
           bbox_to_anchor=(0.0, 1.125),
           ncol=2, fancybox=False, shadow=False, frameon=False, handles=(rect1, rect2),
           handler_map={patches.Rectangle: HandlerRect()})
plt.savefig(save_path + 'noisy_accs' + args.dataset, dpi=500, bbox_inches="tight")
plt.show()
plt.close()
