# -*- coding: utf-8 -*-
"""modification of sdeflow_equivalent_sdes.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1Tx_Yt90NRgHve--ocIXi6SGR-0ebwH0N
    associated with https://github.com/CW-Huang/sdeflow-light
"""


import time
import numpy as np
import torch
import torch.nn as nn
import sys
import os
import pandas as pd
import random
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.ticker import FixedLocator, FixedFormatter
from sklearn.datasets import make_swiss_roll
from netCDF4 import Dataset
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import seaborn as sns

from NN import MLP, NormalizeLogRadius, evaluate
from sde_scheme import euler_maruyama_sampler,heun_sampler,rk4_stratonovich_sampler
from own_plotting import plot_selected_inds, def_pd, pairplots, pairplots_single, \
                         preprocessing, postprocessing
from SDEs import forward_SDE,SDE,VariancePreservingSDE,PluginReverseSDE,multiplicativeNoise
from data import SwissRoll,Cauchy,Gaussian,PIV
import gc

# Reproducibility: seed every random number generator the script draws from.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# Cap the number of rows pandas prints for a DataFrame.
DISPLAY_MAX_ROWS = 20
pd.set_option('display.max_rows', DISPLAY_MAX_ROWS)

# ---------------------------------------------------------------------------
# Training defaults
# ---------------------------------------------------------------------------
T0 = 1                # value of the horizon T handed to the SDE objects
MSGMs = [0, 1]        # variants to run: 0 -> SGM (VP SDE), 1 -> MSGM (multiplicative noise)
beta_min = 0.1        # new default
beta_max = 20
t_eps = 1e-3
norm_sampler = "ecdf"
norm_map = "log"
# SGM baseline schedule (default values from the sdeflow-light repository).
beta_min_SGM = 0.1
beta_max_SGM = 20

num_samples_init_max = 100_000  # cap on the samples given to multiplicativeNoise at init
vtype = 'rademacher'            # forwarded as `vtype` to PluginReverseSDE
lr = 1e-3                       # Adam learning rate (default)
print_every = 10_000            # training-log period, in iterations

# ---------------------------------------------------------------------------
# Inference defaults
# ---------------------------------------------------------------------------
include_t0_reverse = True       # keep the t=0 state in the generated trajectories (plots)
num_samples = 10_000
max_num_samples_for_mmd = num_samples
evalmmmd = False
first_run = True

# ---------------------------------------------------------------------------
# Experiment presets -- keep exactly one of them active.
# ---------------------------------------------------------------------------

# Preset (disabled): fair convergence comparison
# # num_steps_forward = 128 # cauchy
# # num_steps_forward = 64 # old default
# num_steps_forward = 16 # new default ?
# ntrain_maxs = [ 2**16, 2**10, 2**6 ]
# # iterationss = [ 2**24, 2**20, 2**16, 2**12, 2**8 ] # cauchy
# iterationss = [ 2**20, 2**16, 2**12, 2**8]
# # num_stepss_backward = [1000,100,50,10,4,2]
# # num_stepss_backward = [1024,256,64,16,4,1]
# # full CV
# num_stepss_backward = [128,32,8,2]
# nruns_mmd = 10 # full CV comparisons
# # # cheap CV
# # num_stepss_backward = [128,32,8,2]
# # nruns_mmd = 1 # cheap CV comparisons
# fair_comparison = True # comparison SGM vs MSGM with same RAM usage and same learning time
# ssm_intT_ref = False
# first_run = True
# if first_run:
#     evalmmmd = False # 1st pass
# else:
#     evalmmmd = True # 2nd pass (long)

# Preset (active): fair comparison, more CV
ntrain_maxs = [np.inf]
iterationss = [2 ** 20]
num_steps_forward = 16   # default
num_stepss_backward = [128]
nruns_mmd = 1
fair_comparison = True   # SGM vs MSGM compared with the same RAM usage and learning time
ssm_intT_ref = False
evalmmmd = False

# Preset (disabled): expressivity for cauchy (long to run)
# ntrain_maxs = [ np.inf ]
# iterationss = [ 2**20]
# num_steps_forward = 128
# num_stepss_backward = [128]
# nruns_mmd = 1
# fair_comparison = False
# ssm_intT_ref = False
# beta_min=0.1
# beta_max=1
# # MSGMs = [1]

batch_sizes = [256]

# ---------------------------------------------------------------------------
# Dataset: one of 'swissroll', 'PIV', 'gaussian', 'cauchy'.
# ---------------------------------------------------------------------------
datatype = 'swissroll'
normalized_data = True
mixedTimes = False
Res = [None]
dbg = False
delayed = False
print("datatype", datatype)

# Dataset-specific dimensions and noise-schedule overrides of the defaults above.
if datatype == 'swissroll':  # Swiss roll
    dims = [2]
elif datatype == 'PIV':  # vorticity and divergence from 2D PIV
    dims = [2, 4, 8, 16, 32]
    ratio = 4
    # Divide both beta bounds and t_eps by `ratio`, and use the same schedule for SGM.
    beta_max /= ratio  # 20/ratio
    beta_min /= ratio
    t_eps /= ratio
    beta_max_SGM = beta_max
    beta_min_SGM = beta_min
    few_data = True
    localized = True
elif datatype == 'gaussian':  # multi-dimensional Gaussian
    dims = [2, 4, 8, 16, 32]
elif datatype == 'cauchy':  # multi-dimensional Cauchy
    # Other settings that were tried:
    #   dims = [2,4,8,16]  /  dims = [2]  /  dims = [8]
    #   (dims = [2]) correlation = False with beta_max=0.4
    #   (dims = [2]) correlation = True  with beta_max=2
    #   beta_max_SGM=2  /  beta_max_SGM=beta_max ; beta_min_SGM=beta_min
    dims = [4]
    correlation = True
    beta_max = 1
    beta_min = 0.01
    t_eps /= 10
    num_steps_forward = 128  # cauchy
else:
    raise ValueError("Unknown datatype: {}".format(datatype))

# --- Optional DEBUG overrides (tiny problem sizes); uncomment to use ---
# print('WARNING : DEBUG !!!!!!')
# iterationss = [16,8]
# num_stepss_backward = [10]
# num_steps_forward = 10
# num_samples = 10
# batch_sizes = [2]
# dbg = True

# --- Plots ---
scatter_plots = noising_plots = denoising_plots = True
save_results = True
plot_xlim = 3.0
height_seaborn_ref = 1
height_seaborn = ssize = height_seaborn_ref
dpi = 200
dimplot_max = 8
val_hist = plot_crop = plot_xlim
crop_data_plot = False

# --- Reloading saved results ---
# Second pass (first_run = False): reload saved samples instead of regenerating them.
justLoad = not first_run
# Set to True to also reload the MMD tensors; only honoured together with justLoad.
justLoadmmmd = False
justLoadmmmd = justLoadmmmd and justLoad

plt_show = False
plot_validate = False
print_RAM = False
log_scale_pdf = True
plot_ref_pdf = False
pdf_theor = None

if not plt_show:
    # Non-interactive backend: figures are only written to disk.
    matplotlib.use("Agg")


_FROM_GLOBALS = object()  # sentinel meaning "use the module-level value, as before"


def m_name_simu_root(sampler_name, gen_sde_name_SDE, iterations_ref, batch_size, num_steps_forward,
                     beta_min, beta_max, ssm_intT, fair_comparison, *,
                     msgm=_FROM_GLOBALS, premodule_name=_FROM_GLOBALS,
                     n_samples_init=_FROM_GLOBALS, learning_rate=_FROM_GLOBALS,
                     div_vtype=_FROM_GLOBALS, beta_min_sgm=_FROM_GLOBALS,
                     beta_max_sgm=_FROM_GLOBALS):
    """Return the relative path stem identifying one simulation configuration.

    The stem is ``<sampler_name>/<sde name>_<...>`` and encodes the training
    set-up, so results of different configurations do not overwrite each other.
    The four beta values are also printed.

    Historically the function read several settings from module globals
    (``MSGM``, ``premodule``, ``num_samples_init``, ``lr``, ``vtype``,
    ``beta_min_SGM``, ``beta_max_SGM``). These are still used by default.
    Each can now be passed explicitly via the keyword-only overrides below.

    Args:
        sampler_name: dataset sampler name (first path component).
        gen_sde_name_SDE: name of the forward SDE.
        iterations_ref: reference (un-rescaled) number of training iterations.
        batch_size: training batch size.
        num_steps_forward: number of steps used for the forward SDE.
        beta_min, beta_max: schedule bounds written in the name for MSGM runs.
        ssm_intT: whether the time-integrated loss is used (adds ``_intLoss``).
        fair_comparison: whether fair-comparison mode is on (adds ``_fairComp``).
        msgm: MSGM (truthy) vs SGM (falsy) run; defaults to global ``MSGM``.
        premodule_name: network pre-module name or None; defaults to ``premodule``.
        n_samples_init: number of init samples; defaults to ``num_samples_init``.
        learning_rate: learning rate; defaults to ``lr``.
        div_vtype: ``vtype`` of the reverse SDE; defaults to ``vtype``.
        beta_min_sgm, beta_max_sgm: schedule bounds written for SGM runs;
            default to ``beta_min_SGM`` / ``beta_max_SGM``.

    Returns:
        The path stem as a string.
    """
    if msgm is _FROM_GLOBALS:
        msgm = MSGM
    if premodule_name is _FROM_GLOBALS:
        premodule_name = premodule
    if n_samples_init is _FROM_GLOBALS:
        n_samples_init = num_samples_init
    if learning_rate is _FROM_GLOBALS:
        learning_rate = lr
    if div_vtype is _FROM_GLOBALS:
        div_vtype = vtype
    if beta_min_sgm is _FROM_GLOBALS:
        beta_min_sgm = beta_min_SGM
    if beta_max_sgm is _FROM_GLOBALS:
        beta_max_sgm = beta_max_SGM

    name_simu_root = (f"{sampler_name}/{gen_sde_name_SDE}_{iterations_ref}iteRefLearning_"
                      f"{n_samples_init}InitSples_"
                      f"{batch_size}batchSize_"
                      f"{num_steps_forward}stepsForw_")
    print("beta_min_SGM = " + str(beta_min_sgm))
    print("beta_min = " + str(beta_min))
    print("beta_max_SGM = " + str(beta_max_sgm))
    print("beta_max = " + str(beta_max))

    # MSGM runs use (beta_min, beta_max); SGM runs use the SGM-specific schedule.
    if msgm:
        name_simu_root += f"{beta_min}beta_min{beta_max}beta_max"
    else:
        name_simu_root += f"{beta_min_sgm}beta_min{beta_max_sgm}beta_max"

    # Optional suffixes are only added when a setting differs from its default.
    if premodule_name is not None:
        name_simu_root += "_" + premodule_name
    if learning_rate != 0.001:
        name_simu_root += f"{learning_rate}lr"
    if div_vtype != 'rademacher':
        name_simu_root += "vtype=" + div_vtype
    if ssm_intT:
        name_simu_root += "_intLoss"
    if fair_comparison:
        name_simu_root += "_fairComp"
    return name_simu_root

# Pick the compute device: CUDA when available, otherwise the CPU.
# (Apple MPS was tried as well:
#   elif torch.backends.mps.is_available(): device = 'mps'; print('use mps\n'))
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('use gpu\n' if device == 'cuda' else 'use cpu\n')

if __name__ == '__main__':

    for ntrain_max in ntrain_maxs:
        mmd_SGM = torch.zeros((len(dims),len(Res),len(num_stepss_backward),len(iterationss),nruns_mmd))
        mmd_MSGM = torch.zeros((len(dims),len(Res),len(num_stepss_backward),len(iterationss),nruns_mmd))
        mmd_ref = torch.zeros((len(dims),len(Res),len(num_stepss_backward),len(iterationss),nruns_mmd))

        i_Res = -1
        for Re in Res:
            i_Res +=1
            
            i_dims = -1
            for dim in dims:
                i_dims +=1

                i_MGMM = -1
                for MSGM in MSGMs:
                    i_MGMM +=1


                    if not MSGM:
                        normalized_data = True
                        ssm_intT = False
                        premodule = None # default
                        # print('WARNING : SGM with sphericalNN !!!!!!')
                        # premodule = "NormalizeLogRadius" 
                    else:
                        normalized_data = False
                        ssm_intT = ssm_intT_ref
                        premodule = "NormalizeLogRadius" # default
                        # premodule = "PolarCoordinatesWithLogRadius" 

                    np.random.seed(0)
                    torch.manual_seed(0) 
                    random.seed(0)

                    num_samples_init = min(num_samples_init_max,iterationss[0]*batch_sizes[0])

                    ## 1. Initialize dataset
                    match datatype:
                        case 'swissroll':
                            sampler = SwissRoll()
                            normalized_data = False
                        case 'PIV':
                            sampler = PIV(dim, normalized=normalized_data, localized = localized, few_data=few_data, ntrain_max=ntrain_max)
                            log_scale_pdf = True
                            plot_xlim = 6
                            val_hist = 2*plot_xlim
                        case 'gaussian':
                            # correlation = False
                            # normalized_data = False
                            correlation = True # default
                            sampler = Gaussian(dim, normalized=normalized_data, correlation = correlation)
                            if not correlation:
                                plot_ref_pdf = True
                                pdf_theor = torch.distributions.Normal(0.0, 1.0)
                            plot_xlim = 4
                            val_hist = 2*plot_xlim

                        case 'cauchy':
                            sampler = Cauchy(dim, normalized=normalized_data, correlation = correlation)
                            crop_data_plot = True
                            log_scale_pdf = True
                            if dim == 2:
                                height_seaborn = height_seaborn_ref * 2

                            if not dbg:
                                num_samples = 100000 # to have enough points in the tails for the plots
                                evalmmmd = False
                                nruns_mmd = 1

                            if not correlation:
                                plot_xlim = 10 
                                plot_ref_pdf = True
                                scale = (1.0/50)
                                pdf_theor = torch.distributions.Cauchy(0.0, scale)
                            else:
                                if dim == 2:
                                    plot_xlim = 5 # for d=2 / warning : should depend of d : overwise we remove all far points / or separate crop and plot_xlim
                                else:
                                    plot_xlim = 10
                            plot_crop = 3*plot_xlim

                            if MSGM and dim == 2:
                                val_hist = 0.3
                            else:
                                val_hist = plot_xlim

                        case _:
                            raise ValueError("Unknown datatype: {}".format(datatype))

                    folder_results = "results"
                    directory = folder_results + "/" + sampler.name
                    if not os.path.exists(directory):
                        os.makedirs(directory)

                    with torch.no_grad():
                        xtest = sampler.sampletest(num_samples)
                        sampler.dim = xtest.shape[1]
                        std_test = xtest.std(axis=0)
                        if normalized_data:
                            std_norm = sampler.get_std()
                        else:
                            std_norm = torch.ones((xtest.shape[1]))
                        if (datatype == 'cauchy') :
                            std_test_plot = torch.ones_like(std_test) / std_norm
                        else:
                            std_test_plot = std_test

                        plt.close('all')
                        dimplot = np.min([dimplot_max,xtest.shape[1]])
                        columns_plot=range(1,1+dimplot)

                        pairplots_single(xtest, std_norm, std_test_plot, datatype, sampler.name , dimplot=dimplot, \
                                    crop_data_plot=crop_data_plot, plot_crop=plot_crop, plot_xlim=plot_xlim, plot_ref_pdf=plot_ref_pdf, \
                                    pdf_theor=pdf_theor, log_scale_pdf=log_scale_pdf, columns_plot=columns_plot, \
                                    plt_show=plt_show, dpi=dpi, height_seaborn=height_seaborn, ssize=ssize)
                        pairplots_single(sampler.sample(num_samples).to('cpu'), std_norm, std_test_plot, datatype, sampler.name + "_train", dimplot=dimplot, \
                                    crop_data_plot=crop_data_plot, plot_crop=plot_crop, plot_xlim=plot_xlim, plot_ref_pdf=plot_ref_pdf, \
                                    pdf_theor=pdf_theor, log_scale_pdf=log_scale_pdf, columns_plot=columns_plot, \
                                    plt_show=plt_show, dpi=dpi, height_seaborn=height_seaborn, ssize=ssize)

                    ## 3. Train

                    i_iterations = -1
                    for iterations_ref in iterationss:
                        i_iterations +=1

                        xtest = xtest.to(device)             

                        i_batch_size = -1
                        for batch_size_ref in batch_sizes:
                            i_batch_size +=1
                            if (ssm_intT):# for a fair comparison
                                batch_size = int(batch_size_ref/num_steps_forward) # for a fair comparison in term of RAM
                            else:  
                                batch_size = batch_size_ref
                            if (fair_comparison and MSGM):# for a fair comparison
                                ratio_ite = max([1, int(np.sqrt(sampler.dim) * num_steps_forward / 16)])
                                print('ratio_ite = ' + str(ratio_ite))
                                iterations = int(iterations_ref/ratio_ite) # for a fair comparison in term of learning time 
                                iterations = max([1,iterations])
                            else:  
                                iterations = iterations_ref
                            num_samples_init = min(num_samples_init_max,iterations*batch_size)
                            print('num_samples_init = ' + str(num_samples_init))
                        
                            # init models
                            drift_q = MLP(input_dim=sampler.dim, index_dim=1, hidden_dim=128, premodule = premodule).to(device)
                            T = torch.nn.Parameter(torch.FloatTensor([T0]), requires_grad=False)


                            with torch.no_grad():
                                if MSGM:
                                    x_init = sampler.sample(num_samples_init).to(device)
                                    inf_sde = multiplicativeNoise(x_init,beta_min=beta_min, beta_max=beta_max, \
                                                                t_epsilon=t_eps, T=T, num_steps_forward=num_steps_forward, \
                                                                device=device, estim_cst_norm_dens_r_T = False, \
                                                                norm_sampler = norm_sampler,
                                                                norm_map = norm_map, \
                                                                plot_validate = plot_validate)
                                    del x_init
                                else:
                                    inf_sde = VariancePreservingSDE(beta_min=beta_min_SGM, beta_max=beta_max_SGM, \
                                                                    t_epsilon=t_eps, T=T, num_steps_forward=num_steps_forward, \
                                                                    device=device)
                            gen_sde = PluginReverseSDE(inf_sde, drift_q, T, vtype=vtype, debias=False, ssm_intT=ssm_intT).to(device)

                            print("data = " + sampler.name )
                            print("name_SDE = " + str(inf_sde.name_SDE) )   
                            print("num_steps_forward = " + str(num_steps_forward))
                            print("beta_min = " + str(beta_min))
                            print("beta_max = " + str(beta_max))
                            print("t_eps = " + str(t_eps))     
                            print("iterations = " + str(iterations) )
                            print("iterations_ref = " + str(iterations_ref) )
                            print("batch_size = " + str(batch_size) ) 
                            print("ssm_intT = " + str(ssm_intT) )  
                            print("fair_comparison = " + str(fair_comparison) )  
                            print("premodule = " + str(premodule) )
                            print("ntrain_max = " + str(ntrain_max))

                            name_simu_root = m_name_simu_root(sampler.name, gen_sde.base_sde.name_SDE, \
                                                                iterations_ref, batch_size, num_steps_forward, \
                                                                beta_min, beta_max, ssm_intT, fair_comparison)
                            
                            if delayed:
                                print('delayed ...')
                                time.sleep(1e4)
                                # time.sleep(1e5)

                            
                            # Forward SDE
                            with torch.no_grad():
                                print('integrate forward SDE')
                                for_sde = forward_SDE(inf_sde, T)
                                xs_forward = rk4_stratonovich_sampler(for_sde, xtest.clone(), num_steps_forward,  \
                                                                    lmbd=0., keep_all_samples=True, \
                                                                    include_t0=True, norm_correction = MSGM) # sample
                                
                                preprocessing(xtest, xs_forward, num_steps_forward, name_simu_root, \
                                                noising_plots, plt_show, folder_results, val_hist, std_test_plot, device)

                            if (not justLoad):
                                # init optimizer
                                optim = torch.optim.Adam(gen_sde.parameters(), lr=lr)

                                # train
                                start_time = time.time()
                                for i in range(iterations):
                                    optim.zero_grad() # init optimizer
                                    with torch.no_grad():
                                        x = sampler.sample(batch_size).to(device) # sample data
                                    loss = gen_sde.ssm(x).mean() # forward and compute loss
                                    loss.backward() # backward
                                    optim.step() # update

                                    # print
                                    if (i == 0) or ((i+1) % print_every == 0):
                                        # elbo
                                        elbo, elbo_std = evaluate(gen_sde, x)

                                        # print
                                        elapsed = time.time() - start_time
                                        print('| iter {:6d} | {:5.2f} ms/step | loss {:8.3f} | elbo {:8.3f} | elbo std {:8.3f} '
                                            .format(i+1, elapsed*1000/print_every, loss.item(), elbo.item(), elbo_std.item()))
                                        start_time = time.time()

                                        del elbo, elbo_std
                                        gc.collect()
                                    
                                    del x
                                    # torch.mps.empty_cache()  # does not do much on MPS, but still good practice
                                    # gc.collect()
                                del loss, optim
                                gc.collect()


                            ## 4. Visualize
                            with torch.no_grad():

                                ### 4.3. Simulate SDEs
                                """
                                Simulate the generative SDE by using RK4 method
                                """
                                i_num_stepss_backward = -1
                                for num_steps_backward in num_stepss_backward:
                                    i_num_stepss_backward +=1
                                    print("Generation : num_steps_backward = " + str(num_steps_backward))
                                    # init param
                                    # num_samples = 100000

                                    # indices to visualize
                                    fig_step = int(num_steps_backward/8) #4
                                    if fig_step < 1:
                                        fig_step = 1
                                    if include_t0_reverse:
                                        inds = range(0, num_steps_backward+1, fig_step)
                                    else:
                                        inds = range(fig_step-1, num_steps_backward, fig_step)
                                    # sample and plot
                                    plt.close('all')
                                    lmbd = 0.
                                    name_simu = folder_results + "/" + name_simu_root \
                                        + str(t_eps) + "t_eps" \
                                        + str(num_steps_backward) + "stepsBack_" \
                                        + str(include_t0_reverse) + "t0infer"
                                    
                                    for i_run in range(nruns_mmd):
                                        print("Run number : " + str(i_run))
                                        if i_run > 0 :
                                            directory = "runs" + "/" + sampler.name
                                            if not os.path.exists(directory):
                                                os.makedirs(directory)
                                            name_simu = "runs/" + name_simu_root \
                                                + str(t_eps) + "t_eps" \
                                                + str(num_steps_backward) + "stepsBack_" \
                                                + str(include_t0_reverse) + "t0infer" \
                                                + "_run"+ str(i_run)
                                        
                                        if (justLoad):
                                            save_results = False
                                            xs = torch.load(name_simu + ".pt", weights_only=True)
                                        else:
                                            x_0 = gen_sde.latent_sample(num_samples, sampler.dim) # init from prior
                                            xs = rk4_stratonovich_sampler(gen_sde, x_0, num_steps_backward, lmbd=lmbd,\
                                                                        keep_all_samples=True, 
                                                                        include_t0=include_t0_reverse, 
                                                                        norm_correction = MSGM) # sample
                                            del x_0
                                            if (save_results):
                                                torch.save(xs, name_simu + ".pt")
                                        postprocessing(inds, i_dims, i_Res, i_num_stepss_backward, i_iterations, i_run, MSGM, sampler, \
                                                        xs, xtest, std_norm, std_test_plot, datatype, name_simu, dimplot, \
                                                        crop_data_plot, plot_crop, plot_xlim, plot_ref_pdf, \
                                                        pdf_theor, log_scale_pdf, columns_plot, \
                                                        scatter_plots, denoising_plots, include_t0_reverse, plt_show, dpi, height_seaborn, ssize, \
                                                        evalmmmd, justLoadmmmd, justLoad, save_results, lmbd, val_hist, device, \
                                                        mmd_ref, mmd_MSGM,mmd_SGM,max_num_samples_for_mmd)

                    ## Convergence plots (with MMD)

                    if evalmmmd:
                        # if justLoadmmmd and (not MSGM):
                        if justLoadmmmd:
                            if not MSGM:
                                print("filename = " + folder_results + "/" + name_simu_root + "_globalMMDfile_SGM_" + str(nruns_mmd) + "runs.pt")
                                mmd_SGM = torch.load(folder_results + "/" + name_simu_root + "_globalMMDfile_SGM_" + str(nruns_mmd) + "runs.pt")
                            else:
                                print("filename = " + folder_results + "/" + name_simu_root + "_globalMMDfile_MSGM_" + str(nruns_mmd) + "runs.pt")
                                mmd_MSGM = torch.load(folder_results + "/" + name_simu_root + "_globalMMDfile_MSGM_" + str(nruns_mmd) + "runs.pt") 
                                print("filename = " + folder_results + "/" + name_simu_root + "_globalMMDfile_ref_" + str(nruns_mmd) + "runs.pt")
                                mmd_ref = torch.load(folder_results + "/" + name_simu_root + "_globalMMDfile_ref_" + str(nruns_mmd) + "runs.pt") 
                        else:
                            if not MSGM:
                                torch.save(mmd_SGM, folder_results + "/" + name_simu_root + "_globalMMDfile_SGM_" + str(nruns_mmd) + "runs.pt")
                            else:
                                torch.save(mmd_MSGM, folder_results + "/" + name_simu_root + "_globalMMDfile_MSGM_" + str(nruns_mmd) + "runs.pt")
                                torch.save(mmd_ref, folder_results + "/" + name_simu_root + "_globalMMDfile_ref_" + str(nruns_mmd) + "runs.pt")

                if evalmmmd:
                    fig = plt.figure(figsize=(5,3))
                    
                    # Take square root and evaluate mean and quantiles
                    mmmd_SGM = mmd_SGM.sqrt().mean(dim=4)
                    q10mmd_SGM = mmd_SGM.sqrt().quantile(0.1,dim=4)
                    q90mmd_SGM = mmd_SGM.sqrt().quantile(0.9,dim=4)
                    mmmd_MSGM = mmd_MSGM.sqrt().mean(dim=4)
                    q10mmd_MSGM = mmd_MSGM.sqrt().quantile(0.1,dim=4)
                    q90mmd_MSGM = mmd_MSGM.sqrt().quantile(0.9,dim=4)
                    mmmd_ref = mmd_ref.sqrt().mean(dim=4)
                    q10mmd_ref = mmd_ref.sqrt().quantile(0.1,dim=4)
                    q90mmd_ref = mmd_ref.sqrt().quantile(0.9,dim=4)

                    alpha_plot = 0.2
                    range_num_stepss_backward = range(len(num_stepss_backward))
                    plt.loglog(num_stepss_backward,mmmd_SGM[i_dims,i_Res,range_num_stepss_backward,0].flatten(),label='SGM')
                    plt.fill_between(num_stepss_backward, q10mmd_SGM[i_dims,i_Res,range_num_stepss_backward,0].flatten(), \
                                                            q90mmd_SGM[i_dims,i_Res,range_num_stepss_backward,0].flatten(),
                        alpha=alpha_plot)
                    plt.loglog(num_stepss_backward,mmmd_MSGM[i_dims,i_Res,range_num_stepss_backward,0].flatten(),label='MSGM')
                    plt.fill_between(num_stepss_backward, q10mmd_MSGM[i_dims,i_Res,range_num_stepss_backward,0].flatten(), \
                                                            q90mmd_MSGM[i_dims,i_Res,range_num_stepss_backward,0].flatten(),
                        alpha=alpha_plot)
                    plt.loglog(num_stepss_backward,mmmd_ref[i_dims,i_Res,range_num_stepss_backward,0].flatten(),label='train data')
                    plt.fill_between(num_stepss_backward, q10mmd_ref[i_dims,i_Res,range_num_stepss_backward,0].flatten(), 
                                                            q90mmd_ref[i_dims,i_Res,range_num_stepss_backward,0].flatten(),
                        alpha=alpha_plot)
                    plt.legend()
                    plt.ylabel('MMD')
                    plt.xlabel('nb timesteps in backward SDE')
                    xx = num_stepss_backward
                    labels = [f'$2^{{{int(np.log2(idx))}}}$' for idx in xx]
                    ax = plt.gca()
                    ax.set_xticks(xx)
                    ax.xaxis.set_major_locator(FixedLocator(xx))
                    ax.xaxis.set_major_formatter(FixedFormatter(labels))
                    plt.tight_layout()
                    if plt_show:
                        plt.show(block=False)
                    name_fig = folder_results + "/" + name_simu_root + "_MMD_wBckWardSteps_" + str(nruns_mmd) + "runs.png" 
                    plt.savefig(name_fig)
                    if plt_show:
                        plt.pause(1)
                    plt.close(fig)
                    plt.close()
                    del fig


                    if mmd_SGM.shape[3]>1:
                        range_iterations = range(len(iterationss))
                        fig = plt.figure(figsize=(5,3))
                        plt.loglog(iterationss,mmmd_SGM[i_dims,i_Res,0,range_iterations].flatten(),label='SGM')
                        plt.fill_between(iterationss, q10mmd_SGM[i_dims,i_Res,0,range_iterations].flatten(), q90mmd_SGM[i_dims,i_Res,0,range_iterations].flatten(),
                            alpha=alpha_plot)
                        plt.loglog(iterationss,mmmd_MSGM[i_dims,i_Res,0,range_iterations].flatten(),label='MSGM')
                        plt.fill_between(iterationss, q10mmd_MSGM[i_dims,i_Res,0,range_iterations].flatten(), q90mmd_MSGM[i_dims,i_Res,0,range_iterations].flatten(),
                            alpha=alpha_plot)
                        plt.loglog(iterationss,mmmd_ref[i_dims,i_Res,0,range_iterations].flatten(),label='train data')
                        plt.fill_between(iterationss, q10mmd_ref[i_dims,i_Res,0,range_iterations].flatten(), q90mmd_ref[i_dims,i_Res,0,range_iterations].flatten(),
                            alpha=alpha_plot)
                        plt.legend()
                        plt.ylabel('MMD')
                        plt.xlabel('effective number of iterations')
                        xx = iterationss
                        labels = [f'$2^{{{int(np.log2(idx))}}}$' for idx in xx]
                        ax = plt.gca()
                        ax.set_xticks(xx)
                        ax.xaxis.set_major_locator(FixedLocator(xx))
                        ax.xaxis.set_major_formatter(FixedFormatter(labels))
                        plt.tight_layout()
                        if plt_show:
                            plt.show(block=False)
                        name_fig = folder_results + "/" + name_simu_root + "_MMD_wIte_" + str(nruns_mmd) + "runs.png" 
                        plt.savefig(name_fig)
                        if plt_show:
                            plt.pause(1)
                        plt.close(fig)
                        plt.close()
                        del fig

            if evalmmmd:
                if mmd_SGM.shape[0]>1:
                    range_dims = range(len(dims))
                    fig = plt.figure(figsize=(5,3))
                    plt.loglog(dims,mmmd_SGM[range_dims,i_Res,0,0].flatten(),label='SGM')
                    plt.fill_between(dims, q10mmd_SGM[range_dims,i_Res,0,0].flatten(), q90mmd_SGM[range_dims,i_Res,0,0].flatten(),
                        alpha=alpha_plot)
                    plt.loglog(dims,mmmd_MSGM[range_dims,i_Res,0,0].flatten(),label='MSGM')
                    plt.fill_between(dims, q10mmd_MSGM[range_dims,i_Res,0,0].flatten(), q90mmd_MSGM[range_dims,i_Res,0,0].flatten(),
                        alpha=alpha_plot)
                    plt.loglog(dims,mmmd_ref[range_dims,i_Res,0,0].flatten(),label='train data')
                    plt.fill_between(dims, q10mmd_ref[range_dims,i_Res,0,0].flatten(), q90mmd_ref[range_dims,i_Res,0,0].flatten(),
                        alpha=alpha_plot)
                    plt.legend()
                    plt.ylabel('MMD')
                    plt.xlabel('dimension')
                    if datatype == 'era5':
                        xx = dims
                        plt.xticks(ticks=xx )
                    else:
                        xx = dims
                        labels = [f'$2^{{{int(np.log2(idx))}}}$' for idx in xx]
                        ax = plt.gca()
                        ax.set_xticks(xx)
                        ax.xaxis.set_major_locator(FixedLocator(xx))
                        ax.xaxis.set_major_formatter(FixedFormatter(labels))   
                    plt.tight_layout()
                    if plt_show:
                        plt.show(block=False)
                    name_fig = folder_results + "/" + name_simu_root + "_MMD_wDim_" + str(nruns_mmd) + "runs.png" 
                    plt.savefig(name_fig)
                    if plt_show:
                        plt.pause(1)
                    plt.close(fig)
                    plt.close()
                    del fig