from utils import *
import os
import sys
from math import inf
import time
from tabulate import tabulate
from sklearn.cluster import KMeans

# Experiment driver. For each random trial:
#   1. draw four mean vectors with generate_mu and sample N_tot attacker
#      types from generate_theta_normal (isotropic covariance var * I_m);
#   2. compress the samples into N k-means cluster centres with counts s;
#   3. for each K in K_VALUES build the problem with compute_params, solve it
#      with FCP_DRO, and score the solution on the full sample set with
#      utility_robust;
#   4. save the means, scores and solve times to ./simulations.
# All helpers (and np, numerator, denominator) come from `from utils import *`.

N_TRIALS = 10
K_VALUES = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
RESULTS_DIR = './simulations'

# Constants shared by every trial.
m = 7          # dimension argument passed to generate_mu / compute_params
M = 1.0        # passed through unchanged to FCP_DRO
# 4th/5th arguments of generate_mu. Deliberately NOT named `a`/`b`: the loop
# below unpacks a `b` from compute_params, which would clobber a shared name.
mu_a = 3.0
mu_b = 3.0
low = 5.0
high = 8.0
N_tot = 500    # number of sampled attacker types per trial
var = 3.0      # covariance scale: each component uses var * np.eye(m)
xi = 1e4
N = 25         # number of k-means clusters

# np.save does not create parent directories.
os.makedirs(RESULTS_DIR, exist_ok=True)

for trial in range(N_TRIALS):
    mu0, mu1, mu2, mu3 = (generate_mu(m, low, high, mu_a, mu_b) for _ in range(4))
    log = {'mu0': mu0, 'mu1': mu1, 'mu2': mu2, 'mu3': mu3}

    # Generate attacker data (one covariance array per component, as before).
    theta_1, theta_2 = generate_theta_normal(
        mu0, var * np.eye(m), mu1, var * np.eye(m),
        mu2, var * np.eye(m), mu3, var * np.eye(m), N_tot)
    full_theta = np.concatenate([theta_1, theta_2], axis=2)
    print("Mu0 : ", mu0)
    print("Mu1 : ", mu1)
    print("Mu2 : ", mu2)
    print("Mu3 : ", mu3)

    # Weights for utility_robust: 1 for the first N_tot entries, 0 for the last.
    # Rebuilt every trial in case utility_robust modifies it in place.
    w = np.ones(N_tot + 1)
    w[-1] = 0

    # Flatten to one row per sample for KMeans (assumes the sample axis of
    # full_theta has length N_tot -- TODO confirm against utils).
    tot_theta = np.asarray(full_theta).reshape(N_tot, -1)
    kmeans = KMeans(n_clusters=N).fit(tot_theta)
    # Back to (N, m, per-row width); -1 replaces the previously hard-coded 4.
    cluster_cent = kmeans.cluster_centers_.reshape(N, m, -1)
    # s[j] = number of samples assigned to centre j. Length N + 1 with s[N] == 0
    # is kept from the original layout (presumably what FCP_DRO expects,
    # mirroring w[-1] = 0 -- confirm).
    s = np.bincount(kmeans.predict(tot_theta), minlength=N + 1).astype(float)

    all_opt = []
    all_times = []
    for K in K_VALUES:
        A, b, C, d, tL, tU = compute_params(m, K, N, numerator, denominator, cluster_cent)
        sta = time.perf_counter()
        z_dro = FCP_DRO(m, K, N, N_tot, M, tL, tU, A, b, C, d, s, xi)
        all_times.append(time.perf_counter() - sta)  # solve time only
        # Evaluate the clustered solution against the full sample set.
        opt = utility_robust(full_theta, numerator, denominator, m, z_dro, N_tot, xi, w)
        all_opt.append(opt)
        print("K : ", K, " opt : ", opt)

    # NOTE(review): the "K_" slot in this filename is filled with m, not K (K is
    # swept within each file). Kept unchanged so existing loaders find the files.
    results_location = f'{RESULTS_DIR}/SSG_opt_value_K_{m}_trial_{trial}.npy'
    log['K_values'] = K_VALUES
    log['all_opt'] = all_opt
    log['all_times'] = all_times
    # Saving a dict stores a 0-d object array: load with
    # np.load(path, allow_pickle=True).item().
    np.save(results_location, log)
