#include <errno.h>
#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include "linalg.h"
#include "nn.h"

// Label function f(x)
// Inputs: x[d]: sample point
double fun_f(double *x){
    //return sin(2*M_PI * x[0]);
    /*
    if(x[0] < 0.5){
        return 10 * (sin(2*M_PI*x[0]) + sin(6*M_PI*x[0]));
    }
    else{
        return 10 * (sin(8*M_PI*x[0]) + sin(18*M_PI*x[0]) + sin(26*M_PI*x[0]));
    }
    */
    //return (1 + M_PI*M_PI) * cos(M_PI * x[0]);
    return (1 + 2*M_PI*M_PI) * cos(M_PI * x[0]) * cos(M_PI * x[1]);
}

// Exact solution u(x)
double fun_u(double *x){
    //return sin(2*M_PI * x[0]);
    /*
    if(x[0] < 0.5){
        return 10 * (sin(2*M_PI*x[0]) + sin(6*M_PI*x[0]));
    }
    else{
        return 10 * (sin(8*M_PI*x[0]) + sin(18*M_PI*x[0]) + sin(26*M_PI*x[0]));
    }
    */
    //return cos(M_PI * x[0]);
    return cos(M_PI * x[0]) * cos(M_PI * x[1]);
}

// ReLU activation function: returns x when x > 0, and 0 otherwise.
// Because the test is "x > 0", NaN and negative zero both map to +0.0.
double relu(double x){
    return (x > 0) ? x : 0.0;
}

// Derivative of the ReLU activation function: 1.0 when x > 0, otherwise 0.0
// (x == 0 and NaN both yield 0.0, since the comparison x > 0 is false for them).
double d_relu(double x){
    return (double)(x > 0);
}

// Parse arg as a base-10 int in [lo, hi] and store it in *out.
// Returns 1 on success; 0 if arg is empty, has trailing characters, or is out
// of range (in which case *out is left untouched).
static int parse_int_arg(const char *arg, int lo, int hi, int *out){
    char *end = NULL;
    long value;

    errno = 0;
    value = strtol(arg, &end, 10);
    if(end == arg || *end != '\0' || errno == ERANGE || value < lo || value > hi){
        return 0;
    }
    *out = (int)value;
    return 1;
}

// Print the command-line usage message to stream.
static void print_usage(FILE *stream){
    fprintf(stream, "Input format: ./file_name problem_type n\n");
    fprintf(stream, "problem_type: 0 - Function approximation, 1 - Neumann BVP\n");
    fprintf(stream, "n: Number of neurons\n");
}

// Write E_decay[i] - E_ref for i = 0..num_epochs (one "%.10e" value per line)
// to file_name. Returns 0 on success, -1 if the file cannot be opened, written,
// or closed; an error message is printed to stderr in that case.
static int write_energy_gap(const char *file_name, const double *E_decay,
                            int num_epochs, double E_ref){
    FILE *fp = fopen(file_name, "w");
    int rc = 0;
    int i;

    if(fp == NULL){
        // e.g. the output directory does not exist
        perror(file_name);
        return -1;
    }
    for(i=0; i<=num_epochs; i++){
        if(fprintf(fp, "%.10e\n", E_decay[i] - E_ref) < 0){
            rc = -1;
            break;
        }
    }
    // fclose flushes buffered output, so its result must be checked as well.
    if(fclose(fp) != 0){
        rc = -1;
    }
    if(rc != 0){
        fprintf(stderr, "%s: write error\n", file_name);
    }
    return rc;
}

// Main function: for seeds 1..10, initialize the NN parameters with
// new_c_He/new_theta_He, train with train_LSGD for num_epochs epochs, and write
// the per-epoch gap E_decay[i] - E_ref to Results_Ex4/LSGD_<n>_<seed>.txt.
//
// Usage: ./file_name problem_type n
//   problem_type: 0 - Function approximation, 1 - Neumann BVP
//   n:            number of neurons (positive integer)
//
// Returns EXIT_SUCCESS when every result file was written; EXIT_FAILURE on
// invalid arguments, if new_shallow_nn returns NULL, or if any result file
// could not be written.
int main(int argc, char **argv){
    // Fixed problem setup. d = 2 matches fun_f/fun_u, which read x[0] and x[1].
    const int d = 2;
    const int N = 10000;   // passed to new_shallow_nn (presumably the number of sample points; TODO confirm)
    const int num_epochs = 20;
    const int num_seeds = 10;

    int problem_type = 0, n = 0;
    int seed;
    int status = EXIT_SUCCESS;
    shallow_nn *nn;
    char file_name[64];

    if(argc != 3
       || !parse_int_arg(argv[1], 0, 1, &problem_type)
       || !parse_int_arg(argv[2], 1, INT_MAX, &n)){
        print_usage(stderr);
        return EXIT_FAILURE;
    }

    // Construction of shallow_nn structure
    nn = new_shallow_nn(d, N, n, problem_type, &fun_f, &fun_u, &relu, &d_relu);
    if(nn == NULL){
        fprintf(stderr, "Failed to construct shallow_nn\n");
        return EXIT_FAILURE;
    }

    for(seed=1; seed<=num_seeds; seed++){
        // c[n], theta[(d+1)*n]: NN parameters (theta = (W, b))
        double *c, *theta, *E_decay;
        double E_ref;
        int len;

        // Fixed RNG state per seed (presumably consumed by new_c_He/new_theta_He).
        srand(2022 * seed);
        printf("LSGD - Seed %d\n", seed);

        c = new_c_He(nn);
        theta = new_theta_He(nn);

        // E_decay has num_epochs + 1 entries; it is passed to train_LSGD and its
        // entries are written out below. train_LSGD's return value is not needed.
        E_decay = new_vec(num_epochs + 1);
        (void)train_LSGD(nn, c, theta, num_epochs, E_decay);

        // Reference energy subtracted from each recorded energy. Alternatives for
        // the other test problems listed in fun_f/fun_u:
        //   E_ref = -0.25;                                  // Example 1
        //   E_ref = -0.5 * L2_innerprod(nn->u, nn->u, nn);  // Example 2
        //   E_ref = -0.5 - M_PI*M_PI/4;                     // Example 3
        // NOTE(review): for f = (1+pi^2)cos(pi x), u = cos(pi x) on [0,1],
        // -0.5*(f,u) = -0.25 - pi^2/4; confirm the Example 3 constant.
        E_ref = -0.5 * L2_innerprod(nn->f, nn->u, nn);  // Example 4

        len = snprintf(file_name, sizeof file_name, "Results_Ex4/LSGD_%02d_%02d.txt", n, seed);
        if(len < 0 || (size_t)len >= sizeof file_name){
            fprintf(stderr, "Output file name too long\n");
            status = EXIT_FAILURE;
        }
        else if(write_energy_gap(file_name, E_decay, num_epochs, E_ref) != 0){
            status = EXIT_FAILURE;
        }

#if 0
        // Reconstruction of the solution (disabled; enable to evaluate the trained network)
        {
            double *u_nn = new_vec(N);
            double **dudc = new_mat(n, N);
            double **dudtheta = new_mat((d+1)*n, N);
            double **grad_u_nn = new_mat(d, N);
            double ***grad_dudc = new_ten(n, d, N);
            double ***grad_dudtheta = new_ten((d+1)*n, d, N);

            nn_fun(nn, c, theta, u_nn, dudc, dudtheta, grad_u_nn, grad_dudc, grad_dudtheta);

            free_vec(u_nn), free_mat(dudc, n), free_mat(dudtheta, (d+1)*n);
            free_mat(grad_u_nn, d), free_ten(grad_dudc, n, d), free_ten(grad_dudtheta, (d+1)*n, d);
        }
#endif

        free_vec(c), free_vec(theta);
        free_vec(E_decay);
    }

    free_shallow_nn(nn);
    return status;
}

