import sys

# terminal commands to run experiments
#nohup python Experiment_run.py 0 &
#nohup python Experiment_run.py 1 &
#nohup python Experiment_run.py 2 &
#nohup python Experiment_run.py 3 &

# TensorFlow and tf.keras
import json
import tensorflow as tf
#from tensorflow.keras.layers import *
#from tensorflow.keras.models import *
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import backend as K

# Helper libraries
import seaborn as sns
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt


# From the tutorial
# https://www.machinecurve.com/index.php/2020/02/18/how-to-use-k-fold-cross-validation-with-keras/
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D
from tensorflow.keras.losses import sparse_categorical_crossentropy
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import KFold

from datetime import datetime
from packaging import version

# Used Tensorflow 2.8

# Experiment selection (command line)
# The first command-line argument picks the target function used by f():
# valid values are 0, 1, 2 or 3 (see the nohup commands at the top of the file).
experiment_val = int(sys.argv[1]) # 0, 1, 2, 3

# Each experiment writes into its own directory: Experiment_A/ ... Experiment_D/
string_values = 'ABCD'
#output_dir = 'Experiment_A/'
output_dir = 'Experiment_' + string_values[experiment_val] +'/'

# Hyper-parameters

num_trials = 30     # number of independent repetitions of the whole experiment
num_data = 10000    # number of random samples drawn for each train/test set

epochs_1 = 30       # epochs per fit on the first task
epochs_2 = 5        # epochs for the fit on the second task
learning_rate = 0.01
batch_size_param = 100 #100
noise_level = 0.1   # standard deviation of the Gaussian noise added to second-task targets

# Spacing of the regular evaluation grid on [0, 1) x [0, 1).
resolution = 0.01

x0 = np.arange(0.,1., resolution)
x1 = np.arange(0.,1., resolution)

# Both mesh-grid coordinate arrays stacked into one array.
# NOTE(review): array_grid is not referenced in the visible part of the file.
array_grid = np.array(np.meshgrid(x0, x1))


# The 1-D axes are rebuilt (same values as above) and then replaced by their
# 2-D mesh-grid versions, which are flattened into `coordinates` further down.
x0 = np.arange(0.,1., resolution)
x1 = np.arange(0.,1., resolution)

x0, x1 = np.meshgrid(x0, x1)

def f(x, experiment=None):
    """Evaluate the target function of the selected experiment.

    Parameters
    ----------
    x : numpy.ndarray, shape (batch, 2)
        Sample points; column 0 and column 1 hold the two coordinates.
    experiment : int, optional
        Index of the target function (0, 1, 2 or 3). When omitted, the
        module-level ``experiment_val`` is used, which keeps existing
        ``f(x)`` calls working unchanged.

    Returns
    -------
    numpy.ndarray, shape (batch, 1)
        Target value for every sample.

    Raises
    ------
    ValueError
        If the experiment index is not one of 0, 1, 2, 3 (previously this
        surfaced as an ``UnboundLocalError`` on ``v0``).
    """
    if experiment is None:
        experiment = experiment_val

    if experiment == 0:
        # Spiral-like pattern in polar coordinates centred on (0.5, 0.5).
        r = np.sqrt(np.sum(np.abs((x-0.5)**2), axis=1, keepdims=True))
        t = np.arctan2((x[:, :1]-0.5), (x[:, 1:]-0.5))
        return np.sin(30*r+1.*t)+2
    if experiment == 1:
        # Inputs rescaled from [0, 1] to [-10, 10] before evaluation.
        xt = 20.*x - 10.
        bump = np.exp(-np.sum(np.abs((xt)**2), axis=1, keepdims=True))
        return np.cos(xt[:, :1])**2+np.cos(0.5*xt[:, 1:])**2+bump
    if experiment == 2:
        return 2.+np.cos(20.*x[:, :1]-10.)*np.cos(20.*x[:, 1:]-10.)
    if experiment == 3:
        # Discontinuous checkerboard: values 2 or 3 (2.5 exactly on the edges).
        return 2.+np.heaviside(np.sin(2.*np.pi*x[:, :1])*np.sin(2.*np.pi*x[:, 1:]), 0.5)

    raise ValueError('Unknown experiment index %r; expected 0, 1, 2 or 3' % (experiment,))

def f2(x):
    """Second-task target: ``f`` with a square hole around the centre.

    Parameters
    ----------
    x : numpy.ndarray, shape (batch, d)
        Sample points.

    Returns
    -------
    numpy.ndarray
        ``f(x)`` multiplied by ``1 - a``, where ``a`` is 1 for points whose
        every coordinate lies strictly between 0.45 and 0.55 and 0 for points
        with any coordinate outside that interval.
    """
    # Per-coordinate indicator of 0.45 < x_i < 0.55 (0.5 exactly on an edge),
    # multiplied over the coordinates to obtain the indicator of the square.
    inside_square = np.prod(np.heaviside(x-0.45, 0.5)*np.heaviside(0.55-x, 0.5),
                            axis=1, keepdims=True)

    return np.multiply(f(x), (1.-inside_square))

# Flatten the mesh grid into a (num_points, 2) array of evaluation points and
# evaluate the first-task target on it; reshaped to the 100 x 100 grid for plotting.
coordinates = np.array([x0.flatten(),x1.flatten()]).transpose()
y0 = f(coordinates).reshape(100,100)
#model.predict(coordinates).reshape(100,100)

#(train_loss_values_1,train_val_values_1,train_loss_values_2,train_val_values_2)

# Per-trial accumulators; one entry is appended after every atlas_evaluate() run.
list_train_loss_values_1 = []
list_train_val_values_1 = []
list_train_loss_values_2 = []
list_train_val_values_2 = []
list_results_1 = []
list_results_2 = []

def cubic_spline(x):
    """Piecewise-cubic activation used to implement the basis functions.

    The function is non-zero only for 0 <= x < 4 and is assembled from four
    cubic pieces, one per unit interval; everywhere else it evaluates to zero.
    """
    zero = tf.zeros(tf.shape(x))
    one = tf.ones(tf.shape(x))

    def piece(lower, upper, values):
        # Keep `values` where lower <= x < upper and zero elsewhere.
        return K.switch(tf.logical_and(lower<=x, x<upper), values, zero)

    rising = piece(zero, one, x**3/6)
    left_mid = piece(one, 2*one, (-3.*(x-1.)**3 +3.*(x-1.)**2 + 3*(x-1.)+1.)/6.)
    right_mid = piece(one*2, 3*one, (3*(x-2)**3 - 6*(x-2)**2 + 4. )/6.)
    falling = piece(one*3, 4*one, ( 4. -x)**3/6.)

    return rising + left_mid + right_mid + falling

def partition_w(n,c0):
    """Return the (n, n*(c0+3)) block matrix of input scales.

    Row ``i`` holds the value ``c0`` in its own run of ``c0 + 3`` consecutive
    columns and zeros everywhere else.
    """
    width = c0 + 3
    # The Kronecker product with the identity places one run of ones per row.
    blocks = np.kron(np.eye(n), np.ones((1, width)))
    return c0*blocks

def partition_b(n,c0):
    """Return the bias vector 3, 2, 1, 0, -1, ... repeated ``n`` times.

    Each repetition has ``c0 + 3`` entries, counting down from 3 in steps of 1.
    """
    width = c0 + 3
    one_period = 3 - np.arange(0., width)
    return np.tile(one_period, n)

def exponential_weights(m,k):
    """Return a (2*m*k, k) matrix of fixed output weights.

    Column ``i`` contains, in rows ``2*m*i`` to ``2*m*(i+1) - 1``, the values
    ``+1/j**2`` for j = 1..m followed by ``-1/j**2`` for j = 1..m (rounded to
    float32); all other entries are zero. The whole matrix is scaled by 0.1.
    """
    inverse_squares = 1./np.arange(1,m+1)**2
    signed_column = np.concatenate([inverse_squares, -inverse_squares],axis=0).astype('float32')
    rows_per_output = 2*m
    weights = np.zeros((rows_per_output*k,k))
    for output in range(k):
        start = output*rows_per_output
        weights[start:start+rows_per_output,output] = signed_column
    return 0.1*weights

def partition_w2(n,r):
    """Concatenate the ``partition_w`` matrices of density levels 0 .. r-1.

    Level ``i`` uses the scale ``4 * 2**i - 3``; the per-level matrices are
    joined side by side (along the column axis).
    """
    scales = [4*(2**level)-3 for level in range(r)]
    return np.hstack([partition_w(n,scale) for scale in scales])

def partition_b2(n,r):
    """Concatenate the ``partition_b`` vectors of density levels 0 .. r-1.

    Level ``i`` uses the scale ``4 * 2**i - 3``; the per-level vectors are
    joined end to end.
    """
    scales = [4*(2**level)-3 for level in range(r)]
    return np.hstack([partition_b(n,scale) for scale in scales])

def exponential_w(m):
    """Return a (2*m, 1) float32 column of alternating-sign weights.

    The entries are ``+1/j**2, -1/j**2`` for j = 1..m, interleaved, and the
    whole column is scaled by 0.0625.
    """
    inverse_squares = 1./np.arange(1,m+1)**2
    # Row j of `pairs` is (+1/j**2, -1/j**2); flattening row by row interleaves them.
    pairs = np.stack([inverse_squares, -inverse_squares], axis=1)
    column = pairs.astype('float32').reshape(2*m,1)
    return 0.0625*column

class atlas(keras.Model):
    """Keras model built from fixed spline features plus optional exponential terms.

    Constructor arguments
    ---------------------
    n  : input dimension (the model is built on inputs of shape ``(n,)``).
    m  : number of exponential term pairs; ``m == 0`` disables that branch.
    r  : density level. The trainable "A" branch uses ``4 * 2**r`` spline
         features per input; for ``r > 0`` a frozen "C" branch holds the
         features of all coarser levels ``0 .. r-1``.
    k  : number of output units.
    h0 : stored on the instance and forwarded to the models created by the
         resize methods; it is not otherwise used inside this class.

    The resize methods (``increase_*`` / ``decrease_*``) do not modify the
    model in place: they build and return a new ``atlas`` instance with the
    relevant weights copied over.
    """
    def __init__(self, n, m, r, k, h0):
        super(atlas, self).__init__()
        ################################################################################
        #tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.05, seed=None)
        #tf.constant_initializer(0.)
        ################################################################################
        self.n = n
        self.m = m
        self.r = r
        self.k = k
        self.h0 = h0
        ################################################################################
        # Each input coordinate becomes one "step" with a single channel, so the
        # kernel-size-1 convolutions below act on every coordinate independently.
        self.input_reshape = tf.keras.layers.Reshape((n,1),name='Input')
        # "A" branch, feature layer: fixed (non-trainable) scale and shifts
        # followed by the cubic_spline activation.
        self.a00 = tf.keras.layers.Conv1D(
                                name='A_Partition',
                                filters=4*(2**r), 
                                kernel_size=1,
                                strides=1, padding='valid',
                                data_format='channels_last', 
                                dilation_rate=1,
                                activation=cubic_spline, 
                                use_bias=True,
                                trainable=False,
                                kernel_initializer=tf.constant_initializer(partition_w(1,4*(2**r)-3)),
                                bias_initializer=tf.constant_initializer(partition_b(1,4*(2**r)-3)))
        # Flatten the (n, features) activations into one vector per sample.
        self.a01 = tf.keras.layers.Reshape((n*4*(2**r),),name='A_Reshape')
        # "A" branch, trainable linear read-out (no bias, initialised to zero).
        self.a02 = tf.keras.layers.Dense(units=k,
                                name='A_Weights',
                                use_bias=False,
                                trainable=True,
                                kernel_initializer=tf.constant_initializer(0.),
                                kernel_regularizer=None)
        # An Add layer is only needed when more than one branch contributes.
        if ((self.r > 0) or (self.m > 0)):
            self.out = tf.keras.layers.Add(name='Output')
        if self.r > 0:
            # "C" branch: features of all coarser levels 0 .. r-1, with a
            # frozen (non-trainable) linear read-out.
            self.c00 = tf.keras.layers.Conv1D(
                                    name='C_Partition',
                                    filters=((4*(2**np.arange(0,r))).sum()), 
                                    kernel_size=1,
                                    strides=1, padding='valid',
                                    data_format='channels_last', 
                                    dilation_rate=1,
                                    activation= cubic_spline, 
                                    use_bias=True,
                                    trainable=False,
                                    kernel_initializer=tf.constant_initializer(partition_w2(1,r)),
                                    bias_initializer=tf.constant_initializer(partition_b2(1,r)))
            self.c01 = tf.keras.layers.Reshape((n*((4*(2**np.arange(0,r))).sum()),),name='C_Reshape')
            self.c02 = tf.keras.layers.Dense(units=k,
                                    name='C_Weights',
                                    use_bias=False,
                                    trainable=False,
                                    kernel_initializer=tf.constant_initializer(0.),
                                    kernel_regularizer=None)
        if self.m > 0:
            # Exponential branch: trainable linear map of the "A" features to
            # 2*m exponents per output unit ...
            self.ae2 = tf.keras.layers.Dense(activation=None,
                                    units=2*m*k,
                                    name='AE_Weights',
                                    use_bias=False,
                                    trainable=True,
                                    kernel_initializer=tf.constant_initializer(0.),
                                    kernel_regularizer=None)
            # ... exponentiated, grouped as (k, 2*m) and combined per output
            # unit with the fixed alternating-sign weights of exponential_w(m).
            self.ad0 = tf.keras.layers.Activation(tf.math.exp,name='Exponential')
            self.ad1 = tf.keras.layers.Reshape((k,2*m),name='Reshape')
            self.ad2 = tf.keras.layers.Dense(units=1,
                                            use_bias=False,
                                            trainable=False,
                                            kernel_initializer=tf.constant_initializer(exponential_w(m)))
            self.ad3 = tf.keras.layers.Reshape((k,))
        if ((self.r > 0) and (self.m > 0)):
            # Frozen contribution of the "C" features to the exponents; it is
            # added to the trainable ae2 contribution before exponentiation.
            self.ce2 = tf.keras.layers.Dense(activation=None,
                                    units=2*m*k,
                                    name='CE_Weights',
                                    use_bias=False,
                                    trainable=False,
                                    kernel_initializer=tf.constant_initializer(0.),
                                    kernel_regularizer=None)
            self.Inner_Sum = tf.keras.layers.Add(name='Inner_Sum')
        
    def call(self, input_tensor, training=False):
        """Forward pass: sum of the active branches for the current (r, m).

        ``training`` is accepted for Keras compatibility but not used here.
        """
        input_reshape = self.input_reshape(input_tensor)
        a00 = self.a00(input_reshape)
        a01 = self.a01(a00)
        a02 = self.a02(a01)
        if self.r == 0:
            if self.m == 0:
                # Only the trainable spline read-out exists.
                out = a02
                return out
            elif self.m > 0:
                Inner_Sum = self.ae2(a01)
                Exp_Terms = self.ad3(self.ad2(self.ad1(self.ad0(Inner_Sum))))
                out = self.out([a02,Exp_Terms])
                return out
        elif self.r > 0:
            if self.m == 0:
                c00 = self.c00(input_reshape)
                c01 = self.c01(c00)
                c02 = self.c02(c01)
                out = self.out([c02,a02])
                return out
            elif self.m > 0:
                c00 = self.c00(input_reshape)
                c01 = self.c01(c00)
                c02 = self.c02(c01)
                # Exponents receive contributions from both feature sets.
                Inner_Sum = self.Inner_Sum([self.ae2(a01),self.ce2(c01)])
                Exp_Terms = self.ad3(self.ad2(self.ad1(self.ad0(Inner_Sum))))
                out = self.out([c02,a02,Exp_Terms])
                return out
        ################################################################################
                
    def construct(self):
        """Build all layers by calling the model once on a symbolic input of shape (n,)."""
        self(tf.keras.layers.Input(shape=(self.n,)))
    
    def parameter_adaptive_zero(self):
        """Reset the trainable weights (a02 and, if present, ae2) to zero."""
        self.a02.set_weights(0.*np.array(self.a02.get_weights()))
        if self.m > 0:
            self.ae2.set_weights(0.*np.array(self.ae2.get_weights()))

    def increase_bspline_density(self):
        """Return a new model with density level ``r + 1``.

        The current frozen (c02/ce2) and trainable (a02/ae2) weights are merged
        into the frozen weights of the new model; the new model's own trainable
        weights are left at their zero initialisation.
        """
        if self.r > 0:
            if self.m > 0:
                # View the flat kernels as (1, input, feature, output) so that
                # the coarse and current-level features can be joined per input.
                A = np.array(self.c02.get_weights()).reshape(1,
                                                             self.n, 
                                                             ((4*(2**np.arange(0,self.r))).sum()), 
                                                             self.k)
                B = np.array(self.a02.get_weights()).reshape(1,
                                                             self.n, 
                                                             4*2**self.r, 
                                                             self.k)
                C = np.concatenate((A,B),axis=2).reshape(1,
                                                         self.n*((4*(2**np.arange(0,self.r+1))).sum()),
                                                         self.k)
                # Same merge for the exponent weights.
                Ae = np.array(self.ce2.get_weights()).reshape(1,
                                                              self.n, 
                                                              ((4*(2**np.arange(0,self.r))).sum()), 
                                                              2*self.m*self.k)
                Be = np.array(self.ae2.get_weights()).reshape(1,
                                                              self.n, 
                                                              4*2**self.r, 
                                                              2*self.m*self.k)
                Ce = np.concatenate((Ae,Be),axis=2).reshape(1,
                                                            self.n*((4*(2**np.arange(0,self.r+1))).sum()),
                                                            2*self.m*self.k)            
                D = atlas( self.n, self.m, self.r+1, self.k, self.h0)
                D.construct()
                D.c02.set_weights(C)
                D.ce2.set_weights(Ce)
                return D
            
            elif self.m == 0:
                A = np.array(self.c02.get_weights()).reshape(1,
                                                             self.n, 
                                                             ((4*(2**np.arange(0,self.r))).sum()), 
                                                             self.k)
                B = np.array(self.a02.get_weights()).reshape(1,
                                                             self.n, 
                                                             4*2**self.r, 
                                                             self.k)
                C = np.concatenate((A,B),axis=2).reshape(1,
                                                         self.n*((4*(2**np.arange(0,self.r+1))).sum()),
                                                         self.k)            
                D = atlas( self.n, self.m, self.r+1, self.k, self.h0)
                D.construct()
                D.c02.set_weights(C)

                return D
            
        elif self.r == 0:
            # No frozen branch exists yet: the current trainable weights become
            # the frozen weights of the new model unchanged.
            if self.m > 0:
                C = np.array(self.a02.get_weights())
                Ce = np.array(self.ae2.get_weights())
                D = atlas( self.n, self.m, self.r+1, self.k, self.h0)
                D.construct()
                D.c02.set_weights(C)
                D.ce2.set_weights(Ce) 
                return D
            
            elif self.m == 0:
                C = np.array(self.a02.get_weights())
                D = atlas( self.n, self.m, self.r+1, self.k, self.h0)
                D.construct()
                D.c02.set_weights(C)
                return D

    def decrease_bspline_density(self):
        """Return a new model with density level ``r - 1``.

        The current frozen weights are split: the part belonging to levels
        ``0 .. r-2`` becomes the new frozen weights and the part belonging to
        level ``r-1`` becomes the new trainable weights. The current trainable
        weights (a02/ae2) are not copied. For ``r == 0`` a message is printed
        and ``None`` is returned.
        """
        if self.r > 1:
            if self.m > 0:
                A = np.array(self.c02.get_weights()).reshape(1,
                                                             self.n, 
                                                             ((4*(2**np.arange(0,self.r))).sum()), 
                                                             self.k)
                Ae = np.array(self.ce2.get_weights()).reshape(1,
                                                              self.n, 
                                                              ((4*(2**np.arange(0,self.r))).sum()), 
                                                              2*self.m*self.k)
                D = atlas( self.n, self.m, self.r-1, self.k, self.h0)
                D.construct()
                # Leading features (levels 0 .. r-2) stay frozen ...
                c0A = A[:,:,:((4*(2**np.arange(0,self.r-1))).sum()),:]
                D.c02.set_weights(c0A.reshape(1,self.n*((4*(2**np.arange(0,self.r-1))).sum()),self.k ))
                # ... the remaining ones (level r-1) become trainable again.
                a0A = A[:,:,((4*(2**np.arange(0,self.r-1))).sum()):,:]
                D.a02.set_weights(a0A.reshape(1,self.n*4*(2**(self.r-1)),self.k ))
                ceA = Ae[:,:,:((4*(2**np.arange(0,self.r-1))).sum()),:]
                D.ce2.set_weights(ceA.reshape(1,self.n*((4*(2**np.arange(0,self.r-1))).sum()),2*self.m*self.k))
                aeA = Ae[:,:,((4*(2**np.arange(0,self.r-1))).sum()):,:]
                D.ae2.set_weights(aeA.reshape(1,self.n*4*(2**(self.r-1)),2*self.m*self.k))
                return D
            
            elif self.m == 0:
                A = np.array(self.c02.get_weights()).reshape(1,
                                                             self.n, 
                                                             ((4*(2**np.arange(0,self.r))).sum()), 
                                                             self.k)
                D = atlas( self.n, self.m, self.r-1, self.k, self.h0)
                D.construct()
                c0A = A[:,:,:((4*(2**np.arange(0,self.r-1))).sum()),:]
                D.c02.set_weights(c0A.reshape(1,self.n*((4*(2**np.arange(0,self.r-1))).sum()),self.k ))
                a0A = A[:,:,((4*(2**np.arange(0,self.r-1))).sum()):,:]
                D.a02.set_weights(a0A.reshape(1,self.n*4*(2**(self.r-1)),self.k ))
                return D
        
        elif self.r == 1:
            # The new model has no frozen branch: the frozen weights become
            # its trainable weights unchanged.
            if self.m > 0:
                A = np.array(self.c02.get_weights())
                Ae = np.array(self.ce2.get_weights())
                D = atlas( self.n, self.m, self.r-1, self.k, self.h0)
                D.construct()
                D.a02.set_weights(A) 
                D.ae2.set_weights(Ae)
                return D
            
            elif self.m == 0:
                A = np.array(self.c02.get_weights())
                D = atlas( self.n, self.m, self.r-1, self.k, self.h0)
                D.construct()
                D.a02.set_weights(A)                 
                return D
            
        elif self.r == 0:
            # Nothing to remove; falls through and returns None.
            print('Stop doing that!')
            
    def increase_exponential_terms(self):
        """Return a new model with ``m + 1`` exponential term pairs.

        Spline weights are copied as they are. Existing exponent weights are
        copied into the first ``2*m`` slots of each output unit; the two new
        slots keep the zero initialisation of the new model.
        """
        
        if self.r > 0:
            if self.m > 0:
                D = atlas( self.n, self.m+1, self.r, self.k, self.h0)
                D.construct()
                # View the exponent kernel as (1, features, k, slots) so the
                # old slots can be written in front of the new ones.
                wae2 = np.array(D.ae2.get_weights()).reshape((1,self.n*4*2**self.r,self.k,2*(self.m+1)))
                wae2[:,:,:,:2*self.m] = np.array(self.ae2.get_weights()).reshape((1,
                                                                                  self.n*4*2**self.r,
                                                                                  self.k,2*self.m))
                D.ae2.set_weights(wae2.reshape((1,self.n*4*2**self.r,self.k*2*(self.m+1))))
                wce2 = np.array(D.ce2.get_weights()).reshape((1,
                                                              self.n*((4*(2**np.arange(0,self.r))).sum()),
                                                              self.k,
                                                              2*(self.m+1)))
                wce2[:,:,:,:2*self.m] = np.array(self.ce2.get_weights()).reshape((
                                                                1,
                                                                self.n*((4*(2**np.arange(0,self.r))).sum()),
                                                                self.k,
                                                                2*self.m))
                D.ce2.set_weights(wce2.reshape((1,
                                                self.n*((4*(2**np.arange(0,self.r))).sum()),
                                                self.k*2*(self.m+1))))
                D.a02.set_weights(self.a02.get_weights())
                D.c02.set_weights(self.c02.get_weights())
                return D
            
            if self.m == 0:
                # No exponent weights exist yet; only the spline weights are copied.
                D = atlas( self.n, self.m+1, self.r, self.k, self.h0)
                D.construct()
                D.a02.set_weights(self.a02.get_weights())
                D.c02.set_weights(self.c02.get_weights())
                return D
            
        elif self.r == 0:
            if self.m > 0:
                D = atlas( self.n, self.m+1, self.r, self.k, self.h0)
                D.construct()
                wae2 = np.array(D.ae2.get_weights()).reshape((1,self.n*4*2**self.r,self.k,2*(self.m+1)))
                wae2[:,:,:,:2*self.m] = np.array(self.ae2.get_weights()).reshape((1,
                                                                                  self.n*4*2**self.r,
                                                                                  self.k,2*self.m))
                D.ae2.set_weights(wae2.reshape((1,self.n*4*2**self.r,self.k*2*(self.m+1))))            
                D.a02.set_weights(self.a02.get_weights())
                return D
            
            elif self.m == 0:
                D = atlas( self.n, self.m+1, self.r, self.k, self.h0)
                D.construct()          
                D.a02.set_weights(self.a02.get_weights())
                return D
            
    def decrease_exponential_terms(self):
        """Return a new model with ``m - 1`` exponential term pairs.

        Spline weights are copied as they are; of the exponent weights only the
        first ``2*(m-1)`` slots of each output unit are kept. For ``m == 0`` a
        message is printed and ``None`` is returned.
        """
        
        if self.r > 0:
            if self.m > 1:
                D = atlas( self.n, self.m-1, self.r, self.k, self.h0)
                D.construct()
                # NOTE(review): this first value of wae2 is never used; it is
                # overwritten by the next statement.
                wae2 = np.array(D.ae2.get_weights()).reshape((1,
                                                              self.n*4*2**self.r,
                                                              self.k,
                                                              2*(self.m-1)))
                # Keep only the leading 2*(m-1) slots of the current weights.
                wae2 = np.array(self.ae2.get_weights()).reshape((1,
                                                                self.n*4*2**self.r,
                                                                self.k,
                                                                2*self.m))[:,:,:,:2*(self.m-1)]
                D.ae2.set_weights(wae2.reshape((1,
                                                self.n*4*2**self.r,
                                                self.k*2*(self.m-1))))
                # NOTE(review): same pattern, this first value of wce2 is unused.
                wce2 = np.array(D.ce2.get_weights()).reshape((1,
                                                              self.n*((4*(2**np.arange(0,self.r))).sum()),
                                                              self.k,
                                                              2*(self.m-1)))
                wce2 = np.array(self.ce2.get_weights()).reshape((
                                                                1,
                                                                self.n*((4*(2**np.arange(0,self.r))).sum()),
                                                                self.k,
                                                                2*self.m))[:,:,:,:2*(self.m-1)]
                D.ce2.set_weights(wce2.reshape((1,
                                                self.n*((4*(2**np.arange(0,self.r))).sum()),
                                                self.k*2*(self.m-1))))
                D.a02.set_weights(self.a02.get_weights())
                D.c02.set_weights(self.c02.get_weights())
                return D
            
            elif self.m == 1:
                # The new model has no exponential branch at all.
                D = atlas( self.n, self.m-1, self.r, self.k, self.h0)
                D.construct()
                D.a02.set_weights(self.a02.get_weights())
                D.c02.set_weights(self.c02.get_weights())
                return D
            
            elif self.m == 0:
                # Nothing to remove; falls through and returns None.
                print('Stop doing that!')
            
        elif self.r == 0:
            if self.m > 1:
                D = atlas( self.n, self.m-1, self.r, self.k, self.h0)
                D.construct()
                # NOTE(review): this first value of wae2 is never used; it is
                # overwritten by the next statement.
                wae2 = np.array(D.ae2.get_weights()).reshape((1,
                                                              self.n*4*2**self.r,
                                                              self.k,
                                                              2*(self.m-1)))
                wae2 = np.array(self.ae2.get_weights()).reshape((1,
                                                                 self.n*4*2**self.r,
                                                                 self.k,
                                                                 2*self.m))[:,:,:,:2*(self.m-1)]
                D.ae2.set_weights(wae2.reshape((1,self.n*4*2**self.r,self.k*2*(self.m-1))))            
                D.a02.set_weights(self.a02.get_weights())
                return D
            
            elif self.m == 1:
                D = atlas( self.n, self.m-1, self.r, self.k, self.h0)
                D.construct()            
                D.a02.set_weights(self.a02.get_weights())
                return D
            
            elif self.m == 0:
                # Nothing to remove; falls through and returns None.
                print('Stop doing that!')


def _save_image_grid(image_rows, titles, figsize, path):
    """Draw a grid of heat maps with one shared colour bar and save it to ``path``.

    image_rows : sequence of rows, each a sequence of 2-D arrays to display.
    titles     : one title per column, shown above the first row only.
    figsize    : figure size passed to ``plt.subplots``.
    path       : output file name handed to ``Figure.savefig``.
    """
    ticks = [0, 0.5, 1]
    n_rows = len(image_rows)
    n_cols = len(image_rows[0])
    # squeeze=False keeps `axs` two-dimensional even for a single row.
    fig, axs = plt.subplots(n_rows, n_cols, sharex='col', sharey='row',
                            figsize=figsize, squeeze=False)
    for row, images in enumerate(image_rows):
        for col, image in enumerate(images):
            ax = axs[row, col]
            if row == 0:
                ax.set_title(titles[col], fontsize=20)
            if row == n_rows - 1:
                ax.set_xlabel(r'$x_{1}$', fontsize=17)
            ax.set_aspect(1)
            ax.set_xticks(ticks)
            ax.tick_params(axis='x', which='major', labelsize=12)

            if col == 0:
                ax.set_ylabel(r'$x_{2}$', fontsize=17)
                ax.set_yticks(ticks)
                ax.tick_params(axis='y', which='major', labelsize=12)
            pcm = ax.imshow(image,
                            cmap='inferno',
                            interpolation='none',
                            extent=[0., 1, 0., 1.],
                            origin='lower')
            # Fixed colour range so that all panels are directly comparable.
            pcm.set_clim(0., 3.)

    cb = fig.colorbar(pcm, ax=axs, shrink=0.7, pad=0.02, aspect=10)
    cb.set_label(r'Function Value', labelpad=15, size=18)
    cb.ax.tick_params(labelsize=14)

    # `close` and `verbose` are not savefig options and were dropped.
    fig.savefig(path, dpi=500, bbox_inches='tight')
    plt.close(fig)


def atlas_evaluate():
    """Run one trial: grow and train a model on task 1, then train it on task 2.

    Reads the module-level data of the current trial (``train_inputs``,
    ``train_output``, ``x_test1``, ``y_test1``, ``train_inputs_2``,
    ``train_output_2``, ``x_test2``, ``y_test2``), the evaluation grid
    (``coordinates``, ``y0``), the hyper-parameters and ``index_value``, which
    is used to number the output files written below ``output_dir``.

    Returns
    -------
    tuple
        ``(train_loss_values_1, train_val_values_1, train_loss_values_2,
        train_val_values_2, results1, results2)``: the per-epoch training and
        validation errors of both tasks (each fit is preceded by the error
        measured before training) and the final test MAE on each task.
    """
    import os  # local import: only needed for creating the output directories

    prior_dir = output_dir+"model_prior_after_training/"
    retention_dir = output_dir+"memory_retention_visualisation/"
    # Create the output directories up front so a missing directory cannot
    # discard a finished training run at save time.
    os.makedirs(prior_dir, exist_ok=True)
    os.makedirs(retention_dir, exist_ok=True)

    grid_shape = y0.shape

    def predict_on_grid(current_model):
        return current_model.predict(coordinates).reshape(grid_shape)

    def compile_for_mae(current_model):
        # `learning_rate=` replaces the deprecated `lr=` spelling.
        current_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                              loss=tf.keras.losses.mean_absolute_error, metrics=['mae'])

    # Initial model: n, m, r, k, h0
    model = atlas(2, 0, 0, 1, 0.0e-7)
    model.construct()
    model.summary()

    train_loss_values_1 = []
    train_val_values_1 = []
    model_prior_training = []
    model_after_training = []

    # Task 1: four rounds of "train, then enlarge the model".
    for _ in range(4):
        compile_for_mae(model)
        initial_mae_1 = model.evaluate(train_inputs, train_output)
        initial_val_mae_1 = model.evaluate(x_test1, y_test1)

        model_prior_training.append(predict_on_grid(model))

        history_1 = model.fit(train_inputs, train_output, epochs=epochs_1, verbose=1,
                              validation_data=(x_test1, y_test1), batch_size=batch_size_param)

        train_loss_values_1.append(initial_mae_1[-1])
        train_loss_values_1.extend(history_1.history['mae'])
        # The loss is the MAE itself, so 'val_loss' is the validation MAE.
        train_val_values_1.append(initial_val_mae_1[-1])
        train_val_values_1.extend(history_1.history['val_loss'])

        model_after_training.append(predict_on_grid(model))

        # Each resize call returns a new model instance.
        model = model.increase_bspline_density()
        model = model.increase_exponential_terms()
        model = model.increase_exponential_terms()
        model.summary()

    compile_for_mae(model)

    output2 = predict_on_grid(model)
    results1 = model.evaluate(x_test1, y_test1)[-1]

    model_prior_after_training = [model_prior_training, model_after_training]
    np.save(prior_dir+"Model_prior_after_training_" +str(index_value),
            np.array(model_prior_after_training))

    density_titles = [r'$r=0, \; M=0$', r'$r=1, \; M=2$', r'$r=2, \; M=4$', r'$r=3, \; M=6$']

    # Predictions before (top row) and after (bottom row) each training round.
    _save_image_grid(model_prior_after_training, density_titles, (20, 8),
                     prior_dir+"Prior_After_Training_Comparison_" +str(index_value)+".png")

    # After-training predictions only (previously selected through the loop
    # variable `row` left over from the first figure).
    _save_image_grid([model_after_training], density_titles, (20, 5),
                     prior_dir+"Prior_After_Training_Comparison_" +str(index_value)+"_choice.png")

    ##################################################################################################
    # Task 2: continue training the same model on the second data set.
    initial_mae_2 = model.evaluate(train_inputs_2, train_output_2)
    initial_val_mae_2 = model.evaluate(x_test2, y_test2)

    history_2 = model.fit(train_inputs_2, train_output_2, epochs=epochs_2, verbose=1,
                          validation_data=(x_test2, y_test2), batch_size=batch_size_param)

    train_loss_values_2 = [initial_mae_2[-1]]+history_2.history['mae']
    train_val_values_2 = [initial_val_mae_2[-1]]+history_2.history['val_loss']

    output3 = predict_on_grid(model)
    results2 = model.evaluate(x_test2, y_test2)[-1]
    ##################################################################################################

    images_to_plot = [y0, output2, f2(coordinates).reshape(grid_shape), output3,
                      np.abs(output3-output2)]

    # Bug fix: this file used to receive model_prior_after_training instead of
    # the images named in its file name.
    np.save(retention_dir+"images_to_plot_memory_retention_" +str(index_value),
            np.array(images_to_plot))

    _save_image_grid([images_to_plot],
                     ['1st Target', '1st Output', '2nd Target', '2nd Output', 'Difference'],
                     (20, 4),
                     retention_dir+"memory_retention_visualisation"+str(index_value)+".png")

    return (train_loss_values_1, train_val_values_1, train_loss_values_2, train_val_values_2,
            results1, results2)


#index_value = 0
# Run num_trials independent trials. Every iteration redraws the data sets
# below as module-level variables, calls atlas_evaluate(), and appends the six
# values it returns to the matching per-trial lists.
# NOTE(review): atlas_evaluate() takes no arguments, so it presumably reads
# these module-level names (and index_value) directly -- keep them unchanged.
for index_value in range(num_trials):
    # First task: uniform random inputs in [0, 1)^2 with targets from f.
    # No noise is added to the first-task targets.
    train_inputs = np.random.rand(num_data, 2)
    train_output = f(train_inputs)

    # Separate sample for testing on the first task, drawn the same way.
    x_test1 = np.random.rand(num_data, 2)
    y_test1 = f(x_test1)

    # Second task: inputs confined to [0.45, 0.55)^2; targets are zero plus
    # Gaussian noise scaled by noise_level.
    train_inputs_2 = 0.45 + 0.10 * np.random.rand(num_data, 2)
    train_output_2 = np.zeros(num_data)
    train_output_2 = train_output_2 + noise_level * np.random.randn(*train_output_2.shape)

    # Test data for the second task: f2 on uniform inputs, plus the same noise.
    x_test2 = np.random.rand(num_data, 2)
    y_test2 = f2(x_test2)
    y_test2 = y_test2 + noise_level * np.random.randn(*y_test2.shape)

    data_values = atlas_evaluate()

    # atlas_evaluate() returns its six values in exactly this order.
    per_trial_lists = (
        list_train_loss_values_1,
        list_train_val_values_1,
        list_train_loss_values_2,
        list_train_val_values_2,
        list_results_1,
        list_results_2,
    )
    for collected, trial_value in zip(per_trial_lists, data_values):
        collected.append(trial_value)

    
# Persist the per-trial histories and test results as .npy files in
# output_dir, one file per list (first row = first trial).
for npy_name, per_trial_values in (
    ('list_train_loss_values_1.npy', list_train_loss_values_1),
    ('list_train_val_values_1.npy', list_train_val_values_1),
    ('list_train_loss_values_2.npy', list_train_loss_values_2),
    ('list_train_val_values_2.npy', list_train_val_values_2),
    ('results_1.npy', list_results_1),
    ('results_2.npy', list_results_2),
):
    np.save(output_dir + npy_name, np.array(per_trial_values))


def _load_mean_std(npy_name):
    """Load output_dir + npy_name once and return (mean, std) over axis 0.

    Axis 0 of each saved array indexes the trials, so the statistics are
    taken across trials for every remaining position.
    """
    values = np.load(output_dir + npy_name)
    return np.mean(values, axis=0), np.std(values, axis=0)


# Across-trial mean and standard deviation of each saved history. Each file
# is read from disk a single time instead of once per statistic.
mean_t1_mae, stdv_t1_mae = _load_mean_std('list_train_loss_values_1.npy')

mean_t1_mae_val, stdv_t1_mae_val = _load_mean_std('list_train_val_values_1.npy')

mean_t2_mae, stdv_t2_mae = _load_mean_std('list_train_loss_values_2.npy')

mean_t2_mae_val, stdv_t2_mae_val = _load_mean_std('list_train_val_values_2.npy')


# Task 1 learning curves: across-trial mean of the training and validation
# histories, each with a +/- one standard deviation band.
fig0 = plt.figure(figsize=(8,4))
# Derive the epoch axis from the data rather than hard-coding its length, so
# the plot does not fail with a shape mismatch when the number of recorded
# epochs changes.
x = np.arange(len(mean_t1_mae))

tab10 = sns.color_palette('tab10')
marker_cmap = sns.color_palette("ch:start=.1,rot=-.5", as_cmap=True)

plt.plot(x,
         mean_t1_mae,
         linewidth=1.,
         color=tab10[0],
         label="Training MAE")

plt.fill_between(x, mean_t1_mae - stdv_t1_mae,
                 mean_t1_mae + stdv_t1_mae, color=tab10[0], alpha=0.3)

plt.plot(x, mean_t1_mae_val,
         linewidth=1.,
         color=tab10[1],
         label="Validation MAE")

plt.fill_between(x, mean_t1_mae_val - stdv_t1_mae_val,
                 mean_t1_mae_val + stdv_t1_mae_val, color=tab10[1], alpha=0.3)

# Dashed vertical markers at epochs 31, 62 and 93, labelled r = 1..3, M = 2r.
# NOTE(review): 31 presumably equals epochs_1 + 1 (one training stage) --
# confirm before changing the schedule.
for index_val in range(1,4):
    plt.vlines(31*index_val,
               ymin=0,
               ymax=3,
               color=marker_cmap(0.1+0.15*index_val),
               # Raw strings: "\;" is a LaTeX spacing command, not a Python
               # escape sequence (a non-raw literal triggers a warning).
               label=r"$r="+str(index_val)+r", \; M = "+str(2*index_val)+"$",
               linestyles='dashed')

legend = plt.legend()
plt.ylim(0,2.1)
frame = legend.get_frame()
frame.set_facecolor('white')

fig0.gca().set(xlabel="Epochs", ylabel="MAE")
fig0.savefig(output_dir+"training_and_validation_loss_Task_1.png", dpi=500, bbox_inches = 'tight')

plt.close(fig0)

# Task 2 learning curves: across-trial mean of the training and validation
# histories, each with a +/- one standard deviation band.
fig0 = plt.figure(figsize=(8,4))
# Derive the epoch axis from the data rather than hard-coding its length, so
# the plot does not fail with a shape mismatch when the number of recorded
# epochs changes.
x = np.arange(len(mean_t2_mae))

tab10 = sns.color_palette('tab10')

plt.plot(x,
         mean_t2_mae,
         linewidth=1.,
         color=tab10[0],
         label="Training MAE")

plt.fill_between(x, mean_t2_mae - stdv_t2_mae,
                 mean_t2_mae + stdv_t2_mae, color=tab10[0], alpha=0.3)

plt.plot(x, mean_t2_mae_val,
         linewidth=1.,
         color=tab10[1],
         label="Validation MAE")

plt.fill_between(x, mean_t2_mae_val - stdv_t2_mae_val,
                 mean_t2_mae_val + stdv_t2_mae_val, color=tab10[1], alpha=0.3)


legend = plt.legend()
plt.ylim(0,2.1)
frame = legend.get_frame()
frame.set_facecolor('white')

fig0.gca().set(xlabel="Epochs", ylabel="MAE")
fig0.savefig(output_dir+"training_and_validation_loss_Task_2.png", dpi=500, bbox_inches = 'tight')

plt.close(fig0)

def _mean_std_text(values):
    """Return 'mean(std)' text for values, both statistics taken over axis 0."""
    mean_part = np.mean(values, axis=0)
    std_part = np.std(values, axis=0)
    return f"{mean_part!s}({std_part!s})"


# Summarise the saved per-trial test results of both tasks as "mean(std)"
# strings and write them to a JSON file in output_dir.
avr0 = np.load(output_dir+'results_1.npy')
avr1 = np.load(output_dir+'results_2.npy')

result_dict = {
    'task_1': _mean_std_text(avr0),
    'task_2': _mean_std_text(avr1),
}

with open(output_dir+"test_results_std_dev.json", "w") as write_file:
    json.dump(result_dict, write_file, indent=4)
