#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <mpi.h>
#include "linalg.h"
#include "nn.h"
#include "schwarz.h"

// Forward declaration of local_loss, the file-local (static) routine defined further down in this file.
// It is declared here because local_solver calls it before its definition appears.
// Returns the loss evaluated with the local correction rs[d+1] added to the parameters of subdomain s.
static double local_loss(shallow_nn *nn, double *c, double *theta, schwarz *sc, int s, double *rs, int nograd, double *u_nn, double **dudr, double **grad_u_nn, double ***grad_dudr);

// New schwarz structure
// Inputs - n: number of neurons, nprocs: number of processes, pid: id of the calling process
// Outputs - sc: schwarz structure (to be released with free_schwarz),
//           or NULL if nprocs < 1 or a memory allocation fails
// Note - Process ROOT is the master and owns no neurons (empty range ista = -1, iend = -2, isize = 0).
//        The n neurons are distributed over the remaining nprocs-1 processes by para_range.
//        The loop below starts at s = 1, i.e. it assumes ROOT == 0 - TODO confirm against the ROOT definition.
schwarz *new_schwarz(int n, int nprocs, int pid){
    int s = 0;
    schwarz *sc;

    // At least the master process must exist; also keeps the allocation sizes positive.
    if(nprocs < 1){
        return NULL;
    }

    sc = (schwarz *) malloc(sizeof(*sc));
    if(sc == NULL){
        return NULL;
    }

    // MPI variables
    sc->nprocs = nprocs;
    sc->pid = pid;

    // Neuron indices of each process determined by the para-range subroutine
    sc->ista = (int *) malloc((size_t) nprocs * sizeof(int));
    sc->iend = (int *) malloc((size_t) nprocs * sizeof(int));
    sc->isize = (int *) malloc((size_t) nprocs * sizeof(int));
    if(sc->ista == NULL || sc->iend == NULL || sc->isize == NULL){
        // free(NULL) is a no-op, so the partially allocated state can be released uniformly.
        free(sc->ista), free(sc->iend), free(sc->isize);
        free(sc);
        return NULL;
    }

    // s = 0: master node
    sc->ista[ROOT] = -1;
    sc->iend[ROOT] = -2;
    sc->isize[ROOT] = 0;

    // Worker s handles block s-1 out of nprocs-1 blocks.
    for(s = 1; s < nprocs; s++){
        para_range(n, nprocs-1, s-1, &(sc->ista[s]), &(sc->iend[s]));
        sc->isize[s] = sc->iend[s] - sc->ista[s] + 1;
    }

    return sc;
}

// Free schwarz structure
// Inputs - sc: schwarz structure (may be NULL, in which case nothing is done, like free(NULL))
// Releases the three per-process index arrays and then the structure itself.
void free_schwarz(schwarz *sc){
    if(sc == NULL){
        return;
    }

    free(sc->ista);
    free(sc->iend);
    free(sc->isize);
    free(sc);
}

// Train the neural network by NPSC
// Inputs - nn: shallow_nn structure, c[n], theta[(d+1)*n]: initial parameters, sc: schwarz structure,
//          num_epochs: number of epochs
// Outputs - c[n], theta[(d+1)*n]: trained parameters (updated in place), ret: final loss,
//           E_decay[num_epochs + 1]: loss at the start of every epoch, followed by the final loss
// Each epoch: adjust theta, minimize over c, solve the local problems to obtain a correction,
// then backtrack on the learning rate until the loss does not exceed the loss after the c-minimization.
// Only the process with pid == ROOT prints the progress lines.
double train_NPSC(shallow_nn *nn, double *c, double *theta, schwarz *sc, int num_epochs, double *E_decay){
    const int dim = nn->d;
    const int num_neurons = nn->n;
    const int theta_len = (dim + 1) * num_neurons;
    const int rank = sc->pid;

    // Backtracking gives up once the learning rate is no larger than this bound.
    const double min_step = pow(0.1, 12);

    // Initial learning rate
    double step = 1;
    double loss = 0, loss_after_c = 0;
    int epoch = 0, k = 0;

    double *theta_prev = new_vec(theta_len);
    double *grad_c = new_vec(num_neurons);
    double *correction = new_vec(theta_len);

    // NOTE(review): the flag 1 is passed to every loss_fun call here; presumably it suppresses the
    // gradient outputs (cf. nograd of local_loss) - confirm against loss_fun.
    loss = loss_fun(nn, c, theta, 1, grad_c, correction);

    // NPSC iterations
    for(epoch = 0; epoch < num_epochs; epoch++){
        // Keep the current iterate and record its loss.
        vec_cpy(theta, theta_prev, theta_len);
        E_decay[epoch] = loss;
        if(rank == ROOT){
            printf("Epoch %d: %.6e, LR: %.6e\n", epoch, E_decay[epoch], step);
        }

        // Adjustment step for theta (works on the saved copy)
        adjust_theta(nn, c, theta_prev);

        // Optimal c for the fixed (adjusted) theta, and the loss it attains
        c_minimize(nn, c, theta_prev);
        loss_after_c = loss_fun(nn, c, theta_prev, 1, grad_c, correction);

        // Local problems: the assembled correction is written into 'correction'.
        local_solver(nn, c, theta_prev, sc, correction);

        // Backtracking: start from twice the previous learning rate and halve it until the
        // trial point theta_prev + step * correction is accepted or the step becomes too small.
        // NOTE(review): 'correction' is also handed to loss_fun inside this loop, so the search
        // relies on loss_fun leaving it untouched when called with flag 1 - confirm.
        for(step *= 2; step > min_step; step /= 2){
            vec_lincom(theta_prev, correction, 1, step, theta, theta_len);
            loss = loss_fun(nn, c, theta, 1, grad_c, correction);
            // Backtracking criterion
            if(loss <= loss_after_c){
                break;
            }
        }
    }

    // Fill the remaining entries (from the epoch counter up to num_epochs) with the final loss.
    for(k = epoch; k <= num_epochs; k++){
        E_decay[k] = loss;
    }
    if(rank == ROOT){
        printf("Epoch %d: %.6e, LR: %.6e\n", num_epochs, loss, step);
    }

    free_vec(theta_prev);
    free_vec(grad_c);
    free_vec(correction);
    return loss;
}

// Adjustment step for theta: No dead neurons are made.
// Inputs - nn: shallow_nn structure, c[n], theta[(d+1)*n]
// Outputs - c[n], theta[(d+1)*n] (both modified in place)
// Neuron i stores its weight w in theta[(d+1)*i], ..., theta[(d+1)*i + d-1] and its bias b in theta[(d+1)*i + d].
// Steps:
//   1. (w, b) is divided by |w| and c[i] is multiplied by |w|.
//   2. If b lies outside [-wx_max, -wx_min], where wx_max / wx_min are the largest / smallest
//      value of w.x over the vertices of the unit cube {0,1}^d, b is redrawn with rand()
//      inside that interval.
//   3. If d = 1, the neurons are sorted by their node -b/w, nodes that are too close to their
//      left neighbour are pushed to the right, and b is recomputed from the nodes.
void adjust_theta(shallow_nn *nn, double *c, double *theta){
    int d = nn->d, n = nn->n;
    int i = 0, j = 0, k = 0;
    double temp = 0;
    double dx = 0;

    // Variables for b-initialization
    int num_vertices = 1, vertex_id = 0;
    double wx_max = 0, wx_min = 0;
    double **vertices;

    // Variables for sorting
    double *nodes;

    // num_vertices = 2^d
    for(j=0; j<d; j++){
        num_vertices *= 2;
    }
    vertices = new_mat(num_vertices, d);
    for(vertex_id=0; vertex_id<num_vertices; vertex_id++){
        k = vertex_id;
        // Generate vertices from binary numbers: coordinate j is bit j of vertex_id.
        for(j=0; j<d; j++){
            vertices[vertex_id][j] = (double) (k - 2*(k/2));
            k = k/2;
        }
    }

    for(i=0; i<n; i++){
        // Normalization of (w,b) so that |w| = 1
        // NOTE(review): a neuron with w = 0 leads to a division by zero here - confirm that
        // callers never pass such a neuron.
        temp = 0;
        for(j=0; j<d; j++){
            temp += theta[(d+1)*i + j] * theta[(d+1)*i + j];
        }
        temp = sqrt(temp);

        c[i] *= temp;
        for(j=0; j<d+1; j++){
            theta[(d+1)*i + j] /= temp;
        }

        // Relocation of b
        // The initial values act as -infinity / +infinity; after normalization |w.x| <= sqrt(d)
        // on the unit cube, far below 10000.
        wx_max = -10000;
        wx_min = 10000;
        for(vertex_id=0; vertex_id<num_vertices; vertex_id++){
            temp = 0;
            for(j=0; j<d; j++){
                temp += theta[(d+1)*i + j] * vertices[vertex_id][j];
            }
            if(temp > wx_max){
                wx_max = temp;
            }
            if(temp < wx_min){
                wx_min = temp;
            }
        }
        if(theta[(d+1)*i + d] < -wx_max || theta[(d+1)*i + d] > -wx_min){
            theta[(d+1)*i + d] = -wx_max + ((double) rand() / RAND_MAX) * (-wx_min + wx_max);
        }
    }

    // If d = 1, sort the neurons
    if(d == 1){
        nodes = new_vec(n);
        for(i=0; i<n; i++){
            nodes[i] = - theta[2*i+1] / theta[2*i];
        }

        // Sequential sort (ascending in nodes); c and theta are permuted together with nodes.
        for(i=0; i<n; i++){
            for(j=i+1; j<n; j++){
                if(nodes[i] > nodes[j]){
                    temp = c[i];
                    c[i] = c[j];
                    c[j] = temp;
                    temp = theta[2*i];
                    theta[2*i] = theta[2*j];
                    theta[2*j] = temp;
                    temp = theta[2*i+1];
                    theta[2*i+1] = theta[2*j+1];
                    theta[2*j+1] = temp;
                    temp = nodes[i];
                    nodes[i] = nodes[j];
                    nodes[j] = temp;
                }
            }
        }

        // Avoid too-near nodes (heuristic)
        // Only the n-1 adjacent pairs (i, i+1) exist, so the loop must stop at i = n-2;
        // running up to i = n-1 would access nodes[n], one past the end of the array.
        if(n > 1){
            // Distance between the first two entries of nn->x
            // (presumably the spacing of a uniform grid - TODO confirm).
            dx = nn->x[1][0] - nn->x[0][0];
            for(i=0; i<n-1; i++){
                if(nodes[i+1] - nodes[i] < dx){
                    nodes[i+1] += dx;
                }
            }
        }
        // Recover b from the (possibly shifted) nodes: b = -w * node.
        for(i=0; i<n; i++){
            theta[2*i+1] = - theta[2*i] * nodes[i];
        }

        free_vec(nodes);
    }

    free_mat(vertices, num_vertices);
}

// Local solver for NPSC
// Inputs - nn: shallow_nn structure, c[n], theta[(d+1)*n]: parameters, sc: schwarz structure
// Outputs - r[(d+1)*n]: assembled correction (sum of all local corrections)
// Every process solves the local problems of its own neurons s = ista[pid], ..., iend[pid]
// (none for a process with an empty range) by a gradient descent with backtracking on the
// local loss, and the local corrections are summed over all processes by MPI_Allreduce on
// MPI_COMM_WORLD. Hence every process of MPI_COMM_WORLD has to call this function.
void local_solver(shallow_nn *nn, double *c, double *theta, schwarz *sc, double *r){
    int d = nn->d, n = nn->n, N = nn->N;
    int pid = sc->pid;

    // Local corrections
    double **rs;

    // Variables for the Levenberg-Marquardt algorithm
    // (MAX_LAMBDA, E0, E1, dr0, dr1 and JKJ_lambda are only used by the disabled block below)
    double lambda = 1, MAX_LAMBDA = pow(10, 12), E_old = 0, E = 0, E0 = 0, E1 = 0;
    double *rs_old, *dr0, *dr1;
    double *u_nn, **dudr, **grad_u_nn, ***grad_dudr, **JKJ, **JKJ_lambda, *RHS;
    int iter = 0, ITER_MAX = 20;
    double tol = pow(0.1, 10);

    int s = 0, i = 0, j = 0;
    int accepted = 0;
    double *temp_vec;

    // Initialization of local corrections
    // NOTE(review): the initial guess is whatever new_vec returns, presumably zeros - confirm.
    rs = (double **) malloc(sc->isize[pid] * sizeof(double *));
    for(s=sc->ista[pid]; s<=sc->iend[pid]; s++){
        rs[s - sc->ista[pid]] = new_vec(d+1);
    }

    // Solving local problems
    for(s=sc->ista[pid]; s<=sc->iend[pid]; s++){
        // The step size is reset for every local problem.
        lambda = 1;
        rs_old = new_vec(d+1);
        u_nn = new_vec(N);
        dudr = new_mat(d+1, N);
        grad_u_nn = new_mat(d, N);
        grad_dudr = new_ten(d+1, d, N);
        dr0 = new_vec(d+1);
        dr1 = new_vec(d+1);
        JKJ = new_mat(d+1, d+1);
        JKJ_lambda = new_mat(d+1, d+1);
        RHS = new_vec(d+1);

        for(iter=1; iter<=ITER_MAX; iter++){
            // Loss, function values and r-Jacobian at the current iterate
            vec_cpy(rs[s - sc->ista[pid]], rs_old, d+1);
            E_old = local_loss(nn, c, theta, sc, s, rs_old, 0, u_nn, dudr, grad_u_nn, grad_dudr);

            for(i=0; i<d+1; i++){
                // Assembly of JKJ (symmetric; only consumed by the disabled Levenberg-Marquardt block)
                for(j=i; j<d+1; j++){
                    JKJ[i][j] = a_innerprod(dudr[i], grad_dudr[i], dudr[j], grad_dudr[j], nn);
                    JKJ[j][i] = JKJ[i][j];
                }
                // Assembly of RHS: the search direction of the gradient descent below
                RHS[i] = L2_innerprod(dudr[i], nn->f, nn) - a_innerprod(dudr[i], grad_dudr[i], u_nn, grad_u_nn, nn);
            }

            // Gradient descent algorithm with backtracking
            // Start from twice the last step size and halve it until the loss at the trial
            // point rs_old + lambda * RHS does not exceed the loss at rs_old.
            // The loss must be evaluated at the trial point itself, and lambda must be doubled
            // only once per search; otherwise a rejected step would be retried forever with
            // the same lambda.
            accepted = 0;
            lambda *= 2;
            while(lambda > 0.00001){
                vec_lincom(rs_old, RHS, 1, lambda, rs[s - sc->ista[pid]], d+1);
                // nograd = 1: only the loss is needed (u_nn and grad_u_nn are overwritten as a by-product).
                E = local_loss(nn, c, theta, sc, s, rs[s - sc->ista[pid]], 1, u_nn, dudr, grad_u_nn, grad_dudr);
                if(E <= E_old){
                    accepted = 1;
                    break;
                }
                lambda /= 2;
            }
            if(!accepted){
                // No step size was admissible: keep the previous iterate instead of a rejected trial point.
                vec_cpy(rs_old, rs[s - sc->ista[pid]], d+1);
                E = E_old;
            }

            /*
            // Levenberg--Marquardt algorithm
            // Computation of dr0
            mat_cpy(JKJ, JKJ_lambda, d+1, d+1);
            for(i=0; i<d+1; i++){
                // 0.0001 is due to avoid singularity.
                JKJ_lambda[i][i] = JKJ[i][i] + lambda * fmax(JKJ[i][i], 0.0001);
            }
            mat_solve(JKJ_lambda, RHS, dr0, d+1);
            vec_lincom(rs_old, dr0, 1, 1, dr0, d+1);
            E0 = local_loss(nn, c, theta, sc, s, dr0, 1, u_nn, dudr, grad_u_nn, grad_dudr);

            // Computation of dr1
            for(i=0; i<d+1; i++){
                // 0.0001 is due to avoid singularity.
                JKJ_lambda[i][i] = JKJ[i][i] + lambda/2 * fmax(JKJ[i][i], 0.0001);
            }
            mat_solve(JKJ_lambda, RHS, dr1, d+1);

            vec_lincom(rs_old, dr1, 1, 1, dr1, d+1);
            E1 = local_loss(nn, c, theta, sc, s, dr1, 1, u_nn, dudr, grad_u_nn, grad_dudr);

            // Case 1
            if(E0 > E_old && E1 > E_old){
                while(lambda < MAX_LAMBDA){
                    lambda *= 2;
                    mat_cpy(JKJ, JKJ_lambda, d+1, d+1);
                    for(i=0; i<d+1; i++){
                        JKJ_lambda[i][i] *= (1 + lambda);
                    }
                    mat_solve(JKJ_lambda, RHS, dr0, d+1);
                    vec_lincom(rs_old, dr0, 1, 1, dr0, d+1);
                    E0 = local_loss(nn, c, theta, sc, s, dr0, 1, u_nn, dudr, grad_u_nn, grad_dudr);
                    
                    if(E0 <= E_old)
                        break;
                }

                vec_cpy(dr0, rs[s - sc->ista[pid]], d+1);
                E = E0;
            }
            else{
                // Case 2
                if(E0 < E1){
                    vec_cpy(dr0, rs[s - sc->ista[pid]], d+1);
                    E = E0;
                }
                // Case 3
                else{
                    vec_cpy(dr1, rs[s - sc->ista[pid]], d+1);
                    lambda /= 2;
                    E = E1;
                }
            }
            */

            // Stopping criterion: relative decrease of the loss below tol
            // (written without a division so that E = 0 cannot produce inf or NaN)
            if(fabs(E_old - E) <= tol * fabs(E)){
                //printf("LMA converged with %d iterations.\n", iter);
                break;
            }
        }
        // iter exceeds ITER_MAX only if the loop above ended without a break.
        if(iter > ITER_MAX){
            printf("Warning: LMA iterated MAX_ITER iterations but did not converge. %.6e\n", fabs(E_old - E));
        }

        free_vec(rs_old), free_vec(u_nn), free_mat(dudr, d+1), free_mat(grad_u_nn, d), free_ten(grad_dudr, d+1, d);
        free_vec(dr0), free_vec(dr1);
        free_mat(JKJ, d+1), free_mat(JKJ_lambda, d+1), free_vec(RHS);
    }

    // Assemble rs into r.
    // Each process fills only the entries of its own neurons; the sum over all processes gives r.
    // NOTE(review): the accumulation with += assumes that new_vec returns a zero vector - confirm.
    temp_vec = new_vec((d+1)*n);
    for(s=sc->ista[pid]; s<=sc->iend[pid]; s++){
        for(i=(d+1)*s; i<=(d+1)*(s+1) - 1; i++){
            temp_vec[i] += rs[s - sc->ista[pid]][i - (d+1)*s];
        }
    }
    MPI_Allreduce(temp_vec, r, (d+1)*n, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
    free_vec(temp_vec);

    // NOTE(review): rs was built by hand (malloc + new_vec per row); releasing it with free_mat
    // assumes free_mat frees each row and then the row-pointer array - confirm.
    free_mat(rs, sc->isize[pid]);
    return;
}

// Local loss for NPSC
// Inputs - nn: shallow_nn structure, c[n], theta[(d+1)*n]: parameters, sc: schwarz structure (not used here),
//          s: subdomain id, rs[d+1]: current local correction, nograd: the Jacobian outputs are not written if nograd != 0
// Outputs - u_nn[N]: NN function, grad_u_nn[d][N]: its gradient (both always written by nn_fun),
//           dudr[d+1][N]: r-Jacobian (only if nograd == 0),
//           grad_dudr[d+1][d][N]: gradient of the r-Jacobian (only if nograd == 0 and nn->problem_type != 0),
//           ret: loss 0.5 * a(u_nn, u_nn) - (f, u_nn) evaluated at theta + R_s^T r_s
static double local_loss(shallow_nn *nn, double *c, double *theta, schwarz *sc, int s, double *rs, int nograd, double *u_nn, double **dudr, double **grad_u_nn, double ***grad_dudr){
    const int d = nn->d, N = nn->N, n = nn->n;
    const int block = d + 1;        // number of parameters per neuron
    const int offset = block * s;   // first entry of neuron s inside theta
    int p = 0, k = 0;

    // Work arrays for the full evaluation of the network
    double *theta_shifted = new_vec(block * n);
    double **dudc = new_mat(n, N);
    double **dudtheta = new_mat(block * n, N);
    double ***grad_dudc = new_ten(n, d, N);
    double ***grad_dudtheta = new_ten(block * n, d, N);

    // theta_shifted = theta + R_s^T r_s: only the entries of neuron s are changed.
    vec_cpy(theta, theta_shifted, block * n);
    for(p = 0; p < block; p++){
        theta_shifted[offset + p] += rs[p];
    }

    nn_fun(nn, c, theta_shifted, u_nn, dudc, dudtheta, grad_u_nn, grad_dudc, grad_dudtheta);

    if(nograd == 0){
        // The r-Jacobian consists of the rows of the theta-Jacobian that belong to neuron s.
        for(p = 0; p < block; p++){
            // dudr
            vec_cpy(dudtheta[offset + p], dudr[p], N);

            // grad_dudr
            if(nn->problem_type != 0){
                for(k = 0; k < d; k++){
                    vec_cpy(grad_dudtheta[offset + p][k], grad_dudr[p][k], N);
                }
            }
        }
    }

    free_vec(theta_shifted);
    free_mat(dudc, n);
    free_mat(dudtheta, block * n);
    free_ten(grad_dudc, n, d);
    free_ten(grad_dudtheta, block * n, d);

    return 0.5 * a_innerprod(u_nn, grad_u_nn, u_nn, grad_u_nn, nn) - L2_innerprod(nn->f, u_nn, nn);
}

// Para-range subroutine (Distribute 0, ..., n-1 into nblocks contiguous blocks)
// Inputs - n: number of items, nblocks: number of blocks (must be nonzero), blockid: block index (0-based)
// Outputs - ista, iend: first and last index (inclusive) of block blockid
// The first n % nblocks blocks receive n / nblocks + 1 items, the remaining ones n / nblocks items.
// An empty block is returned as iend = ista - 1.
void para_range(int n, int nblocks, int blockid, int *ista, int *iend)
{
    const int base = n / nblocks;       // items every block receives
    const int extra = n % nblocks;      // number of leading blocks with one additional item
    const int has_extra = (blockid < extra);

    // Each preceding block contributes 'base' items plus one for every preceding larger block.
    const int first = blockid * base + (has_extra ? blockid : extra);

    *ista = first;
    *iend = first + base - 1 + (has_extra ? 1 : 0);
}
