import matplotlib.pyplot as plt
import numpy as np
import argparse
from datamodule import data_points
from interpolator import OnlineInterpolator, Environment, BatchInterpolator
from models import LinearModel, ZeroOrderHoldModel
import sympy as sp
import time

import warnings
warnings.filterwarnings("ignore")

# CONDITIONAL DECORATOR
def conditional_decorator(flag,decorator):
    """ Build a decorator that applies `decorator` only when `flag` is truthy.

    When `flag` is falsy the decorated callable is handed back untouched.
    """
    def decorate(func):
        if flag:
            return decorator(func)
        return func
    return decorate

# TIMEIT DECORATOR
# Set to True to print the execution time of every decorated routine.
timeitflag = False
def timeit(foo):
    """ Decorate `foo` so that every call prints its execution time.

    If the call returns an `OnlineInterpolator`, the printed figure is the
    elapsed time divided by the number of spline pieces it holds (time per
    interpolated piece); otherwise it is the total elapsed time. The value
    returned by `foo` is passed through unchanged.
    """
    import functools  # function-scope import keeps this block self-contained

    @functools.wraps(foo)  # keep foo's __name__ / __doc__ on the wrapper
    def wrapper(*args,**kwargs):
        # perf_counter is monotonic, unlike time.time
        time_start = time.perf_counter()
        foo_output = foo(*args,**kwargs)
        elapsed = time.perf_counter() - time_start

        output = elapsed
        if isinstance(foo_output,OnlineInterpolator):
            num_pieces = len(foo_output.spline.pieces)
            # with no pieces, report the total time instead of dividing by zero
            if num_pieces:
                output = elapsed / num_pieces
        print(output, foo.__name__)
        return foo_output
    return wrapper


# AUXILIARY
def construct_A(u,rho,_flag_numeric=False):
    """ Build the (2*rho) x (2*rho) matrix A for the value `u`.

    For the first 2*rho-1 rows, entry (row, column) is
        column*(column-1)*...*(column-row+1) * u**(column-row),
    which is the row-th derivative of u**column with respect to u.
    The last row is all zeros.

    Returns a numpy array when `_flag_numeric` is truthy, otherwise a
    sympy Matrix (so `u` may be a sympy symbol).
    """
    num_columns = 2*rho
    rows = []
    for row in range(num_columns-1):
        entries = []
        for column in range(num_columns):
            # falling factorial: vanishes whenever column < row
            falling = 1
            for l in range(row):
                falling *= (column-l)
            entries.append(falling * u**(column-row))
        rows.append(entries)
    rows.append([0]*num_columns)

    return np.array(rows) if _flag_numeric else sp.Matrix(rows)

def construct_B(rho,_flag_numeric=False):
    """ Build the vector B of length 2*rho: zeros everywhere, 1 in the last entry.

    Returns a 1-D numpy array when `_flag_numeric` is truthy, otherwise a
    sympy column Matrix.
    """
    entries = [0]*(2*rho-1)
    entries.append(1)
    return np.array(entries) if _flag_numeric else sp.Matrix(entries)

def construct_M(A,B,_flag_numeric=False):
    """ Assemble the square matrix M whose columns are products of the A's with B.

    With n = B.shape[0], the last column of M is B and the column c places
    to its left is
        A[n-c+1] @ ... @ A[n-1] @ A[n] @ B
    (A[n] applied first), where `A` is a dict keyed by the string form of
    the index.

    Returns a numpy array when `_flag_numeric` is truthy, otherwise an
    explicit sympy Matrix. Both branches use the same column layout.
    """
    # .shape is available on numpy arrays and sympy matrices alike
    number_rows = B.shape[0]

    columns = [B]
    for column in range(1,number_rows):
        M_column = B
        for i in range(number_rows,number_rows-column,-1):
            M_column = A[f'{i}'] @ M_column
        # newer (longer) products go to the left
        columns.insert(0,M_column)

    if _flag_numeric:
        # column_stack places each vector as a column, matching the
        # horizontal block layout of the symbolic branch
        return np.column_stack(columns)
    return sp.BlockMatrix(columns).as_explicit()

def row_normalization(M):
    """ Return a copy of the 2-D array `M` with every row divided by its own sum. """
    normalized_rows = [ M[index,:] / sum(M[index,:]) for index in range(M.shape[0]) ]
    return np.array(normalized_rows)

def update_moving_average(t,mu_,x):
    """ Recursive mean update: weight the previous mean `mu_` by (t-1)/t and the new value `x` by 1/t. """
    previous_weight = (t-1)/t
    new_contribution = x/t
    return previous_weight * mu_ + new_contribution

def update_moving_variance(t,mu,mu_,var_,x):
    """ Recursive variance update.

    t    : step index used as the normalising count
    mu   : updated mean (after including `x`)
    mu_  : previous mean
    var_ : previous variance
    x    : new value

    NOTE(review): the cross term uses sqrt((t-1)*var_), which differs from
    the textbook recurrence t*var = (t-1)*var_ + (t-1)*(mu-mu_)**2 + (x-mu)**2.
    Confirm that this form is intended.
    """
    squared_deviation = (x-mu)**2
    cross_term = 2*(mu-mu_)*np.sqrt( (t-1)*var_ )
    carried_over = (t-1)*var_
    return 1/t * ( squared_deviation - cross_term + carried_over )

def compute_cumulative_norm_moments(interpolator):
    """ Running mean and standard deviation of the spline-piece coefficient norms.

    Walks over `interpolator.spline.pieces`; at step t (t = 1, 2, ...) the
    running statistics accumulated so far are recorded and then updated with
    the norm of piece t. The update from the last piece is therefore never
    recorded, and each returned sequence has len(pieces)-1 entries.

    Returns (mean, std) as squeezed numpy arrays.
    """
    pieces = interpolator.spline.pieces

    running_avg = np.linalg.norm( pieces[0].coefficients )
    running_var = 0
    avg_history = []
    std_history = []

    for t in range(1,len(pieces)):
        # record the state before piece t is folded in
        avg_history.append(running_avg)
        std_history.append( np.sqrt(running_var) )

        piece_norm = np.linalg.norm( pieces[t].coefficients )

        previous_avg = running_avg
        running_avg = update_moving_average(t,previous_avg,piece_norm)
        running_var = update_moving_variance(t,running_avg,previous_avg,running_var,piece_norm)

    return np.array(avg_history).squeeze(), np.array(std_history).squeeze()

def compute_cumulative_cost_moments(interpolator,x,y):
    """ Running mean and standard deviation of the per-piece cost.

    The cost of piece t is `loss + eta * roughness`, where the loss is
    evaluated by an `Environment` on the sample (x[t], y[t]) and the roughness
    on spline piece t. At step t (t = 1, 2, ...) the running statistics are
    recorded and then updated with the cost of piece t.

    Returns (mean, std) as squeezed numpy arrays with len(pieces)-1 entries each.
    """
    env = Environment(interpolator.eta,interpolator.rho)

    def piece_cost(t):
        # loss on sample t plus weighted roughness of spline piece t
        data = {'time stamps':x[t], 'signal values':y[t]}
        loss = env.loss(interpolator.spline,data,avg=False)
        roughness = env.piece_roughness(interpolator.spline.pieces[t])
        return loss + interpolator.eta * roughness

    running_avg = piece_cost(0)
    running_var = 0
    avg_history = []
    std_history = []

    for t in range(1,len(interpolator.spline.pieces)):
        # record the state before piece t is folded in
        avg_history.append( running_avg )
        std_history.append( np.sqrt(running_var) )

        cost = piece_cost(t)

        previous_avg = running_avg
        running_avg = update_moving_average(t,previous_avg,cost)
        running_var = update_moving_variance(t,running_avg,previous_avg,running_var,cost)

    return np.array(avg_history).squeeze(), np.array(std_history).squeeze()

## computes the total cost of an interpolator reconstruction
def compute_total_cost(interpolator,x,y):
    """ Sum of `loss + eta * roughness` over all spline pieces.

    The loss of piece t is evaluated by an `Environment` on the sample
    (x[t], y[t]); the roughness is evaluated on spline piece t.
    """
    env = Environment(interpolator.eta,interpolator.rho)

    def piece_cost(t):
        data = {'time stamps':x[t], 'signal values':y[t]}
        loss = env.loss(interpolator.spline,data,avg=False)
        roughness = env.piece_roughness(interpolator.spline.pieces[t])
        return loss + interpolator.eta * roughness

    cost = piece_cost(0)
    for t in range(1,len(interpolator.spline.pieces)):
        cost += piece_cost(t)

    return cost

def add_curve(mean,std,color='k',linestyle='-',alpha=0.1,label='',marker='',flag_std=True):
    """ Draw `mean` against the step index 1..len(mean) on a log-scaled x axis.

    When `flag_std` is truthy, the band between mean-std and mean+std is also
    shaded with the same colour and transparency `alpha`.
    """
    steps = np.arange(1,len(mean)+1)
    ## plot in semilog scale
    plt.semilogx(steps,mean,color=color,linestyle=linestyle,marker=marker,label=label)
    if not flag_std:
        return
    plt.fill_between(steps,mean+std,mean-std,facecolor=color,linestyle=linestyle,alpha=alpha)


# LOOKAHEAD
@conditional_decorator(timeitflag,timeit)
def lookahead_interpolation(eta,rho,x,y,delay=1):
    """ Run an OnlineInterpolator that sees a window of `delay` samples per step.

    At step i the interpolator is fed x[i:i+delay] and y[i:i+delay], so
    `delay=1` passes only the current sample. Iteration stops as soon as
    `forward` returns a falsy value. Returns the interpolator.
    """
    interpolator = OnlineInterpolator(2*rho-1,2*(rho-1),rho,eta)
    for start in range(len(x)):
        window = slice(start,start+delay)
        data = {'time stamps': x[window], 'signal values': y[window]}
        if not interpolator.forward(data):
            break
    return interpolator

# ZERO-ORDER HOLD PREDICTION
@conditional_decorator(timeitflag,timeit)
def zoh_interpolation(eta,rho,x,y,horizon=1):
    """ Run an OnlineInterpolator fed with zero-order-hold forecasts.

    At step i the current sample (x[i], y[i]) is extended with the time stamps
    and values returned by `ZeroOrderHoldModel.forecast`. Once `forward` has
    returned a falsy value, no further `forward` call is made. Returns the
    interpolator.
    """
    interpolator = OnlineInterpolator(2*rho-1,2*(rho-1),rho,eta)
    predictor = ZeroOrderHoldModel(horizon=horizon)
    flag_stable = True
    for i in range(len(x)):
        future_stamps, held_values = predictor.forecast(x[i],y[i])
        data = {
            'time stamps': np.concatenate(( np.array([x[i]]),future_stamps)).squeeze(),
            'signal values': np.concatenate((np.array([y[i]]),held_values)).squeeze(),
        }
        if not flag_stable:
            break
        flag_stable = interpolator.forward(data)

    return interpolator

# LINEAR MODEL PREDICTION (PRETRAINED -- OFFLINE)
@conditional_decorator(timeitflag,timeit)
def pretrained_predicted_interpolation(eta,rho,x,y,training_y,lag=1,horizon=1):
    """ Run an OnlineInterpolator fed with forecasts of a pretrained LinearModel.

    The model is trained once on `training_y`. At step i it forecasts from the
    last `lag` observed values (left-padded with zeros while fewer are
    available) and the forecast is appended to the current sample. Once
    `forward` has returned a falsy value, no further `forward` call is made.
    Returns the interpolator.
    """
    interpolator = OnlineInterpolator(2*rho-1,2*(rho-1),rho,eta)
    predictor = LinearModel(lag=lag,horizon=horizon)
    predictor.train_model(training_y)
    flag_stable = True
    for i in range(len(x)):
        history = np.pad( y[:i+1] , (lag,0),'constant',constant_values=0 )[-lag:]
        future_stamps, future_values = predictor.forecast(x[i],history)
        data = {
            'time stamps': np.concatenate(( np.array([x[i]]),future_stamps)).squeeze(),
            'signal values': np.concatenate((np.array([y[i]]),future_values)).squeeze(),
        }
        if not flag_stable:
            break
        flag_stable = interpolator.forward(data)

    return interpolator

# LINEAR MODEL PREDICTION (RECURSIVE -- ONLINE)
@conditional_decorator(timeitflag,timeit)
def online_predicted_interpolation(eta,rho,x,y,lag=1,horizon=1):
    """ Run an OnlineInterpolator fed with forecasts of a LinearModel updated online.

    At step i the model forecasts from the last `lag` observed values
    (left-padded with zeros while fewer are available) and the forecast is
    appended to the current sample. After each `forward` call, once at least
    `horizon` samples have been seen, the model parameters are updated with
    the last `horizon` observations as reference and the regressor window is
    shifted by one. Once `forward` has returned a falsy value, no further
    `forward` call is made. Returns the interpolator.
    """
    interpolator = OnlineInterpolator(2*rho-1,2*(rho-1),rho,eta)
    predictor = LinearModel(lag=lag,horizon=horizon)
    flag_stable = True
    train_lagged_values = np.zeros(lag)
    for i in range(len(x)):
        history = np.pad( y[:i+1] , (lag,0),'constant',constant_values=0 )[-lag:]
        future_stamps, future_values = predictor.forecast(x[i],history)
        data = {
            'time stamps': np.concatenate(( np.array([x[i]]),future_stamps)).squeeze(),
            'signal values': np.concatenate((np.array([y[i]]),future_values)).squeeze(),
        }
        if not flag_stable:
            break
        flag_stable = interpolator.forward(data)

        if i < horizon-1:
            # not enough observations yet to form a training target
            continue
        train_reference_values = y[:i+1][-horizon:]
        predictor.update_parameters(train_lagged_values,train_reference_values)
        # slide the regressor window: drop the oldest, append the oldest reference
        newest = np.array([train_reference_values[0]])
        train_lagged_values = np.concatenate(( train_lagged_values[1:], newest ))

    return interpolator

# AVERAGE CUMULATIVE NORM
def run_acn(L=1,H=1,P=2,
            eta=0.01,
            rho=2,
            num_lags=2,
            num_samples=300,
            num_training_samples=500,
            percentage=None,
            fontsize=18,
            ymax = 10,
            process = 'ar'
            ):
    """ Plot the average cumulative norm (ACN) of every interpolation strategy.

    L : look-ahead of the delayed strategy (it is run with delay=L+1)
    H : forecast horizon of the ZOH and linear-model strategies
    P : number of lags of the linear models
    eta, rho : interpolator parameters
    num_lags, percentage, process : forwarded to `data_points`
    num_samples, num_training_samples : lengths of trajectory / training data
    fontsize, ymax : figure font size and upper y limit
    """
    ## options shared by both data_points calls
    shared = dict(num_lags=num_lags,percentage=percentage,process=process)

    ## training data
    mean,std,training_signal_values = data_points(num_samples=num_training_samples,
                                                  flag_training=True,
                                                  **shared)

    ## trajectory data
    time_stamps, signal_values = data_points(num_samples=num_samples,
                                             mean=mean,
                                             std=std,
                                             **shared)

    ## one interpolator per strategy, in plotting order
    interpolators = [
        lookahead_interpolation(eta,rho,time_stamps,signal_values,delay=1),
        lookahead_interpolation(eta,rho,time_stamps,signal_values,delay=L+1),
        zoh_interpolation(eta,rho,time_stamps,signal_values,horizon=H),
        pretrained_predicted_interpolation(eta,rho,time_stamps,signal_values,training_signal_values,lag=P,horizon=H),
        online_predicted_interpolation(eta,rho,time_stamps,signal_values,lag=P,horizon=H),
    ]
    moments = [compute_cumulative_norm_moments(interp) for interp in interpolators]

    ## plots
    curve_styles = [
        dict(color='r',marker='*',label='Myopic',flag_std=False),
        dict(color='b',linestyle='-.',label=f'Delay (L={L})'),
        dict(color='m',linestyle=':',label=f'ZOH (H={H})'),
        dict(color='k',label=f'LS Linear (H={H})'),
        dict(color='g',linestyle='--',label=f'RLS Linear (H={H})'),
    ]
    for (curve_mean,curve_std),style in zip(moments,curve_styles):
        add_curve(curve_mean,curve_std,**style)

    plt.ylim([0,ymax])
    plt.xlim([1,num_samples-1])
    tick_size = fontsize-2
    plt.xticks(fontsize=tick_size)
    plt.yticks(fontsize=tick_size)
    plt.xlabel('Time step $t$',fontsize=fontsize)
    plt.ylabel('ACN',fontsize=fontsize)
    plt.legend(fontsize=fontsize,framealpha=0.7,loc='upper right')
    plt.tight_layout()
    plt.show()

# AVERAGE CUMULATIVE COST
def run_acc(L=1,H=1,P=2,
            eta=0.01,
            rho=2,
            num_lags=2,
            num_samples=300,
            num_training_samples=500,
            percentage=None,
            fontsize=18,
            ymax = 4,
            process = 'ar'
            ):
    """ Plot the average cumulative cost (ACC) of every interpolation strategy.

    L : look-ahead of the delayed strategy (it is run with delay=L+1)
    H : forecast horizon of the ZOH and linear-model strategies
    P : number of lags of the linear models
    eta, rho : interpolator parameters
    num_lags, percentage, process : forwarded to `data_points`
    num_samples, num_training_samples : lengths of trajectory / training data
    fontsize, ymax : figure font size and upper y limit
    """
    ## options shared by both data_points calls
    shared = dict(num_lags=num_lags,percentage=percentage,process=process)

    ## training data
    mean,std,training_signal_values = data_points(num_samples=num_training_samples,
                                                  flag_training=True,
                                                  **shared)

    ## trajectory data
    time_stamps, signal_values = data_points(num_samples=num_samples,
                                             mean=mean,
                                             std=std,
                                             **shared)

    ## one interpolator per strategy, in plotting order
    interpolators = [
        lookahead_interpolation(eta,rho,time_stamps,signal_values,delay=1),
        lookahead_interpolation(eta,rho,time_stamps,signal_values,delay=L+1),
        zoh_interpolation(eta,rho,time_stamps,signal_values,horizon=H),
        pretrained_predicted_interpolation(eta,rho,time_stamps,signal_values,training_signal_values,lag=P,horizon=H),
        online_predicted_interpolation(eta,rho,time_stamps,signal_values,lag=P,horizon=H),
    ]
    moments = [compute_cumulative_cost_moments(interp,time_stamps,signal_values)
               for interp in interpolators]

    ## plots
    curve_styles = [
        dict(color='r',marker='*',label='Myopic',flag_std=False),
        dict(color='b',linestyle='-.',label=f'Delay (L={L})'),
        dict(color='m',linestyle=':',label=f'ZOH (H={H})'),
        dict(color='k',label=f'LS Linear (H={H})'),
        dict(color='g',linestyle='--',label=f'RLS Linear (H={H})'),
    ]
    for (curve_mean,curve_std),style in zip(moments,curve_styles):
        add_curve(curve_mean,curve_std,**style)

    plt.ylim([0,ymax])
    plt.xlim([1,num_samples-1])
    tick_size = fontsize-2
    plt.xticks(fontsize=tick_size)
    plt.yticks(fontsize=tick_size)
    plt.xlabel('Time step $t$',fontsize=fontsize)
    plt.ylabel('ACC',fontsize=fontsize)
    plt.legend(fontsize=fontsize,framealpha=0.7,loc='upper right')
    plt.tight_layout()
    plt.show()


# MAIN DOCUMENT
def experiment00():
    """ Average cumulative norm | L=H=1 | P=2 | uniform sampling

    Spells out the values that `run_acn` already uses by default.
    """
    run_acn(L=1,H=1,P=2)

def experiment01():
    """ Average cumulative norm | L=H=3 | P=2 | uniform sampling """
    lookahead = horizon = 3
    run_acn(L=lookahead,H=horizon)

def experiment02():
    """ Average cumulative cost | L=H=1 | P=2 | uniform sampling

    Spells out the values that `run_acc` already uses by default.
    """
    run_acc(L=1,H=1,P=2)

def experiment03():
    """ Average cumulative cost | L=H=3 | P=2 | uniform sampling """
    lookahead = horizon = 3
    run_acc(L=lookahead,H=horizon)

def experiment04(delay = 1):
    """ Intro - unstable strategy

    Samples a noisy 3-D curve (x, phi1(x), phi2(x)) at randomly chosen time
    stamps, interpolates each coordinate online with a window of `delay`
    samples per step, and plots the true curve, the reconstruction and the
    data points.
    """
    np.random.seed(42)
    T = 1000
    w = 1/4
    P = 2
    x = np.linspace(0,P*np.pi,T)
    scale = 0.5
    r = lambda x,s=scale: s*np.log(x)
    phi1 = lambda x,w=w: r(x) * np.sin((1/w)*x)
    phi2 = lambda x,w=w: r(x) * np.cos((1/w)*x)
    R = 20
    num_points = int(T/R)
    ## random subset of the grid, sorted, used as time stamps
    idx = np.random.permutation(T)
    s = x[np.sort(idx[:num_points])]
    ## dense evaluation grid spanning the sampled time stamps
    xx = np.linspace(s[0],s[-1],T)

    # (gaussian) Noisy data generation
    var = 0.1
    noise1 = np.random.normal(0,var,num_points)
    noise2 = np.random.normal(0,var,num_points)
    eta=0.001
    rho=2
    order=2*rho-1
    smooth=order-1
    interpolator1=OnlineInterpolator(order,smooth,rho,eta)
    interpolator2=OnlineInterpolator(order,smooth,rho,eta)

    y1 = phi1(s)+noise1
    y2 = phi2(s)+noise2

    for i in range(len(s)):
        window = slice(i,i+delay)
        data1 = {'time stamps': s[window], 'signal values': y1[window]}
        data2 = {'time stamps': s[window], 'signal values': y2[window]}
        interpolator1.forward(data1)
        interpolator2.forward(data2)

    f1 = interpolator1.spline.reconstruct_spline(xx)
    f2 = interpolator2.spline.reconstruct_spline(xx)
    ax = plt.figure(0).add_subplot(projection='3d')
    ## raw strings: '\m' and '\P' are not valid escape sequences
    ax.plot(xx,phi1(xx),phi2(xx),'--c',label=r'$\mathbf{\Psi}(x)$')
    ## f1, f2 were evaluated on xx, so they are plotted against xx
    ax.plot(xx, f1, f2,'gray',label=r'$\mathbf{f}(x)$')
    ax.plot(s, y1, y2,'k.')
    ax.set_axis_off()
    ## to rotate the axis view
    ax.view_init(30, -60)
    ## axis limits
    ax.axes.set_xlim3d(left=0, right=7)
    ax.axes.set_ylim3d(bottom=-2, top=2)
    ax.axes.set_zlim3d(bottom=-1, top=1)
    plt.show()

def experiment05():
    """ Intro - stable strategy (experiment04 with a window of 3 samples) """
    window = 3
    experiment04(delay=window)

# APPENDICES
## NON-UNIFORM
def experiment10():
    """ Average cumulative norm | L=H=1 | P=2 | non-uniform sampling | 10%  """
    settings = dict(num_samples=600,percentage=0.1,ymax=15)
    run_acn(**settings)

def experiment11():
    """ Average cumulative norm | L=H=3 | P=2 | non-uniform sampling | 10% """
    settings = dict(num_samples=600,percentage=0.1,ymax=15)
    run_acn(L=3,H=3,**settings)

def experiment12():
    """ Average cumulative cost | L=H=1 | P=2 | non-uniform sampling | 10%  """
    settings = dict(num_samples=600,percentage=0.1,ymax=15)
    run_acc(**settings)

def experiment13():
    """ Average cumulative cost | L=H=3 | P=2 | non-uniform sampling | 10% """
    settings = dict(num_samples=600,percentage=0.1,ymax=15)
    run_acc(L=3,H=3,**settings)

## RHO = 3
def experiment20():
    """ Average cumulative norm | L=H=1 | P=2 | uniform sampling | rho = 3 | 50 samples """
    run_acn(num_samples=50,rho=3)

def experiment21():
    """ Average cumulative norm | L=H=3 | P=2 | uniform sampling | rho = 3 """
    lookahead = horizon = 3
    run_acn(L=lookahead,H=horizon,rho=3)

def experiment22():
    """ Average cumulative cost | L=H=1 | P=2 | uniform sampling | rho = 3 | 50 samples """
    run_acc(num_samples=50,rho=3)

def experiment23():
    """ Average cumulative cost | L=H=3 | P=2 | uniform sampling | rho = 3 """
    lookahead = horizon = 3
    run_acc(L=lookahead,H=horizon,rho=3)

## INTERPOLATION PLOTS | VISUALLY
def experiment30():
    """ Unstable behaviour myopic

    Interpolates 8 samples with the myopic strategy (delay=1) and plots the
    reconstruction together with its first and second derivatives.
    """
    eta = 0.01
    rho = 2
    num_samples = 8
    num_lags = 2

    ## trajectory data
    time_stamps, signal_values = data_points(num_samples=num_samples,
                                             num_lags=num_lags)

    ## myopic interpolation
    interpolator = lookahead_interpolation(eta,rho,time_stamps,signal_values,delay=1)

    ## plots (dense grid from 0 up to the last time stamp)
    fontsize = 18
    t = np.linspace(0,time_stamps[-1],1000)
    f = interpolator.spline.reconstruct_spline(t)
    df = interpolator.spline.reconstruct_spline(t,k=1)
    ddf = interpolator.spline.reconstruct_spline(t,k=2)
    plt.plot(t,f,'b',label='Interpolation')
    plt.plot(t,df,'g',label='1st Derivative')
    plt.plot(t,ddf,'--c',label='2nd Derivative')
    plt.plot(time_stamps,signal_values,'k*')
    ## fixed window so the diverging curves stay readable
    plt.ylim([-50,50])
    plt.xlim([0,7])
    plt.xlabel('Time',fontsize=fontsize)
    ## single legend call (a second call would only replace the first)
    plt.legend(fontsize=fontsize)
    plt.tight_layout()
    plt.show()

def experiment31():
    """ Stable Cubic Hermite

    Feeds 10 samples one at a time to an OnlineInterpolator built with
    order=3 and smooth=1, then plots the reconstruction together with its
    first and second derivatives.
    """
    eta = 0.01
    rho = 2
    order = 3
    smooth = 1
    num_samples = 10
    num_lags = 2

    ## trajectory data
    time_stamps, signal_values = data_points(num_samples=num_samples,
                                             num_lags=num_lags)

    ## cubic Hermite interpolator, one sample per step
    interpolator = OnlineInterpolator(order,smooth,rho,eta)
    for stamp,value in zip(time_stamps,signal_values):
        interpolator.forward({'time stamps': np.array([stamp]),
                              'signal values': np.array([value])})

    ## plots (dense grid from 0 up to the last time stamp)
    fontsize=18
    t = np.linspace(0,time_stamps[-1],1000)
    spline = interpolator.spline
    f = spline.reconstruct_spline(t)
    df = spline.reconstruct_spline(t,k=1)
    ddf = spline.reconstruct_spline(t,k=2)
    plt.plot(t,f,'b',label='Interpolation')
    plt.plot(t,df,'g',label='1st Derivative')
    plt.plot(t,ddf,'.c',label='2nd Derivative',markersize=1.0)
    plt.plot(time_stamps,signal_values,'k*')
    plt.xlabel('Time',fontsize=fontsize)
    plt.legend(fontsize=fontsize)
    plt.tight_layout()
    plt.show()

def experiment32(L=1):
    """ Stabilization through delay (look-ahead of L samples, L=1 by default)

    Interpolates 20 samples with a window of 1+L samples per step and plots
    the reconstruction together with its first and second derivatives.
    """
    eta = 0.01
    rho = 2
    num_samples = 20
    num_lags = 2

    ## trajectory data
    time_stamps, signal_values = data_points(num_samples=num_samples,
                                             num_lags=num_lags)

    ## lookahead interpolation
    interpolator = lookahead_interpolation(eta,rho,time_stamps,signal_values,delay=1+L)

    ## plots (dense grid from 0 up to the last time stamp)
    fontsize = 18
    t = np.linspace(0,time_stamps[-1],1000)
    f = interpolator.spline.reconstruct_spline(t)
    df = interpolator.spline.reconstruct_spline(t,k=1)
    ddf = interpolator.spline.reconstruct_spline(t,k=2)
    plt.plot(t,f,'b',label='Interpolation')
    plt.plot(t,df,'g',label='1st Derivative')
    plt.plot(t,ddf,'--c',label='2nd Derivative')
    plt.plot(time_stamps,signal_values,'k*')
    plt.xlabel('Time',fontsize=fontsize)
    ## single legend call (a second call would only replace the first)
    plt.legend(fontsize=fontsize)
    plt.tight_layout()
    plt.show()

def experiment33():
    """ Stabilization through delay L=3 """
    lookahead = 3
    experiment32(L=lookahead)

def experiment34():
    """ Stabilization through zero-order hold prediction H=1

    Interpolates 20 samples with the zero-order-hold strategy and plots the
    reconstruction together with its first and second derivatives.
    """
    eta = 0.01
    rho = 2
    num_samples = 20
    num_lags = 2

    ## trajectory data
    time_stamps, signal_values = data_points(num_samples=num_samples,
                                             num_lags=num_lags)

    ## zero-order hold interpolation
    H = 1
    interpolator = zoh_interpolation(eta,rho,time_stamps,signal_values,horizon=H)

    ## plots (dense grid from 0 up to the last time stamp)
    fontsize = 18
    t = np.linspace(0,time_stamps[-1],1000)
    f = interpolator.spline.reconstruct_spline(t)
    df = interpolator.spline.reconstruct_spline(t,k=1)
    ddf = interpolator.spline.reconstruct_spline(t,k=2)
    plt.plot(t,f,'b',label='Interpolation')
    plt.plot(t,df,'g',label='1st Derivative')
    plt.plot(t,ddf,'--c',label='2nd Derivative')
    plt.plot(time_stamps,signal_values,'k*')
    plt.xlabel('Time',fontsize=fontsize)
    ## single legend call (a second call would only replace the first)
    plt.legend(fontsize=fontsize)
    plt.tight_layout()
    plt.show()

def experiment35():
    """ Stabilization through zero-order hold prediction H=3 | non-uniform data

    Same as the H=1 zero-order-hold plot, but `data_points` is called with
    percentage=0.4 and the forecast horizon is 3.
    """
    eta = 0.01
    rho = 2
    num_samples = 20
    num_lags = 2

    ## trajectory data
    percentage_missing = 0.4
    time_stamps, signal_values = data_points(num_samples=num_samples,
                                             num_lags=num_lags,
                                             percentage=percentage_missing)

    ## zero-order hold interpolation
    H = 3
    interpolator = zoh_interpolation(eta,rho,time_stamps,signal_values,horizon=H)

    ## plots (dense grid from 0 up to the last time stamp)
    fontsize = 18
    t = np.linspace(0,time_stamps[-1],1000)
    f = interpolator.spline.reconstruct_spline(t)
    df = interpolator.spline.reconstruct_spline(t,k=1)
    ddf = interpolator.spline.reconstruct_spline(t,k=2)
    plt.plot(t,f,'b',label='Interpolation')
    plt.plot(t,df,'g',label='1st Derivative')
    plt.plot(t,ddf,'--c',label='2nd Derivative')
    plt.plot(time_stamps,signal_values,'k*')
    plt.xlabel('Time',fontsize=fontsize)
    ## single legend call (a second call would only replace the first)
    plt.legend(fontsize=fontsize)
    plt.tight_layout()
    plt.show()


def experiment36():
    """ Stabilization through linear model prediction (online) | H=1, P=2

    Interpolates 100 samples with the online-updated linear predictor and
    plots the reconstruction together with its first and second derivatives.
    """
    eta = 0.01
    rho = 2
    num_samples = 100
    num_lags = 2

    ## trajectory data
    time_stamps, signal_values = data_points(num_samples=num_samples,
                                             num_lags=num_lags)

    ## linear model prediction (recursive -- online)
    H = 1
    P = 2
    interpolator = online_predicted_interpolation(eta,rho,time_stamps,signal_values,lag=P,horizon=H)

    ## plots (dense grid from 0 up to the last time stamp)
    fontsize = 18
    t = np.linspace(0,time_stamps[-1],1000)
    f = interpolator.spline.reconstruct_spline(t)
    df = interpolator.spline.reconstruct_spline(t,k=1)
    ddf = interpolator.spline.reconstruct_spline(t,k=2)
    plt.plot(t,f,'b',label='Interpolation')
    plt.plot(t,df,'g',label='1st Derivative')
    plt.plot(t,ddf,'--c',label='2nd Derivative')
    plt.plot(time_stamps,signal_values,'k*')
    plt.xlabel('Time',fontsize=fontsize)
    ## single legend call (a second call would only replace the first)
    plt.legend(fontsize=fontsize)
    plt.tight_layout()
    plt.show()

def experiment37():
    """ Stabilization through zero-order hold prediction H=3 | rho = 3

    Interpolates 20 samples with the zero-order-hold strategy and plots the
    reconstruction together with its third and fourth derivatives.
    """
    eta = 0.01
    rho = 3
    num_samples = 20
    num_lags = 2

    ## trajectory data
    time_stamps, signal_values = data_points(num_samples=num_samples,
                                             num_lags=num_lags)

    ## zero-order hold interpolation
    H = 3
    interpolator = zoh_interpolation(eta,rho,time_stamps,signal_values,horizon=H)

    ## plots (dense grid from 0 up to the last time stamp)
    fontsize = 18
    t = np.linspace(0,time_stamps[-1],1000)
    f = interpolator.spline.reconstruct_spline(t)
    d3f = interpolator.spline.reconstruct_spline(t,k=3)
    d4f = interpolator.spline.reconstruct_spline(t,k=4)
    plt.plot(t,f,'b',label='Interpolation')
    plt.plot(t,d3f,'g',label='3rd Derivative')
    plt.plot(t,d4f,'--c',label='4th Derivative')
    plt.plot(time_stamps,signal_values,'k*')
    plt.xlabel('Time',fontsize=fontsize)
    ## single legend call (a second call would only replace the first)
    plt.legend(fontsize=fontsize)
    plt.tight_layout()
    plt.show()

## MERTON DISCRETE JUMP-DIFFUSION PROCESS
def experiment40():
    ''' Merton's discrete jump diffusion data

    Plots 500 data points generated by `data_points` with process='merton'.
    '''
    num_samples = 500
    time_stamps, signal_values = data_points(num_samples=num_samples,
                                             process='merton')
    ## plot
    fontsize = 18
    plt.plot(time_stamps,signal_values,'k*',label='Data points')
    plt.xlabel('Time',fontsize=fontsize)
    ## single legend call (a second call would only replace the first)
    plt.legend(fontsize=fontsize)
    plt.tight_layout()
    plt.show()

def experiment41():
    """ Average cumulative norm | run_acn defaults for L, H, P | process = 'merton' """
    settings = dict(num_samples=500,process='merton',ymax=18)
    run_acn(**settings)

def experiment42():
    """ Average cumulative norm | L=H=3 | process = 'merton' """
    settings = dict(num_samples=500,process='merton',ymax=18)
    run_acn(L=3,H=3,**settings)

def experiment43():
    """ Average cumulative cost | run_acc defaults for L, H, P | process = 'merton' """
    settings = dict(num_samples=500,process='merton',ymax=3)
    run_acc(**settings)

def experiment44():
    """ Average cumulative cost | L=H=3 | process = 'merton' """
    settings = dict(num_samples=500,process='merton',ymax=3)
    run_acc(L=3,H=3,**settings)

def experiment45(H=1):
    """ Zero-order hold interpolation of Merton data (horizon H, H=1 by default)

    Interpolates 500 samples generated with process='merton' using the
    zero-order-hold strategy and plots the reconstruction over the data.
    """
    eta = 0.01
    rho = 2
    num_samples = 500

    ## trajectory data
    time_stamps, signal_values = data_points(num_samples=num_samples,
                                             process='merton')

    ## zero-order hold interpolation
    interpolator = zoh_interpolation(eta,rho,time_stamps,signal_values,horizon=H)

    ## plots (dense grid from 0 up to the last time stamp)
    fontsize = 18
    t = np.linspace(0,time_stamps[-1],10000)
    f = interpolator.spline.reconstruct_spline(t)
    plt.plot(t,f,'b',label=f'ZOH (H={H})')
    plt.plot(time_stamps,signal_values,'k*',label='Data points')
    plt.xlabel('Time',fontsize=fontsize)
    ## single legend call (a second call would only replace the first)
    plt.legend(fontsize=fontsize)
    plt.tight_layout()
    plt.show()

def experiment46():
    """ Zero-order hold interpolation of Merton data with horizon H=3 """
    horizon = 3
    experiment45(H=horizon)

## TABLE: TOTAL COST PER INTERPOLATED PIECE
def experiment50(H=1,L=1,P=2):
    ''' Print a table with the total cost per spline piece of each interpolator
    for the experiment in fig. 4.

    H : forecast horizon of the ZOH and linear-model strategies
    L : look-ahead of the delayed strategy (it is run with delay=L+1)
    P : number of lags of the linear models
    '''
    eta = 0.01
    rho = 2
    num_samples = 300
    num_lags = 2
    num_training_samples = 500

    ## training data
    mean,std,training_signal_values = data_points(num_samples=num_training_samples,
                                                  num_lags=num_lags,
                                                  flag_training=True)

    ## trajectory data
    time_stamps, signal_values = data_points(num_samples=num_samples,
                                             num_lags=num_lags,
                                             mean=mean,
                                             std=std)

    ## reconstruction: (table label, interpolator) in printing order
    table = [
        ('myopic', lookahead_interpolation(eta,rho,time_stamps,signal_values,delay=1)),
        (f'delay (L={L})', lookahead_interpolation(eta,rho,time_stamps,signal_values,delay=L+1)),
        (f'zoh (H={H})', zoh_interpolation(eta,rho,time_stamps,signal_values,horizon=H)),
        (f'ppi (H={H})', pretrained_predicted_interpolation(eta,rho,time_stamps,signal_values,training_signal_values,lag=P,horizon=H)),
        (f'opi (H={H})', online_predicted_interpolation(eta,rho,time_stamps,signal_values,lag=P,horizon=H)),
    ]

    ## batch solution over the whole trajectory
    order = 2*rho-1
    smooth = order-1
    batch = BatchInterpolator(order,smooth,rho,eta)
    batch.solve({'time stamps':time_stamps,'signal values':signal_values})
    table.append(('batch', batch))

    ## total losses / per function section
    print('TOTAL COST / SPLINE PIECE')
    for label,interp in table:
        print(label,compute_total_cost(interp,time_stamps,signal_values)/len(interp.spline.pieces))

def experiment51():
    """ Total cost per spline piece table with L=H=3 """
    experiment50(L=3,H=3)

## STABILITY | CONTROLLABILITY
def experiment100():
    """ Result 1

    Builds M symbolically for rho = 2 from the symbols u2, u3, u4, prints M
    and its determinant, and prints what sympy's solver returns for
    det(M) = 0 under u2, u3, u4 > 0.
    """
    rho = 2
    keys = [f'{t}' for t in range(2,2*rho+1)]
    ## symbolic sequence of ut's
    u = {key: sp.symbols('u'+key) for key in keys}
    ## constructing the B matrix
    B = construct_B(rho)
    ## constructing the sequence of matrices At
    A = {key: construct_A(u[key],rho) for key in keys}
    ## constructing matrix M
    M = construct_M(A,B)
    print('M = ')
    sp.pprint(M)
    print('det(M) = ')
    det_M = M.det()
    sp.pprint( det_M )
    print('There are no solutions for det(M)=0 such that u_2, u_3, u_4 > 0')
    positivity = [u[key]>0 for key in ('2','3','4')]
    print( sp.solve(det_M,*positivity) )

def experiment101(rho):
    """ Display matrix A (symbolic in u) for the given rho """
    symbol = sp.symbols('u')
    print('A = ')
    sp.pprint(construct_A(symbol,rho))

def experiment102(rho,flag_uniform):
    """ Display matrix M (symbolic) for the given rho.

    With `flag_uniform` truthy every A_t is built from the same symbol u;
    otherwise A_t is built from its own symbol u{t}, t = 2, ..., 2*rho.
    """
    keys = [f'{t}' for t in range(2,2*rho+1)]
    if flag_uniform:
        shared = sp.symbols('u')
        u = {key: shared for key in keys}
    else:
        u = {key: sp.symbols('u'+key) for key in keys}

    B = construct_B(rho)
    ## constructing the sequence of matrices At
    A = {key: construct_A(u[key],rho) for key in keys}

    M = construct_M(A,B)
    print('M = ')
    sp.pprint(M)

def experiment103():
    """ Numerically test the full-rank conjecture on M for rho = 3, ..., 10.

    For every rho, `max_iterations` random sequences u_t drawn uniformly from
    [lower_limit, upper_limit) are tried. The search stops with a message at
    the first draw for which the row-normalised M is flagged singular; the
    conjecture is reported to hold only after every draw has been checked.
    """
    ## from rho=7 onwards we start observing numerical instabilities...
    max_iterations = 1000
    lower_limit = 1e-3
    upper_limit = 2
    max_rho_value = 10

    for rho in range(3,max_rho_value+1):
        B = construct_B(rho,_flag_numeric=True)
        for _ in range(max_iterations):
            A = {}
            for t in range(2,2*rho+1):
                u = np.random.uniform(lower_limit,upper_limit)
                A[f'{t}'] = construct_A(u,rho,_flag_numeric=True)
            M = construct_M(A,B,_flag_numeric=True)

            # To add a bit of numerical stability (when computing the rank)
            nM = row_normalization(M)
            rank_deficient = np.linalg.matrix_rank(nM,tol=1e-32) != nM.shape[0]
            if rank_deficient and (np.linalg.det(nM) == 0):
                # counterexample found: no need to keep sampling
                return print(f'Conjecture does not hold at rho = {rho}')

    # reached only when no draw produced a counterexample
    return print('Conjecture holds')

# PARSER
def run_experiment(args):
    """ Dispatch to the experiment selected by `args.id`.

    The ids 'result', 'A', 'M' and 'conjecture' map to experiment100-103
    ('A' and 'M' also receive `args.rho`, 'M' additionally
    `bool(args.uniform)`); any other id is appended to 'experiment' and the
    function of that name is looked up among the module globals.
    """
    idnumber = args.id
    ## id -> (function name, lazily built positional arguments)
    named = {
        'result': ('experiment100', lambda: ()),
        'A': ('experiment101', lambda: (args.rho,)),
        'M': ('experiment102', lambda: (args.rho,bool(args.uniform))),
        'conjecture': ('experiment103', lambda: ()),
    }
    if idnumber in named:
        function_name, build_arguments = named[idnumber]
    else:
        function_name, build_arguments = 'experiment'+idnumber, (lambda: ())

    globals()[function_name](*build_arguments())
 
if __name__ == '__main__':
    # Command-line entry point: pick the experiment with -id; -rho and
    # -u/--uniform are only read by the 'A' and 'M' experiments.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-id',
        default='00',
        type=str,
        help='experiment number (from 00 to 99)',
    )
    parser.add_argument('-rho',default=2,type=int)
    parser.add_argument(
        '-u','--uniform',
        default=1,
        type=int,
        choices=[0,1],
        help='uniform sampling:1 , non-uniform sampling 0',
    )
    args = parser.parse_args()

    run_experiment(args)