import pickle
import os
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import os
from scipy.stats import norm
import random
from scipy.stats import gaussian_kde
from scipy.interpolate import UnivariateSpline
from scipy.interpolate import interp1d

# Seed both random number generators for reproducibility. The Gaussian matrix
# M below is drawn with np.random.normal, i.e. from NumPy's global generator,
# which random.seed() does not affect -- so NumPy must be seeded explicitly.
random.seed(13)
np.random.seed(13)

def quantile(T, tau1, tau2, L, mu, delta_x, delta_y, q, hyp_param):
    """Evaluate the closed-form theoretical quantile Q(q, T) plotted by this script.

    The returned value is::

        A * (hyp_param / T + B + (C / T) * log(1 / (1 - q)))

    with
        A = 102400 * (L / mu) / tau2
        B = 960 * L * tau1**2 * delta_x**2 + 3128 * L * tau2**2 * delta_y**2
        C = 320 * tau1 * delta_x**2 + 64 * tau2 * delta_y**2

    Args:
        T: Horizon dividing the ``hyp_param`` and log terms (this script
            passes T = 10000; presumably the iteration count).
        tau1, tau2: Step-size parameters.
        L, mu: Constants whose ratio L / mu scales the leading factor.
        delta_x, delta_y: Constants entering B and C squared.
        q: Probability level (scalar or NumPy array). Must be < 1, because
            log(1 / (1 - q)) diverges at q = 1.
        hyp_param: Additive constant divided by T.

    Returns:
        Q(q, T). All operations are element-wise, so an array ``q`` yields
        an array of the same shape.
    """
    # Leading factor A; the parenthesised term is the ratio kappa = L / mu.
    lead = 102400 * (L / mu) / tau2
    # B: the part of the bracket that is not divided by T.
    offset = 960 * L * tau1**2 * delta_x**2 + 3128 * L * tau2**2 * delta_y**2
    # C: coefficient of the q-dependent tail term.
    tail = 320 * tau1 * delta_x**2 + 64 * tau2 * delta_y**2
    log_tail = np.log(1 / (1 - q))
    return lead * (hyp_param / T + offset + (tail / T) * log_tail)

# --- Problem setup --------------------------------------------------------
# Dimension and parameters of the Gaussian used to draw the coupling matrix.
d = 30
center, sigma = 0, 1
# Constants entering the bound L below.
c1, c2 = 1, 1

# Draw an i.i.d. Gaussian d x d matrix and symmetrise it.
M = np.random.normal(center, sigma, (d, d))
K_tilde = (M + M.T) / 2
# Rescale to norm 10. np.linalg.norm of a 2-D array defaults to the Frobenius
# norm, so ||K||_F == 10 up to rounding.
frobenius = np.linalg.norm(K_tilde)
K = 10 * K_tilde / frobenius

# A norm is never negative, so no abs() is needed here. Because ||K||_F is 10
# and 12 * c1 == 12, L evaluates to 12 with these settings.
# NOTE(review): if the spectral norm of K was intended, use ord=2 -- confirm.
L = max(12 * c1, 2 * c2, np.linalg.norm(K))
mu = 8
delta_x = 1
delta_y = 1

# Step sizes: tau1 = 1/(3L), tau2 = tau1/48.
tau1 = 1 / (3 * L)
tau2 = tau1 / 48
hyp_param = 12
T = 10000

# --- Theoretical quantiles -----------------------------------------------
# Probability levels q in [start, stop]. stop stays below 1 because the
# log(1 / (1 - q)) term in quantile() diverges at q = 1.
start = 0
stop = 0.99
step = 0.0002
# round() rather than int(): (stop - start) / step is a float quotient, and
# truncating a value such as 4949.999999... would silently drop the last point.
num_points = round((stop - start) / step) + 1

quantiles = np.linspace(start, stop, num_points)

# quantile() is built only from element-wise NumPy arithmetic, so it is
# evaluated on the whole grid in one call instead of once per level.
Q_qT = np.asarray(
    quantile(
        T=T,
        tau1=tau1,
        tau2=tau2,
        L=L,
        mu=mu,
        delta_x=delta_x,
        delta_y=delta_y,
        q=quantiles,
        hyp_param=hyp_param,
    )
)

# (Disabled) Pad the quantile grid with levels 0 and 1 so the interpolated CDF
# reaches exactly 0 at the left end and 1 at the right end:
# quantiles = np.insert(quantiles, 0, 0)  # Insert 0 at the start of the quantile levels
# quantiles = np.append(quantiles, 1)  # Append 1 to the end

# Q_qT = np.insert(Q_qT, 0, Q_qT[0])  # Extend values to the left
# Q_qT = np.append(Q_qT, Q_qT[-1])  # Extend values to the right

### avg_metric portion ###

# Load the empirical metric recorded for this step size. The file name embeds
# repr(tau1); with the settings above tau1 == 1/36, so this resolves to
# 'SSAGDA_avg_metrics_tau1=0.027777777777777776.pkl'.
# NOTE: pickle.load can execute arbitrary code -- only load result files you trust.
file_path_1 = f'./ncpl_result_data/SSAGDA_avg_metrics_tau1={tau1}.pkl'
with open(file_path_1, 'rb') as file:
    avg_metric = pickle.load(file)

# Empirical quantiles of the loaded metric at the same probability levels
# used for the theoretical curve.
quantiles_avg_metric = np.quantile(avg_metric, quantiles)

# Map Q_qT affinely onto [min, max] of the empirical quantiles so that both
# curves span the same x-range (this compares shapes, not magnitudes).
theory_lo, theory_hi = np.min(Q_qT), np.max(Q_qT)
emp_lo, emp_hi = np.min(quantiles_avg_metric), np.max(quantiles_avg_metric)
if theory_hi == theory_lo:
    raise ValueError('Q_qT is constant over the quantile grid; cannot rescale it.')
Q_qT = ((emp_hi - emp_lo) / (theory_hi - theory_lo)) * (Q_qT - theory_lo) + emp_lo

# Turn each quantile curve into a CDF by linearly interpolating
# value -> probability level; below / above the data the CDF is 0 / 1.
# np.interp expects non-decreasing x-coordinates: np.quantile is monotone in
# the level, and Q_qT grows with q for the positive constants used here.
# Both CDFs are evaluated on ONE shared grid -- the same x_values that gets
# plotted -- so neither curve is drawn against a grid it was not computed on.
x_values = np.linspace(Q_qT.min(), Q_qT.max(), 500)
cdf_values_Q_qT = np.interp(x_values, Q_qT, quantiles, left=0, right=1)
cdf_values_avg_metric = np.interp(
    x_values, quantiles_avg_metric, quantiles, left=0, right=1
)

# --- Plot: theoretical vs. empirical CDF on the shared x grid -------------
plt.rcParams['font.size'] = 20
# Optional LaTeX text rendering: plt.rcParams['text.usetex'] = True
fig, ax = plt.subplots(figsize=(8, 4))
curves = (
    (cdf_values_Q_qT, 'Theoretical CDF', 'blue'),
    (cdf_values_avg_metric, 'Empirical CDF', 'orange'),
)
for cdf_curve, curve_label, curve_color in curves:
    ax.plot(x_values, cdf_curve, label=curve_label, color=curve_color)
# Axis labels (currently disabled):
# ax.set_xlabel('Empirical' r' $ E[||\nabla _x f(x,y)||^2]$')
# ax.set_ylabel('Cumulative Probability')
ax.legend()
ax.grid(True)
fig.tight_layout()
plt.show()

####################### PATHS ###########################
# file_path_1 = f'./ncpl_result_data/SSAGDA_x_iterate_paths_tau1=0.027777777777777776.pkl'
# file_path_2 = f'./ncpl_result_data/SSAGDA_y_iterate_paths_tau1=0.027777777777777776.pkl'
# file_path_3 = f'./ncpl_result_data/SSAGDA_x_iterate_paths_tau1=0.006944444444444444.pkl'
# file_path_4 = f'./ncpl_result_data/SSAGDA_y_iterate_paths_tau1=0.006944444444444444.pkl'

# file_path_5 = f'./ncpl_result_data/SSAGDA_x_grad_paths_tau1=0.027777777777777776.pkl'
# file_path_6 = f'./ncpl_result_data/SSAGDA_y_grad_paths_tau1=0.027777777777777776.pkl'
# file_path_7 = f'./ncpl_result_data/SSAGDA_x_grad_paths_tau1=0.006944444444444444.pkl'
# file_path_8 = f'./ncpl_result_data/SSAGDA_y_grad_paths_tau1=0.006944444444444444.pkl'

# #### Iterates 1 ####
# with open(file_path_1, 'rb') as file:
#     x_path_data = pickle.load(file)

# with open(file_path_2, 'rb') as file:
#     y_path_data = pickle.load(file)

# norm_sum_1 = x_path_data + y_path_data

# norm_sum_row_means_1 = np.mean(norm_sum_1, axis=1)

# # Upper and lower bounds
# lower_bound_1 = np.min(norm_sum_1, axis = 1)
# upper_bound_1 = np.max(norm_sum_1, axis = 1)


# #### Iterates 2 ####
# with open(file_path_3, 'rb') as file:
#     x_path_data = pickle.load(file)

# with open(file_path_4, 'rb') as file:
#     y_path_data = pickle.load(file)

# norm_sum_2 = x_path_data + y_path_data

# norm_sum_row_means_2 = np.mean(norm_sum_2, axis=1)

# # Upper and lower bounds
# lower_bound_2 = np.min(norm_sum_2, axis = 1)
# upper_bound_2 = np.max(norm_sum_2, axis = 1)


# #### Gradients 1 #####
# with open(file_path_5, 'rb') as file:
#     x_path_data = pickle.load(file)

# with open(file_path_6, 'rb') as file:
#     y_path_data = pickle.load(file)

# norm_sum_3 = x_path_data + y_path_data

# norm_sum_row_means_3 = np.mean(norm_sum_3, axis=1)

# # Upper and lower bounds
# lower_bound_3 = np.min(norm_sum_3, axis = 1)
# upper_bound_3 = np.max(norm_sum_3, axis = 1)

# #### Gradients 2 ####
# with open(file_path_7, 'rb') as file:
#     x_path_data = pickle.load(file)

# with open(file_path_8, 'rb') as file:
#     y_path_data = pickle.load(file)

# norm_sum_4 = x_path_data + y_path_data

# norm_sum_row_means_4 = np.mean(norm_sum_4, axis=1)

# # Upper and lower bounds
# lower_bound_4 = np.min(norm_sum_4, axis = 1)
# upper_bound_4 = np.max(norm_sum_4, axis = 1)

# # Transformation for visibility
# norm_sum_row_means_1 = np.log(norm_sum_row_means_1)
# norm_sum_row_means_2 = np.log(norm_sum_row_means_2)
# norm_sum_row_means_3 = np.log(norm_sum_row_means_3)
# norm_sum_row_means_4 = np.log(norm_sum_row_means_4)
# lower_bound_1 = np.log(lower_bound_1)
# upper_bound_1 = np.log(upper_bound_1)
# lower_bound_2 = np.log(lower_bound_2)
# upper_bound_2 = np.log(upper_bound_2)
# lower_bound_3 = np.log(lower_bound_3)
# upper_bound_3 = np.log(upper_bound_3)
# lower_bound_4 = np.log(lower_bound_4)
# upper_bound_4 = np.log(upper_bound_4)

# x_axis = np.arange(len(norm_sum_row_means_1)) 

# fig, (ax1, ax2) = plt.subplots(1, 2, sharex=True, figsize=(12,2))
# plt.rcParams['font.size'] = 16
# # plt.rcParams['text.usetex'] = True
# # First plot
# ax2.plot(x_axis, norm_sum_row_means_1, color = 'blue', linewidth = 2.5)
# ax2.fill_between(x_axis, lower_bound_1, upper_bound_1, color='blue', alpha=0.2, label= r'$\tau_1 = 1/(3\ell)$')
# ax2.plot(x_axis, norm_sum_row_means_2, color = 'orange', linewidth = 2.5)
# ax2.fill_between(x_axis, lower_bound_2, upper_bound_2, color='orange', alpha=0.2, label= r'$\tau_1 = 1/(12\ell)$')
# ax2.set_xlim(0,10000)
# ax2.set_title('Average path across iterates')
# ax2.set_ylabel(r'$\log(\|x_k\|^2 + \|y_k\|^2)$', fontsize = 16)
# ax2.set_xlabel('Iterations', fontsize = 16)

# # Second plot
# ax1.plot(x_axis, norm_sum_row_means_3, color = 'blue')
# ax1.fill_between(x_axis, lower_bound_3, upper_bound_3, color='blue', alpha=0.2, label=r'$\tau_1 = 1/(3\ell)$')
# ax1.plot(x_axis, norm_sum_row_means_4, color = 'orange')
# ax1.fill_between(x_axis, lower_bound_4, upper_bound_4, color='orange', alpha=0.2, label=r'$\tau_1 = 1/(12\ell)$')
# ax1.set_xlim(0,10000)
# ax1.set_title('Average path across gradients')
# ax1.set_ylabel(r'$\log(\|\nabla f(x_k,y_k)\|^2)$', fontsize = 16)
# ax1.set_xlabel('Iterations', fontsize = 16)

# #$ Common x-axis label
# # Add a common x-axis label at the center of the figure, below the subplots
# # fig.text(0.52, 0.15, 'Iterations', ha='center', va='center', fontsize=12)

# # Adjust layout to provide space for the central x-axis label
# fig.tight_layout(rect=[0, 0.04, 1, 1])  # Adjust the bottom parameter to make space for the x-label

# # Create a common legend
# lines, labels = ax1.get_legend_handles_labels()
# ax1.legend(lines, labels, loc='upper right')
# ax2.legend(lines, labels, loc='upper right')

# # Display the plot
# ax1.grid(True)
# ax2.grid(True)
# plt.show()

# # Plot mean
# plt.plot(x_axis, norm_sum_row_means_1, label='Mean of 25 Paths, tau1 = 1/(3L)', linewidth = 2.5, color='blue')
# plt.plot(x_axis, norm_sum_row_means_2, label='Mean of 25 Paths, tau1 = 1/(12L)', linewidth = 2.5, color='orange')

# # Plot confidence intervals
# plt.fill_between(x_axis, lower_bound_1, upper_bound_1, color='blue', alpha=0.2, label='pointwise max/min')
# plt.fill_between(x_axis, lower_bound_2, upper_bound_2, color='orange', alpha=0.2, label='pointwise max/min')

# plt.title('Average path across iterates for NCPL game')
# plt.xlabel('Iteration')
# plt.ylabel('log(||x_K||^2 + ||y_k||^2)')
# plt.xlim(left=0,right=10000)
# plt.legend()
# plt.grid(True)
# plt.tight_layout()
# plt.show()