from utils import *
import sys
from math import inf
import os
import pandas as pd
from sklearn.cluster import KMeans
import pickle
import random
from copy import deepcopy
random.seed(0)

# Load the raw dataset. Every CSV row carries one whitespace-separated record in
# its first column. Record 0 is parsed as ints (used downstream as K-Means
# sample weights); record 1 is additionally stashed in U; every record after
# record 0 is parsed as floats.
# NOTE(review): pd.read_csv uses the file's first line as a header, so that line
# never reaches `dset` — presumably it is metadata; confirm against the file.
dset = pd.read_csv('./data/NYC_82341_59_5_1.data').to_numpy()
full_data = []
U = []
for rec_no, record in enumerate(dset):
    tokens = record[0].split()
    if rec_no == 0:
        full_data.append(list(map(int, tokens)))
        continue
    values = list(map(float, tokens))
    if rec_no == 1:
        U.append(list(values))  # independent copy, as in a separate parse
    full_data.append(values)
full_data = np.array(full_data)

# Transpose so each row of trans_data is one sample: column 0 comes from
# record 0, column 1 from record 1, columns 2.. from the remaining records.
trans_data = full_data.T
# Filter data
# count = 0
# upper = 1.0
# lower = -1.0
# Sweep over windows [lower, upper] on the per-sample gap max_j v_j - u. For
# each window: filter the samples, split 80/20 train/test, cluster the training
# samples with K-Means, build the cost parameters, then compare the DRO
# solution against ERM (DRO with xi = 0) on train and test.
os.makedirs('./kmeans', exist_ok=True)
os.makedirs('./logs', exist_ok=True)
for upper in [2.0]:
    for lower in [0.0]:
        if upper == lower:
            continue

        # Keep samples whose gap lies in [lower, upper]. In trans_data each
        # row is a sample: column 1 is u, columns 2.. are the v_j.
        gap = np.max(trans_data[:, 2:], axis=1) - trans_data[:, 1]
        keep = (gap >= lower) & (gap <= upper)
        filtered_data = trans_data[keep]
        count = int(np.count_nonzero(keep))
        print(count)
        if count == 0:
            # Nothing to cluster or evaluate for this window.
            continue

        # Shuffle samples to form the train/test split. random.shuffle must
        # not be applied to a 2-D ndarray directly: its swaps go through row
        # views, which duplicates some rows and loses others. Shuffle an index
        # list (still driven by the seeded `random` module) and reorder.
        N_train = int(0.8 * count)
        order = list(range(count))
        random.shuffle(order)
        filtered_data = filtered_data[order]
        full_data = filtered_data.T

        # Rows of full_data: 0 = sample weight, 1 = u, 2..m+1 = v_1..v_m.
        # Derived from the data; the cost arrays below require m == rows - 2.
        m = full_data.shape[0] - 2

        # Per-alternative baseline B_j = min over all filtered samples of v_j.
        # NOTE(review): this min includes the test samples, so test data
        # influences the training parameters — confirm this is intended.
        Bjs = np.min(full_data[2:m + 2, :], axis=1)

        # Columns are samples: first N_train for training, rest for test.
        test_data = full_data[:, N_train:]
        full_data = full_data[:, :N_train]

        # Test-set A_ij = v_ij - B_j, one row per test sample.
        Aijs = test_data[2:m + 2, :].T - Bjs

        features = full_data[1:].T
        weights = np.array(full_data[0])
        test_features = test_data[1:].T
        test_weights = np.array(test_data[0])
        N_act = len(weights)
        print(N_act)

        # Cluster the weighted training samples, reusing a pickled model
        # cached for the same configuration when one exists.
        N = 50
        kmeans_loc = './kmeans/filtered_kmeans_{}clusters_{}upper_{}lower_{}train'.format(N, upper, lower, N_train)
        if os.path.isfile(kmeans_loc):
            print("Using Saved K-Means")
            # Only ever load caches this script wrote itself: unpickling
            # untrusted files can execute arbitrary code.
            with open(kmeans_loc, 'rb') as fh:
                kmeans = pickle.load(fh)
        else:
            print("Running new K-Means")
            # Fixed random_state so the clustering is reproducible across runs.
            kmeans = KMeans(n_clusters=N, random_state=0).fit(features, sample_weight=weights)
            with open(kmeans_loc, 'wb') as fh:
                pickle.dump(kmeans, fh)
            print("Saved K-Means")
        centers = kmeans.cluster_centers_

        # Total training weight assigned to each cluster. The array has one
        # extra trailing slot (always 0, since labels lie in [0, N)). It is
        # passed to FCP_DRO as-is, presumably because FCP_DRO expects length
        # N + 1 (confirm in utils); train-set averages below use s[:-1].
        Y = kmeans.predict(features)
        s = np.bincount(Y, weights=weights, minlength=N + 1)
        print(s)

        # Training cost parameters, one row per cluster centre c_i = (u, v_1..v_m):
        #   A_ij = c_ij - B_j   (centre's v_j relative to the baseline)
        #   B_ij = B_j - c_i0   (baseline relative to the centre's u)
        Aijs_train = centers[:, 1:] - Bjs
        Bijs_train = Bjs - centers[:, :1]
        print(Aijs_train[0])
        print(Bijs_train[0])

        # Per-test-sample analogue of B_ij.
        Bijs_test = Bjs - test_features[:, :1]

        theta = [Aijs_train, Bijs_train, np.zeros(N)]
        test_theta = [Aijs, Bijs_test, np.zeros(len(test_features))]

        K = 10
        A, b, C, d, tL, tU = compute_params(m, K, N, FCP_numerator, FCP_denominator, theta)
        print("Range t")
        print(np.max(tU))
        print(np.min(tL))

        n_test = count - N_train
        for xi in [1e2, 1e3, 1e4]:
            for M in range(7, 14, 3):
                print("M : ", M)
                print("XI : ",xi)
                log = {}
                z_dro = FCP_DRO(m, K, N, N_act, M, tL, tU, A, b, C, d, s, xi)
                # ERM is the DRO problem with a zero ambiguity radius.
                z_erm = FCP_DRO(m, K, N, N_act, M, tL, tU, A, b, C, d, s, 0)
                print("Z DRO : ",z_dro)
                print("Z ERM : ",z_erm)
                log_file = './logs/filtered{}N_{}upper_{}lower_{}train_FCP_weights_M{}_xi{}'.format(N, upper, lower, N_train, M,xi)
                log['weights'] = [z_dro, z_erm]

                # Train: cluster-level values weighted by cluster mass.
                values_DRO = FCP_values(z_dro, theta)
                values_ERM = FCP_values(z_erm, theta)
                print("Train")
                avg_d, std_d = weighted_avg_and_std(values_DRO, s[:-1])
                avg_e, std_e = weighted_avg_and_std(values_ERM, s[:-1])
                print(avg_d)
                print(avg_e)
                print(std_d)
                print(std_e)

                # Test: first the weighted average over the 5% of test samples
                # with the smallest values, then over the whole test set.
                print("Test")
                l = int(0.05 * n_test)
                values_DRO = FCP_values(z_dro, test_theta)
                values_ERM = FCP_values(z_erm, test_theta)
                low_d = np.argsort(values_DRO)[:l]
                low_e = np.argsort(values_ERM)[:l]
                avg_d, std_d = weighted_avg_and_std(values_DRO[low_d], test_weights[low_d])
                avg_e, std_e = weighted_avg_and_std(values_ERM[low_e], test_weights[low_e])
                print(avg_d)
                print(avg_e)
                avg_d, std_d = weighted_avg_and_std(values_DRO, test_weights)
                avg_e, std_e = weighted_avg_and_std(values_ERM, test_weights)
                print(avg_d)
                print(avg_e)
                print(std_d)
                print(std_e)
                np.save(log_file, log)
