import datetime
import numpy as np
from gpcam import GPOptimizer
from gpcam import gpMCMC
import sys
from dask.distributed import Client
import torch
from scipy.interpolate import griddata
from dask.distributed import performance_report
from loguru import logger as log
import time
from scipy.stats import binom
from scipy.stats import beta
import scipy
import scipy.sparse as sparse
import socket
from distributed.scheduler import logger
import climate_kernel
from imate import logdet as imate_logdet


def main():
    """Compute GP posterior mean, standard deviation, RMSE and CRPS on the test set.

    Command-line arguments:
        sys.argv[1]: address of the Dask scheduler to connect to.
        sys.argv[2]: number of Dask workers to wait for before starting.

    Inputs read from disk: the four ``./data/*_3dclimate.csv`` files and the
    hyperparameters in ``pred_hps.npy``.

    Outputs written to disk: ``posterior_mean.npy``, ``posterior_st_dev.npy``
    and the Dask performance report ``dask-report10Mill.html``.

    Raises:
        SystemExit: if fewer than two command-line arguments are given.
    """
    # Fail fast, before the (potentially large) CSV files are parsed.
    if len(sys.argv) < 3:
        raise SystemExit(
            f"usage: {sys.argv[0]} <scheduler_address> <number_of_workers>"
        )

    start_time = time.time()

    x_train = np.genfromtxt("./data/x_train_3dclimate.csv", delimiter=" ")
    x_test = np.genfromtxt("./data/x_test_3dclimate.csv", delimiter=" ")
    y_train = np.genfromtxt("./data/y_train_3dclimate.csv", delimiter=" ")
    y_test = np.genfromtxt("./data/y_test_3dclimate.csv", delimiter=" ")

    print("x1 range: ", np.min(x_train[:,0]),",", np.max(x_train[:,0]))
    print("x2 range: ", np.min(x_train[:,1]),",", np.max(x_train[:,1]))
    print("t range: ", np.min(x_train[:,2]),",", np.max(x_train[:,2]))
    print("temp  range: ", np.min(y_train),",", np.max(y_train))
    print("data size:", x_train.shape)

    # Connect to the already-running Dask cluster and wait for its workers.
    print("inputs to the run script: ",sys.argv, flush = True)
    print("port: ", str(sys.argv[1]), flush = True)
    client = Client(str(sys.argv[1]))
    client.wait_for_workers(int(sys.argv[2]))
    print("Client is ready", flush = True)
    print(datetime.datetime.now().isoformat())
    print("client received: ", client, flush = True)
    worker_info = list(client.scheduler_info()["workers"].keys())
    print("client worker info", worker_info)

    # Log the ssh port-forwarding command needed to reach the Dask dashboard.
    host = client.run_on_scheduler(socket.gethostname)
    port = client.scheduler_info()['services']['dashboard']
    login_node_address = "mcn@perlmutter-p1.nersc.gov" # Change this to the address/domain of your login node
    logger.info(f"ssh -N -L {port}:{host}:{port} {login_node_address}")
    log.enable("fvgp")

    print("Everything is ready to call gp2Scale after ", time.time() - start_time, flush = True)
    print("Number of GPUs per Node: ", torch.cuda.device_count())

    init_hyperparameters = np.load("pred_hps.npy")

    print("starting hyperparameters: ")
    print(init_hyperparameters)

    with performance_report(filename="dask-report10Mill.html"):
        my_gp = GPOptimizer(x_train,y_train,
                init_hyperparameters = init_hyperparameters,
                compute_device='cpu',
                gp_kernel_function=climate_kernel.kernel,
                gp_noise_function=climate_kernel.my_noise,
                gp2Scale = True,
                gp2Scale_dask_client = client,
                gp2Scale_batch_size=15000,
                gp2Scale_linalg_mode = "sparseMINRES",
                )

        print("GP initialized ")
        print("=============================================")
        print("=============================================")
        print("=============================================")
        print("=============================================")

        # Constant prior mean: the first entry of the GP prior mean vector is
        # used for every test point.
        mean_at_x_test = np.zeros((len(x_test))) + my_gp.prior.m[0]
        print(mean_at_x_test)


        def calc_mean(x, mean_chunk, k, KVinvY, hps):
            """Return the posterior mean for one chunk: ``mean_chunk + k.T @ KVinvY``.

            ``x`` and ``hps`` are accepted but not used.
            """
            return mean_chunk + k.T @ KVinvY


        def calculate_sparse_minres(KV, vec, info=False):
            """Solve ``KV @ res = vec`` column by column with MINRES (rtol=0.01).

            Args:
                KV: sparse matrix exposing ``tocsc()``.
                vec: right-hand side(s); a 1-D vector is treated as one column.
                info: if true, log progress and the total solve time.

            Returns:
                Array with the same shape as the (2-D) ``vec`` holding the
                solutions, one per column.
            """
            # Local import: the module does not import ``warnings``.
            import warnings

            logger.info("calculate_sparse_minres")
            st = time.time()
            if info:
                logger.info("MINRES solve in progress ...")
            if np.ndim(vec) == 1:
                vec = vec.reshape(len(vec), 1)
            res = np.zeros(vec.shape)
            # Convert once instead of once per right-hand side.
            KV_csc = KV.tocsc()
            for i in range(vec.shape[1]):
                res[:, i], exit_code = sparse.linalg.minres(KV_csc, vec[:, i], rtol=0.01)
                # SciPy reports success as 0; any other value (positive or
                # negative) means the solve did not succeed.
                if exit_code != 0:
                    warnings.warn(
                        f"MINRES not successful (column {i}, exit code {exit_code})"
                    )
            if info:
                # ``logger`` is a stdlib logger, so use %-style placeholders.
                logger.info("MINRES compute time: %s seconds", time.time() - st)
            return res


        def calc_st_dev(x, kxx, k, KV, hps):
            """Return the posterior standard deviation for one chunk.

            Computes ``sqrt(diag(kxx - k.T @ KV^{-1} k))`` with the inverse
            applied through MINRES. ``x`` and ``hps`` are accepted but not used.
            """
            st = time.time()
            print("st dev started", flush = True)
            KVinvk = calculate_sparse_minres(KV, k, info = True)
            res = np.diag(kxx - k.T @ KVinvk)
            print("std dev time: ", time.time() - st, flush = True)
            # NOTE(review): the solve is only approximate (rtol=0.01), so entries
            # of ``res`` could come out slightly negative and yield NaN here --
            # confirm whether clipping at zero is wanted.
            return np.sqrt(res)


        ###make chunks
        chunk_size = 10
        def divide_chunks(l, n):
            """Yield consecutive slices of ``l`` of length ``n`` (last may be shorter)."""
            for i in range(0, len(l), n):
                yield l[i:i + n]
        tasks = list(divide_chunks(x_test, chunk_size))
        prior_mean_chunks = list(divide_chunks(mean_at_x_test, chunk_size))
        print("Number of tasks: ", len(tasks), " size of tasks: ", len(tasks[0]))
        print("TASKS:")
        print(tasks)
        print("=================================")

        ###get k(x,X)
        k_futures = [
            client.submit(climate_kernel.kernel, x_train, task, init_hyperparameters)
            for task in tasks
        ]
        ks = client.gather(k_futures)


        ###get posterior means
        means = []
        for i, task in enumerate(tasks):
            print("mean ", i, " of ", len(tasks), " len(task): ", len(task), " len(mean chunk: )", len(prior_mean_chunks[i]), " len(ks):", len(ks[i]), flush = True)
            means.append(client.submit(calc_mean, task, prior_mean_chunks[i], ks[i], my_gp.marginal_density.KVinvY, init_hyperparameters))
        means = client.gather(means)
        print("MEANS:")
        print(means)
        print("===================================")
        mean = np.concatenate(means)
        np.save("posterior_mean", mean)
        print("POSTERIOR MEAN WRITTEN")
        # Score only the test points that actually have an observation.
        non_nan_ind = np.argwhere(~np.isnan(y_test))

        def rmse(x_test, y_test, post_mean):
            """Return the root-mean-square error between ``y_test`` and ``post_mean``.

            ``x_test`` is accepted but not used.
            """
            v1 = y_test.reshape(len(y_test))
            v2 = post_mean.reshape(len(v1))
            return np.sqrt(np.sum((v1 - v2) ** 2) / len(v1))


        print("RMSE = ", rmse(x_test[non_nan_ind], y_test[non_nan_ind], mean[non_nan_ind]))


        ###get k(x,x)
        kxx_futures = [
            client.submit(climate_kernel.kernel, task, task, init_hyperparameters)
            for task in tasks
        ]
        kxxs = client.gather(kxx_futures)

        ###get posterior covariance
        KV = my_gp.marginal_density.KVlinalg.KV
        print("SIZE OF KV: ", (KV.data.nbytes + KV.indptr.nbytes + KV.indices.nbytes)/1e9, "GBytes")
        # Ship the sparse matrix to every worker once instead of once per task.
        big_future = client.scatter(KV, broadcast = True)
        print("scattered", flush = True)
        std_devs = []
        for i, task in enumerate(tasks):
            print("std dev ", i, " of ", len(tasks), flush = True)
            # Distribute tasks round-robin over the workers; indexing by the raw
            # task number fails as soon as there are more chunks than workers.
            target_worker = worker_info[i % len(worker_info)] if worker_info else None
            std_devs.append(client.submit(calc_st_dev, task, kxxs[i], ks[i], big_future, init_hyperparameters, workers = target_worker))
        std_devs = client.gather(std_devs)
        print("STD DEV: ",std_devs)
        std_dev = np.concatenate(std_devs)
        print("===================================")
        np.save("posterior_st_dev", std_dev)
        print("POSTERIOR St. Dev. WRITTEN")


        from scipy.stats import norm
        def _crps_s(x, mu, sigma):
            """Return (mean, standard deviation) of the per-point Gaussian CRPS."""
            z = (x - mu) / sigma
            res = abs(sigma * ((1. / np.sqrt(np.pi))
                               - 2. * norm.pdf(z)
                               - (z * (2. * norm.cdf(z) - 1.))))
            return np.mean(res), np.sqrt(np.var(res))

        def crps(x_test, y_test, post_mean, post_var):
            """Return the CRPS summary of ``y_test`` under the predictive Gaussians.

            Despite its name, ``post_var`` is passed on unchanged as the
            standard deviation ``sigma``. ``x_test`` is accepted but not used.
            """
            return _crps_s(y_test, post_mean, post_var)


        print("RMSE = ", rmse(x_test[non_nan_ind], y_test[non_nan_ind], mean[non_nan_ind]))
        print("CRPS = ", crps(x_test[non_nan_ind], y_test[non_nan_ind], mean[non_nan_ind], std_dev[non_nan_ind]))

        print("END")



# Run the prediction workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()


