import jax.numpy as jnp
from jax import random
import os

import matplotlib
matplotlib.use('TkAgg')
import seaborn as sns
from jax.config import config
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import argparse

from plot_utils import HandlerRect
from kernels import RandomFeatures
from data import MNIST, RandomData, CIFAR10

# Command-line options as (flag, default, type) triples. How each one is used
# further down in this script:
#   --n_train / --n_test : sample counts handed to the dataset constructor;
#                          n_train also determines the range of widths swept
#   --classes            : 10 selects one-hot labels, anything else binary
#   --num                : number of linspace points per width segment
#   --runs               : number of independent feature draws per width
#   --dataset            : name of the dataset to load
#   --reg                : regularisation strength (0 is passed on as None)
_CLI_OPTIONS = (
    ('--n_train', 10, int),
    ('--n_test', 10000, int),
    ('--classes', 10, int),
    ('--num', 10, int),
    ('--runs', 1, int),
    ('--dataset', 'MNIST', str),
    ('--reg', 0, float),
)

parser = argparse.ArgumentParser()
for _flag, _default, _type in _CLI_OPTIONS:
    parser.add_argument(_flag, default=_default, type=_type)

args = parser.parse_args()

# Switch JAX to 64-bit floats before any array is created.
config.update("jax_enable_x64", True)

# Figures are written next to this script, under store/plots/.
_script_dir = os.path.dirname(os.path.realpath(__file__))
save_path = _script_dir + '/store/plots/'

runs = args.runs

# A regularisation strength of exactly 0 is passed on as None.
reg = args.reg if args.reg != 0 else None

# Flattened input dimension for every dataset name this script knows about.
_INPUT_DIMS = {'MNIST': 784, 'FashionMNIST': 784, 'CIFAR10': 3072}
if args.dataset not in _INPUT_DIMS:
    # Fail fast with a clear message instead of a NameError on `dim` (or on
    # `data`) much further down the script.
    raise ValueError(
        'Unsupported dataset %r; expected one of %s'
        % (args.dataset, sorted(_INPUT_DIMS))
    )
dim = _INPUT_DIMS[args.dataset]

# Exactly 10 classes selects one-hot labels; any other value selects the
# binary variant of the task. The two flags are always opposites.
one_hot = args.classes == 10
binary = not one_hot

n_train = args.n_train
n_test = args.n_test

# The widths to sweep come from three linspace segments of args.num points
# each: one below n_train, one straddling it (0.9 * n_train + 1 up to
# 1.1 * n_train) and one above it, up to 2 * n_train.
_width_segments = (
    (1, n_train * 0.9),
    (n_train * 0.9 + 1, n_train * 1.1),
    (n_train * 1.1 + 1, 2 * n_train),
)
_raw_widths = [
    int(point)
    for _lo, _hi in _width_segments
    for point in jnp.linspace(_lo, _hi, num=args.num)
]
# Truncating to int produces many repeated widths (especially in the narrow
# middle segment); every repeat would retrain an identical model, so keep each
# width once. Widths below 1 are dropped because slicing the feature matrix to
# zero columns would give an empty model.
widths = sorted({w for w in _raw_widths if w >= 1})

K = args.classes

# Fixed seed so that runs are reproducible; data_key is reserved for the
# dataset randomisation.
key = random.PRNGKey(200)
key, data_key = random.split(key, 2)

# Dataset constructors this script can actually load.
_DATASET_LOADERS = {'MNIST': MNIST, 'CIFAR10': CIFAR10}
if args.dataset not in _DATASET_LOADERS:
    # Without this check `data` would simply stay undefined and the script
    # would die with a NameError on the next statement.
    raise ValueError(
        'No loader available for dataset %r; expected one of %s'
        % (args.dataset, sorted(_DATASET_LOADERS))
    )
data = _DATASET_LOADERS[args.dataset](
    n_train=n_train, n_test=n_test, binary=binary, flat=True, one_hot=one_hot)

# Use the dedicated data key here rather than re-using `key`, so the label
# randomisation does not share its random stream with the feature draw below.
# NOTE(review): `runs=1` is hard-coded here although the script has a --runs
# option; confirm against data.randomize whether that is intended.
data.randomize(noise_level=1.0, runs=1, key=data_key)

# One summary dict per width is appended to each of these lists by the sweep.
train_losses, train_accs, test_losses, test_accs, cross_losses, cross_accs = [], [], [], [], [], []

# Draw the random features once, at the largest width, with one independent
# draw per run; narrower models take the leading columns of the same matrix.
# A fresh sub-key is split off because JAX keys must not be consumed twice.
key, feature_key = random.split(key)
Zs = random.normal(key=feature_key, shape=(dim, max(widths), runs))

def _summarise(samples):
    """Collapse the per-run values of one metric into summary statistics.

    Args:
        samples: list with one scalar per run.

    Returns:
        dict with the keys 'median', '55' and '45'. The key names are kept
        because the plotting code further down looks them up by name; note
        that '55' and '45' hold the 60th and 40th percentiles respectively.
    """
    values = jnp.array(samples)
    return {
        'median': jnp.median(values),
        '55': jnp.percentile(values, q=60),
        '45': jnp.percentile(values, q=40),
    }


# Sweep over model widths; for each width fit one model per run and record
# train, test and leave-one-out loss/accuracy.
for m in widths:
    train_loss_runs, train_acc_runs = [], []
    test_loss_runs, test_acc_runs = [], []
    loo_loss_runs, loo_acc_runs = [], []

    for i in range(runs):
        # The width-m model uses the first m columns of run i's feature draw.
        model = RandomFeatures(m=m, dim=dim, key=key, x_train=data.x_train, x_test=data.x_test, Z=Zs[:, :m, i], reg=reg)

        # presumably t=jnp.inf requests the fully-trained solution - confirm
        # in kernels.RandomFeatures.predict_fn
        pred_train = model.predict_fn(x=data.x_train, labels=data.y_train, t=jnp.inf, train=True)
        train_loss_runs.append(jnp.mean(jnp.square(data.y_train - pred_train)))
        # NOTE(review): accuracy is the fraction of entries where
        # sign(prediction) equals the label; confirm the label encoding makes
        # this meaningful in the one-hot (--classes 10) setting.
        train_acc_runs.append(jnp.mean(jnp.sign(pred_train) == data.y_train))

        pred_test = model.predict_fn(x=data.x_test, labels=data.y_train, t=jnp.inf, test=True)
        test_loss_runs.append(jnp.mean(jnp.square(data.y_test - pred_test)))
        test_acc_runs.append(jnp.mean(jnp.sign(pred_test) == data.y_test))

        cross_loss, cross_acc = model.leave_one_out(labels=data.y_train)
        loo_loss_runs.append(cross_loss)
        loo_acc_runs.append(cross_acc)

    train_losses.append(_summarise(train_loss_runs))
    train_accs.append(_summarise(train_acc_runs))
    test_losses.append(_summarise(test_loss_runs))
    test_accs.append(_summarise(test_acc_runs))
    cross_losses.append(_summarise(loo_loss_runs))
    cross_accs.append(_summarise(loo_acc_runs))

    # Progress indicator: the width that has just been processed.
    print(m)


sns.set(font_scale=1.3)
sns.set_style('whitegrid')

# Lower edge of the shaded area under the test-accuracy curve in the accuracy
# figure. Kept per dataset so it can be tuned individually; the fallback makes
# sure the name is always bound.
_ACC_FLOORS = {'MNIST': 0.4, 'CIFAR10': 0.4}
acc_low = _ACC_FLOORS.get(args.dataset, 0.4)

# Coloured patches used as legend handles (test = red, leave-one-out = green).
rect1 = patches.Rectangle((0, 0), 2, 2, facecolor='#ff7f78')
rect2 = patches.Rectangle((0, 0), 2, 2, facecolor='#439c74')

# Loss figure: median test loss (shaded down to zero) against the median
# leave-one-out loss with its '45'..'55' band.
test_loss_curve = [entry['median'] for entry in test_losses]
loo_loss_curve = [entry['median'] for entry in cross_losses]
loo_loss_upper = [entry['55'] for entry in cross_losses]
loo_loss_lower = [entry['45'] for entry in cross_losses]

plt.plot(widths, test_loss_curve, c='#ff7f78', alpha=0.4)
plt.fill_between(widths, test_loss_curve, alpha=0.2, color='#ff7f78')

plt.plot(widths, loo_loss_curve, c='#439c74', zorder=10)
plt.fill_between(x=widths, y1=loo_loss_upper, y2=loo_loss_lower,
                 alpha=0.2, color='#439c74', zorder=10)
plt.legend(labels=['Test Loss', 'LOO Loss'], loc='upper left',
           bbox_to_anchor=(0.0, 1.125), handles=(rect1, rect2),
           handler_map={patches.Rectangle: HandlerRect()},
           ncol=2, fancybox=False, shadow=False, frameon=False)

plt.xlabel('Width')
plt.ylabel('Loss')
# Create the output directory first so the results of the sweep are not lost
# to a missing-directory error at save time.
os.makedirs(save_path, exist_ok=True)
plt.savefig(save_path + 'dd_losses' + args.dataset, dpi=500, bbox_inches="tight")
plt.show()
plt.close()

# Accuracy figure: median test accuracy (shaded down to acc_low) against the
# median leave-one-out accuracy with its '45'..'55' band.
test_acc_curve = [entry['median'] for entry in test_accs]
loo_acc_curve = [entry['median'] for entry in cross_accs]
loo_acc_upper = [entry['55'] for entry in cross_accs]
loo_acc_lower = [entry['45'] for entry in cross_accs]

plt.plot(widths, test_acc_curve, c='#ff7f78', alpha=0.4)
plt.fill_between(widths, y1=test_acc_curve, y2=acc_low, alpha=0.2, color='#ff7f78')

plt.plot(widths, loo_acc_curve, c='#439c74', zorder=10)
plt.fill_between(x=widths, y1=loo_acc_upper, y2=loo_acc_lower,
                 alpha=0.2, color='#439c74', zorder=10)

plt.legend(labels=['Test Accuracy', 'LOO Accuracy'], loc='upper left',
           bbox_to_anchor=(0.0, 1.125), handles=(rect1, rect2),
           handler_map={patches.Rectangle: HandlerRect()},
           ncol=2, fancybox=False, shadow=False, frameon=False)
plt.xlabel('Width')
plt.ylabel('Accuracy')
# Create the output directory first so saving cannot fail on a missing
# directory after the whole sweep has already been computed.
os.makedirs(save_path, exist_ok=True)
plt.savefig(save_path + 'dd_accs' + args.dataset, dpi=500, bbox_inches="tight")
plt.show()
plt.close()
