import ast
import numpy as np
from utils import *
import os 
from paretoset import paretoset
from time import time
import pandas as pd
from sklearn import preprocessing
from copy import deepcopy

from Polyhedron import Polyhedron
import torch


# Every benchmark keeps its experiment results in its own sub-folder of
# "<current working directory>/results"; the root is resolved once here.
_results_root = os.path.join(os.getcwd(), "results")

# One result folder per benchmark. The numbered names are what the
# evaluation loop below compares against, so they must stay as they are.
path1 = os.path.join(_results_root, "BRANIN")
path2 = os.path.join(_results_root, "SINE_BIG")
path3 = os.path.join(_results_root, "BC")
path4 = os.path.join(_results_root, "SINE")
path5 = os.path.join(_results_root, "SNW")
path6 = os.path.join(_results_root, "JAHS")
path7 = os.path.join(_results_root, "CHEM")
path8 = os.path.join(_results_root, "OKA")
path9 = os.path.join(_results_root, "ZINC")
path10 = os.path.join(_results_root, "DEBUG")
path11 = os.path.join(_results_root, "GP")
path12 = os.path.join(_results_root, "DTLZ1")
path13 = os.path.join(_results_root, "SnAR")



#for path in [path1,path2,path3,path4,path5]: #TODO: Match the data files with result folder names.
# Tolerance used by the success criteria: a returned design counts as a
# type-2 failure when its Delta value exceeds 2 * epsilon, and epsilon is also
# forwarded to get_uncovered_set for the type-1 check.
epsilon = 0.3

# Every dataset handled by this script has two objectives.
_NUM_OBJECTIVES = 2

# Result folder -> (x file, y file, number of designs n, input dimension d).
# ``n is None`` means "use the number of rows of x".
# NOTE(review): path9 is the "ZINC" result folder yet loads the Chem_*_small
# files, while path7 ("CHEM") has no dataset registered at all -- confirm the
# intended mapping between result folders and data files.
_NPY_DATASETS = {
    path1: ("braninx.npy", "braniny.npy", 1000, 2),
    path2: ("sinex.npy", "siney.npy", 1000, 2),
    path3: ("braninx_small.npy", "braniny_small.npy", 250, 2),
    path4: ("sinex_small.npy", "siney_small.npy", 250, 2),
    path6: ("jahsx.npy", "jahsy.npy", 785, 6),
    path8: ("okax.npy", "okay.npy", 250, 3),
    path9: ("Chem_x_small.npy", "Chem_y_small.npy", 250, 2048),
    path10: ("debug_x.npy", "debug_y.npy", None, 2),
    path11: ("gp_sample_x.npy", "gp_sample_y.npy", None, 1),
    path12: ("dtlz1x.npy", "dtlz1y.npy", 250, 3),
    path13: ("SnAr_x.npy", "SnAr_y.npy", 950, 4),
}


def _objective_matrix(y, n, m):
    """Copy the first two columns of ``y`` into a fresh ``(n, m)`` float array."""
    mu = np.empty((n, m))
    mu[:, 0] = y[:, 0]
    mu[:, 1] = y[:, 1]
    return mu


def _load_snw():
    """Load the SNW benchmark from ``datasets/sort_256.csv``.

    The first three CSV columns are taken as inputs and the remaining ones as
    objectives. The first objective is negated, then the first two objective
    columns are standardised independently with ``StandardScaler``.

    Returns:
        ``(x, mu, n, m, d)`` with the hard-coded sizes n=206, m=2, d=3.
    """
    designs = np.genfromtxt(os.path.join('datasets', 'sort_256.csv'), delimiter=';')
    x = designs[:, :3]
    mu = np.copy(designs[:, 3:])
    mu[:, 0] = -mu[:, 0]
    # TODO: save the scaled data.
    for column in range(2):
        values = mu[:, column].reshape(-1, 1)
        scaler = preprocessing.StandardScaler().fit(values)
        mu[:, column] = scaler.transform(values).reshape(-1,)
    return x, mu, 206, _NUM_OBJECTIVES, 3


def _load_dataset(path):
    """Return ``(x, mu, n, m, d)`` for the benchmark whose results live in ``path``.

    Args:
        path: one of the module-level ``path*`` result folders.

    Raises:
        ValueError: if no dataset is registered for ``path``. Without this
            check an unknown folder would be evaluated against whatever
            dataset the previous iteration left behind (or fail with a
            ``NameError`` on the first iteration).
    """
    if path == path5:
        return _load_snw()
    if path not in _NPY_DATASETS:
        raise ValueError(f"No dataset is registered for result folder {path!r}.")

    x_file, y_file, n, d = _NPY_DATASETS[path]
    x = np.load(os.path.join('datasets', x_file))
    y = np.load(os.path.join('datasets', y_file))
    if n is None:
        n = x.shape[0]
    m = _NUM_OBJECTIVES
    # SINE_BIG uses the stored objective array directly; every other dataset
    # copies its first two columns into a new (n, m) array.
    mu = y if path == path2 else _objective_matrix(y, n, m)
    return x, mu, n, m, d


def _cone_matrix(angle, m):
    """Build the 2-row matrix ``A`` describing the ordering cone for ``angle``.

    Args:
        angle: cone angle in degrees as a string: "45", "90" or "135".
        m: number of objectives (only used for the 90 degree identity cone).

    Raises:
        ValueError: if ``angle`` is not one of the supported values.
    """
    if angle == "90":
        return np.eye(m)
    if angle == "135":
        theta = 3*np.pi/4
        w_1 = np.array([-np.tan(np.pi/4-theta/2), 1])
        w_2 = np.array([-np.tan(np.pi/4+theta/2), 1])
    elif angle == "45":
        theta = np.pi/4
        w_1 = np.array([-np.tan(np.pi/4-theta/2), 1])
        w_2 = np.array([+np.tan(np.pi/4+theta/2), -1])
    else:
        raise ValueError('The given cone angle is invalid.')
    # Rows are normalised to unit length before stacking.
    return np.vstack((w_1/np.linalg.norm(w_1), w_2/np.linalg.norm(w_2)))


def _mean_text(values):
    """Format the mean of ``values`` rounded to two decimals."""
    return str(round(np.mean(values).item(), 2))


def _mean_std_text(values):
    """Format ``values`` as "<mean> +-<std>", both rounded to two decimals."""
    return _mean_text(values) + " +-" + str(round(np.std(values).item(), 2))


def _append_result_row(angle, row):
    """Append ``row`` to ``results/tables/result_table_<angle>.csv``.

    The table (and its folder) is created when it does not exist yet, so a
    missing file no longer aborts the script after all metrics were computed.
    """
    table_dir = os.path.join('results', 'tables')
    table_path = os.path.join(table_dir, 'result_table_' + angle + '.csv')
    table = pd.DataFrame([row])
    if os.path.exists(table_path):
        table = pd.concat([pd.read_csv(table_path), table], ignore_index=True)
    else:
        os.makedirs(table_dir, exist_ok=True)
    table.to_csv(table_path, index=False)


for path in [path13]:
    dataset_name = os.path.basename(os.path.normpath(path))
    x, mu, n, m, d = _load_dataset(path)

    for angle in ["45", "90", "135"]:
        A = _cone_matrix(angle, m)

        # get_alpha_vec / get_pareto_set / get_delta / get_uncovered_set are
        # not defined in this file; presumably they come from
        # `from utils import *`.
        alpha_vec = get_alpha_vec(A)
        p_opt = get_pareto_set(mu, A, alpha_vec)
        b = np.zeros((m,))
        # NOTE(review): C is never read below; the construction is kept in
        # case Polyhedron validates A -- confirm and remove if it does not.
        C = Polyhedron(A = A, b = b)
        Delta = get_delta(mu, A, alpha_vec)

        # The ground-truth mask depends only on the dataset and the cone, so
        # it is built once per angle rather than once per result file.
        mask_true = np.zeros(n, dtype=bool)
        mask_true[p_opt] = True
        true_indices = np.flatnonzero(mask_true)
        true_count = mask_true.sum()

        success_ratio_list = list()
        success_ratio_list_1 = list()
        success_ratio_list_2 = list()
        pareto_accuracy_list = list()
        pareto_recall_list = list()
        pareto_precision_list = list()
        coverage_list = list()
        NF1_list = list()
        NF2_list = list()
        PM_list = list()
        SC_list = list()

        run_dir = os.path.join(path, angle)
        # Sorted so the processing order does not depend on the file system.
        for file in sorted(os.listdir(run_dir)):
            # SECURITY: allow_pickle=True unpickles arbitrary objects; only
            # load result files produced by trusted runs.
            experiment = np.load(os.path.join(run_dir, file), allow_pickle=True).item()
            P_hat = experiment["P_hat"].detach().cpu().numpy()
            mask_hat = experiment["P_hat_mask"]
            sample_count = experiment["SC"]

            # `== True` is kept on purpose: the dtype of the stored mask is
            # not visible from this file.
            hat_indices = np.where(mask_hat == True)[0]

            # Returned designs whose Delta exceeds twice the tolerance.
            fail_2 = np.count_nonzero(Delta[mask_hat] > 2*epsilon)

            # True Pareto designs that were not returned, and the subset of
            # them reported by get_uncovered_set.
            p_optmiss = np.setdiff1d(true_indices, hat_indices)
            fail1_points = get_uncovered_set(p_optmiss, hat_indices, mu, epsilon, A)

            SR1_rate = 100*(1.-len(fail1_points)/len(p_opt))
            SR2_rate = 100*(1-fail_2/P_hat.shape[0])

            true_positives = np.logical_and(mask_hat, mask_true).sum()

            # A run is a success only when both criteria are fully met.
            success_ratio_list.append(100*(SR1_rate == 100 and SR2_rate == 100))
            success_ratio_list_1.append(SR1_rate)
            success_ratio_list_2.append(SR2_rate)

            pareto_accuracy_list.append((mask_hat == mask_true).sum()*100/n)
            pareto_recall_list.append(true_positives*100/true_count)
            pareto_precision_list.append(true_positives*100/mask_hat.sum())
            coverage_list.append(100*(1-(fail_2)/true_count))
            NF1_list.append(len(fail1_points))
            NF2_list.append(fail_2)
            PM_list.append(100*(len(p_optmiss)/len(p_opt)))
            SC_list.append(sample_count)

        # Key order defines the column order of a freshly created table.
        result_dict = {
            "Dataset": dataset_name,
            "Success Ratio": _mean_text(success_ratio_list),
            "Success 1 Ratio": _mean_text(success_ratio_list_1),
            "Success 2 Ratio": _mean_text(success_ratio_list_2),
            "Pareto Accuracy": _mean_std_text(pareto_accuracy_list),
            "Pareto Recall": _mean_std_text(pareto_recall_list),
            "Pareto Precision": _mean_std_text(pareto_precision_list),
            "Front Coverage": _mean_std_text(coverage_list),
            "NF2": _mean_std_text(NF2_list),
            "NF1": _mean_std_text(NF1_list),
            "PM": _mean_std_text(PM_list),
            "SC": _mean_std_text(SC_list),
        }

        _append_result_row(angle, result_dict)

    print("Done")
    




# ---------------------------------------------------------------------------
# Everything below is disabled code kept as bare triple-quoted string
# literals. The interpreter evaluates each string as a no-op expression, so
# none of it runs.
# ---------------------------------------------------------------------------

# Disabled snippet 1: read "input_output.txt", keep the text after "---" on
# each line, parse it with ast.literal_eval and stack the results into the
# array `res`; also starts a timer.
""" f = open("input_output.txt", "r")
results = f.readlines()
f.close()

res = [i.strip().split("---")[1] for i in results]
res = [ast.literal_eval(i) for i in res]
res = np.array(res)

t0 = time()
 """

# Disabled snippet 2: build `mask_hat` from `res` by taking its Pareto set
# with get_pareto_set and looking the selected rows up in `mu`.
""" p_opt = get_pareto_set(mu, A, alpha_vec)


mask_hat = np.zeros(n, dtype=bool)
p_opt_alg = get_pareto_set(res, A, alpha_vec)
y_hat = res[p_opt_alg]
hat_indices = [(np.where((mu==i))[0][0]==np.where((mu==i))[0][1])*np.where((mu==i))[0][1] for i in y_hat]
mask_hat[hat_indices] = True """


# Disabled snippet 3: the same evaluation done with `paretoset`
# (sense=["max", "max"]) on both `res` and `mu`, followed by the
# SR1/SR2/NF1/NF2/PM computation and a summary print.
# NOTE(review): this variant normalises SR2 by sum(mask_hat), whereas the
# active loop above divides by P_hat.shape[0] -- confirm which is intended.
""" time1 = time()

mask_hat = np.zeros(n, dtype=bool)
hotels = pd.DataFrame({"func1": res[:,0].reshape(-1,), "func2": res[:,1].reshape(-1,)})
mask_hat_fast = paretoset(hotels, sense=["max", "max"])
y_hat = res[mask_hat_fast]
hat_indices = [(np.where((mu==i))[0][0]==np.where((mu==i))[0][1])*np.where((mu==i))[0][1] for i in y_hat]
mask_hat[hat_indices] = True

hotels = pd.DataFrame({"func1": mu[:,0].reshape(-1,), "func2": mu[:,1].reshape(-1,)})
mask_true_fast = paretoset(hotels, sense=["max", "max"])
p_opt = np.where(mask_true_fast==True)[0]

time2 = time()

#y_hat = res[mask_hat]

Delta = get_delta(mu, A, alpha_vec)
mask_true = np.zeros(n, dtype=bool)
mask_true[p_opt] = True
fail_2 = np.count_nonzero(Delta[mask_hat] > 2*epsilon) #This should be 2 epsilon for publication


p_subopt = np.setdiff1d(np.where(mask_hat==True)[0], np.where(mask_true==True)[0])#1D array of values in ar1 that are not in ar2
p_optmiss= np.setdiff1d(np.where(mask_true==True)[0], np.where(mask_hat==True)[0])
fail1_points = get_uncovered_set(p_optmiss, np.where(mask_hat==True)[0], mu, epsilon, A)

SR1_rate = 100*(1.-len(fail1_points)/len(p_opt))
SR2_rate = 100*(1-fail_2/sum(mask_hat))

NF1 = len(fail1_points)
NF2 = fail_2

PM = 100*(len(p_optmiss)/len(p_opt))

uniq_res = np.unique(res,axis=0) 
print(sum(mask_hat),PM,SR1_rate,SR2_rate,NF1,NF2,uniq_res.shape[0],res.shape[0])
 """
# Leftover commented-out accumulator update that followed the snippet above.
#success += 100*(SR1_rate == 100 and SR2_rate == 100) 