import numpy as np

import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

from anpm import *
from synthetic_instances import *
from metrics import *

import argparse

# Figure styling. seaborn's style presets include font settings, so the
# whitegrid preset is applied first and the serif/LaTeX overrides after it.
# Note: text.usetex=True needs a LaTeX installation on the rendering machine.
sns.set_style("whitegrid")
mpl.rcParams.update({
    "font.family": "serif",
    "mathtext.fontset": "dejavuserif",
    "text.usetex": True,
    "font.serif": ["Computer Modern Roman"],
})

def generate_noise_type(noise_type: str, d: int, k: int, T: int, eta: float, U: np.ndarray):
    """Dispatch to the noise generator selected by ``noise_type``.

    Args:
        noise_type: ``'adversarial'`` calls ``generate_adversarial_noise``
            (which also receives ``U``); ``'stochastic'`` calls
            ``generate_noise`` (``U`` is not used).
        d, k, T, eta: Passed through unchanged to the selected generator.
        U: Passed to the adversarial generator only.

    Returns:
        Whatever the selected generator returns.

    Raises:
        ValueError: If ``noise_type`` is neither ``'adversarial'`` nor
            ``'stochastic'``.
    """
    if noise_type == 'adversarial':
        return generate_adversarial_noise(d, k, T, eta, U)
    if noise_type == 'stochastic':
        return generate_noise(d, k, T, eta)
    # Reject typos instead of silently falling back to stochastic noise.
    raise ValueError(
        f"unknown noise_type {noise_type!r}; expected 'adversarial' or 'stochastic'"
    )


def _spectrum(d, k, lambda_1, lambda_k, lambda_kp1, lambda_rest):
    """Return the length-``d`` eigenvalue vector of a synthetic test matrix.

    Layout: ``k - 1`` copies of ``lambda_1``, then ``lambda_k``, then
    ``lambda_kp1``, then ``d - k - 1`` copies of ``lambda_rest``.  The panel
    titles report ``lambda_k - lambda_kp1`` as the eigengap Delta_k.
    """
    return np.array(
        [lambda_1] * (k - 1) + [lambda_k] + [lambda_kp1] + [lambda_rest] * (d - k - 1)
    )


def _sin_theta_curve(U, X_list, k):
    """Return ``sin_thetak`` between ``U[:, :k]`` and ``X[:, :k]`` for each iterate ``X``."""
    U_k = U[:, :k]
    return [sin_thetak(U_k, X[:, :k], k) for X in X_list]


def _xtick_step(T):
    """Return the spacing of the labelled ``t`` ticks for a run of ``T`` iterations.

    ``T <= 200`` keeps a spacing of 50; longer runs use the next multiple of
    50 that keeps roughly five labels per panel (e.g. 250 for ``T = 1000``).
    A fixed step of 50 would put 21 labels on each narrow panel at the
    default ``T = 1000``.
    """
    return 50 * max(1, int(np.ceil(T / 200)))


def _style_panel(a, title, legend_title=None):
    """Apply the title, log y-scale and ``t`` x-label shared by every panel.

    With ``legend_title=None`` the panel's legend is removed; otherwise it is
    drawn in the lower-right corner under ``legend_title``.
    """
    a.set_title(title, fontsize=16)
    a.set_yscale('log')
    a.set_xlabel(r'$t$', fontsize=16)
    if legend_title is None:
        a.legend().remove()
    else:
        a.legend(
            loc="lower right",
            ncol=1,
            fontsize=14,
            title=legend_title,
            title_fontsize=14,
        )


def main():
    """Run the four ANPM experiments and save them as one row of panels.

    Panels, left to right: (1) beta in {0, beta*/2, beta*} plus the
    ``anpm_tune`` run at Delta_k = 1e-2; (2) the same at Delta_k = 1e-3;
    (3) beta = beta* for several eigengaps; (4) beta = beta* for several
    noise levels xi.  Here beta* = lambda_{k+1}**2 / 4.  The figure is written
    to ``plots/anpm_experiment_<noise>.pdf``, creating ``plots/`` if needed.
    """
    import os  # only used to create the output directory

    np.random.seed(0)

    parser = argparse.ArgumentParser()
    parser.add_argument('--d', type=int, default=1000)
    parser.add_argument('--k', type=int, default=10)
    parser.add_argument('--T', type=int, default=1000)
    parser.add_argument('--noise', type=str, default='adversarial', choices=['adversarial', 'stochastic'])

    args = parser.parse_args()
    d = args.d
    k = args.k
    T = args.T
    steps = range(T + 1)

    U = generate_eigenvectors(d)

    # sharey=True links every panel's y-axis to the first one; the previous
    # per-axes sharey() loop also made ax[0] share with itself.
    fig, ax = plt.subplots(1, 4, figsize=(12, 3.5), sharey=True)

    xticks = np.arange(0, T + 1, _xtick_step(T))
    for a in ax:
        # tick_params stores the size on the axis, so it also applies to tick
        # labels matplotlib creates later; setp only touched existing labels.
        a.tick_params(labelsize=14)
        a.xaxis.set_ticks(xticks)
    for a in ax[1:]:
        a.tick_params(labelleft=False)

    ax[0].set_ylim(5e-5, 2)

    lambda_1 = 1.
    lambda_k = 1.0
    lambda_rest = .5
    eta = 1e-4

    beta_labels = [r'$0$', r'$\beta^*/2$', r'$\beta^*$']
    beta_colors = sns.color_palette("viridis", n_colors=len(beta_labels) + 1)

    def plot_beta_sweep(a, lambda_kp1):
        """Plot the fixed-beta runs and the ``anpm_tune`` run for one eigengap on ``a``."""
        A = generate_matrix(_spectrum(d, k, lambda_1, lambda_k, lambda_kp1, lambda_rest), U)
        # Draw order (X0, then Xi) is kept: the generators presumably use the
        # global NumPy RNG seeded above, so reordering would change the curves.
        X0 = generate_X0(d, k + 1)
        Xi = generate_noise_type(args.noise, d, k + 1, T, eta, U)

        beta_star = (lambda_kp1 ** 2) / 4
        for beta, label, color in zip([0.0, beta_star / 2, beta_star], beta_labels, beta_colors):
            errors = _sin_theta_curve(U, anpm(A, beta, T, X0, Xi), k)
            sns.lineplot(x=steps, y=errors, ax=a, label=label, color=color)

        errors = _sin_theta_curve(U, anpm_tune(A, T, X0, Xi), k)
        sns.lineplot(x=steps, y=errors, ax=a, label=r'$\beta_t$', color="gold", linestyle='--')

    # Experiment 1: large eigengap.
    plot_beta_sweep(ax[0], .99)
    _style_panel(ax[0], r'$\Delta_k = 10^{-2},\; \xi = 10^{-4}$')
    ax[0].set_ylabel(r'$\sin\theta_k(\mathbf{U}_k, \mathbf{X}_t)$', fontsize=16)

    # Experiment 2: small eigengap.
    plot_beta_sweep(ax[1], .999)
    _style_panel(ax[1], r'$\Delta_k = 10^{-3},\; \xi = 10^{-4}$', legend_title=r'$\beta$')

    # Experiment 3: varying eigengap at beta = beta*, with one shared X0/Xi
    # of k columns.  Noise is drawn before X0 here; the order is kept for
    # reproducibility under the fixed seed.
    lambda_kp1_list = [0.9, 0.99, 0.999, 0.9999]
    deltas = [r"$10^{-1}$", r"$10^{-2}$", r"$10^{-3}$", r"$10^{-4}$"]

    Xi = generate_noise_type(args.noise, d, k, T, eta, U)
    X0 = generate_X0(d, k)

    gap_colors = sns.color_palette("rocket_r", n_colors=len(lambda_kp1_list))
    for lambda_kp1, label, color in zip(lambda_kp1_list, deltas, gap_colors):
        A = generate_matrix(_spectrum(d, k, lambda_1, lambda_k, lambda_kp1, lambda_rest), U)
        beta_star = (lambda_kp1 ** 2) / 4
        errors = _sin_theta_curve(U, anpm(A, beta_star, T, X0, Xi), k)
        sns.lineplot(x=steps, y=errors, ax=ax[2], label=label, color=color)

    _style_panel(ax[2], r'$\beta=\beta^*,\; \xi=10^{-4}$', legend_title=r'$\Delta_k$')

    # Experiment 4: varying noise level at Delta_k = 1e-3 and beta = beta*,
    # reusing X0 from experiment 3 and drawing fresh noise for each level.
    lambda_kp1 = .999
    beta_star = (lambda_kp1 ** 2) / 4
    A = generate_matrix(_spectrum(d, k, lambda_1, lambda_k, lambda_kp1, lambda_rest), U)

    eta_list = np.array([1e-2, 1e-3, 1e-4, 1e-5])
    etas_label = [r"$10^{-2}$", r"$10^{-3}$", r"$10^{-4}$", r"$10^{-5}$"]

    noise_colors = sns.color_palette("mako_r", n_colors=len(eta_list))
    # The palette is applied in reverse: the first level gets the last colour.
    for noise_level, label, color in zip(eta_list, etas_label, reversed(noise_colors)):
        Xi = generate_noise_type(args.noise, d, k, T, noise_level, U)
        errors = _sin_theta_curve(U, anpm(A, beta_star, T, X0, Xi), k)
        sns.lineplot(x=steps, y=errors, ax=ax[3], label=label, color=color)

    _style_panel(ax[3], r'$\Delta_k = 10^{-3},\; \beta=\beta^*$', legend_title=r'$\xi$')

    fig.tight_layout()
    # savefig does not create missing directories.
    os.makedirs('plots', exist_ok=True)
    fig.savefig(f'plots/anpm_experiment_{args.noise}.pdf')


if __name__ == "__main__":
    main()
