#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include "linalg.h"
#include "nn.h"

// Forward declarations of the file-local helpers of the 1D (d == 1) preconditioner.
// Call chain in this file: c_minimize -> set_CK, P_pcg;  P_pcg -> P_mult;  P_mult -> bP_mult.
// C1_inv, C2 and K_phi are all stored as [n+2][3] arrays (three entries per row).

// 1D Preconditioner: Set C1_inv, C2, and K_phi from the parameters theta
static void set_CK(shallow_nn *nn, double *theta, double **C1_inv, double **C2, double **K_phi);
// 1D Preconditioner: Multiplication by bar{P}, bb[n+2] = bar{P} * aa[n+2]
static void bP_mult(shallow_nn *nn, double **C1_inv, double **C2, double **K_phi, double *aa, double *bb);
// 1D Preconditioner: Multiplication by P, b[n] computed from a[n]
static void P_mult(shallow_nn *nn, double **C1_inv, double **C2, double **K_phi, double *a, double *b);
// 1D Preconditioner: Preconditioned conjugate gradient method for K c = f, preconditioned by P
static void P_pcg(shallow_nn *nn, double **C1_inv, double **C2, double **K_phi, double **K, double *f, double *c);

// New shallow_nn structure
// Inputs - d: input dimension, N: number of sampling points, n: number of neurons,
//          problem_type: 0 - L^2 function approximation, 1 - Neumann BVP
//                        (the rest of this file treats any nonzero value as Neumann BVP),
//          fun_f, fun_u: functions sampled at every point into nn->f and nn->u,
//          phi, dphi: activation function and its derivative
// Outputs - nn: shallow_nn structure (release it with free_shallow_nn),
//           or NULL if the structure itself could not be allocated
shallow_nn *new_shallow_nn(int d, int N, int n, int problem_type, double (*fun_f)(double *), double (*fun_u)(double *), double (*phi)(double), double (*dphi)(double)){
    int i = 0;

    shallow_nn *nn = malloc(sizeof *nn);
    if(nn == NULL){
        // Do not write through a failed allocation; let the caller decide.
        return NULL;
    }

    // d: Input dimension
    nn->d = d;
    // N: Number of sampling points for numerical integration
    nn->N = N;
    // n: Number of neurons
    nn->n = n;

    // problem_type: 0 - L^2 function approximation, 1 - Neumann BVP
    nn->problem_type = problem_type;

    // x[N][d]: Coordinates of sampling points
    nn->x = new_mat(N, d);
    if(d == 1){
        // d = 1: uniform partition of [0,1]
        // With a single point the spacing 1/(N-1) is undefined (0/0 would
        // store NaN), so that lone point is placed at 0.
        for(i=0; i<N; i++){
            nn->x[i][0] = (N > 1) ? (1.0 * i) / (N - 1) : 0.0;
        }
    }
    else{
        // d > 1: Halton sequence for quasi-Monte Carlo integration
        halton_seq(d, N, nn->x);
    }

    // Function values at the sampling points
    nn->f = new_vec(N);
    nn->u = new_vec(N);
    for(i=0; i<N; i++){
        nn->f[i] = fun_f(nn->x[i]);
        nn->u[i] = fun_u(nn->x[i]);
    }

    // double phi(double): Activation function
    nn->phi = phi;
    // double dphi(double): Derivative of activation function
    nn->dphi = dphi;

    return nn;
}

// Free shallow_nn structure
// Inputs - nn: shallow_nn structure; NULL is accepted and ignored
// Releases the sampling points nn->x (nn->N rows), the sampled values nn->f and
// nn->u, and finally the structure itself. nn must not be used afterwards.
void free_shallow_nn(shallow_nn *nn){
    // Same contract as free(NULL): nothing to release, nothing to dereference.
    if(nn == NULL){
        return;
    }

    free_mat(nn->x, nn->N);
    free_vec(nn->f);
    free_vec(nn->u);
    free(nn);
}

// He initialization for the outer coefficients c
// Inputs - nn: shallow_nn structure (only nn->n is read)
// Outputs - c[n]: vector obtained from new_vec, every entry drawn by normrnd
//           with mean 0 and standard deviation sqrt(2/n)
double *new_c_He(shallow_nn *nn){
    const int num_neurons = nn->n;
    const double sigma = sqrt(2.0/num_neurons);
    double *coef = new_vec(num_neurons);
    int idx = 0;

    // No reseeding happens here: normrnd draws from rand() with whatever
    // seed is currently active.
    while(idx < num_neurons){
        coef[idx] = normrnd(0, sigma);
        idx++;
    }

    return coef;
}

// He initialization for theta
// Inputs - nn: shallow_nn structure (nn->d and nn->n are read)
// Outputs - theta[(d+1)*n]: vector obtained from new_vec; neuron i owns the
//           d+1 consecutive entries starting at (d+1)*i, and every entry is
//           drawn by normrnd with mean 0 and standard deviation sqrt(2/d)
double *new_theta_He(shallow_nn *nn){
    const int d = nn->d;
    const int total = (d+1) * nn->n;
    const double sigma = sqrt(2.0/d);
    double *theta = new_vec(total);
    int pos;

    // The per-neuron blocks are contiguous, so one flat sweep visits the
    // entries in the same order as a (neuron, component) double loop.
    for(pos = 0; pos < total; pos++){
        theta[pos] = normrnd(0, sigma);
    }

    return theta;
}

// NN function u(x) = sum_i c_i * phi(w_i . x + b_i) evaluated at the sampling points of nn
// Inputs - nn: shallow_nn structure, c[n], theta[(d+1)*n]: parameters
//          (neuron i owns theta[(d+1)*i + 0..d-1] = w_i and theta[(d+1)*i + d] = b_i)
// Outputs - u_nn[N]: NN function values, dudc[n][N]: c-partial derivative, dudtheta[(d+1)*n][N]: theta-partial derivative
//           grad_u_nn[d][N], grad_dudc[n][d][N], grad_dudtheta[(d+1)*n][d][N]: gradients with respect to x;
//           these three are written only when nn->problem_type != 0
void nn_fun(shallow_nn *nn, double *c, double *theta, double *u_nn, double **dudc, double **dudtheta, double **grad_u_nn, double ***grad_dudc, double ***grad_dudtheta){
    const int d = nn->d, N = nn->N, n = nn->n;
    const int stride = d + 1;
    double **pts = nn->x;
    double scaled = 0;
    int neuron = 0, pt = 0, dim = 0, comp = 0;

    // pre[i][j] = w_i . x_j + b_i
    double **pre = new_mat(n, N);
    theta_mult(nn, theta, pre);

    // dudc[i][j] = phi(pre[i][j])
    for(neuron = 0; neuron < n; neuron++){
        for(pt = 0; pt < N; pt++){
            dudc[neuron][pt] = nn->phi(pre[neuron][pt]);
        }
    }

    // u_nn[j] = sum_i c_i * dudc[i][j]
    for(pt = 0; pt < N; pt++){
        u_nn[pt] = 0;
        for(neuron = 0; neuron < n; neuron++){
            u_nn[pt] += c[neuron] * dudc[neuron][pt];
        }
    }

    // dudtheta: chain rule through the activation
    for(neuron = 0; neuron < n; neuron++){
        for(pt = 0; pt < N; pt++){
            scaled = c[neuron] * nn->dphi(pre[neuron][pt]);
            // du/dW
            for(dim = 0; dim < d; dim++){
                dudtheta[stride*neuron + dim][pt] = scaled * pts[pt][dim];
            }
            // du/db
            dudtheta[stride*neuron + d][pt] = scaled;
        }
    }

    // L^2 function approximation: the x-gradients are not needed
    if(nn->problem_type == 0){
        free_mat(pre, n);
        return;
    }

    // Neumann BVP: gradients with respect to x
    // grad_u_nn
    for(dim = 0; dim < d; dim++){
        for(pt = 0; pt < N; pt++){
            grad_u_nn[dim][pt] = 0;
            for(neuron = 0; neuron < n; neuron++){
                grad_u_nn[dim][pt] += c[neuron] * nn->dphi(pre[neuron][pt]) * theta[stride*neuron + dim];
            }
        }
    }

    // grad_dudc
    for(dim = 0; dim < d; dim++){
        for(neuron = 0; neuron < n; neuron++){
            for(pt = 0; pt < N; pt++){
                grad_dudc[neuron][dim][pt] = nn->dphi(pre[neuron][pt]) * theta[stride*neuron + dim];
            }
        }
    }

    // grad_dudtheta
    // NOTE(review): no second-derivative term of phi appears below, which is
    // exact only for a piecewise linear activation (e.g. ReLU) - confirm that
    // this is the intended setting.
    for(dim = 0; dim < d; dim++){
        for(neuron = 0; neuron < n; neuron++){
            for(pt = 0; pt < N; pt++){
                scaled = c[neuron] * nn->dphi(pre[neuron][pt]);
                // du/dW: only the weight component matching the direction dim is nonzero
                for(comp = 0; comp < d; comp++){
                    grad_dudtheta[stride*neuron + comp][dim][pt] = (comp == dim) ? scaled : 0;
                }
                // du/db
                grad_dudtheta[stride*neuron + d][dim][pt] = 0;
            }
        }
    }

    free_mat(pre, n);
    return;
}

// Loss function E = 0.5 * a(u_nn, u_nn) - (f, u_nn)
// Inputs - nn: shallow_nn structure, c[n], theta[(d+1)*n]: parameters,
//          nograd: if nonzero, only E is computed and dEdc/dEdtheta are left untouched
// Outputs - E: loss function value, dEdc[n]: c-partial derivative,
//          dEdtheta[(d+1)*n]: theta-partial derivative
double loss_fun(shallow_nn *nn, double *c, double *theta, int nograd, double *dEdc, double *dEdtheta){
    const int d = nn->d, N = nn->N, n = nn->n;
    const int num_theta = (d+1)*n;
    double energy = 0;
    int p = 0;

    // Work arrays for the network values, their parameter derivatives and the x-gradients
    double *u_nn = new_vec(N);
    double **dudc = new_mat(n, N);
    double **dudtheta = new_mat(num_theta, N);
    double **grad_u_nn = new_mat(d, N);
    double ***grad_dudc = new_ten(n, d, N);
    double ***grad_dudtheta = new_ten(num_theta, d, N);

    nn_fun(nn, c, theta, u_nn, dudc, dudtheta, grad_u_nn, grad_dudc, grad_dudtheta);

    // Loss function value
    energy = 0.5 * a_innerprod(u_nn, grad_u_nn, u_nn, grad_u_nn, nn) - L2_innerprod(nn->f, u_nn, nn);

    // Gradient, unless the "nograd" option is set
    if(nograd == 0){
        // dEdc[i] = a(du/dc_i, u_nn) - (du/dc_i, f)
        for(p = 0; p < n; p++){
            dEdc[p] = a_innerprod(dudc[p], grad_dudc[p], u_nn, grad_u_nn, nn) - L2_innerprod(dudc[p], nn->f, nn);
        }

        // dEdtheta[i] = a(du/dtheta_i, u_nn) - (du/dtheta_i, f)
        for(p = 0; p < num_theta; p++){
            dEdtheta[p] = a_innerprod(dudtheta[p], grad_dudtheta[p], u_nn, grad_u_nn, nn) - L2_innerprod(dudtheta[p], nn->f, nn);
        }
    }

    free_vec(u_nn);
    free_mat(dudc, n);
    free_mat(dudtheta, num_theta);
    free_mat(grad_u_nn, d);
    free_ten(grad_dudc, n, d);
    free_ten(grad_dudtheta, num_theta, d);
    return energy;
}

// Train the neural network by gradient descent with a backtracking learning rate
// Inputs - nn: shallow_nn structure, c[n], theta[(d+1)*n]: initial parameters, num_epochs: number of epochs
// Outputs - c[n], theta[(d+1)*n]: parameters (updated in place), ret: final loss
//           E_decay[num_epochs + 1]: loss decay; entry k is the loss at the start of epoch k,
//           the last entry is the final loss
double train_GD(shallow_nn *nn, double *c, double *theta, int num_epochs, double *E_decay){
    const int d = nn->d, n = nn->n;
    const int num_theta = (d+1)*n;
    // Backtracking stops once the learning rate is no longer above this bound
    const double MIN_LR = 0.0001;
    // Learning rate, carried over from one epoch to the next
    double lr = 1;
    double E = 0;
    int epoch = 0, slot = 0;

    double *c_prev = new_vec(n);
    double *theta_prev = new_vec(num_theta);
    double *grad_c = new_vec(n);
    double *grad_theta = new_vec(num_theta);

    for(epoch = 0; epoch < num_epochs; epoch++){
        // Freeze the current iterate and evaluate loss and gradient there
        vec_cpy(c, c_prev, n);
        vec_cpy(theta, theta_prev, num_theta);
        E_decay[epoch] = loss_fun(nn, c_prev, theta_prev, 0, grad_c, grad_theta);

        // Backtracking: start from twice the previous rate and halve until the
        // loss does not increase. If no trial is accepted before lr reaches
        // MIN_LR, c and theta keep the values of the last trial.
        for(lr *= 2; lr > MIN_LR; lr /= 2){
            vec_lincom(c_prev, grad_c, 1, -lr, c, n);
            vec_lincom(theta_prev, grad_theta, 1, -lr, theta, num_theta);
            E = loss_fun(nn, c, theta, 1, grad_c, grad_theta);
            if(E <= E_decay[epoch]){
                break;
            }
        }
    }

    // Final loss fills the remaining history entries
    E = loss_fun(nn, c, theta, 1, grad_c, grad_theta);
    for(slot = epoch; slot <= num_epochs; slot++){
        E_decay[slot] = E;
    }

    free_vec(c_prev);
    free_vec(theta_prev);
    free_vec(grad_c);
    free_vec(grad_theta);
    return E;
}

// Adam update of one parameter vector:
// out[i] = base[i] - lr * (s[i]/bias1) / (sqrt(r[i]/bias2) + delta)
// Inputs - base[len]: starting point, s[len], r[len]: biased first/second moment estimates,
//          bias1, bias2: bias-correction denominators, delta: stabilization constant
// Outputs - out[len]
static void adam_step(double *base, double *s, double *r, double lr, double bias1, double bias2, double delta, double *out, int len){
    int i = 0;

    for(i = 0; i < len; i++){
        out[i] = base[i] - lr * ((s[i]/bias1) / (sqrt(r[i]/bias2) + delta));
    }
}

// Train the neural network by Adam with a backtracking learning rate
// Inputs - nn: shallow_nn structure, c[n], theta[(d+1)*n]: initial parameters, num_epochs: number of epochs
// Outputs - c[n], theta[(d+1)*n]: parameters (updated in place), ret: final loss
//           E_decay[num_epochs + 1]: loss decay; entry k is the loss at the start of epoch k,
//           the last entry is the final loss
double train_Adam(shallow_nn *nn, double *c, double *theta, int num_epochs, double *E_decay){
    const int d = nn->d, n = nn->n;
    const int num_theta = (d+1)*n;
    // Backtracking stops once the learning rate is no longer above this bound
    const double MIN_LR = 0.0001;
    // Exponential decay rates
    const double rho1 = 0.9, rho2 = 0.999;
    // Numerical stabilization constant
    const double delta = 0.00000001;
    // Learning rate, carried over from one epoch to the next
    double lr = 1;
    double E = 0;
    // Bias-correction denominators 1 - rho^(epoch+1)
    double bias1 = 0, bias2 = 0;
    int epoch = 0, k = 0;

    double *c_prev = new_vec(n);
    double *theta_prev = new_vec(num_theta);
    double *grad_c = new_vec(n);
    double *grad_theta = new_vec(num_theta);

    // Moment variables. They are accumulated below without an explicit reset,
    // so this presumably relies on new_vec returning zeros - verify.
    double *c_s = new_vec(n);
    double *c_r = new_vec(n);
    double *theta_s = new_vec(num_theta);
    double *theta_r = new_vec(num_theta);

    for(epoch = 0; epoch < num_epochs; epoch++){
        // Freeze the current iterate and evaluate loss and gradient there
        vec_cpy(c, c_prev, n);
        vec_cpy(theta, theta_prev, num_theta);
        E_decay[epoch] = loss_fun(nn, c_prev, theta_prev, 0, grad_c, grad_theta);

        // Update biased moment estimates
        for(k = 0; k < n; k++){
            c_s[k] = rho1 * c_s[k] + (1-rho1) * grad_c[k];
            c_r[k] = rho2 * c_r[k] + (1-rho2) * grad_c[k]*grad_c[k];
        }
        for(k = 0; k < num_theta; k++){
            theta_s[k] = rho1 * theta_s[k] + (1-rho1) * grad_theta[k];
            theta_r[k] = rho2 * theta_r[k] + (1-rho2) * grad_theta[k]*grad_theta[k];
        }

        // The bias corrections depend on the epoch only
        bias1 = 1 - pow(rho1, epoch+1);
        bias2 = 1 - pow(rho2, epoch+1);

        // Backtracking: start from twice the previous rate and halve until the
        // loss does not increase.
        for(lr *= 2; lr > MIN_LR; lr /= 2){
            adam_step(c_prev, c_s, c_r, lr, bias1, bias2, delta, c, n);
            adam_step(theta_prev, theta_s, theta_r, lr, bias1, bias2, delta, theta, num_theta);
            E = loss_fun(nn, c, theta, 1, grad_c, grad_theta);
            if(E <= E_decay[epoch]){
                break;
            }
        }

        // Apply the step with the learning rate the search ended on. After an
        // accepted trial this reproduces the same values; after an exhausted
        // search it uses the last halved rate.
        adam_step(c_prev, c_s, c_r, lr, bias1, bias2, delta, c, n);
        adam_step(theta_prev, theta_s, theta_r, lr, bias1, bias2, delta, theta, num_theta);
    }

    // Final loss fills the remaining history entries
    E = loss_fun(nn, c, theta, 1, grad_c, grad_theta);
    for(k = epoch; k <= num_epochs; k++){
        E_decay[k] = E;
    }

    free_vec(c_prev);
    free_vec(theta_prev);
    free_vec(grad_c);
    free_vec(grad_theta);
    free_vec(c_s);
    free_vec(c_r);
    free_vec(theta_s);
    free_vec(theta_r);
    return E;
}

// Train the neural network by least squares/gradient descent:
// every epoch first minimizes over c with theta fixed (c_minimize), then takes a
// backtracking gradient step in theta.
// Inputs - nn: shallow_nn structure, c[n], theta[(d+1)*n]: initial parameters, num_epochs: number of epochs
// Outputs - c[n], theta[(d+1)*n]: parameters (updated in place), ret: final loss
//           E_decay[num_epochs + 1]: loss decay; entry k is the loss at the start of epoch k
//           (before c is re-optimized), the last entry is the final loss
double train_LSGD(shallow_nn *nn, double *c, double *theta, int num_epochs, double *E_decay){
    const int d = nn->d, n = nn->n;
    const int num_theta = (d+1)*n;
    // Backtracking stops once the learning rate is no longer above this bound
    const double MIN_LR = 0.0001;
    // Learning rate, carried over from one epoch to the next
    double lr = 1;
    // E: loss of the current iterate, E_c: loss after the c-minimization
    double E = 0, E_c = 0;
    int epoch = 0, slot = 0;

    double *theta_prev = new_vec(num_theta);
    double *grad_c = new_vec(n);
    double *grad_theta = new_vec(num_theta);

    // Loss of the initial parameters
    E = loss_fun(nn, c, theta, 1, grad_c, grad_theta);

    for(epoch = 0; epoch < num_epochs; epoch++){
        // Freeze theta and record the loss reached so far
        vec_cpy(theta, theta_prev, num_theta);
        E_decay[epoch] = E;

        // Find the optimal c when theta is fixed, then evaluate loss and gradient there
        c_minimize(nn, c, theta_prev);
        E_c = loss_fun(nn, c, theta_prev, 0, grad_c, grad_theta);

        // Backtracking in theta: start from twice the previous rate and halve
        // until the loss does not exceed E_c. If no trial is accepted before
        // lr reaches MIN_LR, theta keeps the values of the last trial.
        for(lr *= 2; lr > MIN_LR; lr /= 2){
            vec_lincom(theta_prev, grad_theta, 1, -lr, theta, num_theta);
            E = loss_fun(nn, c, theta, 1, grad_c, grad_theta);
            if(E <= E_c){
                break;
            }
        }
    }

    // Final loss fills the remaining history entries
    for(slot = epoch; slot <= num_epochs; slot++){
        E_decay[slot] = E;
    }

    free_vec(theta_prev);
    free_vec(grad_c);
    free_vec(grad_theta);
    return E;
}

// Multiplication by theta = (W, b): thetax[i][j] = w_i . x_j + b_i
// Inputs - nn: shallow_nn structure (sampling points nn->x), theta[(d+1)*n]
//          (neuron i owns theta[(d+1)*i + 0..d-1] = w_i and theta[(d+1)*i + d] = b_i)
// Output - thetax[n][N]
void theta_mult(shallow_nn *nn, double *theta, double **thetax){
    const int d = nn->d, N = nn->N, n = nn->n;
    double **pts = nn->x;
    int neuron = 0, pt = 0, dim = 0;

    for(neuron = 0; neuron < n; neuron++){
        // Parameters of this neuron and the output row they produce
        const double *wb = theta + (d+1)*neuron;
        double *row = thetax[neuron];

        for(pt = 0; pt < N; pt++){
            const double *xp = pts[pt];

            row[pt] = 0;
            // W * x
            for(dim = 0; dim < d; dim++){
                row[pt] += wb[dim] * xp[dim];
            }
            // b
            row[pt] += wb[d];
        }
    }

    return;
}

// Find the optimal parameter c when theta is given, by solving the Galerkin system K c = f
// built from the basis functions phi(w_i . x + b_i).
// Inputs - nn: shallow_nn structure, c[n]: initial guess for the solver, theta[(d+1)*n]: fixed parameters
// Outputs - c[n]: optimized parameter (overwritten in place)
void c_minimize(shallow_nn *nn, double *c, double *theta){
    const int d = nn->d, N = nn->N, n = nn->n;
    int row = 0, col = 0, pt = 0, dim = 0;

    // System matrix and load vector
    double **K = new_mat(n, n);
    double *f = new_vec(n);

    // Basis functions and their x-gradients at the sampling points
    double **thetax = new_mat(n, N);
    double **basis = new_mat(n, N);
    double ***grad_basis = new_ten(n, d, N);

    // Matrices for the construction of the 1D preconditioner
    double **C1_inv, **C2, **K_phi;

    theta_mult(nn, theta, thetax);
    for(row = 0; row < n; row++){
        for(pt = 0; pt < N; pt++){
            basis[row][pt] = nn->phi(thetax[row][pt]);
        }
    }
    // The gradients are filled only for the Neumann BVP; for the L^2 problem
    // a_innerprod (defined in this file) returns before reading them.
    if(nn->problem_type != 0){
        for(row = 0; row < n; row++){
            for(dim = 0; dim < d; dim++){
                for(pt = 0; pt < N; pt++){
                    grad_basis[row][dim][pt] = nn->dphi(thetax[row][pt]) * theta[(d+1)*row + dim];
                }
            }
        }
    }

    // Assemble the symmetric matrix (upper triangle computed, mirrored below) and the load vector
    for(row = 0; row < n; row++){
        for(col = row; col < n; col++){
            K[row][col] = a_innerprod(basis[row], grad_basis[row], basis[col], grad_basis[col], nn);
            K[col][row] = K[row][col];
        }
        f[row] = L2_innerprod(nn->f, basis[row], nn);
    }

    if(d == 1){
        // d = 1: conjugate gradient method with the dedicated 1D preconditioner
        C1_inv = new_mat(n+2, 3);
        C2 = new_mat(n+2, 3);
        K_phi = new_mat(n+2, 3);

        set_CK(nn, theta, C1_inv, C2, K_phi);
        P_pcg(nn, C1_inv, C2, K_phi, K, f, c);

        free_mat(C1_inv, n+2);
        free_mat(C2, n+2);
        free_mat(K_phi, n+2);
    }
    else{
        // d > 1: generic solver; NULL is passed in the third argument
        // (presumably "no preconditioner" - confirm against pcg's declaration)
        pcg(K, f, NULL, c, n);
    }

    free_mat(K, n);
    free_vec(f);
    free_mat(thetax, n);
    free_mat(basis, n);
    free_ten(grad_basis, n, d);
    return;
}

// 1D Preconditioner: Set C1_inv, C2, and K_phi
// Inputs - nn: shallow_nn structure, theta[(d+1)*n]: initial parameters
//          (read with stride 2, i.e. the d = 1 layout: theta[2*i] = w_i, theta[2*i+1] = b_i)
// Outputs - C1_inv[n+2][3], C2[n+2][3], K_phi[n+2][3]
//           Each row holds three entries; bP_mult in this file shows how the
//           columns are combined when the matrices are applied.
// Preconditions visible from the indexing and divisions below (none is checked here):
//   - n >= 2, since nodes[1] and nodes[n-2] are accessed;
//   - every w_i = theta[2*i] is nonzero, and consecutive nodes, 0 and 1 are all distinct,
//     since they appear as divisors;
//   - the comment on the nodal points states an increasing ordering, but nothing in this
//     function sorts them - presumably the caller guarantees it; verify.
// Only some entries of each output row are assigned; the others keep whatever the caller's
// allocation put there (presumably zeros from new_mat - verify).
static void set_CK(shallow_nn *nn, double *theta, double **C1_inv, double **C2, double **K_phi){
    int n = nn->n;
    double *nodes, **C1;
    int i = 0;
    // Lower bound applied to interval lengths when building K_phi
    double h_MIN = pow(0.1, 4);

    // Nodal points: nodes[0] < nodes[1] < ... < nodes[n-1] (assumed, not enforced)
    // nodes[i] = -b_i / w_i is the zero of the affine map w_i * x + b_i
    nodes = new_vec(n);
    for(i=0; i<n; i++){
        nodes[i] = - theta[2*i+1] / theta[2*i];
    }

    // Construction of C1
    // Rows 0 and n+1 are identity rows; row i+1 belongs to neuron i.
    C1 = new_mat(n+2, 3);
    C1[0][1] = 1;
    for(i=0; i<n; i++){
        // w_i > 0: only the middle entry is set
        if(theta[2*i] > 0){
            C1[i+1][1] = theta[2*i];
        }
        // w_i < 0 (w_i == 0 would also land here): all three entries are set
        else{
            C1[i+1][0] = theta[2*i] * (1-nodes[i]);
            C1[i+1][1] = -theta[2*i];
            C1[i+1][2] = -theta[2*i] * nodes[i];
        }
    }
    C1[n+1][1] = 1;

    // Construction of C1_inv
    // Entries [0] and [2] of a w_i > 0 row of C1 were never assigned above and are read
    // here; this relies on new_mat returning zeros (presumably - verify).
    C1_inv[0][1] = 1 / C1[0][1];
    for(i=0; i<n; i++){
        C1_inv[i+1][1] = 1 / C1[i+1][1];
        C1_inv[i+1][0] = - C1[i+1][0] / C1[i+1][1] / C1[0][1];
        C1_inv[i+1][2] = - C1[i+1][2] / C1[i+1][1] / C1[n+1][1];
    }
    C1_inv[n+1][1] = 1 / C1[n+1][1];

    // Construction of C2
    // Built from reciprocals of the gaps between consecutive nodes, with 0 and 1 acting
    // as the outer end points. Unlike K_phi below, the gaps are NOT clamped by h_MIN here.
    C2[0][0] = - (1-nodes[0]) / nodes[0];
    C2[0][1] = 1 / nodes[0];
    C2[0][2] = 1;
    C2[1][0] = 1 / nodes[0];
    C2[1][1] = - nodes[1] / (nodes[1] - nodes[0]) / nodes[0];
    C2[1][2] = 1 / (nodes[1] - nodes[0]);
    for(i=2; i<=n-1; i++){
        C2[i][0] = 1 / (nodes[i-1] - nodes[i-2]);
        C2[i][1] = - (nodes[i] - nodes[i-2]) / (nodes[i] - nodes[i-1]) / (nodes[i-1] - nodes[i-2]);
        C2[i][2] = 1 / (nodes[i] - nodes[i-1]);
    }
    C2[n][0] = 1 / (nodes[n-1] - nodes[n-2]);
    C2[n][1] = - (1 - nodes[n-2]) / (1 - nodes[n-1]) / (nodes[n-1] - nodes[n-2]); 
    C2[n+1][0] = 1 / (1 - nodes[n-1]);

    // Construction of K_phi: every interval length is clamped from below by h_MIN
    // (2*h_MIN for a pair of adjacent intervals) to avoid numerical instability.
    // NOTE(review): the h/3, h/6 pattern looks like a piecewise-linear mass matrix on the
    // mesh 0 < nodes[0] < ... < nodes[n-1] < 1 - confirm.
    // Function approximation
    K_phi[0][1] = fmax(nodes[0], h_MIN) / 3;
    K_phi[0][2] = fmax(nodes[0], h_MIN) / 6;
    K_phi[1][0] = fmax(nodes[0], h_MIN) / 6;
    K_phi[1][1] = fmax(nodes[1], 2*h_MIN) / 3;
    K_phi[1][2] = fmax(nodes[1] - nodes[0], h_MIN) / 6;
    for(i=2; i<=n-1; i++){
        K_phi[i][0] = fmax(nodes[i-1] - nodes[i-2], h_MIN) / 6;
        K_phi[i][1] = fmax(nodes[i] - nodes[i-2], 2*h_MIN) / 3;
        K_phi[i][2] = fmax(nodes[i] - nodes[i-1], h_MIN) / 6;
    }
    K_phi[n][0] = fmax(nodes[n-1] - nodes[n-2], h_MIN) / 6;
    K_phi[n][1] = fmax(1 - nodes[n-2], 2*h_MIN) / 3;
    K_phi[n][2] = fmax(1 - nodes[n-1], h_MIN) / 6;
    K_phi[n+1][0] = fmax(1 - nodes[n-1], h_MIN) / 6;
    K_phi[n+1][1] = fmax(1 - nodes[n-1], h_MIN) / 3;
    
    // Neumann BVP: add the 1/h, -1/h terms on top of the entries above
    // NOTE(review): this looks like the matching piecewise-linear stiffness matrix - confirm.
    if(nn->problem_type != 0){
        K_phi[0][1] += 1.0 / fmax(nodes[0], h_MIN);
        K_phi[0][2] += -1.0 / fmax(nodes[0], h_MIN);
        K_phi[1][0] += -1.0 / fmax(nodes[0], h_MIN);
        K_phi[1][1] += 1.0 / fmax(nodes[0], h_MIN) + 1.0 / fmax(nodes[1] - nodes[0], h_MIN);
        K_phi[1][2] += -1.0 / fmax(nodes[1] - nodes[0], h_MIN);
        for(i=2; i<=n-1; i++){
            K_phi[i][0] += -1.0 / fmax(nodes[i-1] - nodes[i-2], h_MIN);
            K_phi[i][1] += 1.0 / fmax(nodes[i-1] - nodes[i-2], h_MIN) + 1.0 / fmax(nodes[i] - nodes[i-1], h_MIN);
            K_phi[i][2] += -1.0 / fmax(nodes[i] - nodes[i-1], h_MIN);
        }
        K_phi[n][0] += -1.0 / fmax(nodes[n-1] - nodes[n-2], h_MIN);
        K_phi[n][1] += 1.0 / fmax(nodes[n-1] - nodes[n-2], h_MIN) + 1.0 / fmax(1 - nodes[n-1], h_MIN);
        K_phi[n][2] += -1.0 / fmax(1 - nodes[n-1], h_MIN);
        K_phi[n+1][0] += -1.0 / fmax(1 - nodes[n-1], h_MIN);
        K_phi[n+1][1] += 1.0 / fmax(1 - nodes[n-1], h_MIN);
    }

    // C1 and nodes are temporaries; only C1_inv, C2 and K_phi are handed back
    free_mat(C1, n+2), free_vec(nodes);
    return;
}

// 1D Preconditioner: Multiplication by bar{P}
// The product is applied factor by factor, in this order:
//   C1_inv, C2, K_phi^-1, C2^T, C1_inv^T
// Inputs - nn: shallow_nn structure, C1_inv[n+2][3], C2[n+2][3], K_phi[n+2][3], aa[n+2]
// Outputs - bb[n+2]: bb = bar{P} * aa
// aa and bb may be the same array: aa is read completely into the work vector cc
// before bb is first written (P_mult in this file calls it that way).
static void bP_mult(shallow_nn *nn, double **C1_inv, double **C2, double **K_phi, double *aa, double *bb){
    int n = nn->n;
    int i = 0, j = 0;
    // Work vector holding the intermediate products
    double *cc = new_vec(n+2);

    // Multiplication by C1_inv
    // Row i (1 <= i <= n) couples aa[i] with the first and last entries of aa only;
    // rows 0 and n+1 are diagonal.
    cc[0] = C1_inv[0][1]*aa[0];
    for(i=1; i<=n; i++){
        cc[i] = C1_inv[i][0]*aa[0] + C1_inv[i][1]*aa[i] + C1_inv[i][2]*aa[n+1];
    }
    cc[n+1] = C1_inv[n+1][1]*aa[n+1];

    // Multiplication by C2
    // Rows 1..n-1 are tridiagonal. Row 0 is special: its third entry multiplies cc[n+1].
    // Row n uses two entries and row n+1 only one.
    bb[0] = C2[0][0]*cc[0] + C2[0][1]*cc[1] + C2[0][2]*cc[n+1];
    for(i=1; i<=n-1; i++){
        bb[i] = C2[i][0]*cc[i-1] + C2[i][1]*cc[i] + C2[i][2]*cc[i+1];
    }
    bb[n] = C2[n][0]*cc[n-1] + C2[n][1]*cc[n];
    bb[n+1] = C2[n+1][0]*cc[n];

    // Multiplication by K_phi^-1
    // bb is passed as both the second and the third vector argument, so the solve is
    // presumably done in place - confirm that mat_trisolve allows this aliasing.
    mat_trisolve(K_phi, bb, bb, n+2);
    // Multiplication by diag(K_phi)^-1 (disabled alternative to the solve above)
    /*
    for(i=0; i<n+2; i++){
        bb[i] /= K_phi[i][1];
    }
    */

    // Multiplication by C2^T
    // Column-wise counterpart of the C2 product above, including the special
    // row-0 entry C2[0][2], which ends up in cc[n+1].
    cc[0] = C2[0][0]*bb[0] + C2[1][0]*bb[1];
    cc[1] = C2[0][1]*bb[0] + C2[1][1]*bb[1] + C2[2][0]*bb[2];
    for(i=2; i<=n; i++){
        cc[i] = C2[i-1][2]*bb[i-1] + C2[i][1]*bb[i] + C2[i+1][0]*bb[i+1];
    }
    cc[n+1] = C2[0][2]*bb[0];

    // Multiplication by C1_inv^T
    // bb[0] and bb[n+1] start from their diagonal terms and then accumulate the
    // contributions of every interior row, so they must be initialised before the loop.
    bb[0] = C1_inv[0][1]*cc[0];
    bb[n+1] = C1_inv[n+1][1]*cc[n+1];
    for(i=1; i<=n; i++){
        bb[0] += C1_inv[i][0]*cc[i];
        bb[i] = C1_inv[i][1]*cc[i];
        bb[n+1] += C1_inv[i][2]*cc[i];
    }

    free_vec(cc);
    return;
}

// 1D Preconditioner: Multiplication by P
// a is padded to (p, a[0..n-1], q), multiplied by bar{P}, and the n interior
// entries of the result are returned.
// Inputs - nn: shallow_nn structure, C1_inv[n+2][3], C2[n+2][3], K_phi[n+2][3], a[n]
// Outputs - b[n]: b = P * a (only b[0..n-1] is written)
static void P_mult(shallow_nn *nn, double **C1_inv, double **C2, double **K_phi, double *a, double *b){
    const int n = nn->n;
    double rhs_first = 0, rhs_last = 0, inv_det = 0;
    double p = 0, q = 0;
    int k = 0;

    // bar{P} applied to the first and the last unit vector. Only one entry of
    // each vector is set here, so the rest presumably relies on new_vec
    // returning zeros - verify.
    double *img_first = new_vec(n+2);
    double *img_last = new_vec(n+2);
    // The same two buffers are reused below for the padded vector and its image
    double *padded = img_first;
    double *image = img_last;

    img_first[0] = 1;
    img_last[n+1] = 1;
    bP_mult(nn, C1_inv, C2, K_phi, img_first, img_first);
    bP_mult(nn, C1_inv, C2, K_phi, img_last, img_last);

    // Computation of p and q by solving a linear system with two unknowns.
    // NOTE(review): the system looks like the condition that the first and the
    // last entry of bar{P} * (p, a, q) vanish, assuming bar{P} is symmetric - confirm.
    for(k = 0; k < n; k++){
        rhs_first -= img_first[k+1]*a[k];
        rhs_last -= img_last[k+1]*a[k];
    }
    inv_det = 1 / (img_first[0]*img_last[n+1] - img_first[n+1]*img_last[0]);
    p = inv_det * (img_last[n+1]*rhs_first - img_last[0]*rhs_last);
    q = inv_det * (-img_first[n+1]*rhs_first + img_first[0]*rhs_last);

    // Multiplication by bP of the padded vector (p, a, q)
    padded[0] = p;
    for(k = 0; k < n; k++){
        padded[k+1] = a[k];
    }
    padded[n+1] = q;
    bP_mult(nn, C1_inv, C2, K_phi, padded, image);

    // Multiplication by R: keep the n interior entries
    for(k = 0; k < n; k++){
        b[k] = image[k+1];
    }

    free_vec(img_first);
    free_vec(img_last);
    return;
}

// 1D Preconditioner: Preconditioned conjugate gradient method for K c = f
// Inputs - nn: shallow_nn structure, C1_inv[n+2][3], C2[n+2][3], K_phi[n+2][3]: preconditioner data,
//          K[n][n]: system matrix, f[n]: load vector, c[n]: initial guess
// Outputs - c[n]: solution (updated in place)
// Stops when |r|^2 / |r_0|^2 < 1e-10 or after ITER_MAX = 20 iterations; in the
// latter case a warning is printed and c holds the last iterate.
// If the initial guess already has a zero residual, c is returned unchanged.
// (vec_lincom(a, b, s, t, out, n) is used as out = s*a + t*b - presumably; see linalg.h.)
static void P_pcg(shallow_nn *nn, double **C1_inv, double **C2, double **K_phi, double **K, double *f, double *c){
    int n = nn->n;
    double rs0 = 0, rsnew = 0, rzold = 0, rznew = 0, alpha = 0, beta = 0;
    double *r, *p, *Mp, *z;

    int iter = 0, ITER_MAX = 20;
    double tol = pow(0.1, 10);

    double *temp_vec1;
    r = new_vec(n);
    p = new_vec(n);
    Mp = new_vec(n);
    z = new_vec(n);
    temp_vec1 = new_vec(n);

    // Initial residual r = f - K*c
    mat_vec_mult(K, c, temp_vec1, n, n);
    vec_lincom(f, temp_vec1, 1, -1, r, n);
    rs0 = vec_innerprod(r, r, n);

    // An exactly zero residual means c already solves the system. Iterating
    // anyway would compute alpha = 0/0 and overwrite c with NaN.
    if(rs0 == 0){
        goto P_pcg_end;
    }

    // Preconditioned residual and first search direction
    P_mult(nn, C1_inv, C2, K_phi, r, z);
    vec_cpy(z, p, n);
    rzold = vec_innerprod(r, z, n);

    for (iter = 1; iter <= ITER_MAX; iter++)
    {
        mat_vec_mult(K, p, Mp, n, n);
        alpha = rzold / vec_innerprod(p, Mp, n);
        // p and Mp are scaled by alpha in place, so the updates of c and r
        // below are plain sums/differences.
        vec_lincom(p, p, alpha, 0, p, n);
        vec_lincom(Mp, Mp, alpha, 0, Mp, n);
        vec_lincom(c, p, 1, 1, c, n);
        vec_lincom(r, Mp, 1, -1, r, n);
        rsnew = vec_innerprod(r, r, n);

        // Check the stop condition (relative squared residual)
        if (rsnew / rs0 < tol) {
            //printf("PCG converged with %d iterations.\n", iter);
            break;
        }
        //printf("Iteration %d: %.10e\n", iter, rsnew);

        P_mult(nn, C1_inv, C2, K_phi, r, z);
        rznew = vec_innerprod(r, z, n);
        beta = rznew / rzold;
        // p currently holds alpha*p_old; beta/alpha undoes that scaling so that
        // the new direction is z + beta*p_old.
        vec_lincom(p, p, beta / alpha, 0, p, n);
        vec_lincom(z, p, 1, 1, p, n);
        rzold = rznew;
    }
    if(iter > ITER_MAX){
        printf("Warning: PCG iterated MAX_ITER iterations but did not converge.\n");
    }

    P_pcg_end:
    free_vec(r), free_vec(z), free_vec(p), free_vec(Mp), free_vec(temp_vec1);
    return;
}


// Halton sequence of dimension d and length N
// Coordinate j of point i is the radical inverse of (i+1) in the j-th prime base:
// the base-p digits of (i+1), least significant first, are weighted by
// p^-1, p^-2, ...
// Inputs - d: dimension (at most 10, the number of tabulated primes), N: length
// Outputs - x[N][d]; if d is larger than the supported maximum, a message is
//           printed and x is left untouched
void halton_seq(int d, int N, double **x){
    // First primes, one base per coordinate
    static const int primes[] = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29};
    const int p_MAX = (int) (sizeof(primes) / sizeof(primes[0]));
    int i = 0, j = 0, base = 0, rest = 0;
    double scale = 0;

    if(d > p_MAX){
        printf("The case d > %d is not supported yet.\n", p_MAX);
        return;
    }

    for(j = 0; j < d; j++){
        base = primes[j];
        for(i = 0; i < N; i++){
            // Peel off the digits of (i+1) until none is left. Counting the
            // digits with integer arithmetic (instead of a floating-point
            // logarithm) cannot miss the leading digit.
            rest = i + 1;
            scale = 1.0;
            x[i][j] = 0;
            while(rest > 0){
                scale /= base;
                x[i][j] += (rest % base) * scale;
                rest /= base;
            }
        }
    }

    return;
}

// Normal random number generation by the Marsaglia polar method
// Inputs - mu: mean, sigma: standard deviation
double normrnd(double mu, double sigma){
    static double spare = 0;
    static int has_spare = 0;
    double u = 0, v = 0, s = 0;

    if(has_spare == 1){
        has_spare = 0;
        return spare * sigma + mu;
    }

    else{
        do{
            u = (rand() / ((double) RAND_MAX)) * 2 - 1;
            v = (rand() / ((double) RAND_MAX)) * 2 - 1;
            s = u*u + v*v;
        } while(s >= 1 || s == 0);
        s = sqrt(-2.0 * log(s) / s);
        spare = v * s;
        has_spare = 1;
        return u * s * sigma + mu;       
    }
}

// L^2-inner product of two functions f and g defined on the grid x
// Inputs - f[N], g[N]: function values at the sampling points nn->x, nn: shallow_nn structure
// Outputs - ret: L^2-inner product of f and g
//           d == 1: composite trapezoidal rule on the points nn->x[i][0]
//           d != 1: plain average of f[i]*g[i] over the N points (quasi-Monte Carlo)
double L2_innerprod(double *f, double *g, shallow_nn *nn){
    const int N = nn->N;
    double total = 0, width = 0;
    int k = 0;

    if(nn->d != 1){
        // quasi-Monte Carlo integration
        for(k = 0; k < N; k++){
            total += f[k]*g[k];
        }
        total /= N;
        return total;
    }

    // trapezoidal rule, one interval [x_k, x_{k+1}] at a time
    for(k = 0; k + 1 < N; k++){
        width = nn->x[k+1][0] - nn->x[k][0];
        total += 0.5 * width * (f[k]*g[k] + f[k+1]*g[k+1]);
    }

    return total;
}

// a-inner product of two functions f and g defined on the grid x
// Inputs - f[N], g[N]: function values, grad_f[d][N], grad_g[d][N]: gradient values, nn: shallow_nn structure
// Outputs - ret: a-inner product of f and g
//           problem_type == 0 (L^2 function approximation): (f, g); grad_f and grad_g are not read
//           otherwise (Neumann BVP): (f, g) + sum over k of (grad_f[k], grad_g[k])
double a_innerprod(double *f, double **grad_f, double *g, double **grad_g, shallow_nn *nn){
    const int d = nn->d;
    double total = L2_innerprod(f, g, nn);
    int dim = 0;

    // Neumann BVP: add the gradient part, one coordinate direction at a time
    if(nn->problem_type != 0){
        for(dim = 0; dim < d; dim++){
            total += L2_innerprod(grad_f[dim], grad_g[dim], nn);
        }
    }

    return total;
}
