import jax.numpy as jnp
from jax import random
from jax.config import config
import os

import matplotlib
matplotlib.use('TkAgg')
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import argparse
from plot_utils import HandlerRect

from kernels import RandomFeatures, NeuralTangent, NeuralTangentConv
from data import MNIST, RandomData, CIFAR10
from utils import acc_multi

# Command-line interface of the sample-size sweep.
parser = argparse.ArgumentParser(
    description='Compare test and leave-one-out loss/accuracy of a kernel model '
                'over a range of training-set sizes.')
parser.add_argument('--n_train_s', default=10, type=int,
                    help='smallest training-set size of the sweep')
parser.add_argument('--n_train_e', default=1000, type=int,
                    help='largest training-set size of the sweep')
parser.add_argument('--n_test', default=1000, type=int,
                    help='number of test samples')
parser.add_argument('--classes', default=10, type=int,
                    help='number of classes; 10 selects one-hot multi-class targets, '
                         'any other value selects binary targets')
parser.add_argument('--num', default=10, type=int,
                    help='number of evenly spaced training-set sizes between '
                         '--n_train_s and --n_train_e')
parser.add_argument('--runs', default=1, type=int,
                    help='number of repetitions per training-set size')
# Only the datasets and models the script actually has a branch for are accepted,
# so a typo is reported at parse time instead of as a NameError deep in the sweep.
parser.add_argument('--dataset', default='MNIST', type=str, choices=['MNIST', 'CIFAR10'],
                    help='dataset to load')
parser.add_argument('--reg', default=0, type=float,
                    help='regularisation strength; 0 is passed to the model as None')
parser.add_argument('--model', default='NTK', type=str, choices=['RF', 'NTK', 'ConvNTK'],
                    help='kernel model to evaluate')
parser.add_argument('--depth', default=5, type=int,
                    help='depth passed to the NTK / ConvNTK model')
parser.add_argument('--width', default=None, type=int,
                    help='accepted for compatibility; not referenced by this script as written')

args = parser.parse_args()

# Figures are written to a 'store/plots/' folder located next to this script.
save_path = os.path.dirname(os.path.realpath(__file__)) + '/store/plots/'

# Switch JAX to 64-bit types before the first array is created below.
config.update("jax_enable_x64", True)

runs, n_test = args.runs, args.n_test

# Training-set sizes of the sweep: evenly spaced between the two bounds, truncated to int.
n_trains = [int(size) for size in jnp.linspace(args.n_train_s, args.n_train_e, num=args.num)]

# A regularisation strength of exactly 0 is represented as None.
reg = None if args.reg == 0 else args.reg

# Flattened input dimension of the dataset (stays undefined for any other dataset name).
if args.dataset in ('MNIST', 'FashionMNIST'):
    dim = 784
elif args.dataset == 'CIFAR10':
    dim = 3072

# Exactly ten classes means one-hot targets; every other value means binary targets.
one_hot = args.classes == 10
binary = not one_hot

# Only the convolutional model receives unflattened inputs.
flat = args.model != 'ConvNTK'

K = args.classes

# Fixed base key, split into one key per run.
key = random.PRNGKey(200)
data_keys = random.split(key, runs)


def _load_data(n_train, permute_key):
    """Load the dataset selected by ``args.dataset`` with ``n_train`` training samples.

    Raises:
        ValueError: if this script has no loader for ``args.dataset``.
    """
    loaders = {'MNIST': MNIST, 'CIFAR10': CIFAR10}
    if args.dataset not in loaders:
        raise ValueError(f"Unsupported dataset {args.dataset!r}; expected one of {sorted(loaders)}")
    return loaders[args.dataset](n_train=n_train, n_test=n_test, binary=binary, flat=flat,
                                 one_hot=one_hot, permute_key=permute_key, scale_dim=None)


def _build_model(data):
    """Build the kernel model selected by ``args.model`` on the given data split.

    Raises:
        ValueError: if ``args.model`` is not one of 'RF', 'NTK' or 'ConvNTK'.
    """
    if args.model == 'RF':
        m = 50
        # NOTE(review): the same `key` draws Z and is also handed to the model, so Z is
        # identical for every run and every training-set size -- confirm this is intended.
        Z = random.normal(key=key, shape=(dim, m))
        return RandomFeatures(dim=dim, m=m, key=key, x_train=data.x_train, x_test=data.x_test,
                              reg=reg, Z=Z)
    if args.model == 'NTK':
        return NeuralTangent(dim=dim, x_train=data.x_train, x_test=data.x_test, depth=args.depth,
                             reg=reg, classes=K)
    if args.model == 'ConvNTK':
        strides = [[2, 2] for _ in range(args.depth)]
        filter_shapes = [[3, 3] for _ in range(args.depth)]
        widths = [100 for _ in range(args.depth)]
        # NOTE(review): this branch passes reg=None and classes=10 regardless of
        # --reg / --classes -- confirm this is intended.
        return NeuralTangentConv(dim=dim, widths=widths, filter_shapes=filter_shapes,
                                 strides=strides, x_train=data.x_train, x_test=data.x_test,
                                 depth=args.depth, reg=None, classes=10)
    raise ValueError(f"Unsupported model {args.model!r}; expected 'RF', 'NTK' or 'ConvNTK'")


def _accuracy(labels, preds):
    """Return the accuracy of ``preds``: sign agreement when binary, ``acc_multi`` otherwise."""
    if binary:
        return jnp.mean(jnp.sign(preds) == labels)
    return acc_multi(labels, preds)


def _mean_std(values):
    """Return the mean and standard deviation of a list of per-run scalars."""
    stacked = jnp.array(values)
    return jnp.mean(stacked), jnp.std(stacked)


# Per training-set size: mean over runs, and (suffix `_v`) standard deviation over runs.
train_losses, train_accs, test_losses, test_accs, cross_losses, cross_accs = [], [], [], [], [], []
train_losses_v, train_accs_v, test_losses_v, test_accs_v, cross_losses_v, cross_accs_v = [], [], [], [], [], []

for n in n_trains:
    tr_l_inter, tr_a_inter, te_l_inter, te_a_inter, c_l_inter, c_a_inter = [], [], [], [], [], []
    for j in range(runs):
        data = _load_data(n, data_keys[j])
        model = _build_model(data)

        # Training statistics: summed squared error divided by the number of samples.
        pred_train = model.predict_fn(x=data.x_train, labels=data.y_train, t=jnp.inf, train=True)
        tr_l_inter.append(1/n * jnp.sum(jnp.square(data.y_train - pred_train)))
        tr_a_inter.append(_accuracy(data.y_train, pred_train))

        # Test statistics, computed the same way on the held-out split.
        pred_test = model.predict_fn(x=data.x_test, labels=data.y_train, t=jnp.inf, test=True)
        te_l_inter.append(1/n_test * jnp.sum(jnp.square(data.y_test - pred_test)))
        te_a_inter.append(_accuracy(data.y_test, pred_test))

        # Leave-one-out estimates as reported by the model itself.
        cross_loss, cross_acc = model.leave_one_out(labels=data.y_train)
        c_l_inter.append(cross_loss)
        c_a_inter.append(cross_acc)
        print(n)

    # Reduce the per-run values of this training-set size to mean and standard deviation.
    for per_run, means, stds in (
            (tr_l_inter, train_losses, train_losses_v),
            (tr_a_inter, train_accs, train_accs_v),
            (te_l_inter, test_losses, test_losses_v),
            (te_a_inter, test_accs, test_accs_v),
            (c_l_inter, cross_losses, cross_losses_v),
            (c_a_inter, cross_accs, cross_accs_v),
    ):
        mean, std = _mean_std(per_run)
        means.append(mean)
        stds.append(std)

    # Progress output: mean test accuracy for this training-set size.
    print(test_accs[-1])

sns.set(font_scale=1.3)
sns.set_style('whitegrid')

# Dataset-specific baseline down to which the area under the test curve is shaded.
if args.dataset == 'MNIST':
    loss_low = 0.05
    acc_low = 0.6
elif args.dataset == 'CIFAR10':
    loss_low = 0.6
    acc_low = 0.2

# Coloured rectangles used as legend handles (test curve, leave-one-out curve).
rect1 = patches.Rectangle((0, 0), 2, 2, facecolor='#ff7f78')
rect2 = patches.Rectangle((0, 0), 2, 2, facecolor='#439c74')

# Make sure the output folder exists; otherwise savefig would fail only after the
# whole sweep has already been computed.
os.makedirs(save_path, exist_ok=True)


def _plot_test_vs_loo(test_vals, loo_vals, loo_stds, baseline, legend_labels, ylabel, file_stem):
    """Draw, save and show one figure comparing a test curve with its leave-one-out estimate.

    Args:
        test_vals: mean test metric per training-set size.
        loo_vals: mean leave-one-out metric per training-set size.
        loo_stds: standard deviation of the leave-one-out metric per training-set size.
        baseline: constant y-value down to which the test curve is shaded.
        legend_labels: the two legend entries (test curve, leave-one-out curve).
        ylabel: label of the y-axis.
        file_stem: file name prefix; the dataset name is appended to it.
    """
    # (A training curve could be added here, e.g. plt.plot(n_trains, train_vals, c='#ffba7f').)
    plt.plot(n_trains, test_vals, c='#ff7f78', alpha=0.4)
    plt.fill_between(x=n_trains, y1=list(test_vals), y2=baseline, alpha=0.2, color='#ff7f78')

    # Leave-one-out curve with a band of one standard deviation on either side.
    plt.plot(n_trains, loo_vals, c='#439c74', zorder=10)
    lower = [mean - std for mean, std in zip(loo_vals, loo_stds)]
    upper = [mean + std for mean, std in zip(loo_vals, loo_stds)]
    plt.fill_between(x=n_trains, y1=lower, y2=upper, alpha=0.2, color='#439c74', zorder=10)

    plt.legend(labels=legend_labels, loc='upper left', bbox_to_anchor=(0.0, 1.125),
               ncol=2, fancybox=False, shadow=False, frameon=False, handles=(rect1, rect2),
               handler_map={patches.Rectangle: HandlerRect()})
    plt.xlabel('Number of Samples')
    plt.ylabel(ylabel)
    plt.savefig(save_path + file_stem + str(args.dataset), dpi=500, bbox_inches="tight")
    plt.show()
    plt.close()


_plot_test_vs_loo(test_losses, cross_losses, cross_losses_v, loss_low,
                  ['Test Loss', 'LOO Loss'], 'Loss', 'sample_losses')
_plot_test_vs_loo(test_accs, cross_accs, cross_accs_v, acc_low,
                  ['Test Accuracy', 'LOO Accuracy'], 'Accuracy', 'sample_accs')
