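/* WARNING: DOES NOT WORK PROPERLY, AS GETTING OPENMP SUPPORT IN MATLAB IS DIFFICULT. */
/* Build: mex -lmwlapack -lmwblas -largeArrayDims partXY_blas.c
 * Without compiler/linker OpenMP flags (e.g. -fopenmp for GCC) the OpenMP
 * pragmas are ignored and the code runs serially. */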

/*
 *   TTeMPS Toolbox. 
 *   Michael Steinlechner, 2013-2016
 *   Questions and contact: michael.steinlechner@epfl.ch
 *   BSD 2-clause license, see LICENSE.txt
 */

/* Pointer to the slice selected by the 1-based index ind[d*j+i] of core i for
 * sample j. A core consists of consecutive slices of r[i]*r[i+1] doubles, so
 * slice k starts at offset (k-1)*r[i]*r[i+1].
 * These macros expand at the call site and rely on the local names
 * U, V, result_part, ind, d and r being in scope there.
 * Arguments are parenthesized so that expressions such as U_SLICE(i, j+1)
 * address the intended sample. */
#define U_SLICE(i,j)   (&U[(i)][(ind[d*(j)+(i)]-1)*r[(i)]*r[(i)+1]])
#define V_SLICE(i,j)   (&V[(i)][(ind[d*(j)+(i)]-1)*r[(i)]*r[(i)+1]])
#define RES_SLICE(i,j) (&result_part[(i)][(ind[d*(j)+(i)]-1)*r[(i)]*r[(i)+1]])

#include <stdlib.h>
#ifdef _OPENMP
#include <omp.h>
#endif

#include "mex.h"
#include "blas.h"

/*
 * MEX gateway. MATLAB calling syntax (six inputs, one output):
 *
 *     res = partXY_blas( n, r, U, V, ind, vals )
 *
 *   n    : double vector with at least d entries; n(i) is the number of
 *          slices of core i.
 *   r    : double vector with at least d+1 ranks (each >= 1).
 *   U, V : cell arrays with at least d cores each. Core i is read as n(i)
 *          consecutive column-major r(i)-by-r(i+1) slices, so it needs at
 *          least r(i)*r(i+1)*n(i) double entries.
 *   ind  : d-by-numSubsref double matrix of 1-based slice indices; column j
 *          picks one slice per core for sample j. Its row count defines d.
 *   vals : double vector with at least numSubsref weights.
 *
 *   res  : 1-by-d cell array; res{i} is an r(i)*r(i+1)*n(i)-by-1 column.
 *
 * For every sample j and core i, slice ind(i,j) of res{i} accumulates
 *     vals(j) * left' * right'
 * with left  = U{1}(ind(1,j)) * ... * U{i-1}(ind(i-1,j))   (row vector) and
 *      right = V{i+1}(ind(i+1,j)) * ... * V{d}(ind(d,j))   (column vector).
 * The code assumes r(1) = r(d+1) = 1; this is not checked.
 *
 * When compiled with OpenMP, the loop over samples runs in parallel. Each
 * thread accumulates into private buffers, which are summed into res at the end.
 */
void mexFunction( int nlhs, mxArray *plhs[],
                  int nrhs, const mxArray *prhs[] ) {

	/* input variables */
	double* n_raw;
	double* r_raw;
	double** U;
	double** V;
	double* ind_raw;
	double* vals;

	/* output variables */
	double** result;
	mxArray* result_cell;

	/* per-thread work buffers (allocated inside the parallel region) */
	double* L;
	double* current;
	double* tmp;
	double** result_part;

	mwSignedIndex* n;
	mwSignedIndex* r;
	mwSignedIndex* ind;

	mwSignedIndex numSubsref;
	mwSignedIndex d;
	mwSignedIndex i;
	mwSignedIndex j;
	mwSignedIndex k;
	mwSignedIndex maxrank = 1;
	int alloc_failed = 0;

	/* helper constants for the BLAS calls (BLAS takes every argument by pointer) */
	char transa = 'T';
	char no_transa = 'N';
	mwSignedIndex ONE_i = 1;
	double ONE_d = 1.0;
	double ZERO_d = 0.0;

	/* ---- argument checks: malformed input would otherwise corrupt memory ---- */
	if( nrhs != 6 ) {
		mexErrMsgIdAndTxt( "TTeMPS:partXY:nrhs",
		                   "Six inputs required: n, r, U, V, ind, vals." );
	}
	if( nlhs > 1 ) {
		mexErrMsgIdAndTxt( "TTeMPS:partXY:nlhs", "Too many output arguments." );
	}
	if( !mxIsDouble( prhs[0] ) || !mxIsDouble( prhs[1] ) ||
	    !mxIsDouble( prhs[4] ) || !mxIsDouble( prhs[5] ) ) {
		mexErrMsgIdAndTxt( "TTeMPS:partXY:type",
		                   "n, r, ind and vals must be double arrays." );
	}
	if( !mxIsCell( prhs[2] ) || !mxIsCell( prhs[3] ) ) {
		mexErrMsgIdAndTxt( "TTeMPS:partXY:type", "U and V must be cell arrays." );
	}

	/* get indices: one row per core, one column per sample */
	ind_raw = mxGetPr( prhs[4] );
	d = (mwSignedIndex) mxGetM( prhs[4] );
	numSubsref = (mwSignedIndex) mxGetN( prhs[4] );

	/* L holds d-1 partial products and the sweeps treat cores 0 and d-1
	 * separately, so fewer than two cores would index out of bounds. */
	if( d < 2 ) {
		mexErrMsgIdAndTxt( "TTeMPS:partXY:order",
		                   "ind must have at least two rows (d >= 2)." );
	}
	if( (mwSignedIndex) mxGetNumberOfElements( prhs[0] ) < d ||
	    (mwSignedIndex) mxGetNumberOfElements( prhs[1] ) < d+1 ||
	    (mwSignedIndex) mxGetNumberOfElements( prhs[2] ) < d ||
	    (mwSignedIndex) mxGetNumberOfElements( prhs[3] ) < d ||
	    (mwSignedIndex) mxGetNumberOfElements( prhs[5] ) < numSubsref ) {
		mexErrMsgIdAndTxt( "TTeMPS:partXY:size",
		                   "Need numel(n) >= d, numel(r) >= d+1, numel(U) >= d, numel(V) >= d and numel(vals) >= size(ind,2)." );
	}

	n_raw = mxGetPr( prhs[0] );
	r_raw = mxGetPr( prhs[1] );
	vals = mxGetPr( prhs[5] );

	n = mxMalloc( d*sizeof(mwSignedIndex) );
	r = mxMalloc( (d+1)*sizeof(mwSignedIndex) );
	ind = mxMalloc( d*numSubsref*sizeof(mwSignedIndex) );

	/* Convert sizes, ranks and indices (passed as doubles by MATLAB) to
	 * integers once, so the inner loop needs no casts. The checks run on the
	 * doubles, so NaN or out-of-range values are rejected before casting. */
	for( i = 0; i <= d; ++i ) {
		if( !( r_raw[i] >= 1.0 ) ) {
			mexErrMsgIdAndTxt( "TTeMPS:partXY:rank",
			                   "r(%d) must be at least 1.", (int)(i+1) );
		}
		r[i] = (mwSignedIndex) r_raw[i];
	}
	for( i = 0; i < d; ++i ) {
		if( !( n_raw[i] >= 0.0 ) ) {
			mexErrMsgIdAndTxt( "TTeMPS:partXY:size",
			                   "n(%d) must be nonnegative.", (int)(i+1) );
		}
		n[i] = (mwSignedIndex) n_raw[i];
		/* ranks r[1..d-1] bound the length of every vector kept in L, current, tmp */
		if( r[i] > maxrank ) {
			maxrank = r[i];
		}
	}
	for( j = 0; j < numSubsref; ++j ) {
		for( i = 0; i < d; ++i ) {
			double idx = ind_raw[d*j+i];
			if( !( idx >= 1.0 && idx <= (double) n[i] ) ) {
				mexErrMsgIdAndTxt( "TTeMPS:partXY:index",
				                   "ind(%d,%d) must lie in 1..n(%d).",
				                   (int)(i+1), (int)(j+1), (int)(i+1) );
			}
			ind[d*j+i] = (mwSignedIndex) idx;
		}
	}

	/* Get pointers to the cores within the cell arrays */
	U = mxMalloc( d*sizeof(double*) );
	V = mxMalloc( d*sizeof(double*) );
	for( i = 0; i < d; ++i ) {
		const mxArray* Ucore = mxGetCell( prhs[2], i );
		const mxArray* Vcore = mxGetCell( prhs[3], i );
		mwSignedIndex coresize = r[i]*r[i+1]*n[i];
		if( Ucore == NULL || Vcore == NULL ||
		    !mxIsDouble( Ucore ) || !mxIsDouble( Vcore ) ||
		    (mwSignedIndex) mxGetNumberOfElements( Ucore ) < coresize ||
		    (mwSignedIndex) mxGetNumberOfElements( Vcore ) < coresize ) {
			mexErrMsgIdAndTxt( "TTeMPS:partXY:core",
			                   "U{%d} and V{%d} must be double arrays with at least r(%d)*r(%d)*n(%d) entries.",
			                   (int)(i+1), (int)(i+1), (int)(i+1), (int)(i+2), (int)(i+1) );
		}
		U[i] = mxGetPr( Ucore );
		V[i] = mxGetPr( Vcore );
	}

	/* Allocate the output cores (mxCreateDoubleMatrix zero-initializes them) */
	plhs[0] = mxCreateCellMatrix( 1, d );
	result = mxMalloc( d*sizeof(double*) );
	for( i = 0; i < d; i++ ) {
		result_cell = mxCreateDoubleMatrix( r[i]*r[i+1]*n[i], 1, mxREAL );
		result[i] = mxGetPr( result_cell );
		mxSetCell( plhs[0], i, result_cell );
	}

	#pragma omp parallel default(none) \
	        shared(n,r,d,ind,result,U,V,numSubsref,maxrank,ONE_i,ZERO_d,ONE_d,transa,no_transa,vals,alloc_failed) \
	        private(i,j,k,L,current,tmp,result_part)
	{
		/* Work buffers come from the C runtime: the MATLAB MX API
		 * (mxMalloc/mxCalloc/mxFree) is not thread-safe and must not be
		 * called from worker threads, even inside a critical section. */
		int ok;

		L = calloc( (size_t)( maxrank*(d-1) ), sizeof(double) );
		current = calloc( (size_t) maxrank, sizeof(double) );
		tmp = calloc( (size_t) maxrank, sizeof(double) );
		result_part = calloc( (size_t) d, sizeof(double*) );
		ok = ( L != NULL && current != NULL && tmp != NULL && result_part != NULL );
		for( i = 0; ok && i < d; ++i ) {
			mwSignedIndex coresize = r[i]*r[i+1]*n[i];
			result_part[i] = calloc( (size_t) coresize, sizeof(double) );
			/* calloc of zero bytes may legitimately return NULL */
			ok = ( result_part[i] != NULL || coresize == 0 );
		}
		if( !ok ) {
			#pragma omp critical
			alloc_failed = 1;
		}

		/* Every thread must reach the worksharing loop, so a thread whose
		 * allocation failed still takes part but skips its iterations. */
		#pragma omp for
		for( j = 0; j < numSubsref; ++j ) {
			if( !ok ) {
				continue;
			}

			/* LEFT TO RIGHT FIRST (PRECOMPUTE) */
			/* ... copy first core slice to L: */
			dcopy( &r[1], U_SLICE(0,j), &ONE_i, &L[0], &ONE_i );
			/* ... then multiply with the other cores and store the results in columns of L: */
			for( i = 1; i < d-1; ++i ) {
				dgemv( &transa, &r[i], &r[i+1], &ONE_d,
				       U_SLICE(i,j),
				       &r[i],
				       &L[maxrank*(i-1)],
				       &ONE_i, &ZERO_d, &L[maxrank*i], &ONE_i );
			}

			/* RIGHT TO LEFT PRODUCTS NOW -- USING PRECOMPUTED LEFT SIDES FROM ABOVE */
			/* last core gets no contribution from the right */
			daxpy( &r[d-1], &vals[j], &L[maxrank*(d-2)], &ONE_i, RES_SLICE(d-1,j), &ONE_i );

			/* copy rightmost slice to current */
			dcopy( &r[d-1], V_SLICE(d-1,j), &ONE_i, current, &ONE_i );

			/* sweep right-to-left over the inner cores */
			for( i = d-2; i > 0; --i ) {
				/* Outer product update:
				 * result(:,:,idx) = result(:,:,idx) + vals(j) * L(1:r(i), i-1) * current' */
				dger( &r[i], &r[i+1], &vals[j],
				      &L[maxrank*(i-1)], &ONE_i,
				      current, &ONE_i,
				      RES_SLICE(i,j), &r[i] );

				/* current = V_i slice * current (via tmp, since dgemv must not alias) */
				dgemv( &no_transa, &r[i], &r[i+1], &ONE_d,
				       V_SLICE(i,j),
				       &r[i],
				       current,
				       &ONE_i, &ZERO_d, tmp, &ONE_i );
				dcopy( &r[i], tmp, &ONE_i, current, &ONE_i );
			}

			/* first core gets no contribution from the left */
			daxpy( &r[1], &vals[j], current, &ONE_i, RES_SLICE(0,j), &ONE_i );
		}

		/* gather all thread-local parts into the output cores */
		if( ok ) {
			#pragma omp critical
			{
				for( i = 0; i < d; ++i ) {
					mwSignedIndex coresize = r[i]*r[i+1]*n[i];
					for( k = 0; k < coresize; ++k ) {
						result[i][k] += result_part[i][k];
					}
				}
			}
		}

		/* free(NULL) is a no-op, so partially failed allocations are fine here */
		if( result_part != NULL ) {
			for( i = 0; i < d; ++i ) {
				free( result_part[i] );
			}
			free( result_part );
		}
		free( tmp );
		free( current );
		free( L );
	}

	mxFree( result );
	mxFree( n );
	mxFree( r );
	mxFree( ind );
	mxFree( U );
	mxFree( V );

	if( alloc_failed ) {
		mexErrMsgIdAndTxt( "TTeMPS:partXY:outOfMemory",
		                   "Out of memory while allocating per-thread work buffers." );
	}
}
