import ast
import numpy as np
from utils import *
import os 
from paretoset import paretoset
from time import time
import pandas as pd
from sklearn import preprocessing
from copy import deepcopy
from pygmo import hypervolume
from Polyhedron import Polyhedron
import torch

# Method label. It is not referenced elsewhere in the visible part of this
# script; the other label noted here is "MESMO".
method = "VOGP"


def _result_dir(name):
    """Return ``<current working directory>/results/<name>``."""
    return os.path.join(os.getcwd(), "results", name)


# One result folder per benchmark; the numbering is relied on by the
# dataset selection further down in the script.
path1 = _result_dir("BRANIN")
path2 = _result_dir("SINE_BIG")
path3 = _result_dir("BC")
path4 = _result_dir("SINE")
path5 = _result_dir("SNW")
path6 = _result_dir("JAHS")
path7 = _result_dir("CHEM")
path8 = _result_dir("OKA")
path9 = _result_dir("ZINC")
path10 = _result_dir("DEBUG")
path11 = _result_dir("DTLZ1")
path12 = _result_dir("SnAR")



# TODO: Match the data files with result folder names (for example, the ZINC
# result folder currently loads the Chem_*_small files and the CHEM folder has
# no dataset registered at all).

# Settings shared by every dataset.
epsilon = 0.3  # tolerance used by the failure checks in the evaluation loop
angle = "90"  # sub-folder of each result folder that holds the experiment files

# Cumulative table; every evaluated dataset appends one row to it.
_RESULT_TABLE = os.path.join('results', 'tables', 'result_table_' + "comp" + '.csv')

# Result columns reported as "mean" only, and as "mean +-std", in output order.
_MEAN_ONLY_COLUMNS = ("Success Ratio", "Success 1 Ratio", "Success 2 Ratio")
_MEAN_STD_COLUMNS = (
    "Hypervolume",
    "Pareto Accuracy",
    "Pareto Recall",
    "Pareto Precision",
    "Front Coverage",
    "NF2",
    "NF1",
    "PM",
    "SC",
)

# result folder -> (x file, y file, number of designs n, objectives m, input
# dimension d).  ``n`` set to None means "use the row count of the x file".
_NPY_DATASETS = {
    path1: ("braninx.npy", "braniny.npy", 1000, 2, 2),
    path3: ("braninx_small.npy", "braniny_small.npy", 250, 2, 2),
    path4: ("sinex_small.npy", "siney_small.npy", 250, 2, 2),
    path6: ("jahsx.npy", "jahsy.npy", 785, 2, 6),
    path8: ("okax.npy", "okay.npy", 250, 2, 3),
    path9: ("Chem_x_small.npy", "Chem_y_small.npy", 250, 2, 2048),
    path10: ("debug_x.npy", "debug_y.npy", None, 2, 2),
    path11: ("dtlz1x.npy", "dtlz1y.npy", 250, 2, 3),
    path12: ("SnAr_x.npy", "SnAr_y.npy", 950, 2, 4),
}


def _load_dataset(path):
    """Load the data that belongs to the result folder ``path``.

    Returns:
        Tuple ``(x, y, mu, n, m, d)``: the inputs, the raw objective array,
        the objective matrix used for evaluation, the number of designs, the
        number of objectives and the input dimension.

    Raises:
        ValueError: if no dataset is registered for ``path`` (previously the
            script silently reused the data of the preceding dataset), or if
            the objective file does not provide ``n`` rows and ``m`` columns.
    """
    if path == path5:
        # SNW: ';'-separated table, first three columns are inputs, the rest objectives.
        n, d, m = 206, 3, 2
        designs = np.genfromtxt(os.path.join('datasets', 'sort_256.csv'), delimiter=';')
        x = designs[:, :3]
        y = np.copy(designs[:, 3:])
        y[:, 0] = -y[:, 0]  # the first objective enters with flipped sign
        # Standardise the two objective columns independently.
        # TODO: save the scaled data.
        for col in range(m):
            column = y[:, col].reshape(-1, 1)
            scaler = preprocessing.StandardScaler().fit(column)
            y[:, col] = scaler.transform(column).reshape(-1,)
        return x, y, deepcopy(y), n, m, d

    if path == path2:
        # SINE_BIG uses the objective array as-is (no copy, no column selection).
        x = np.load(os.path.join('datasets', "sinex.npy"))
        y = np.load(os.path.join('datasets', "siney.npy"))
        return x, y, y, 1000, 2, 2

    spec = _NPY_DATASETS.get(path)
    if spec is None:
        raise ValueError(f"No dataset is registered for result folder {path!r}")

    x_file, y_file, n, m, d = spec
    x = np.load(os.path.join('datasets', x_file))
    y = np.load(os.path.join('datasets', y_file))
    if n is None:
        n = x.shape[0]
    # Guard against silent broadcasting of a wrongly shaped objective file.
    if y.ndim != 2 or y.shape[0] != n or y.shape[1] < m:
        raise ValueError(
            f"{y_file}: expected at least {m} objective columns for {n} designs, "
            f"got an array of shape {y.shape}"
        )
    mu = np.empty((n, m))
    mu[:, :] = y[:, :m]
    return x, y, mu, n, m, d


def _mean_str(values):
    """Format the mean of ``values`` rounded to two decimals."""
    return str(round(np.mean(values).item(), 2))


def _mean_std_str(values):
    """Format ``values`` as ``"<mean> +-<std>"``, both rounded to two decimals."""
    return _mean_str(values) + " +-" + str(round(np.std(values).item(), 2))


def _append_result_row(row):
    """Append ``row`` (a dict) to the results table, creating the table if needed."""
    new_row = pd.DataFrame([row])
    try:
        table = pd.read_csv(_RESULT_TABLE)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        # First run: there is nothing to append to yet.
        os.makedirs(os.path.dirname(_RESULT_TABLE), exist_ok=True)
        table = new_row
    else:
        table = pd.concat([table, new_row], ignore_index=True)
    table.to_csv(_RESULT_TABLE, index=False)


# For every selected dataset: score each saved experiment against the true
# Pareto set and append one summary row to the results table.
for path in [path5, path3, path8, path12]:
    dataset_name = os.path.basename(os.path.normpath(path))
    x, y, mu, n, m, d = _load_dataset(path)

    # Ground truth for the ordering defined by A (identity matrix here).
    A = np.eye(m)
    alpha_vec = get_alpha_vec(A)
    p_opt = get_pareto_set(mu, A, alpha_vec)
    Delta = get_delta(mu, A, alpha_vec)

    # The true mask does not depend on the experiment file, so build it once.
    mask_true = np.zeros(n, dtype=bool)
    mask_true[p_opt] = True
    true_idx = np.flatnonzero(mask_true)
    n_true = mask_true.sum()

    # Sorted for a reproducible order; directories and hidden files are not
    # experiment records and would make np.load fail.
    run_dir = os.path.join(path, angle)
    run_files = sorted(
        name
        for name in os.listdir(run_dir)
        if not name.startswith(".") and os.path.isfile(os.path.join(run_dir, name))
    )
    if not run_files:
        # Averaging zero runs would only write a row of NaNs.
        print(f"No experiment files found in {run_dir}; skipping {dataset_name}.")
        continue

    per_run = {column: [] for column in _MEAN_ONLY_COLUMNS + _MEAN_STD_COLUMNS}

    for run_file in run_files:
        # SECURITY: allow_pickle=True unpickles arbitrary objects; only point
        # this script at result files you produced yourself.
        record = np.load(os.path.join(run_dir, run_file), allow_pickle=True).item()
        P_hat = record["P_hat"].detach().cpu().numpy()
        mask_hat = record["P_hat_mask"]
        sample_count = record["SC"]

        n_hat = mask_hat.sum()
        if P_hat.shape[0] == 0 or n_hat == 0:
            # The ratios below divide by the size of the predicted set.
            raise ValueError(
                f"{os.path.join(run_dir, run_file)}: the predicted Pareto set is empty"
            )

        hat_idx = np.flatnonzero(mask_hat)
        # True Pareto designs that the run did not return.
        p_optmiss = np.setdiff1d(true_idx, hat_idx)
        # NOTE(review): presumably the missed designs that are not
        # epsilon-covered by the prediction - confirm in utils.get_uncovered_set.
        fail1_points = get_uncovered_set(p_optmiss, hat_idx, mu, epsilon, A)

        NF1 = len(fail1_points)
        # Returned designs whose Delta value exceeds 2 * epsilon.
        NF2 = np.count_nonzero(Delta[mask_hat] > 2 * epsilon)

        SR1_rate = 100 * (1. - NF1 / len(p_opt))
        SR2_rate = 100 * (1 - NF2 / P_hat.shape[0])
        PM = 100 * (len(p_optmiss) / len(p_opt))

        # Hypervolume of the negated predicted front w.r.t. the point (5, ..., 5).
        hv = hypervolume(points=-mu[mask_hat])

        n_hits = np.logical_and(mask_hat, mask_true).sum()

        per_run["Success Ratio"].append(100 * (SR1_rate == 100 and SR2_rate == 100))
        per_run["Success 1 Ratio"].append(SR1_rate)
        per_run["Success 2 Ratio"].append(SR2_rate)
        per_run["Hypervolume"].append(hv.compute([5] * m))
        per_run["Pareto Accuracy"].append((mask_hat == mask_true).sum() * 100 / n)
        per_run["Pareto Recall"].append(n_hits * 100 / n_true)
        per_run["Pareto Precision"].append(n_hits * 100 / n_hat)
        per_run["Front Coverage"].append(100 * (1 - NF2 / n_true))
        per_run["NF2"].append(NF2)
        per_run["NF1"].append(NF1)
        per_run["PM"].append(PM)
        per_run["SC"].append(sample_count)

    result_dict = {"Dataset": dataset_name}
    for column in _MEAN_ONLY_COLUMNS:
        result_dict[column] = _mean_str(per_run[column])
    for column in _MEAN_STD_COLUMNS:
        result_dict[column] = _mean_std_str(per_run[column])

    _append_result_row(result_dict)

    print("Done")
    




""" f = open("input_output.txt", "r")
results = f.readlines()
f.close()

res = [i.strip().split("---")[1] for i in results]
res = [ast.literal_eval(i) for i in res]
res = np.array(res)

t0 = time()
 """

""" p_opt = get_pareto_set(mu, A, alpha_vec)


mask_hat = np.zeros(n, dtype=bool)
p_opt_alg = get_pareto_set(res, A, alpha_vec)
y_hat = res[p_opt_alg]
hat_indices = [(np.where((mu==i))[0][0]==np.where((mu==i))[0][1])*np.where((mu==i))[0][1] for i in y_hat]
mask_hat[hat_indices] = True """


""" time1 = time()

mask_hat = np.zeros(n, dtype=bool)
hotels = pd.DataFrame({"func1": res[:,0].reshape(-1,), "func2": res[:,1].reshape(-1,)})
mask_hat_fast = paretoset(hotels, sense=["max", "max"])
y_hat = res[mask_hat_fast]
hat_indices = [(np.where((mu==i))[0][0]==np.where((mu==i))[0][1])*np.where((mu==i))[0][1] for i in y_hat]
mask_hat[hat_indices] = True

hotels = pd.DataFrame({"func1": mu[:,0].reshape(-1,), "func2": mu[:,1].reshape(-1,)})
mask_true_fast = paretoset(hotels, sense=["max", "max"])
p_opt = np.where(mask_true_fast==True)[0]

time2 = time()

#y_hat = res[mask_hat]

Delta = get_delta(mu, A, alpha_vec)
mask_true = np.zeros(n, dtype=bool)
mask_true[p_opt] = True
fail_2 = np.count_nonzero(Delta[mask_hat] > 2*epsilon) #This should be 2 epsilon for publication


p_subopt = np.setdiff1d(np.where(mask_hat==True)[0], np.where(mask_true==True)[0])#1D array of values in ar1 that are not in ar2
p_optmiss= np.setdiff1d(np.where(mask_true==True)[0], np.where(mask_hat==True)[0])
fail1_points = get_uncovered_set(p_optmiss, np.where(mask_hat==True)[0], mu, epsilon, A)

SR1_rate = 100*(1.-len(fail1_points)/len(p_opt))
SR2_rate = 100*(1-fail_2/sum(mask_hat))

NF1 = len(fail1_points)
NF2 = fail_2

PM = 100*(len(p_optmiss)/len(p_opt))

uniq_res = np.unique(res,axis=0) 
print(sum(mask_hat),PM,SR1_rate,SR2_rate,NF1,NF2,uniq_res.shape[0],res.shape[0])
 """
#success += 100*(SR1_rate == 100 and SR2_rate == 100) 