import numpy as np
import matplotlib.pyplot as plt
import random
import os
import pickle
import time

# This code is for the NCPL games experiment: f(x,y) = c1[||x||^2 + sin(3*sqrt(||x||^2 + 1))] + <Kx,y> - c2[||y||^2 + 3*sin^2(||y||)]
# Defining the problem here.
# Seed BOTH generators: every random draw in this file goes through
# np.random, which is not affected by random.seed(). Without the NumPy seed
# the matrix K (and anything drawn later) changes on every run.
random.seed(13)
np.random.seed(13)

d = 30       # dimension of x and y
center = 0   # mean of the Gaussian entries of M
sigma = 1    # standard deviation of the Gaussian entries of M
c1 = 1       # weight of the x-dependent term of f
c2 = 1       # weight of the y-dependent term of f

M = np.random.normal(center, sigma, (d, d))  # randomly sampled Gaussian
K_tilde = (M + np.transpose(M))/2            # symmetric part of M
K = 10*K_tilde/np.linalg.norm(K_tilde)       # rescaled so the Frobenius norm of K is 10

def SSAGDA_paths(x0, y0, z0, delta, K, T, c1, c2, sim_time, is_save_data):
    """Run repeated SSAGDA simulations on the NCPL game and average a gradient metric.

    The objective is

        f(x, y) = c1*[||x||^2 + sin(3*sqrt(||x||^2 + 1))] + <Kx, y>
                  - c2*[||y||^2 + 3*sin^2(||y||)]

    Each simulation starts from a fresh copy of (x0, y0, z0) and performs T
    iterations of: a noisy descent step on x with a proximal pull towards the
    auxiliary point z, a noisy ascent step on y (using the updated x), and a
    relaxation step of z towards x. The caller's arrays are never modified,
    so passing the same array for x0 and z0 is safe.

    Args:
        x0: Initial x iterate (1-D array-like).
        y0: Initial y iterate (1-D array-like).
        z0: Initial auxiliary iterate, same length as x0.
        delta: Standard deviation of the Gaussian noise added to each gradient.
        K: Coupling matrix of the bilinear term <Kx, y>.
        T: Number of iterations per simulation.
        c1: Weight of the x-dependent term.
        c2: Weight of the y-dependent term.
        sim_time: Number of independent simulations.
        is_save_data: If true, pickle the list of per-simulation averages to
            './ncpl_result_data/SSAGDA_avg_metrics_tau1=<tau_1>.pkl'.

    Returns:
        A list with one entry per simulation: the mean over the T iterations
        of ||grad_x||^2 + kappa*||grad_y||^2 (stochastic gradients).
    """
    K = np.asarray(K, dtype=float)

    # Constants of the method; np.linalg.norm on a matrix is the Frobenius norm.
    L = max(12*c1, 8*c2, np.linalg.norm(K))
    tau_1 = 1/(3*L)          # x step size
    tau_2 = tau_1/48         # y step size
    p = 2*L                  # weight of the proximal term p*(x - z)
    mu = 8*c2
    beta = mu*tau_2/1600     # relaxation rate of z towards x
    kappa = L/mu             # weight of the y-gradient in the metric

    avg_metric = []

    for _ in range(sim_time):
        # Fresh float copies: every simulation restarts from the same initial
        # point, and in-place updates below cannot leak into the caller's
        # arrays or couple x and z when the caller passed z0 = x0.
        x = np.array(x0, dtype=float)
        y = np.array(y0, dtype=float)
        z = np.array(z0, dtype=float)
        metric_sim = np.zeros(T)

        for t in range(T):
            # --- noisy gradient of f with respect to x ---
            w_x = np.random.normal(0, delta, x.shape[0])
            root = np.sqrt(np.linalg.norm(x)**2 + 1)
            # d/dx <Kx, y> = K^T y (equal to K y only when K is symmetric).
            grad_x = c1*(2*x + np.cos(3*root)*(3*x/root)) + np.dot(K.T, y) + w_x
            x -= tau_1 * (grad_x + p * (x - z))

            # --- noisy gradient of f with respect to y, at the updated x ---
            w_y = np.random.normal(0, delta, y.shape[0])
            y_norm = np.linalg.norm(y)
            if y_norm > 0:
                # d/dy 3*sin^2(||y||) = 6*sin(||y||)*cos(||y||) * y/||y||
                sin_sq_grad = 6*np.sin(y_norm)*np.cos(y_norm)*y/y_norm
            else:
                # The expression above tends to 0 as y -> 0; avoid 0/0.
                sin_sq_grad = np.zeros_like(y)
            grad_y = np.dot(K, x) - c2*(2*y + sin_sq_grad) + w_y
            y += tau_2 * grad_y

            z += beta * (x - z)

            metric_sim[t] = np.linalg.norm(grad_x)**2 + kappa*np.linalg.norm(grad_y)**2

        avg_metric.append(np.mean(metric_sim))

    if is_save_data:
        directory = './ncpl_result_data'
        os.makedirs(directory, exist_ok=True)
        metric_file = os.path.join(directory, f'SSAGDA_avg_metrics_tau1={tau_1}.pkl')
        with open(metric_file, 'wb') as fp:
            pickle.dump(avg_metric, fp)

    return avg_metric

# Initial point: each coordinate drawn uniformly from [low, high).
low = -20
high = 20
x0 = np.random.uniform(low, high, d)
y0 = np.random.uniform(low, high, d)
# Independent copy: z0 must be a separate array from x0, otherwise any
# in-place update of x0 would also move z0 and the proximal term
# p*(x0 - z0) would vanish.
z0 = x0.copy()
delta = 1             # standard deviation of the gradient noise
T = 10000             # iterations per simulation
is_save_data = True
sim_time = 1000       # number of simulations

# perf_counter is monotonic, so the measured duration cannot be distorted
# by system clock adjustments.
start_time = time.perf_counter()
SSAGDA_paths(x0, y0, z0, delta, K, T, c1, c2, sim_time, is_save_data)
end_time = time.perf_counter()  # End timing
elapsed_time = end_time - start_time  # Calculate elapsed time
print(f"The algorithm took {elapsed_time:.2f} seconds to run.")