import datetime
import numpy as np
from gpcam import GPOptimizer
from gpcam import gpMCMC
import sys
from dask.distributed import Client
import torch
from scipy.interpolate import griddata
from dask.distributed import performance_report
from loguru import logger as log
import time
from scipy.stats import binom
from scipy.stats import beta
import scipy
import scipy.sparse as sparse
import socket
from distributed.scheduler import logger
import climate_kernel
from imate import logdet as imate_logdet


def main():
    """Fit the climate GP hyperparameters with gp2Scale and MCMC.

    Command-line arguments (read from ``sys.argv``):
        1. Address of the Dask scheduler to connect to.
        2. Number of Dask workers to wait for before starting.

    Steps performed:
        * load the training inputs/targets from ``./data``,
        * look up elevation and distance-to-coast at the training locations
          (nearest-neighbour ``griddata`` on the ``climate_kernel`` tables),
        * connect to the Dask cluster and log an ssh port-forward command
          for its dashboard,
        * build a ``GPOptimizer`` in gp2Scale mode,
        * run ``gpMCMC`` over all 14 hyperparameters, and
        * set and save the median of the trace tail as final hyperparameters.

    Besides the Dask performance report, the run writes the sparse prior
    covariance, the proposal covariance, the last evaluated hyperparameters,
    the running trace, the final hyperparameters and the MCMC result to the
    current working directory.

    Raises:
        Exception: if the initial hyperparameters are outside ``hps_bounds``.
    """
    start_time = time.time()
    descriptor = "_10Mill"
    n_hps = 14

    x_train = np.genfromtxt("./data/x_train_3dclimate.csv", delimiter=" ")
    y_train = np.genfromtxt("./data/y_train_3dclimate.csv", delimiter=" ")

    # Columns 0 and 1 of x_train are used as the lookup coordinates; the third
    # column of each climate_kernel table holds the value being looked up.
    elevation_at_x_train = griddata(
        climate_kernel.train_station_elevation[:, 0:2],
        climate_kernel.train_station_elevation[:, 2],
        x_train[:, 0:2], method='nearest', fill_value=0.)
    dist2coast_at_x_train = griddata(
        climate_kernel.train_station_dist2coast[:, 0:2],
        climate_kernel.train_station_dist2coast[:, 2],
        x_train[:, 0:2], method='nearest', fill_value=0.)
    print("elevation:  ", elevation_at_x_train)
    print("dist2coast: ", dist2coast_at_x_train)

    print("x1 range: ", np.min(x_train[:,0]),",", np.max(x_train[:,0]))
    print("x2 range: ", np.min(x_train[:,1]),",", np.max(x_train[:,1]))
    print("t range: ", np.min(x_train[:,2]),",", np.max(x_train[:,2]))
    print("temp  range: ", np.min(y_train),",", np.max(y_train))
    print("data size:", x_train.shape)

    def in_bounds(v, bounds):
        """Return True if every entry of ``v`` lies within ``bounds``.

        ``bounds`` is indexed as ``bounds[:, 0]`` (lower) and
        ``bounds[:, 1]`` (upper); values equal to a bound count as inside.
        """
        return not (np.any(v < bounds[:, 0]) or np.any(v > bounds[:, 1]))

    def prior_function(theta, args):
        """Log-prior for the MCMC: 0. if ``theta`` is admissible, else -inf.

        ``theta`` is rejected when, at any training point, the ratio of the
        two ``climate_kernel.Lambda`` fields leaves [0.01, 100], either field
        exceeds 10, or the signal variance exceeds 1e6, and finally when
        ``theta`` is outside ``args["bounds"]``. The reason is printed.
        """
        bounds = args["bounds"]
        lamda1_hps = theta[0:3]
        lamda2_hps = theta[3:6]

        # Log-linear signal variance in elevation and distance to coast.
        e0, e1, e2 = theta[9], theta[10], theta[11]
        signal_variance = np.exp(e0 +
                                 e1 * elevation_at_x_train +
                                 e2 * dist2coast_at_x_train)

        # from_numpy shares memory with the arrays, so the tensors can be
        # reused for both Lambda evaluations.
        x_t = torch.from_numpy(x_train)
        elev_t = torch.from_numpy(elevation_at_x_train)
        dist_t = torch.from_numpy(dist2coast_at_x_train)
        l1 = climate_kernel.Lambda(x_t, lamda1_hps, elev_t, dist_t).numpy()
        l2 = climate_kernel.Lambda(x_t, lamda2_hps, elev_t, dist_t).numpy()

        # np.any instead of the builtin any(): these arrays have one entry
        # per training point, and the builtin iterates them in Python.
        ratio = l1 / l2
        if np.any(ratio > 100.) or np.any(ratio < 0.01):
            print("PRIOR 0 --- L ratio", flush = True)
            return -np.inf

        if np.any(l1 > 10.) or np.any(l2 > 10.):
            print("PRIOR 0 --- L too large", flush = True)
            return -np.inf

        if np.any(signal_variance > 1e6):
            print("PRIOR 0 --- signal var>1e6", flush = True)
            return -np.inf

        if in_bounds(theta, bounds):
            # Flat prior inside the admissible region.
            prior = 0.
            print("PRIOR=", prior, flush = True)
            return prior

        print("PRIOR 0 --- out of bounds", flush = True)
        return -np.inf

    def proposal_distribution_normal(x0, hps, obj):
        """Draw one Gaussian proposal centred at ``x0``.

        The covariance is read from ``obj.prop_args["prop_Sigma"]`` and is
        also saved to ``"prop_cov_" + obj.ID`` on every call. ``hps`` is
        accepted for interface compatibility but not used.
        """
        cov = obj.prop_args["prop_Sigma"]
        np.save("prop_cov_"+obj.ID, cov)
        return np.random.multivariate_normal(
            mean = x0, cov = cov, size = 1).reshape(len(x0))

    print("inputs to the run script: ",sys.argv, flush = True)
    print("port: ", str(sys.argv[1]), flush = True)
    client = Client(str(sys.argv[1]))  #direct_to_workers=None, connection_limit=512
    client.wait_for_workers(int(sys.argv[2]))
    print("Client is ready", flush = True)
    print(datetime.datetime.now().isoformat())
    print("client received: ", client, flush = True)
    worker_info = list(client.scheduler_info()["workers"].keys())
    print("client worker info", worker_info)

    # Log the ssh command that forwards the Dask dashboard port.
    host = client.run_on_scheduler(socket.gethostname)
    port = client.scheduler_info()['services']['dashboard']
    login_node_address = "mcn@perlmutter-p1.nersc.gov" # Change this to the address/domain of your login node
    logger.info(f"ssh -N -L {port}:{host}:{port} {login_node_address}")
    log.enable("fvgp")

    print("Everything is ready to call gp2Scale after ", time.time() - start_time, flush = True)
    print("Number of GPUs per Node: ", torch.cuda.device_count())

    # Hyperparameter bounds: entries 0-11 share [-20, 20]; the last two have
    # their own ranges.
    hps_bounds = np.zeros((n_hps, 2))
    hps_bounds[0:12] = np.array([-20,20])
    hps_bounds[12:14] = np.array([ ####new version
                        [1.,10], #time length scale
                        [1.,4.] #noise
                        ])

    # Restart from the hyperparameters that func() checkpoints on every
    # likelihood evaluation; on a first run (no checkpoint yet) fall back to
    # a random draw within the bounds with the first 12 entries zeroed.
    try:
        init_hyperparameters = np.load("last_hps_backup.npy")
    except FileNotFoundError:
        print("last_hps_backup.npy not found, using random initial hyperparameters", flush = True)
        init_hyperparameters = np.random.uniform(low = hps_bounds[:,0],
                                                 high= hps_bounds[:,1],
                                                 size = len(hps_bounds))
        init_hyperparameters[0:12] = 0.

    if in_bounds(init_hyperparameters, hps_bounds): print("ALL IN BOUNDS", flush = True)
    else: raise Exception("INIT HPS NOT IN BOUNDS")

    print("starting hyperparameters: ")
    print(init_hyperparameters)

    with performance_report(filename="dask-report10Mill.html"):
        my_gp = GPOptimizer(x_train, y_train,
                init_hyperparameters = init_hyperparameters,
                compute_device='cpu',
                gp_kernel_function=climate_kernel.kernel,
                gp_noise_function=climate_kernel.my_noise,
                gp2Scale = True,
                gp2Scale_dask_client = client,
                gp2Scale_batch_size=15000,
                gp2Scale_linalg_mode = "sparseMINRES",
                )
        sparse.save_npz("sparse_matrix10Mill",my_gp.prior.K)

        print("GP initialized ")
        print("=============================================")
        print("=============================================")
        print("=============================================")
        print("=============================================")

        def func(hps, args):
            """MCMC target: log-likelihood of the GP at ``hps``.

            Saves ``hps`` to ``last_hps_backup`` before evaluating, so an
            interrupted run can restart from the last evaluated point.
            ``args`` is not used.
            """
            st = time.time()
            np.save("last_hps_backup", hps)
            result = my_gp.log_likelihood(hyperparameters=hps)
            print("f(x): ", result, "  after: ", time.time() - st)
            return result

        def write_results(obj):
            """Save the current ``obj.trace``; passed to run_mcmc as ``run_in_every_iteration``."""
            np.save("current_trace"+descriptor, obj.trace)

        from gpcam.gpMCMC import gpMCMC, ProposalDistribution
        coreW_hps_ind = list(range(n_hps))

        # Initial proposal std: 1% of each hyperparameter's bound width.
        axis_std_coreW = (hps_bounds[coreW_hps_ind, 1] - hps_bounds[coreW_hps_ind,0])/100.
        init_s_coreW = np.diag(axis_std_coreW**2)

        pd1 = ProposalDistribution(coreW_hps_ind, proposal_dist = proposal_distribution_normal,
                                init_prop_Sigma = init_s_coreW, adapt_callable="normal", K=10, ID = "core")

        my_mcmc = gpMCMC(func, prior_function, [pd1],
                        args={"bounds":hps_bounds})

        mcmc_result = my_mcmc.run_mcmc(x0=init_hyperparameters, info=True, n_updates=500, run_in_every_iteration=write_results)
        trace_x = mcmc_result["x"]
        len_trace = len(trace_x)

        # Summarize the samples from index 481 onward. If the trace is
        # shorter than that, the slice would be empty and the median NaN, so
        # fall back to the final sample.
        tail_start = 481
        if len_trace <= tail_start:
            tail_start = max(len_trace - 1, 0)
            print("trace has only", len_trace, "samples; using samples from index", tail_start, flush = True)

        hps1 = np.median(trace_x[tail_start:], axis=0)
        print("FINAL HYPERPARAMETERS")
        print(hps1)

        my_gp.set_hyperparameters(hps1)
        print(my_gp.log_likelihood(my_gp.get_hyperparameters()))
        np.save("full_gp2ScaleHPS", my_gp.get_hyperparameters())
        np.save("mcmc_result"+descriptor, mcmc_result)

        sparsity = float(my_gp.prior.K.nnz) / float(my_gp.prior.K.shape[0]**2) #would be 0 for full sparsity aka all zeros
        sparse.save_npz("sparse_matrix10Mill",my_gp.prior.K)
        print("sparsity: ", sparsity)
      

# Script entry point. main() reads the Dask scheduler address from
# sys.argv[1] and the number of workers to wait for from sys.argv[2].
if __name__ == "__main__":
    main()


