import torch
from scipy.stats import qmc
from pygranso.pygransoStruct import pygransoStruct
import os
import subprocess
# Expose only the GPU that nvidia-smi reports as having the most free memory.
# PCI_BUS_ID ordering makes CUDA's device indices agree with nvidia-smi's "index".
# NOTE(review): if nvidia-smi is missing the pipeline prints nothing, so
# CUDA_VISIBLE_DEVICES becomes "" (no GPU visible) — confirm that is acceptable.
_FREE_GPU_QUERY = (
    "nvidia-smi --query-gpu=memory.free,index --format=csv,nounits,noheader"
    " | sort -nr | head -1 | awk '{ print $NF }'"
)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
_free_gpu = subprocess.run(_FREE_GPU_QUERY, shell=True, capture_output=True, text=True)
# Slice off the trailing newline emitted by awk.
os.environ["CUDA_VISIBLE_DEVICES"] = _free_gpu.stdout[:-1]
import SALGP
import SALPlots
import SALRealities
import SALOptimizer
from datetime import datetime
import matplotlib.animation as manimation
# from torch.profiler import profile, record_function, ProfilerActivity, schedule
import logging
import torchquad
# torchquad integrates with the torch backend in double precision.
torchquad.set_up_backend("torch", data_type="float64")
# Root logging at WARNING; matplotlib's own logger is also held at WARNING so its
# debug/info output is suppressed.
logging.basicConfig(level=logging.WARNING)
logging.getLogger('matplotlib').setLevel(logging.WARNING)
from matplotlib import pyplot as plt
import csv
import time

# GPU device configuration-
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')

# global settings
global_name = "ToySeasonal_wide_experiments"
reality_class = SALRealities.SimpleReality
# Construct the reality once and read its static properties (bounds, number of
# outputs) from that single instance instead of building three instances.
_reference_reality = reality_class(device)
reality_xmin = _reference_reality.xmin
reality_xmax = _reference_reality.xmax
num_outputs = _reference_reality.num_outputs
version=1
sal_crit = 0 # which output is relevant for sal
constraints = [
    [0,[1],[2]] # 0 + 1*f[1]+2*fstd[1]
]

# GP hyperparameter priors, passed to SALGP as prior_dict.
prior_dict = {
    'SE': {'raw_lengthscale' : {"mean": [ 5.0, 0.0, 0.0 ] , "std": [ 5.0, 1.0, 1.0 ]}},
    'c':{'raw_outputscale':{"mean": 1.0, "std": 1.0 } },
    'noise': {'raw_noise':{"mean": -3.0, "std": 1.0 } },
    'mean': {'raw_constant':{"mean": 10, "std": 0.01 } }
}
train_numbers = [3]          # values used as "optimization_iter"
train_GP = [ True ]          # True -> train GP, "Hard" -> train_hard(), other -> no training
seed_start = 0
nr_seeds = 25
sal_iter = 100               # SAL steps per experiment
quick_gp_training_iter = 30  # Adam iterations for the per-step GP refit
local_changes = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
experiments = []
noises = [0.1]
log2_int_pts = 3             # initial design has 2**log2_int_pts Sobol' points

# Full factorial grid; every configuration is run twice, once with the
# time-windowed IMSPE objective ("TIMSPE") and once with the entropy objective.
for local_change in local_changes:
    for seed in range(seed_start, seed_start + nr_seeds):
        for train in train_numbers:
            for noise in noises:
                for train_gp in train_GP:
                    # Name suffix shared by both objective variants of this configuration.
                    name_suffix = "_" + str(train) + "_seed_" + str(seed) + "_noise_" + str(noise) + "_local_change_" + str(local_change)
                    experiments.append({
                        "seed": seed,
                        "name": global_name + "_TIMSPE" + name_suffix,
                        "quick_gp_training_iter": quick_gp_training_iter,
                        "sal_iter": sal_iter,
                        "optimization_iter": train,
                        "objective": "IMSPEequal",
                        "objective_IMSPEequal_a": reality_xmin,
                        "objective_IMSPEequal_b": reality_xmax,
                        "objective_IMSPEequal_tback": 0,
                        "objective_IMSPEequal_tforward": 10,
                        "noise": noise,
                        "model": "GPFullSEARD",
                        "local_change": local_change,
                        "train_gp": train_gp
                    })
                    experiments.append({
                        "seed": seed,
                        "name": global_name + "_Entropy" + name_suffix,
                        "quick_gp_training_iter": quick_gp_training_iter,
                        "sal_iter": sal_iter,
                        "optimization_iter": train,
                        "objective": "Entropy",
                        "noise": noise,
                        "model": "GPFullSEARD",
                        "local_change": local_change,
                        "train_gp": train_gp
                    })

# initialize summarizing csv file
# Create every output directory this script writes to (summary CSV here; per-
# experiment logs/ and animations/ later) so a fresh checkout does not fail on open().
for _output_dir in ("summaries", "logs", "animations"):
    os.makedirs(_output_dir, exist_ok=True)

now = datetime.now()
dt_string = now.strftime("%d/%m/%Y %H:%M:%S")
dt_string_file = now.strftime("%Y%m%d-%H%M%S")
summary_name = "summaries/" + global_name + "_" + dt_string_file + ".csv"
# CSV header (column names). The per-output columns below must stay in the same
# order as the values appended for each experiment at the end of the main loop.
summary_rows = [
        "name", "objective", "local_change", "noise", "train_gp",
        "time", "changed_optima", "failed_optimization_runs",
        "proportion safe points (with noise)", "proportion safe points (without noise)", 
        "model_safe_precision", "model_safe_recall", "model_safe_f1", "model_safe_accuracy",
        "model_optimistic_safe_precision", "model_optimistic_safe_recall", "model_optimistic_safe_f1", "model_optimistic_safe_accuracy" ]
# "up_to_*": averages over the first 20/40/60/80/100% of SAL steps;
# "at_*": single-step values at 10%, ..., 90% and at the last step.
_error_column_suffixes = (
    ["up_to_0.2", "up_to_0.4", "up_to_0.6", "up_to_0.8", "up_to_end"]
    + [f"at_0.{k}" for k in range(1, 10)]
    + ["at_end"]
)
for i in range(num_outputs):
    for _suffix in _error_column_suffixes:
        summary_rows.append(f"RMSE_safe_area_model_{i}_{_suffix}")
        summary_rows.append(f"RMSE_all_area_model_{i}_{_suffix}")
        summary_rows.append(f"NLL_eval_model_{i}_{_suffix}")
with open(summary_name, 'w', newline='') as myfile:
    wr = csv.writer(myfile, quoting=csv.QUOTE_ALL)
    wr.writerow(summary_rows)

for experiment in experiments:

    # Timestamp-prefixed name shared by this run's log file, animation and CSV row.
    now = datetime.now()
    dt_string = now.strftime("%d/%m/%Y %H:%M:%S")
    dt_string_file = now.strftime("%Y%m%d-%H%M%S")
    experiment_name = "_".join((dt_string_file, experiment["name"]))

    # Point the 'log_sal' logger at a fresh file for this experiment, first
    # detaching and closing any FileHandler left from the previous experiment.
    logger = logging.getLogger('log_sal')
    stale_handlers = [h for h in logger.handlers if isinstance(h, logging.FileHandler)]
    for stale in stale_handlers:
        logger.removeHandler(stale)
        stale.close()
    file_handler = logging.FileHandler('logs/' + experiment_name + '.log')
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )
    logger.addHandler(file_handler)
    logger.setLevel(logging.DEBUG)

    # Run header: torch version, experiment name and the full configuration dict.
    rule = "------------------------------------"
    logger.info("torch version: " + torch.__version__)
    for header_line in (rule, f"   {experiment_name}", rule, str(experiment), rule):
        logger.info(header_line)

    # Report which compute devices are visible to this process.
    logger.info(f"Number of GPU(s) available = {torch.cuda.device_count()}")
    if not torch.cuda.is_available():
        logger.info("PyTorch does not have access to GPU")
    else:
        gpu_idx = torch.cuda.current_device()
        logger.info(f"Current GPU: {gpu_idx}")
        logger.info(f"Current GPU name: {torch.cuda.get_device_name(gpu_idx)}")
    logger.info(f"Available device: {device}")

    # Seed torch, Python and NumPy RNGs so each configuration is reproducible.
    import random
    import numpy as np
    rng_seed = experiment["seed"]
    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(rng_seed)

    # set up "reality"
    reality = reality_class(device, noise=experiment["noise"], local_change=experiment["local_change"], version=version)

    # Initial design: 2**log2_int_pts unscrambled Sobol' points mapped into
    # [safe_min, safe_max], intentionally a smaller area than xmin/xmax.
    sampler = qmc.Sobol(d=reality.num_inputs, scramble=False)
    unit_points = sampler.random_base2(m=log2_int_pts)
    x = reality.safe_min + (reality.safe_max - reality.safe_min) * unit_points
    n_init = x.shape[0]
    for point in x:
        reality.next_point(*list(point))

    # GP on the initial data, optionally trained according to "train_gp".
    gp = SALGP.SALGP(
        reality.xdata,
        reality.ydata,
        prior_dict=prior_dict,
        model=experiment["model"],
        sal_crit=sal_crit,
    )
    train_mode = experiment["train_gp"]
    if train_mode == True:  # noqa: E712 -- train_gp may also be the string "Hard"
        gp = gp.train_pygranso()
    elif train_mode == "Hard":
        gp = gp.train_hard()
    gp.print_training_parameters()

    # Optimisation problem: one PyGRANSO variable "x" with one entry per reality input.
    var_in = {"x": [reality.num_inputs]}
    optimization_iter = experiment["optimization_iter"]

    # PyGRANSO solver options, applied to a fresh options struct.
    pygranso_settings = {
        "torch_device": device,
        "globalAD": True,
        "quadprog_info_msg": False,
        "print_level": 0,
        "step_tol": 1e-4,
        "opt_tol": 1e-4,
        "viol_ineq_tol": 1e-4,
        "maxit": 200,
        "ngrad": 3,  # recommended if objective is smooth
    }
    opts = pygransoStruct()
    for setting, value in pygranso_settings.items():
        setattr(opts, setting, value)

    # Handed to SALOptimizer; by name these presumably bound/size the relaxation
    # applied after failed optimisation runs (TODO confirm in SALOptimizer).
    failed_optimizations_relaxation_iter_bound = 2
    failed_optimization_sigma_relaxation = 0.2
    def ObjectiveFactory(gp, t, device):
        """Build the acquisition objective selected by ``experiment["objective"]``.

        ``experiment`` and ``reality`` are read from the enclosing script scope at
        call time, i.e. they refer to the experiment currently being run.

        Args:
            gp: surrogate model whose objective method gets wrapped.
            t: one-element time tensor; the IMSPE variants prepend it (shifted)
               to the bounds / mean they build.
            device: torch device for the tensors created here.

        Returns:
            A callable ``xs -> objective value``.

        Raises:
            NotImplementedError: if the objective name is unknown.
        """
        objective = experiment["objective"]

        if objective == "Entropy":
            return lambda xs: gp.NegEntropy(xs)

        if objective == "IMSPEequal":
            # Integration box: [t + tback, t + tforward] on the time axis, followed by
            # the configured bounds a/b (presumably the non-time inputs — TODO confirm).
            # torch.as_tensor accepts the tensors stored in the experiment dict as
            # well as plain lists; torch.tensor(tensor) is discouraged and warns.
            a = torch.as_tensor(experiment["objective_IMSPEequal_a"], device=device)
            b = torch.as_tensor(experiment["objective_IMSPEequal_b"], device=device)
            xmin = torch.cat((t + experiment["objective_IMSPEequal_tback"], a))
            xmax = torch.cat((t + experiment["objective_IMSPEequal_tforward"], b))
            return lambda xs: gp.IMSPEMarginalizedOverUniformDistribution(xs, xmin, xmax)

        if objective == "IMSPEgauss":
            # (t, 0, ..., 0) plus the configured offset, and the configured scale.
            m = torch.cat((t, torch.zeros(reality.num_inputs, device=device))) + torch.as_tensor(experiment["objective_IMSPEgauss_m"], device=device)
            s = torch.as_tensor(experiment["objective_IMSPEgauss_s"], device=device)
            return lambda xs: gp.IMSPEMarginalizedOverGaussianDistribution(xs, m, s)

        raise NotImplementedError(f"Unknown objective: {objective!r}")

    SAL_Optimizer = SALOptimizer.SALOptimizer(
        reality.xmin,
        reality.xmax,
        var_in,
        ObjectiveFactory,
        constraints,
        optimization_iter,
        failed_optimizations_relaxation_iter_bound,
        failed_optimization_sigma_relaxation,
        opts,
        device,
    )

    # Snapshot of the measurements as NumPy arrays for plotting:
    # xdata column 0 -> tpts, column 1 -> xpts, column 2 -> ypts.
    xdata_np = reality.xdata.cpu().numpy()
    tpts = xdata_np[:, 0]
    xpts = xdata_np[:, 1]
    ypts = xdata_np[:, 2]
    ydata = reality.ydata.cpu().detach().numpy()
    fdata = reality.fdata.cpu().detach().numpy()

    # Large figure canvas with automatic layout.
    plt.rcParams["figure.figsize"] = [22.50, 10.50]
    plt.rcParams["figure.autolayout"] = True
    fig = plt.figure()

    t = torch.tensor([reality.current_time], device=device)
    SAL_Plots = SALPlots.SALPlotsTime(
        fig,
        reality.xmin.detach().cpu(),
        reality.xmax.detach().cpu(),
        16,
        experiment["sal_iter"],
        xpts, ypts, tpts,
        ydata, fdata,
        reality.xdiff.detach().cpu() / 15,
        gp,
        lambda x1, x2: reality.reality(0.0, x1, x2),
        lambda x1, x2: reality.reality(reality.current_time, x1, x2),
        ObjectiveFactory(gp, t, device),
        constraints,
        device,
        current_time=reality.current_time,
    )

    # One frame per SAL step, written to animations/<experiment_name>.mp4 with the
    # file-based ffmpeg writer (manimation.writers['ffmpeg'] is the piped variant).
    metadata = {"title": experiment_name, "artist": 'MLH', "comment": experiment_name}
    moviewriter = manimation.writers['ffmpeg_file'](fps=1.0, metadata=metadata)
    moviewriter.setup(fig, 'animations/' + experiment_name + '.mp4', dpi=100)

    # prof = profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],profile_memory=True, with_stack=True, record_shapes=True, schedule=schedule(skip_first=2, wait=2,warmup=2,active=2,repeat=2))
    # from torch.profiler import profile, record_function, ProfilerActivity, schedule
    # with profile(activities=[ProfilerActivity.CUDA], profile_memory=True, with_stack=True, record_shapes=True, schedule=schedule(skip_first=0, wait=2,warmup=0,active=1,repeat=2)) as prof:

    # Wall-clock start of the SAL loop (the elapsed time goes into the summary CSV).
    timer = time.time()

    # Main SAL loop. Each step chooses the next point, measures it, refits the GP
    # and records one animation frame. The writer is always finalised and the
    # figure closed (also on error): pyplot keeps every figure alive until it is
    # closed, and this script opens one figure per experiment.
    try:
        for i in range(experiment["sal_iter"]):
            logger.info("step: " + str(i) + " init")

            # Optimise the acquisition objective for the next time step.
            t = torch.tensor([reality.current_time + 1.0], device=device)

            # with torch.profiler.record_function("optimize"):
            newx = SAL_Optimizer.optimize(t, gp)
            if newx is None:
                # Optimizer returned no point: fall back to the next Sobol' point
                # mapped into the safe box.
                logger.info("chose safe point!!!!!")
                newx = reality.safe_min + (reality.safe_max-reality.safe_min) * sampler.random(1)[0,:]
            logger.info("new point chosen: " + str(newx))

            # measure
            reality.next_point(*list(newx))

            xdata_np = reality.xdata.cpu().numpy()
            xpts = xdata_np[:, 1]
            ypts = xdata_np[:, 2]
            tpts = xdata_np[:, 0]
            ydata = reality.ydata.detach().cpu().numpy()
            fdata = reality.fdata.detach().cpu().numpy()
            logger.info("outputs: " + str(ydata[-1,:]))

            # Refit the GP on all data, except after the final measurement.
            # with torch.profiler.record_function("update models"):
            if i != experiment["sal_iter"] - 1:
                gp = gp.set_train_data(reality.xdata, reality.ydata)
                if experiment["train_gp"] == True:  # noqa: E712 -- may also be "Hard"
                    gp = gp.train_adam(training_iter=experiment["quick_gp_training_iter"], lr=0.1)
                if experiment["train_gp"] == "Hard":
                    gp = gp.train_hard()

            with torch.no_grad():
                # with torch.profiler.record_function("output"):
                SAL_Plots.update_points_and_models(
                    xpts, ypts, tpts, ydata, fdata, gp,
                    lambda x1, x2: reality.reality(reality.current_time + 1, x1, x2),
                    ObjectiveFactory(gp, t, device),
                    current_time=reality.current_time,
                )
                moviewriter.grab_frame()

            # prof.step()
    finally:
        moviewriter.finish()
        plt.close(fig)

    timer = time.time() - timer

    def Average(lst):
        """Return the arithmetic mean of ``lst``, or NaN if it is empty.

        Prefix slices such as "first 20% of steps" are empty when a run records
        fewer than five steps; returning NaN keeps the summary row writable
        instead of raising ZeroDivisionError after the whole experiment ran.
        """
        if len(lst) == 0:
            return float("nan")
        return sum(lst) / len(lst)
    # Fraction of measured points satisfying each constraint, on the noisy
    # observations (ydata) and on the noise-free values (fdata). Only the values of
    # the last constraint survive the loop; those are written to the summary CSV.
    # NOTE(review): torch.mul broadcasts constraint[1] over every output column and
    # torch.sum counts elements, not rows, so this is a per-point proportion only
    # when ydata/fdata have a single column — confirm the constraint encoding.
    for c_idx, constraint in enumerate(constraints):
        weights = torch.tensor(constraint[1], device=device)
        bound = -constraint[0]
        psp_noise = (torch.sum(torch.mul(reality.ydata, weights) < bound) / (reality.ydata.shape[0])).item()
        psp = (torch.sum(torch.mul(reality.fdata, weights) < bound) / (reality.fdata.shape[0])).item()
        logger.info(f"proportion safe points w.r.t. constraint {c_idx}: {psp_noise} (with noise)")
        logger.info(f"proportion safe points w.r.t. constraint {c_idx}: {psp} (without noise)")

    # Log the per-step safe-set classification metrics and their averages.
    metric_means = {}
    for stats_prefix, stats_list in (
        ("model_safe", SAL_Plots.model_safe_stats_list),
        ("model_optimistic_safe", SAL_Plots.model_optimistic_safe_stats_list),
    ):
        for metric in ("precision", "recall", "f1", "accuracy"):
            key = stats_prefix + "_" + metric
            series = [stats[key] for stats in stats_list]
            logger.info(f"{key}: {series} ")
            metric_means[key] = Average(series)
            # Existing log format: every average line ends with a space except the last.
            trailer = "" if key == "model_optimistic_safe_accuracy" else " "
            logger.info(f"{key}_average: {metric_means[key]}{trailer}")

    safe_precision = metric_means["model_safe_precision"]
    safe_recall = metric_means["model_safe_recall"]
    safe_f1 = metric_means["model_safe_f1"]
    safe_accuracy = metric_means["model_safe_accuracy"]
    optimistic_safe_precision = metric_means["model_optimistic_safe_precision"]
    optimistic_safe_recall = metric_means["model_optimistic_safe_recall"]
    optimistic_safe_f1 = metric_means["model_optimistic_safe_f1"]
    optimistic_safe_accuracy = metric_means["model_optimistic_safe_accuracy"]
    
    # Per-output model-error summaries. SAL_Plots.model_errors_stats_list holds one
    # dict of error metrics per recorded step; each metric is averaged over the
    # first 20/40/60/80/100% of those entries.
    errors_stats = SAL_Plots.model_errors_stats_list
    n_error_steps = len(errors_stats)
    # Entry positions at 10%, ..., 90% of the run (also read by the CSV writer).
    index10 = int(n_error_steps * 0.1)
    index20 = int(n_error_steps * 0.2)
    index30 = int(n_error_steps * 0.3)
    index40 = int(n_error_steps * 0.4)
    index50 = int(n_error_steps * 0.5)
    index60 = int(n_error_steps * 0.6)
    index70 = int(n_error_steps * 0.7)
    index80 = int(n_error_steps * 0.8)
    index90 = int(n_error_steps * 0.9)

    def _mean_over_first(key, stop=None):
        """Average metric ``key`` over errors_stats[:stop] (all entries if stop is None)."""
        return Average([a[key] for a in errors_stats[:stop]])

    RMSE_safe20 = [None] * num_outputs
    RMSE_all20 = [None] * num_outputs
    NLL_20_train = [None] * num_outputs
    NLL_20_eval = [None] * num_outputs
    RMSE_safe40 = [None] * num_outputs
    RMSE_all40 = [None] * num_outputs
    NLL_40_train = [None] * num_outputs
    NLL_40_eval = [None] * num_outputs
    RMSE_safe60 = [None] * num_outputs
    RMSE_all60 = [None] * num_outputs
    NLL_60_train = [None] * num_outputs
    NLL_60_eval = [None] * num_outputs
    RMSE_safe80 = [None] * num_outputs
    RMSE_all80 = [None] * num_outputs
    NLL_80_train = [None] * num_outputs
    NLL_80_eval = [None] * num_outputs
    RMSE_safe = [None] * num_outputs
    RMSE_all = [None] * num_outputs
    NLL_all_train = [None] * num_outputs
    NLL_all_eval = [None] * num_outputs
    for i in range(num_outputs):
        logger.info( f"model {i}")
        name_safe = 'RMSE_safe_area_model_'+str(i)
        name_all = 'RMSE_all_area_model_'+str(i)
        name_NLL_train = 'NLL_train_'+str(i)
        name_NLL_eval = 'NLL_eval_'+str(i)
        logger.info( f"RMSE_safe: {[a[name_safe] for a in errors_stats]} ")
        logger.info( f"RMSE_all: {[a[name_all] for a in errors_stats]} ")
        logger.info( f"NLL train: {[a[name_NLL_train] for a in errors_stats]} ")
        logger.info( f"NLL eval: {[a[name_NLL_eval] for a in errors_stats]} ")
        RMSE_safe20[i] = _mean_over_first(name_safe, index20)
        RMSE_all20[i] = _mean_over_first(name_all, index20)
        NLL_20_train[i] = _mean_over_first(name_NLL_train, index20)
        NLL_20_eval[i] = _mean_over_first(name_NLL_eval, index20)
        RMSE_safe40[i] = _mean_over_first(name_safe, index40)
        RMSE_all40[i] = _mean_over_first(name_all, index40)
        NLL_40_train[i] = _mean_over_first(name_NLL_train, index40)
        NLL_40_eval[i] = _mean_over_first(name_NLL_eval, index40)
        RMSE_safe60[i] = _mean_over_first(name_safe, index60)
        RMSE_all60[i] = _mean_over_first(name_all, index60)
        NLL_60_train[i] = _mean_over_first(name_NLL_train, index60)
        NLL_60_eval[i] = _mean_over_first(name_NLL_eval, index60)
        RMSE_safe80[i] = _mean_over_first(name_safe, index80)
        RMSE_all80[i] = _mean_over_first(name_all, index80)
        NLL_80_train[i] = _mean_over_first(name_NLL_train, index80)
        NLL_80_eval[i] = _mean_over_first(name_NLL_eval, index80)
        RMSE_safe[i] = _mean_over_first(name_safe)
        RMSE_all[i] = _mean_over_first(name_all)
        NLL_all_train[i] = _mean_over_first(name_NLL_train)
        NLL_all_eval[i] = _mean_over_first(name_NLL_eval)
        logger.info( f"RMSE_safe first 20%: {RMSE_safe20[i]} ")
        logger.info( f"RMSE_all first 20%: {RMSE_all20[i]} ")
        logger.info( f"NLL train first 20%: {NLL_20_train[i]} ")
        logger.info( f"NLL eval first 20%: {NLL_20_eval[i]} ")
        logger.info( f"RMSE_safe first 40%: {RMSE_safe40[i]} ")
        logger.info( f"RMSE_all first 40%: {RMSE_all40[i]} ")
        logger.info( f"NLL train first 40%: {NLL_40_train[i]} ")
        logger.info( f"NLL eval first 40%: {NLL_40_eval[i]} ")
        logger.info( f"RMSE_safe first 60%: {RMSE_safe60[i]} ")
        logger.info( f"RMSE_all first 60%: {RMSE_all60[i]} ")
        logger.info( f"NLL train first 60%: {NLL_60_train[i]} ")
        logger.info( f"NLL eval first 60%: {NLL_60_eval[i]} ")
        logger.info( f"RMSE_safe first 80%: {RMSE_safe80[i]} ")
        logger.info( f"RMSE_all first 80%: {RMSE_all80[i]} ")
        logger.info( f"NLL train first 80%: {NLL_80_train[i]} ")
        logger.info( f"NLL eval first 80%: {NLL_80_eval[i]} ")
        logger.info( f"RMSE_safe: {RMSE_safe[i]} ")
        logger.info( f"RMSE_all: {RMSE_all[i]} ")
        logger.info( f"NLL train: {NLL_all_train[i]} ")
        logger.info( f"NLL eval: {NLL_all_eval[i]} ")

    # The measurement history does not depend on the output index, so it is
    # dumped once rather than once per output.
    logger.info("xdata")
    logger.info(f"{reality.xdata.detach().cpu().numpy().tolist()}")
    logger.info("ydata")
    logger.info(f"{reality.ydata.detach().cpu().numpy().tolist()}")
    logger.info("fdata")
    logger.info(f"{reality.fdata.detach().cpu().numpy().tolist()}")

    # Build this experiment's summary line; column order must match the header
    # written once at start-up (summary_rows).
    summary_line = [
        experiment_name, experiment["objective"], experiment["local_change"], experiment["noise"], experiment["train_gp"],
        timer, SAL_Optimizer.changed_optima, SAL_Optimizer.failed_optimization_runs,
        psp_noise, psp,
        safe_precision, safe_recall, safe_f1, safe_accuracy,
        optimistic_safe_precision, optimistic_safe_recall, optimistic_safe_f1, optimistic_safe_accuracy,
    ]
    for out_idx in range(num_outputs):
        metric_keys = (
            'RMSE_safe_area_model_' + str(out_idx),
            'RMSE_all_area_model_' + str(out_idx),
            'NLL_eval_' + str(out_idx),
        )
        # Averages over the first 20/40/60/80/100% of recorded steps.
        summary_line.extend([
            RMSE_safe20[out_idx], RMSE_all20[out_idx], NLL_20_eval[out_idx],
            RMSE_safe40[out_idx], RMSE_all40[out_idx], NLL_40_eval[out_idx],
            RMSE_safe60[out_idx], RMSE_all60[out_idx], NLL_60_eval[out_idx],
            RMSE_safe80[out_idx], RMSE_all80[out_idx], NLL_80_eval[out_idx],
            RMSE_safe[out_idx], RMSE_all[out_idx], NLL_all_eval[out_idx],
        ])
        # Single-step values at 10%, ..., 90% of the run and at the last step.
        snapshot_positions = (index10, index20, index30, index40, index50,
                              index60, index70, index80, index90, -1)
        for pos in snapshot_positions:
            step_stats = SAL_Plots.model_errors_stats_list[pos]
            summary_line.extend(step_stats[k] for k in metric_keys)
    with open(summary_name, 'a', newline='') as summary_file:
        csv.writer(summary_file).writerow(summary_line)
