import numpy as np

from depca import *
from datasets import *
from metrics import sin_thetak
from gossip import compute_omega

import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.lines import Line2D

sns.set_style("whitegrid")

# Figure typography: serif fonts with all text rendered through LaTeX using
# Computer Modern. Applied after seaborn's style so these keys take precedence
# over the font settings seaborn installs. (plt.rcParams is mpl.rcParams, so a
# single ordered update is equivalent to setting them through either name.)
mpl.rcParams.update(
    {
        "font.family": "serif",
        "mathtext.fontset": "dejavuserif",
        "text.usetex": True,
        "font.serif": ["Computer Modern Roman"],
    }
)


# Method labels, in the order each panel runs, colours and plots them.
_ALGS = ['DePM', 'DeEPCA', r'ADePM $\beta=\beta^*$', r'ADePM $\beta=\beta_t$']
# Line style used for the first / second value of ``Ls`` within a panel.
_LINESTYLES = [':', '-']


def _run_algorithm(alg_idx, A, T, X0, W, omega, L, k, beta_star):
    """Run method ``_ALGS[alg_idx]`` for ``T`` iterations and return its output.

    DePM, DeEPCA and ADePM (fixed ``beta_star``) start from the first ``k``
    columns of ``X0``; ADePM_tune receives the full ``X0`` (``k + 1`` columns).
    """
    if alg_idx == 0:
        return DePM(A, T, X0[:, :k], W, omega, L)
    if alg_idx == 1:
        return DeEPCA(A, T, X0[:, :k], W, omega, L)
    if alg_idx == 2:
        return ADePM(A, T, beta_star, X0[:, :k], W, omega, L)
    return ADePM_tune(A, T, X0, W, omega, L)


def _plot_experiment(ax, A, W, A_mean, n, d, k, T, Ls, yticks):
    """Run every method for every ``L`` in ``Ls`` and plot mean sin-theta error.

    Args:
        ax: matplotlib Axes to draw on.
        A, W, A_mean: outputs of a ``load_*`` dataset loader; ``A_mean`` is
            passed to ``np.linalg.eigh`` and so must be symmetric.
        n: number of per-node iterates averaged (``X[node]`` for
            ``node in range(n)``).
        d: ambient dimension; the random start ``X0`` has shape ``(d, k + 1)``.
        k: target subspace dimension.
        T: number of iterations passed to each method.
        Ls: values of ``L`` to sweep (at most ``len(_LINESTYLES)`` of them).
        yticks: y tick positions for the log-scale axis.

    Returns:
        The colour palette used for the methods (one colour per ``_ALGS`` entry),
        so the caller can build a matching legend.
    """
    eigs, U = np.linalg.eigh(A_mean)
    # eigh sorts eigenvalues ascending: the top-k eigenvectors are the last k
    # columns and eigs[-k - 1] is the (k+1)-th largest eigenvalue.
    Uk = U[:, -k:]
    lambda_k1 = eigs[-k - 1]
    beta_star = lambda_k1 ** 2 / 4

    # One random start shared by all methods and all L values in this panel.
    X0 = np.random.randn(d, k + 1)
    omega = compute_omega(W)

    colors = sns.color_palette("inferno", n_colors=len(_ALGS))
    for i, alg in enumerate(_ALGS):
        for j, L in enumerate(Ls):
            X_list = _run_algorithm(i, A, T, X0, W, omega, L, k, beta_star)
            # errors[node][t]: sin theta_k between Uk and the first k columns
            # of node's iterate at step t.
            errors = [
                [sin_thetak(Uk, X[node][:, :k], k) for X in X_list]
                for node in range(n)
            ]
            sin_mean = np.mean(errors, axis=0)
            sns.lineplot(
                x=np.arange(len(sin_mean)),
                y=sin_mean,
                label=f'{alg}, L={L}',
                color=colors[i],
                linestyle=_LINESTYLES[j],
                ax=ax,
            )

    ax.set_yscale('log')
    ax.set_xlabel(r'$t$', fontsize=16)
    ax.set_xticks(np.arange(0, T + 1, 50))
    ax.set_yticks(yticks)
    plt.setp(ax.get_xticklabels(), fontsize=14)
    plt.setp(ax.get_yticklabels(), fontsize=14)
    return colors


def _add_L_legend(ax, Ls):
    """Set ``ax``'s legend to a line-style -> ``L`` key and return it.

    ``ax.legend`` replaces any legend previously set on ``ax``; callers that
    want to keep an earlier legend must re-attach it with ``ax.add_artist``.
    """
    handles = [
        Line2D([0], [0], color='black', lw=2, linestyle=_LINESTYLES[j],
               label=rf'$L={L}$')
        for j, L in enumerate(Ls)
    ]
    return ax.legend(handles=handles, loc="lower left", fontsize=14)


if __name__ == "__main__":
    import os

    np.random.seed(0)

    fig, ax = plt.subplots(1, 2, figsize=(12, 3.5))

    # ======================
    # Experiment on FLamby
    # ======================

    print("FLamby Experiment")

    flamby_Ls = [5, 10]
    A, W, A_mean = load_fed_heart_disease()
    colors = _plot_experiment(
        ax[0], A, W, A_mean, n=4, d=13, k=5, T=200, Ls=flamby_Ls,
        yticks=[1e-8, 1e-6, 1e-4, 1e-2, 1],
    )
    ax[0].set_ylabel(
        r'$\frac{1}{n}\sum \sin\theta_k(\mathbf{U}_k, \mathbf{X}_{i,t})$',
        fontsize=16,
    )

    # Legends: methods (upper right) and L values (lower left).
    alg_handles = [
        Line2D([0], [0], color=color, lw=2, linestyle='-', label=alg)
        for alg, color in zip(_ALGS, colors)
    ]
    legend1 = ax[0].legend(
        handles=alg_handles,
        loc="upper right",
        fontsize=14,
        ncol=2,
    )
    # The L legend becomes the axes legend and displaces legend1, so re-attach
    # legend1 as a plain artist. (The L legend itself must not be re-added, or
    # it is drawn twice.)
    _add_L_legend(ax[0], flamby_Ls)
    ax[0].add_artist(legend1)

    # ==========================
    # Experiment on Ego-Facebook
    # ==========================

    print("Ego-Facebook Experiment")

    n = 50
    ego_Ls = [20, 40]
    A, W, A_mean = load_ego_facebook(n)
    _plot_experiment(
        ax[1], A, W, A_mean, n=n, d=n, k=5, T=200, Ls=ego_Ls,
        yticks=[1e-6, 1e-4, 1e-2, 1],
    )
    # Target ax[1] explicitly instead of relying on pyplot's "current axes".
    _add_L_legend(ax[1], ego_Ls)

    plt.tight_layout()
    # savefig does not create missing directories.
    os.makedirs("plots", exist_ok=True)
    plt.savefig("plots/adepm_experiment.pdf")
    
