import argparse
import os

import torch
import gpytorch

from VectorEpsilonPAL import VectorEpsilonPAL
from OptimizationProblem import OptimizationProblem
from GaussianProcessModel import GaussianProcessModel,GaussianProcessModelDependent
from Polyhedron import Polyhedron
from utils import *
from utils_plot import *
from copy import deepcopy

import numpy as np
#mpl.rcParams['text.latex.preamble'] = [r'\usepackage{amsmath}']
import time
from sklearn import preprocessing
from paretoset import paretoset
import pandas as pd
import uuid
from scipy.stats import boxcox
from sklearn.preprocessing import power_transform
from sklearn.preprocessing import PowerTransformer

from multiprocessing import freeze_support

def _str_to_bool(value):
    """Convert a command-line token such as ``true``/``0``/``no`` into a bool.

    Without an explicit converter argparse stores the raw string, and every
    non-empty string (including ``"False"``) is truthy.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if token in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Every dataset name handled by the loading branches of this script.
DATASET_CHOICES = (
    "SNW", "SINE", "DTLZ1", "JAHS", "JAHS_FLOAT", "GP", "ZINC",
    "SINE_SMALL", "BC", "OKA", "SnAR", "penicillin",
)

parser = argparse.ArgumentParser(description='Run the algorithm.')
# The dataset is mandatory and validated up front: an unknown or missing name
# would otherwise only surface later as a NameError on the undefined data.
parser.add_argument('--dataset', dest='dataset', required=True, choices=DATASET_CHOICES,
                    help='Which dataset to run the algorithm with.')
#parser.add_argument('--cone_angle', dest='angle',help='The angle of the ordering cone.')
parser.add_argument('--epsilon', dest='epsilon', type=float, default=0.1, help='the error margin.')
parser.add_argument('--delta', dest='delta', type=float, default=0.05, help='The precision parameter.')
parser.add_argument('--TrainGPDuringAlgorithm', dest='TrainGPDuringAlgorithm', type=_str_to_bool, default=False,
                    help=' Whether to run the hyperparameter training between rounds.')
parser.add_argument('--noise_std', dest='std', type=float, default=0.1, help='The noise std.')
parser.add_argument('--device', dest='DEVICE', default="cpu", help='Choosing the device on which the GP will work on.')
parser.add_argument("--gp_dependent", dest="gp_dependent", type=int, default=1,
                    help="Whether the gp will be dependent or not.")
parser.add_argument('--seed', dest='seed', type=int, default=0, help='the random seed.')
args = parser.parse_args()

#if __name__ == "__main__":


# Unpack the parsed command-line options into module-level settings.
DEVICE = args.DEVICE
TrainGPDuringAlgorithm = args.TrainGPDuringAlgorithm
epsilon, delta = args.epsilon, args.delta
batched = 0  # forwarded unchanged to VectorEpsilonPAL as its `batched` argument

# Observation noise: the standard deviation comes from the CLI, and the GP
# models are given the corresponding variance.
std = args.std
noise_variance = std ** 2

# Seed both NumPy and PyTorch global generators with the same value.
seed = args.seed
np.random.seed(seed)
torch.manual_seed(seed)

# freeze_support()

if args.dataset == "SNW":
    # SNW: designs are read from a ';'-separated file; the first three columns
    # are used as inputs and the remaining columns as objectives.
    n = 206
    d = 3  # The input dimension
    m = 2  # The output dimension
    jumpstart_gp_n = 10

    datafile = os.path.join('datasets', 'sort_256.csv')
    designs = np.genfromtxt(datafile, delimiter=';')
    x = designs[:, :3]
    y = np.copy(designs[:, 3:])
    y[:, 0] = -y[:, 0]  # the first objective enters with its sign flipped
    raw_y = np.copy(y)  # objectives before standardization

    # Standardize each objective column independently with its own scaler.
    # (An earlier mean-centering / max-norm scaling variant was disabled here.)
    for j in range(m):
        column = y[:, j].reshape(-1, 1)
        scaler = preprocessing.StandardScaler().fit(column)
        y[:, j] = scaler.transform(column).reshape(-1,)

    # Snapshot the objectives as a NumPy array *before* moving the data to
    # DEVICE: NumPy cannot copy from a tensor that lives on a non-CPU device.
    mu = np.copy(y)
    x = torch.from_numpy(x).to(DEVICE)
    y = torch.from_numpy(y).to(DEVICE)

    problem_model = OptimizationProblem(x, y, std)
    kernel = gpytorch.kernels.MultitaskKernel(gpytorch.kernels.RBFKernel(ard_num_dims=d), num_tasks=m, rank=m)
    #kernel = gpytorch.kernels.MultitaskKernel(gpytorch.kernels.RBFKernel(), num_tasks=m, rank=m)
    kernel_list = [gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel()) for _ in range(m)]
    paths = os.path.join("results", "SNW")
if args.dataset == "SINE":
    # SINE: designs and objectives are loaded pre-built from disk.
    # NOTE(review): disabled code that used to sit here sampled inputs
    # uniformly in [-1, 1]^2, evaluated two sine-product functions and
    # standardized each objective; presumably the .npy files were produced
    # that way -- TODO confirm.
    x = np.load(os.path.join('datasets', "sinex.npy"))
    y = np.load(os.path.join('datasets', "siney.npy"))
    n = 1000  # number of designs the files are expected to contain
    m = 2     # output dimension
    d = 2     # input dimension

    # Fill the NumPy objective matrix from the host array before the data is
    # moved to DEVICE: NumPy cannot read from a tensor on a non-CPU device.
    mu = np.empty((n, m))
    mu[:, 0] = y[:, 0]
    mu[:, 1] = y[:, 1]
    x = torch.from_numpy(x).to(DEVICE)
    y = torch.from_numpy(y).to(DEVICE)

    problem_model = OptimizationProblem(x, y, std)
    kernel_list = [gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel()) for _ in range(m)]
    kernel = gpytorch.kernels.MultitaskKernel(gpytorch.kernels.RBFKernel(ard_num_dims=d), num_tasks=m, rank=m)
    paths = os.path.join("results", "SINE")
if args.dataset == "DTLZ1":
    # DTLZ1: designs and objectives are loaded pre-built from disk.
    x = np.load(os.path.join('datasets', "dtlz1x.npy"))
    y = np.load(os.path.join('datasets', "dtlz1y.npy"))
    n = 250  # number of designs the files are expected to contain
    m = 2    # output dimension
    d = 3    # input dimension

    # Fill the NumPy objective matrix from the host array before the data is
    # moved to DEVICE: NumPy cannot read from a tensor on a non-CPU device.
    mu = np.empty((n, m))
    mu[:, 0] = y[:, 0]
    mu[:, 1] = y[:, 1]
    x = torch.from_numpy(x).to(DEVICE)
    y = torch.from_numpy(y).to(DEVICE)

    problem_model = OptimizationProblem(x, y, std)
    kernel_list = [gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel()) for _ in range(m)]
    kernel = gpytorch.kernels.MultitaskKernel(gpytorch.kernels.RBFKernel(ard_num_dims=d), num_tasks=m, rank=m)
    paths = os.path.join("results", "DTLZ1")
if args.dataset == "JAHS":
    # JAHS: designs and objectives are loaded pre-built from disk.
    # NOTE(review): disabled code that used to sit here derived the arrays
    # from the JAHS-Bench cifar10 metric data (rows filtered to fixed
    # hyperparameters, Op1..Op6 as inputs, negated runtime and test accuracy
    # as objectives, duplicate inputs dropped, IQR-based outlier removal,
    # per-objective standardization); presumably the .npy files were produced
    # that way -- TODO confirm.
    m = 2  # output dimension
    d = 6  # input dimension
    x = np.load(os.path.join('datasets', "jahsx.npy"))
    y = np.load(os.path.join('datasets', "jahsy.npy"))
    n = y.shape[0]

    # The index permutation is not used below, but the draw is kept because it
    # advances NumPy's global random state.
    jumpstart_gp_n = n
    jumpstart_indexes = np.random.choice(n, jumpstart_gp_n, replace=False)

    # (A PowerTransformer step on y was disabled at this point.)
    x = torch.from_numpy(x).float().to(DEVICE)
    y = torch.from_numpy(y).float().to(DEVICE)

    # In this branch `mu` is the very same tensor as `y`.
    mu = y
    problem_model = OptimizationProblem(x, y, std)
    #kernel_list = [gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel()) for _ in range(m)]
    kernel_list = [CategoricalKernel() for _ in range(m)]
    kernel = gpytorch.kernels.MultitaskKernel(CategoricalKernel(), num_tasks=m, rank=m)
    paths = os.path.join("results", "JAHS")
if args.dataset == "JAHS_FLOAT":
    # JAHS_FLOAT: designs and objectives are loaded pre-built from disk.
    x = np.load(os.path.join('datasets', "jahsx_float.npy"))
    y = np.load(os.path.join('datasets', "jahsy_float.npy"))
    m = 2  # output dimension
    d = 4  # input dimension

    n = y.shape[0]

    # The index permutation is not used below, but the draw is kept because it
    # advances NumPy's global random state.
    jumpstart_gp_n = n
    jumpstart_indexes = np.random.choice(n, jumpstart_gp_n, replace=False)

    # (A PowerTransformer step on y was disabled at this point.)
    x = torch.from_numpy(x).float().to(DEVICE)
    y = torch.from_numpy(y).float().to(DEVICE)

    # In this branch `mu` is the very same tensor as `y`.
    mu = y
    problem_model = OptimizationProblem(x, y, std)
    kernel_list = [gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel()) for _ in range(m)]
    # This branch previously defined neither `kernel` nor `paths`, so the run
    # loop failed with a NameError. The multitask kernel mirrors the one used
    # by the other continuous-input datasets.
    # NOTE(review): confirm that an ARD RBF kernel is the intended choice.
    kernel = gpytorch.kernels.MultitaskKernel(gpytorch.kernels.RBFKernel(ard_num_dims=d), num_tasks=m, rank=m)
    paths = os.path.join("results", "JAHS_FLOAT")
if args.dataset == "GP":
    # GP: designs and objectives are loaded pre-built from disk.
    x = np.load(os.path.join('datasets', "gp_sample_x.npy"))
    y = np.load(os.path.join('datasets', "gp_sample_y.npy"))
    n = 250  # number of designs the files are expected to contain
    m = 2    # output dimension
    d = 1    # input dimension

    # Fill the NumPy objective matrix from the host array before the data is
    # moved to DEVICE: NumPy cannot read from a tensor on a non-CPU device.
    mu = np.empty((n, m))
    mu[:, 0] = y[:, 0]
    mu[:, 1] = y[:, 1]
    x = torch.from_numpy(x).to(DEVICE)
    y = torch.from_numpy(y).to(DEVICE)

    problem_model = OptimizationProblem(x, y, std)
    kernel_list = [gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel()) for _ in range(m)]
    kernel = gpytorch.kernels.MultitaskKernel(gpytorch.kernels.RBFKernel(), num_tasks=m, rank=m)
    paths = os.path.join("results", "GP")
if args.dataset == "ZINC":
    # ZINC: designs and objectives are loaded pre-built from disk.
    x = np.load(os.path.join('datasets', "Chem_x_small_2048b.npy"))
    y = np.load(os.path.join('datasets', "Chem_y_small_2048b.npy"))
    m = 2     # output dimension
    n = 250   # number of designs the files are expected to contain
    d = 2048  # input dimension

    # Fill the NumPy objective matrix from the host array before the data is
    # moved to DEVICE: NumPy cannot read from a tensor on a non-CPU device.
    mu = np.empty((n, m))
    mu[:, 0] = y[:, 0]
    mu[:, 1] = y[:, 1]
    x = torch.from_numpy(x).to(DEVICE)
    y = torch.from_numpy(y).to(DEVICE)

    kernel_list = [TanimotoKernel() for _ in range(m)]
    problem_model = OptimizationProblem(x, y, std)
    kernel = gpytorch.kernels.MultitaskKernel(TanimotoKernel(), num_tasks=m, rank=m)
    paths = os.path.join("results", "ZINC")
    
""" if args.dataset == "ZINC":
    x = np.load(os.path.join('datasets',"Chem_x_ssk.npy"),allow_pickle=True)
    y = np.load(os.path.join('datasets',"Chem_y_ssk.npy"),allow_pickle=True)
    m = 2
    n=250
    d = None
    #x= torch.from_numpy(x).to(DEVICE)
    y= torch.from_numpy(y).to(DEVICE)
    mu = np.empty((n, m))
    mu[:, 0] = y[:, 0] 
    mu[:, 1] = y[:, 1] 
    kernel_list = [PyTorchSSK(n=5, lbda=0.8) for _ in range(m)]
    problem_model = OptimizationProblem(x, y, std)
    kernel = gpytorch.kernels.MultitaskKernel(PyTorchSSK(n=5, lbda=0.8) , num_tasks=m, rank=m)
    paths= os.path.join("results","ZINC") """
if args.dataset == "SINE_SMALL":
    # SINE_SMALL: designs and objectives are loaded pre-built from disk.
    x = np.load(os.path.join('datasets', "sinex_small.npy"))
    y = np.load(os.path.join('datasets', "siney_small.npy"))
    n = 250  # number of designs the files are expected to contain
    m = 2    # output dimension
    d = 2    # input dimension

    # Fill the NumPy objective matrix from the host array before the data is
    # moved to DEVICE: NumPy cannot read from a tensor on a non-CPU device.
    mu = np.empty((n, m))
    mu[:, 0] = y[:, 0]
    mu[:, 1] = y[:, 1]
    x = torch.from_numpy(x).to(DEVICE)
    y = torch.from_numpy(y).to(DEVICE)

    problem_model = OptimizationProblem(x, y, std)
    kernel_list = [gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel()) for _ in range(m)]
    kernel = gpytorch.kernels.MultitaskKernel(gpytorch.kernels.RBFKernel(), num_tasks=m, rank=m)
    paths = os.path.join("results", "SINE_SMALL")
if args.dataset == "BC":
    # BC: designs and objectives are loaded pre-built from disk.
    x = np.load(os.path.join('datasets', "braninx_small.npy"))
    y = np.load(os.path.join('datasets', "braniny_small.npy"))
    n = 250  # number of designs the files are expected to contain
    m = 2    # output dimension
    d = 2    # input dimension

    # Fill the NumPy objective matrix from the host array before the data is
    # moved to DEVICE: NumPy cannot read from a tensor on a non-CPU device.
    mu = np.empty((n, m))
    mu[:, 0] = y[:, 0]
    mu[:, 1] = y[:, 1]
    x = torch.from_numpy(x).to(DEVICE)
    y = torch.from_numpy(y).to(DEVICE)

    problem_model = OptimizationProblem(x, y, std)
    kernel_list = [gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel()) for _ in range(m)]
    kernel = gpytorch.kernels.MultitaskKernel(gpytorch.kernels.RBFKernel(ard_num_dims=d), num_tasks=m, rank=m)
    paths = os.path.join("results", "BC")
if args.dataset == "OKA":
    # OKA: designs and objectives are loaded pre-built from disk.
    x = np.load(os.path.join('datasets', "okax.npy"))
    y = np.load(os.path.join('datasets', "okay.npy"))
    n = 250  # number of designs the files are expected to contain
    m = 2    # output dimension
    d = 3    # input dimension

    # Fill the NumPy objective matrix from the host array before the data is
    # moved to DEVICE: NumPy cannot read from a tensor on a non-CPU device.
    mu = np.empty((n, m))
    mu[:, 0] = y[:, 0]
    mu[:, 1] = y[:, 1]
    x = torch.from_numpy(x).to(DEVICE)
    y = torch.from_numpy(y).to(DEVICE)

    problem_model = OptimizationProblem(x, y, std)
    kernel_list = [gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel()) for _ in range(m)]
    kernel = gpytorch.kernels.MultitaskKernel(gpytorch.kernels.MaternKernel(ard_num_dims=d), num_tasks=m, rank=m)
    #kernel = gpytorch.kernels.MultitaskKernel(gpytorch.kernels.RBFKernel(), num_tasks=m, rank=m)
    paths = os.path.join("results", "OKA")
if args.dataset == "SnAR":
    # SnAR: designs and objectives are loaded pre-built from disk.
    x = np.load(os.path.join('datasets', "SnAr_x.npy"))
    y = np.load(os.path.join('datasets', "SnAr_y.npy"))
    n = 950  # number of designs the files are expected to contain
    m = 2    # output dimension
    d = 4    # input dimension

    # Fill the NumPy objective matrix from the host array before the data is
    # moved to DEVICE: NumPy cannot read from a tensor on a non-CPU device.
    mu = np.empty((n, m))
    mu[:, 0] = y[:, 0]
    mu[:, 1] = y[:, 1]
    x = torch.from_numpy(x).to(DEVICE)
    y = torch.from_numpy(y).to(DEVICE)

    problem_model = OptimizationProblem(x, y, std)
    kernel_list = [gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel()) for _ in range(m)]
    kernel = gpytorch.kernels.MultitaskKernel(gpytorch.kernels.RBFKernel(ard_num_dims=d), num_tasks=m, rank=m)
    paths = os.path.join("results", "SnAR")
if args.dataset == "penicillin":
    # penicillin: designs and objectives are loaded pre-built from disk.
    x = np.load(os.path.join('datasets', "penicillin_x.npy"))
    y = np.load(os.path.join('datasets', "penicillin_y.npy"))
    n = 250  # number of designs the files are expected to contain
    m = 2    # output dimension
    d = 7    # input dimension

    # Fill the NumPy objective matrix from the host array before the data is
    # moved to DEVICE: NumPy cannot read from a tensor on a non-CPU device.
    mu = np.empty((n, m))
    mu[:, 0] = y[:, 0]
    mu[:, 1] = y[:, 1]
    x = torch.from_numpy(x).to(DEVICE)
    y = torch.from_numpy(y).to(DEVICE)

    problem_model = OptimizationProblem(x, y, std)
    kernel_list = [gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel()) for _ in range(m)]
    kernel = gpytorch.kernels.MultitaskKernel(gpytorch.kernels.RBFKernel(ard_num_dims=d), num_tasks=m, rank=m)
    paths = os.path.join("results", "penicillin")

#indexes = np.random.choice(len(x), jumpstart_gp_n, replace=True)

# The GP models receive the loaded design/objective tensors themselves as
# their sample data (plain aliases, no copy is made).
x_sample, y_sample = x, y

# Random UUID used as the base name of every result file written by this run.
filename = f"{uuid.uuid4()}"

def _cone_matrix(angle, dim):
    """Return ``(A, cone_text)`` for the ordering cone with the given angle.

    Args:
        angle: cone opening angle in degrees as a string: "45", "90" or "135".
        dim: number of objectives; only used for the 90-degree identity cone.

    Returns:
        The matrix ``A`` whose rows are the unit-norm vectors defining the
        cone, and a LaTeX label describing the cone.

    Raises:
        ValueError: if ``angle`` is not one of the supported values.
    """
    if angle == "90":
        return np.eye(dim), r"$C_{\theta}=\pi/2$"
    if angle == "135":
        theta = 3 * np.pi / 4
        w_1 = np.array([-np.tan(np.pi / 4 - theta / 2), 1])
        w_2 = np.array([-np.tan(np.pi / 4 + theta / 2), 1])
        text = r"$C_{\theta}=3\pi/4$"
    elif angle == "45":
        theta = np.pi / 4
        w_1 = np.array([-np.tan(np.pi / 4 - theta / 2), 1])
        # Note the flipped sign convention of the second row for this cone.
        w_2 = np.array([+np.tan(np.pi / 4 + theta / 2), -1])
        text = r"$C_{\theta}=\pi/4$"
    else:
        raise ValueError('The given cone angle is invalid.')
    # Normalize both rows to unit length before stacking them.
    w_1 = w_1 / np.linalg.norm(w_1)
    w_2 = w_2 / np.linalg.norm(w_2)
    return np.vstack((w_1, w_2)), text


# Run the algorithm once per ordering cone and store each result separately.
for angle in ["45", "90", "135"]:
    A, cone_text = _cone_matrix(angle, m)

    alpha_vec = get_alpha_vec(A)
    # NOTE(review): p_opt and Delta are computed but not stored in the saved
    # results; the calls are kept in case the helpers are relied on elsewhere.
    p_opt = get_pareto_set(mu, A, alpha_vec)
    b = np.zeros((m,))
    C = Polyhedron(A=A, b=b)

    if args.gp_dependent:
        # Give every run its own kernel copy. The copy used to be created but
        # the shared `kernel` object was passed instead, so kernel state could
        # carry over from one cone angle to the next.
        kernel_copied = deepcopy(kernel)
        gp = GaussianProcessModelDependent(d=d, m=m, noise_variance=noise_variance,
                                           x_sample=x_sample, y_sample=y_sample,
                                           kernel=kernel_copied, verbose=True, device=DEVICE,
                                           train_during_alg=TrainGPDuringAlgorithm)
    else:
        kernel_copied = [deepcopy(kern) for kern in kernel_list]
        gp = GaussianProcessModel(d=d, m=m, noise_variance=noise_variance,
                                  x_sample=x_sample, y_sample=y_sample,
                                  kernel_list=kernel_copied, verbose=True, device=DEVICE,
                                  train_during_alg=TrainGPDuringAlgorithm)

    alg = VectorEpsilonPAL(problem_model=problem_model, cone=C, epsilon=epsilon, delta=delta,
                           gp=gp, obj_dim=m, maxiter=None, batched=batched)
    time0 = time.time()
    pareto_set = alg.algorithm()
    duration = time.time() - time0
    print(f"time spent = {duration}")
    print("Pareto set")
    print_list(pareto_set)

    # Boolean mask over all n designs marking the ones returned by the algorithm.
    mask_hat = np.zeros(n, dtype=bool)
    p_hat_indexes = [design.design_index for design in pareto_set]
    mask_hat[p_hat_indexes] = True

    Delta = get_delta(mu, A, alpha_vec)

    P_hat = x[mask_hat]

    save_dict = {
        "P_hat": P_hat,
        "P_hat_mask": mask_hat,
        "SC": alg.sample_count,
        "RC": alg.t,
    }

    # np.save does not create missing directories, so make sure the per-angle
    # output directory exists before writing.
    out_dir = os.path.join(paths, angle)
    os.makedirs(out_dir, exist_ok=True)
    np.save(os.path.join(out_dir, filename + '.npy'), save_dict)




