import torch
from scipy.stats import qmc
from pygranso.pygransoStruct import pygransoStruct
import os
import subprocess
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = subprocess.run("nvidia-smi --query-gpu=memory.free,index --format=csv,nounits,noheader | sort -nr | head -1 | awk '{ print $NF }'", shell=True, capture_output = True, text = True).stdout[0:-1]
import SALGP
import SALPlots
import SALRealities
import SALOptimizer
from datetime import datetime
import matplotlib.animation as manimation
# from torch.profiler import profile, record_function, ProfilerActivity, schedule
import logging
import torchquad
torchquad.set_up_backend("torch", data_type="float64")
logging.getLogger('matplotlib').setLevel(logging.WARNING)
logging.basicConfig(level=logging.WARNING)
from matplotlib import pyplot as plt
import csv
import time

# GPU device configuration: use CUDA when available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')

# global settings
global_name = "RPDynFirstPaper"
reality_class = SALRealities.RailPressureDynamicReality
# Construct the reality once to read its static metadata (input box and number
# of outputs) instead of building three throw-away instances.
_reality_probe = reality_class(device)
reality_xmin = _reality_probe.xmin
reality_xmax = _reality_probe.xmax
num_outputs = _reality_probe.num_outputs
del _reality_probe
starting_lengthscale = 0.5*(reality_xmax-reality_xmin)
sal_crit = 0 # which output is relevant for sal
# Each constraint is [offset, output_weights, third_entry]. The safety statistics
# in the main loop count a point as safe when output * output_weights < -offset.
# NOTE(review): the meaning of the third entry ([2.0]) is defined by
# SALOptimizer / SALPlots and is not visible here - confirm there.
constraints = [
    # Output must stay at most 18 (2-sigma margin). In normalised units,
    # y * 5.1489801806121 + 14.886169231027 = 18  =>  y = 0.6047470877 (solved in Maple).
    [-0.6047470877,[1.0],[2.0]]
]

# Priors on the GP hyperparameters (mean/std of the raw, untransformed
# parameters, presumably interpreted by SALGP). The prior mean of the constant
# mean function is placed at the safety threshold itself.
prior_dict = {'SE': {'raw_lengthscale' : {"mean": -1.0 , "std": 0.01}},
                    'c':{'raw_outputscale':{"mean": 0.5, "std": 0.1 } },
                    'noise': {'raw_noise':{"mean": -3.0, "std": 0.1 } },
                    'mean': {'raw_constant':{"mean": -constraints[0][0], "std": 0.01 } }
                    }
train_numbers = [3]
train_GP = [ True ]
seed_start = 0
nr_seeds = 10
sal_iter = 1000
quick_gp_training_iter = 30
local_changes = [5.0]
local_change = local_changes[0]
experiments = []
# alternative noise grids: [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0] or [0.05, 0.1, 0.2, 0.5]
noises = [0.1]
# alternative max_dist grid: [0.1, 0.15, 0.2]
max_dists = [0.1]
previous_inputs_list = [2,3,4,6,7,8]
# log2 of the number of initial points is set by the experiment-grid loop below.
# lookbacks = [0, 1, 3, 10]
# lookaheads = [0, 1, 3, 10]
# for local_change in local_changes:
# Objective variants queued for every parameter combination, in this order:
# (name tag, objective id, objective-specific settings). The specific settings
# are spliced into the middle of the config dict so that the key order (which
# shows up in the logged str(experiment)) stays fixed.
_objective_variants = (
    ("_TIMSPE_seed_", "IMSPEequal", {
        "objective_IMSPEequal_a": reality_xmin,
        "objective_IMSPEequal_b": reality_xmax,
        "objective_IMSPEequal_tback": 0,
        "objective_IMSPEequal_tforward": 10,
    }),
    ("_Entropy_seed_", "Entropy", {}),
)
for log2_int_pts in [2]:
    for seed in range(seed_start, seed_start + nr_seeds):
        for train in train_numbers:
            for noise in noises:
                for train_gp in train_GP:
                    for max_dist in max_dists:
                        for name_tag, objective_id, objective_settings in _objective_variants:
                            run_name = (global_name + name_tag + str(seed)
                                        + "_maxdist_" + str(max_dist)
                                        + "_pts_" + str(log2_int_pts))
                            experiments.append({
                                "seed": seed,
                                "name": run_name,
                                "quick_gp_training_iter": quick_gp_training_iter,
                                "sal_iter": sal_iter,
                                "optimization_iter": train,
                                "objective": objective_id,
                                **objective_settings,
                                "noise": noise,
                                "model": "GPTimeConstantSEARD",
                                "local_change": local_change,
                                "train_gp": train_gp,
                                "max_dist": max_dist,
                                "log2_int_pts": log2_int_pts,
                            })

# initialize summarizing csv file: write the header row once; the main
# experiment loop appends one data row per experiment.
now = datetime.now()
dt_string = now.strftime("%d/%m/%Y %H:%M:%S")
dt_string_file = now.strftime("%Y%m%d-%H%M%S")
summary_name = "summaries/" + global_name + "_" + dt_string_file + ".csv"
summary_rows = [
        "name", "objective", "local_change", "noise", "train_gp", "max_dist",
        "time", "changed_optima", "failed_optimization_runs",
        "proportion safe points (with noise)", "proportion safe points (without noise)",
        "model_safe_precision", "model_safe_recall", "model_safe_f1", "model_safe_accuracy",
        "model_optimistic_safe_precision", "model_optimistic_safe_recall", "model_optimistic_safe_f1", "model_optimistic_safe_accuracy" ]
# Per-output error columns: averages over growing prefixes ("up_to_*") followed
# by snapshots at fixed fractions of the run ("at_*"), each as the triple
# RMSE_safe / RMSE_all / NLL_eval. This order must match the order in which the
# main experiment loop appends the values to each data row.
summary_windows = (
    [f"up_to_{q}" for q in ("0.2", "0.4", "0.6", "0.8", "end")]
    + [f"at_{q}" for q in ("0.1", "0.2", "0.3", "0.4", "0.5", "0.6", "0.7", "0.8", "0.9", "end")]
)
for i in range(num_outputs):
    for window in summary_windows:
        summary_rows.append(f"RMSE_safe_area_model_{i}_{window}")
        summary_rows.append(f"RMSE_all_area_model_{i}_{window}")
        summary_rows.append(f"NLL_eval_model_{i}_{window}")
# open() does not create missing directories.
os.makedirs("summaries", exist_ok=True)
with open(summary_name, 'w', newline='') as myfile:
    wr = csv.writer(myfile, quoting=csv.QUOTE_ALL)
    wr.writerow(summary_rows)

for experiment in experiments:

    now = datetime.now()
    dt_string = now.strftime("%d/%m/%Y %H:%M:%S")
    dt_string_file = now.strftime("%Y%m%d-%H%M%S")
    experiment_name = dt_string_file + "_" + experiment["name"]

    # --- logging: one log file per experiment ---------------------------------
    # Detach and close the previous experiment's FileHandler so its file handle
    # is released and messages only go to this experiment's log file.
    logger = logging.getLogger('log_sal')
    file_handlers = [handler for handler in logger.handlers if isinstance(handler, logging.FileHandler)]
    for file_handler in file_handlers:
        logger.removeHandler(file_handler)
        file_handler.close()
    os.makedirs('logs', exist_ok=True)  # FileHandler does not create missing directories
    file_handler = logging.FileHandler('logs/'+experiment_name+'.log')
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    logger.addHandler(file_handler)
    logger.setLevel(logging.DEBUG)

    logger.info("torch version: " + torch.__version__)

    logger.info(f"------------------------------------")
    logger.info(f"   {experiment_name}")
    logger.info(f"------------------------------------")
    logger.info(str(experiment))
    logger.info(f"------------------------------------")

    # Check if there are multiple devices (i.e., GPU cards)-
    logger.info(f"Number of GPU(s) available = {torch.cuda.device_count()}")
    if torch.cuda.is_available():
        logger.info(f"Current GPU: {torch.cuda.current_device()}")
        logger.info(f"Current GPU name: {torch.cuda.get_device_name(torch.cuda.current_device())}")
    else:
        logger.info("PyTorch does not have access to GPU")
    logger.info(f"Available device: {device}")

    # --- reproducibility: seed torch, random and numpy with the experiment seed
    torch.manual_seed(experiment["seed"])
    import random
    random.seed(experiment["seed"])
    import numpy as np
    np.random.seed(experiment["seed"])
    reality = reality_class(device, noise=experiment["noise"], local_change=experiment["local_change"])
    # reality_test is only used by the commented-out evaluation-set generator
    # below. Its construction is kept after seeding because the constructor may
    # (presumably) draw random numbers; dropping it could shift this run's RNG
    # stream - verify before removing.
    reality_test = reality_class(device, noise=experiment["noise"], local_change=experiment["local_change"])

    # make initial measurements, in an intentionally smaller area than xmin/xmax
    # Sobol points are drawn from [safe_min, safe_max] until
    # 2**log2_int_pts points are kept. A measured point is removed again when its
    # noiseless (fdata) or noisy (ydata) first output exceeds constraints[0][0].
    # NOTE(review): the safety statistics at the end count a point as safe when
    # output * constraint[1] < -constraint[0] (i.e. y < 0.6047), while this filter
    # only keeps y <= constraint[0] (i.e. y <= -0.6047). Confirm that this
    # stricter, sign-flipped threshold is intended.
    sampler = qmc.Sobol(d=reality.num_inputs, scramble=False)
    while reality.xdata is None or reality.xdata.shape[0] < 2**experiment["log2_int_pts"]:
        x = reality.safe_min + (reality.safe_max-reality.safe_min) * torch.tensor(sampler.random(1), device=device)
        reality.next_point(*list(x[0,:]))
        if reality.fdata[-1,0]>constraints[0][0] or reality.ydata[-1,0]>constraints[0][0]:
            reality.remove_last_point()

    # construct evaluation points, write into file, comment out generation, instead load from file
    # sampler = qmc.Sobol(d=reality.num_inputs, scramble=False)
    # # x = realityreality_test.xdata.safe_min + (reality.safe_max-reality.safe_min) * torch.tensor(sampler.random_base2(m=experiment["log2_int_pts"]), device=device)
    # patience = 0
    # meta_patience = 0
    # while reality_test.xdata == None or reality_test.xdata.shape[0] < 2048:
    #     x = reality_test.xmin + (reality_test.xmax-reality_test.xmin) * torch.tensor(sampler.random(1), device=device)
    #     reality_test.next_point(*list(x[0,:]))
    #     if reality_test.fdata[-1,0]>constraints[0][0] or reality_test.ydata[-1,0]>constraints[0][0]:
    #         reality_test.remove_last_point()
    #         patience = patience + 1
    #         print("patience: " + str(patience))
    #         if patience>30:
    #             meta_patience = meta_patience + 1
    #             for i in range(meta_patience):
    #                 reality_test.remove_last_point()
    #             meta_patience = meta_patience + 1
    #             patience = 0
    #     else:
    #         print(reality_test.xdata.shape[0])
    #         patience = 0
    #         meta_patience = max(meta_patience - 1, 0)
    # Load the fixed evaluation inputs and keep columns 1 and 5 (the same columns
    # used as plot axes below). The file is closed by the context manager.
    # SECURITY NOTE(review): eval() executes arbitrary code from this file. That is
    # only acceptable for a trusted local artifact; if the file is a plain list
    # literal, ast.literal_eval would be the safe replacement.
    with open(r'evaldata_dynrailpressure', 'r') as f:
        model_input = torch.tensor(eval(f.read()), device=device)[:,[1,5]]

    # initialize models
    gp = SALGP.SALGP(reality.xdata, reality.ydata,
                    prior_dict=prior_dict,
                    model=experiment["model"],
                    sal_crit=sal_crit,
                    previous_inputs_list = previous_inputs_list)
    # train_gp: True -> train_pygranso, "Hard" -> train_hard, anything else -> untrained.
    if experiment["train_gp"] == True:
        gp = gp.train_pygranso()
    if experiment["train_gp"] == "Hard":
        gp = gp.train_hard()
    gp.print_training_parameters()

    # set up optimization problem
    var_in = {"x": [reality.num_inputs]}
    optimization_iter = experiment["optimization_iter"]
    opts = pygransoStruct()
    opts.torch_device = device
    opts.globalAD = True
    opts.quadprog_info_msg = False
    opts.print_level = 0
    opts.step_tol = 1e-4
    opts.opt_tol = 1e-4
    opts.viol_ineq_tol = 1e-4
    opts.maxit = 200
    opts.ngrad = 3 # recommended if objective is smooth
    failed_optimizations_relaxation_iter_bound = 2
    failed_optimization_sigma_relaxation = 0.2

    def ObjectiveFactory(gp, t, device):
        """Return the acquisition objective ``xs -> value`` for the current experiment.

        ``experiment["objective"]`` (read from the enclosing loop) selects it:
        - "Entropy": ``gp.NegEntropy``.
        - "IMSPEequal": IMSPE marginalised over a uniform box whose first (time)
          coordinate spans [t + tback, t + tforward] and whose other coordinates
          span [objective_IMSPEequal_a, objective_IMSPEequal_b].
        - "IMSPEgauss": IMSPE marginalised over a Gaussian located at
          (t, 0, ..., 0) + objective_IMSPEgauss_m with parameter objective_IMSPEgauss_s.

        Raises NotImplementedError for any other objective name.
        """

        if experiment["objective"]=="Entropy":
            return lambda xs: gp.NegEntropy(xs)

        if experiment["objective"]=="IMSPEequal":
            xmin = torch.cat((t+experiment["objective_IMSPEequal_tback"], torch.tensor(experiment["objective_IMSPEequal_a"], device=device)))
            xmax = torch.cat((t+experiment["objective_IMSPEequal_tforward"], torch.tensor(experiment["objective_IMSPEequal_b"], device=device)))
            return lambda xs: gp.IMSPEMarginalizedOverUniformDistribution(xs, xmin, xmax)

        if experiment["objective"]=="IMSPEgauss":
            m = torch.cat((t, torch.zeros(reality.num_inputs, device=device)))+torch.tensor(experiment["objective_IMSPEgauss_m"], device=device)
            s = torch.tensor(experiment["objective_IMSPEgauss_s"], device=device)
            return lambda xs: gp.IMSPEMarginalizedOverGaussianDistribution(xs, m, s)

        raise NotImplementedError()

    SAL_Optimizer = SALOptimizer.SALOptimizer(
        reality.xmin, reality.xmax, var_in, ObjectiveFactory, constraints,
        optimization_iter, failed_optimizations_relaxation_iter_bound, failed_optimization_sigma_relaxation,
        opts,device,
        max_dist=experiment["max_dist"]
        )

    # main computation
    xpts = reality.xdata.cpu().numpy()[:,1]
    ypts = reality.xdata.cpu().numpy()[:,5]
    tpts = reality.xdata.cpu().numpy()[:,0]
    ydata = reality.ydata.cpu().detach().numpy()
    fdata = reality.fdata.cpu().detach().numpy()

    # settings
    plt.rcParams["figure.figsize"] = [22.50, 10.50]
    plt.rcParams["figure.autolayout"] = True

    fig = plt.figure()

    t = torch.tensor([reality.current_time], device=device)
    SAL_Plots = SALPlots.SALPlotsTime(
        fig,
        reality.xmin.detach().cpu(), reality.xmax.detach().cpu(),
        2**experiment["log2_int_pts"], experiment["sal_iter"],
        xpts, ypts, tpts,
        ydata, fdata,
        reality.xdiff.detach().cpu()/15,
        gp,
        lambda *args: reality.reality(0.0, *args), lambda *args: reality.reality(reality.current_time, *args), ObjectiveFactory(gp, t, device),
        constraints,
        device,
        current_time = reality.current_time,
        model_input = model_input,
        fixed_reality = True)

    WriterClass = manimation.writers['ffmpeg_file']
    metadata = dict(title=experiment_name, artist='MLH', comment=experiment_name)
    moviewriter = WriterClass(fps=2.0, metadata=metadata)
    os.makedirs('animations', exist_ok=True)  # the movie writer does not create missing directories
    moviewriter.setup(fig, 'animations/' + experiment_name + '.mp4', dpi=100)

    # prof = profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],profile_memory=True, with_stack=True, record_shapes=True, schedule=schedule(skip_first=2, wait=2,warmup=2,active=2,repeat=2))
    # from torch.profiler import profile, record_function, ProfilerActivity, schedule
    # with profile(activities=[ProfilerActivity.CUDA], profile_memory=True, with_stack=True, record_shapes=True, schedule=schedule(skip_first=0, wait=2,warmup=0,active=1,repeat=2)) as prof:

    timer = time.time()

    for i in range(experiment["sal_iter"]):
    # as long as we want to continue

        logger.info("step: " + str(i) + " init")
        # logger.info(subprocess.run("nvidia-smi", shell=True, capture_output = True, text = True).stdout)

        # optimization: choose the next input for the next time step
        t = torch.tensor([reality.current_time + 1.0], device=device)

        # with torch.profiler.record_function("optimize"):
        newx = SAL_Optimizer.optimize(t, gp, starting_point=reality.xdata[-1,[1,5]], max_dist = experiment["max_dist"])
        if newx is None:
            # Optimizer found no point: fall back to a Sobol point inside the
            # [safe_min, safe_max] box (not checked against the model).
            logger.info("chose safe point!!!!!")
            newx = reality.safe_min + (reality.safe_max-reality.safe_min) * torch.tensor(sampler.random(1)[0,:], device=device)
        logger.info("new point chosen: " + str(newx))

        # measure
        reality.next_point(*list(newx))

        xpts = reality.xdata.cpu().numpy()[:,1]
        ypts = reality.xdata.cpu().numpy()[:,5]

        tpts = reality.xdata.cpu().numpy()[:,0]
        ydata = reality.ydata.detach().cpu().numpy()
        fdata = reality.fdata.detach().cpu().numpy()
        logger.info("outputs: " + str(ydata[-1,:]))

        # update models (skipped after the final measurement)
        # with torch.profiler.record_function("update models"):
        if i != experiment["sal_iter"] - 1:
            gp = gp.set_train_data(reality.xdata, reality.ydata)
            if experiment["train_gp"] == True:
                gp = gp.train_adam(training_iter=experiment["quick_gp_training_iter"], lr=0.1)
            if experiment["train_gp"] == "Hard":
                gp = gp.train_hard()

        with torch.no_grad():
            # with torch.profiler.record_function("output"):
            SAL_Plots.update_points_and_models(xpts, ypts, tpts, ydata, fdata, gp,
                                               lambda x1,x2: reality.reality(reality.current_time + 1, x1, x2), ObjectiveFactory(gp, t, device), current_time = reality.current_time)
            moviewriter.grab_frame()

        # prof.step()

    moviewriter.finish()
    # Release the figure; otherwise every experiment leaves one open pyplot
    # figure (with all its artists) alive for the rest of the process.
    plt.close(fig)

    timer = time.time() - timer

    def Average(lst):
        """Arithmetic mean of ``lst``; NaN for an empty list (e.g. a prefix window of a very short run)."""
        if len(lst) == 0:
            return float('nan')
        return sum(lst) / len(lst)

    # Proportion of measured points satisfying each constraint
    # (safe when output * weights < -offset).
    # NOTE(review): with several constraints only the last one's values reach the summary CSV.
    for (i,constraint) in enumerate(constraints):
        psp_noise = (torch.sum(torch.mul(reality.ydata,torch.tensor(constraint[1],device=device))<-constraint[0]) / (reality.ydata.shape[0])).item()
        psp = (torch.sum(torch.mul(reality.fdata,torch.tensor(constraint[1],device=device))<-constraint[0]) / (reality.fdata.shape[0])).item()
        logger.info( f"proportion safe points w.r.t. constraint {i}: {psp_noise} (with noise)")
        logger.info( f"proportion safe points w.r.t. constraint {i}: {psp} (without noise)")

    logger.info( f"model_safe_precision: {[a['model_safe_precision'] for a in SAL_Plots.model_safe_stats_list]} ")
    safe_precision = Average([a['model_safe_precision'] for a in SAL_Plots.model_safe_stats_list])
    logger.info( f"model_safe_precision_average: {safe_precision} ")
    logger.info( f"model_safe_recall: {[a['model_safe_recall'] for a in SAL_Plots.model_safe_stats_list]} ")
    safe_recall = Average([a['model_safe_recall'] for a in SAL_Plots.model_safe_stats_list])
    logger.info( f"model_safe_recall_average: {safe_recall} ")
    logger.info( f"model_safe_f1: {[a['model_safe_f1'] for a in SAL_Plots.model_safe_stats_list]} ")
    safe_f1 = Average([a['model_safe_f1'] for a in SAL_Plots.model_safe_stats_list])
    logger.info( f"model_safe_f1_average: {safe_f1} ")
    logger.info( f"model_safe_accuracy: {[a['model_safe_accuracy'] for a in SAL_Plots.model_safe_stats_list]} ")
    safe_accuracy = Average([a['model_safe_accuracy'] for a in SAL_Plots.model_safe_stats_list])
    logger.info( f"model_safe_accuracy_average: {safe_accuracy} ")

    logger.info( f"model_optimistic_safe_precision: {[a['model_optimistic_safe_precision'] for a in SAL_Plots.model_optimistic_safe_stats_list]} ")
    optimistic_safe_precision = Average([a['model_optimistic_safe_precision'] for a in SAL_Plots.model_optimistic_safe_stats_list])
    logger.info( f"model_optimistic_safe_precision_average: {optimistic_safe_precision} ")
    logger.info( f"model_optimistic_safe_recall: {[a['model_optimistic_safe_recall'] for a in SAL_Plots.model_optimistic_safe_stats_list]} ")
    optimistic_safe_recall = Average([a['model_optimistic_safe_recall'] for a in SAL_Plots.model_optimistic_safe_stats_list])
    logger.info( f"model_optimistic_safe_recall_average: {optimistic_safe_recall} ")
    logger.info( f"model_optimistic_safe_f1: {[a['model_optimistic_safe_f1'] for a in SAL_Plots.model_optimistic_safe_stats_list]} ")
    optimistic_safe_f1 = Average([a['model_optimistic_safe_f1'] for a in SAL_Plots.model_optimistic_safe_stats_list])
    logger.info( f"model_optimistic_safe_f1_average: {optimistic_safe_f1} ")
    logger.info( f"model_optimistic_safe_accuracy: {[a['model_optimistic_safe_accuracy'] for a in SAL_Plots.model_optimistic_safe_stats_list]} ")
    optimistic_safe_accuracy = Average([a['model_optimistic_safe_accuracy'] for a in SAL_Plots.model_optimistic_safe_stats_list])
    logger.info( f"model_optimistic_safe_accuracy_average: {optimistic_safe_accuracy}")

    RMSE_safe20 = list(range(num_outputs))
    RMSE_all20 = list(range(num_outputs))
    NLL_20_train = list(range(num_outputs))
    NLL_20_eval = list(range(num_outputs))
    RMSE_safe40 = list(range(num_outputs))
    RMSE_all40 = list(range(num_outputs))
    NLL_40_train = list(range(num_outputs))
    NLL_40_eval = list(range(num_outputs))
    RMSE_safe60 = list(range(num_outputs))
    RMSE_all60 = list(range(num_outputs))
    NLL_60_train = list(range(num_outputs))
    NLL_60_eval = list(range(num_outputs))
    RMSE_safe80 = list(range(num_outputs))
    RMSE_all80 = list(range(num_outputs))
    NLL_80_train = list(range(num_outputs))
    NLL_80_eval = list(range(num_outputs))
    RMSE_safe = list(range(num_outputs))
    RMSE_all = list(range(num_outputs))
    NLL_all_train = list(range(num_outputs))
    NLL_all_eval = list(range(num_outputs))
    # Positions at 10%, 20%, ..., 90% of the recorded per-step error statistics.
    # They are the same for every output, so compute them once.
    n_error_stats = len(SAL_Plots.model_errors_stats_list)
    index10 = int(n_error_stats * 0.1)
    index20 = int(n_error_stats * 0.2)
    index30 = int(n_error_stats * 0.3)
    index40 = int(n_error_stats * 0.4)
    index50 = int(n_error_stats * 0.5)
    index60 = int(n_error_stats * 0.6)
    index70 = int(n_error_stats * 0.7)
    index80 = int(n_error_stats * 0.8)
    index90 = int(n_error_stats * 0.9)
    for i in range(num_outputs):
        logger.info( f"model {i}")
        name_safe = 'RMSE_safe_area_model_'+str(i)
        name_all = 'RMSE_all_area_model_'+str(i)
        name_NLL_train = 'NLL_train_'+str(i)
        name_NLL_eval = 'NLL_eval_'+str(i)
        logger.info( f"RMSE_safe: {[a[name_safe] for a in SAL_Plots.model_errors_stats_list]} ")
        logger.info( f"RMSE_all: {[a[name_all] for a in SAL_Plots.model_errors_stats_list]} ")
        logger.info( f"NLL train: {[a[name_NLL_train] for a in SAL_Plots.model_errors_stats_list]} ")
        logger.info( f"NLL eval: {[a[name_NLL_eval] for a in SAL_Plots.model_errors_stats_list]} ")
        RMSE_safe20[i] = Average([a[name_safe] for a in SAL_Plots.model_errors_stats_list[:index20]])
        RMSE_all20[i] = Average([a[name_all] for a in SAL_Plots.model_errors_stats_list[:index20]])
        NLL_20_train[i] = Average([a[name_NLL_train] for a in SAL_Plots.model_errors_stats_list[:index20]])
        NLL_20_eval[i] = Average([a[name_NLL_eval] for a in SAL_Plots.model_errors_stats_list[:index20]])
        RMSE_safe40[i] = Average([a[name_safe] for a in SAL_Plots.model_errors_stats_list[:index40]])
        RMSE_all40[i] = Average([a[name_all] for a in SAL_Plots.model_errors_stats_list[:index40]])
        NLL_40_train[i] = Average([a[name_NLL_train] for a in SAL_Plots.model_errors_stats_list[:index40]])
        NLL_40_eval[i] = Average([a[name_NLL_eval] for a in SAL_Plots.model_errors_stats_list[:index40]])
        RMSE_safe60[i] = Average([a[name_safe] for a in SAL_Plots.model_errors_stats_list[:index60]])
        RMSE_all60[i] = Average([a[name_all] for a in SAL_Plots.model_errors_stats_list[:index60]])
        NLL_60_train[i] = Average([a[name_NLL_train] for a in SAL_Plots.model_errors_stats_list[:index60]])
        NLL_60_eval[i] = Average([a[name_NLL_eval] for a in SAL_Plots.model_errors_stats_list[:index60]])
        RMSE_safe80[i] = Average([a[name_safe] for a in SAL_Plots.model_errors_stats_list[:index80]])
        RMSE_all80[i] = Average([a[name_all] for a in SAL_Plots.model_errors_stats_list[:index80]])
        NLL_80_train[i] = Average([a[name_NLL_train] for a in SAL_Plots.model_errors_stats_list[:index80]])
        NLL_80_eval[i] = Average([a[name_NLL_eval] for a in SAL_Plots.model_errors_stats_list[:index80]])
        RMSE_safe[i] = Average([a[name_safe] for a in SAL_Plots.model_errors_stats_list])
        RMSE_all[i] = Average([a[name_all] for a in SAL_Plots.model_errors_stats_list])
        NLL_all_train[i] = Average([a[name_NLL_train] for a in SAL_Plots.model_errors_stats_list])
        NLL_all_eval[i] = Average([a[name_NLL_eval] for a in SAL_Plots.model_errors_stats_list])
        logger.info( f"RMSE_safe first 20%: {RMSE_safe20[i]} ")
        logger.info( f"RMSE_all first 20%: {RMSE_all20[i]} ")
        logger.info( f"NLL train first 20%: {NLL_20_train[i]} ")
        logger.info( f"NLL eval first 20%: {NLL_20_eval[i]} ")
        logger.info( f"RMSE_safe first 40%: {RMSE_safe40[i]} ")
        logger.info( f"RMSE_all first 40%: {RMSE_all40[i]} ")
        logger.info( f"NLL train first 40%: {NLL_40_train[i]} ")
        logger.info( f"NLL eval first 40%: {NLL_40_eval[i]} ")
        logger.info( f"RMSE_safe first 60%: {RMSE_safe60[i]} ")
        logger.info( f"RMSE_all first 60%: {RMSE_all60[i]} ")
        logger.info( f"NLL train first 60%: {NLL_60_train[i]} ")
        logger.info( f"NLL eval first 60%: {NLL_60_eval[i]} ")
        logger.info( f"RMSE_safe first 80%: {RMSE_safe80[i]} ")
        logger.info( f"RMSE_all first 80%: {RMSE_all80[i]} ")
        logger.info( f"NLL train first 80%: {NLL_80_train[i]} ")
        logger.info( f"NLL eval first 80%: {NLL_80_eval[i]} ")
        logger.info( f"RMSE_safe: {RMSE_safe[i]} ")
        logger.info( f"RMSE_all: {RMSE_all[i]} ")
        logger.info( f"NLL train: {NLL_all_train[i]} ")
        logger.info( f"NLL eval: {NLL_all_eval[i]} ")
    # Dump the full measurement history once. It does not depend on the output
    # index, so it was previously repeated num_outputs times.
    logger.info( f"xdata")
    logger.info( f"{reality.xdata.detach().cpu().numpy().tolist()}")
    logger.info( f"ydata")
    logger.info( f"{reality.ydata.detach().cpu().numpy().tolist()}")
    logger.info( f"fdata")
    logger.info( f"{reality.fdata.detach().cpu().numpy().tolist()}")


    # Append this experiment's row to the summary CSV. Column order must match
    # the header written when the summary file was created.
    with open(summary_name, 'a', newline='') as myfile:
        wr = csv.writer(myfile)
        row = [
            experiment_name, experiment["objective"], experiment["local_change"], experiment["noise"], experiment["train_gp"], experiment["max_dist"],
            timer,SAL_Optimizer.changed_optima,SAL_Optimizer.failed_optimization_runs,
            psp_noise, psp,
            safe_precision, safe_recall, safe_f1, safe_accuracy,
            optimistic_safe_precision, optimistic_safe_recall, optimistic_safe_f1, optimistic_safe_accuracy
        ]
        for i in range(num_outputs):
            name_safe = 'RMSE_safe_area_model_'+str(i)
            name_all = 'RMSE_all_area_model_'+str(i)
            name_NLL_eval = 'NLL_eval_'+str(i)
            row.append(RMSE_safe20[i])
            row.append(RMSE_all20[i])
            row.append(NLL_20_eval[i])
            row.append(RMSE_safe40[i])
            row.append(RMSE_all40[i])
            row.append(NLL_40_eval[i])
            row.append(RMSE_safe60[i])
            row.append(RMSE_all60[i])
            row.append(NLL_60_eval[i])
            row.append(RMSE_safe80[i])
            row.append(RMSE_all80[i])
            row.append(NLL_80_eval[i])
            row.append(RMSE_safe[i])
            row.append(RMSE_all[i])
            row.append(NLL_all_eval[i])
            row.append(SAL_Plots.model_errors_stats_list[index10][name_safe])
            row.append(SAL_Plots.model_errors_stats_list[index10][name_all])
            row.append(SAL_Plots.model_errors_stats_list[index10][name_NLL_eval])
            row.append(SAL_Plots.model_errors_stats_list[index20][name_safe])
            row.append(SAL_Plots.model_errors_stats_list[index20][name_all])
            row.append(SAL_Plots.model_errors_stats_list[index20][name_NLL_eval])
            row.append(SAL_Plots.model_errors_stats_list[index30][name_safe])
            row.append(SAL_Plots.model_errors_stats_list[index30][name_all])
            row.append(SAL_Plots.model_errors_stats_list[index30][name_NLL_eval])
            row.append(SAL_Plots.model_errors_stats_list[index40][name_safe])
            row.append(SAL_Plots.model_errors_stats_list[index40][name_all])
            row.append(SAL_Plots.model_errors_stats_list[index40][name_NLL_eval])
            row.append(SAL_Plots.model_errors_stats_list[index50][name_safe])
            row.append(SAL_Plots.model_errors_stats_list[index50][name_all])
            row.append(SAL_Plots.model_errors_stats_list[index50][name_NLL_eval])
            row.append(SAL_Plots.model_errors_stats_list[index60][name_safe])
            row.append(SAL_Plots.model_errors_stats_list[index60][name_all])
            row.append(SAL_Plots.model_errors_stats_list[index60][name_NLL_eval])
            row.append(SAL_Plots.model_errors_stats_list[index70][name_safe])
            row.append(SAL_Plots.model_errors_stats_list[index70][name_all])
            row.append(SAL_Plots.model_errors_stats_list[index70][name_NLL_eval])
            row.append(SAL_Plots.model_errors_stats_list[index80][name_safe])
            row.append(SAL_Plots.model_errors_stats_list[index80][name_all])
            row.append(SAL_Plots.model_errors_stats_list[index80][name_NLL_eval])
            row.append(SAL_Plots.model_errors_stats_list[index90][name_safe])
            row.append(SAL_Plots.model_errors_stats_list[index90][name_all])
            row.append(SAL_Plots.model_errors_stats_list[index90][name_NLL_eval])
            row.append(SAL_Plots.model_errors_stats_list[-1][name_safe])
            row.append(SAL_Plots.model_errors_stats_list[-1][name_all])
            row.append(SAL_Plots.model_errors_stats_list[-1][name_NLL_eval])
        wr.writerow(row)