""" This script generates the plots in Figure 1 (one PDF per (N1, N2) pair).
Before executing this file, run the following to generate the data for the plots:
python WMC_synthetic/240519_compute_geometric_Lambdavlambda.py

In this script:
N1 = number of samples generated from the first manifold
N2 = number of samples generated from the second manifold """

import numpy as np
import os
import sys
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rcParams

# Global figure styling applied to every plot below: Times New Roman at a
# large base size, with the STIX math font set (Times-like) for $...$ labels.
rcParams.update({
    'font.family': 'Times New Roman',
    'font.size': 28,
    'mathtext.fontset': 'stix',
})

# Put the current working directory first on the import path, presumably so
# the local `manifoldGen` import that follows can be resolved.
sys.path.insert(0, './')

from manifoldGen import manifoldGen
# from utils import *

# The input .npz files are written by WMC_synthetic/240519_compute_geometric_Lambdavlambda.py via:
# np.savez(f'240519-variables-L-vs-l_N1_{N1}_N2_{N2}.npz', inradius=inradius, gamma_W=gamma_W, ineq=ineq, LL=LL, LU=LU, lambd_list=lambd_list, N1=N1, N2=N2, seed_list=seed_list)

# (N1, N2) sample-count pairs; one figure (PDF) is produced per pair.
N1_N2_list = [[120, 40], [150, 40], [160, 40], [200, 40]]

# Plot styling shared by every figure (loop-invariant, so built once).
colors = sns.color_palette("deep")
linestyle = {"linestyle": "-", "linewidth": 2, "markeredgewidth": 1, "elinewidth": 1, "capsize": 4}


def _per_lambda_stats(values, n_lambda):
    """Return (mean, std) arrays over each of the first ``n_lambda`` rows of ``values``.

    All entries of row ``l`` (every remaining axis, presumably the random
    seeds) are pooled, exactly like ``np.mean(values[l, :])`` and
    ``np.std(values[l, :])`` would.
    """
    rows = np.asarray(values)[:n_lambda].reshape(n_lambda, -1)
    return rows.mean(axis=1), rows.std(axis=1)


for N1, N2 in N1_N2_list:
    print(f'N1={N1}, N2={N2}')

    # Load saved variables. NpzFile keeps the archive open, so close it via
    # `with`; arrays read from it stay valid after closing.
    data_path = f'WMC_synthetic/240519-variables-L-vs-l_N1_{N1}_N2_{N2}.npz'
    try:
        data = np.load(data_path)
    except FileNotFoundError as err:
        raise FileNotFoundError(
            f'{data_path} not found; run '
            'python WMC_synthetic/240519_compute_geometric_Lambdavlambda.py first.'
        ) from err
    with data:
        # The archive also stores inradius, gamma_W and ineq; they are not
        # used in this figure.
        LL = data['LL']
        LU = data['LU']
        lambd_list = data['lambd_list']
        # Stored as numpy scalars/0-d arrays; use plain ints for labels/paths.
        N1 = int(data['N1'])
        N2 = int(data['N2'])
        seed_list = data['seed_list']

    # Mean and standard deviation of the lower/upper estimates per lambda.
    n_lambda = len(lambd_list)
    LL_mean, LL_std = _per_lambda_stats(LL, n_lambda)
    LU_mean, LU_std = _per_lambda_stats(LU, n_lambda)

    # Generate plots
    fig, ax = plt.subplots(figsize=(8, 8))
    # Kept at the same point as before (before artists are added) so the
    # figure layout is unchanged; savefig(bbox_inches='tight') does the crop.
    fig.tight_layout(pad=4)
    # Identity line lambda = lambda as the reference curve.
    ax.plot(lambd_list, lambd_list, color=colors[3], label=r"$\lambda$")
    ax.errorbar(lambd_list, LL_mean, LL_std, color=colors[2], label=r"$\lambda^l$", **linestyle)
    ax.errorbar(lambd_list, LU_mean, LU_std, color=colors[0], label=r"$\lambda^u$", **linestyle)
    ax.set_xlabel(r"$\lambda$")
    if N1 == 120:
        # Only the N1=120 panel (first entry of N1_N2_list) draws the legend.
        ax.legend()
    ax.set_title(f"$N_1$={N1}, $N_2$={N2}")

    # Save plots
    filename = f'./output_WMCsynthetic/Lambda_FixedN{N1}_{N2}_r{len(seed_list)}.pdf'
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    fig.savefig(filename, bbox_inches='tight')
    plt.close(fig)
