#include "mex.h"

#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include <iostream>
#include <vector>

#include <blas.h>    // dgemv
#include <lapack.h>  // dstegr
#include <omp.h>
//#include <iomanip>
using namespace std;



// Small numeric helpers and constants shared by the routines in this file.
// These are function-like macros: an argument may be evaluated more than once,
// so do not pass expressions with side effects.
#define  abs1(a)         ((a) < 0.0 ? -(a) : (a))
// Sign of a as -1, 0 or 1. The whole expansion is parenthesised so the macro
// can be embedded in a larger expression (e.g. 2*sign1(v)) without the
// surrounding operators binding to the ?: condition.
#define  sign1(a)        (((a)==0) ? 0 : (((a)>0.0)?1:(-1)))
#define  max1(a,b)       ((a) > (b) ? (a) : (b))
#define  min1(a,b)       ((a) < (b) ? (a) : (b))
// PI carries an L suffix, so PI and M_2PI are long double constants.
# define PI 3.141592653589793238463L
# define M_2PI (2 * PI)
// Absolute tolerance used for "is zero" and feasibility checks.
# define eps 2.2204e-14
// Sentinel for "unbounded" / "infeasible" values.
# define LARGE 1e100

// Debug helper: prints the first n entries of x on a single line, each
// formatted with "%f " (value followed by a space), then writes a newline
// followed by one space.
void pvec(double *x,int n)
{
    int pos = 0;
    while (pos < n)
    {
        printf("%f ", x[pos]);
        ++pos;
    }

    printf("\n ");
}





//---------------------------------------------------------------------------
// Solve the monic cubic x^3 + a*x^2 + b*x + c = 0.
// x must have room for 3 doubles; all three entries are always written.
// Return value and layout of x:
//   3 -> three real roots in x[0], x[1], x[2]
//   2 -> real roots x[0] and x[1]; x[2] is overwritten with a copy of x[1]
//   1 -> one real root x[0]; the complex pair is x[1] +/- i*x[2]
unsigned int solveP3(double* x, double a, double b, double c)
{
    const double a_sq = a * a;
    const double q = (a_sq - 3 * b) / 9;
    const double r = (a * (2 * a_sq - 9 * b) + 27 * c) / 54;
    const double r_sq = r * r;
    const double q_cu = q * q * q;
    const double shift = a / 3;  // every root is reported as (value - a/3)

    if (r_sq < q_cu)
    {
        // Three real roots, obtained from the trigonometric form.
        double ratio = r / sqrt(q_cu);
        // Clamp so the acos argument stays inside [-1, 1].
        if (ratio < -1) ratio = -1;
        if (ratio > 1) ratio = 1;
        const double theta = acos(ratio);
        const double amp = -2 * sqrt(q);
        x[0] = amp * cos(theta / 3) - shift;
        x[1] = amp * cos((theta + M_2PI) / 3) - shift;
        x[2] = amp * cos((theta - M_2PI) / 3) - shift;
        return 3;
    }

    // Otherwise use the cube-root form; A takes the sign opposite to r.
    double A = -pow(abs1(r) + sqrt(r_sq - q_cu), 1. / 3);
    if (r < 0) A = -A;
    const double B = (0 == A ? 0 : q / A);  // avoid dividing by a zero A
    const double sum = A + B;

    x[0] = sum - shift;
    x[1] = -0.5 * sum - shift;
    x[2] = 0.5 * sqrt(3.) * (A - B);
    if (abs1(x[2]) < eps)
    {
        // Imaginary part is negligible: report a repeated real root instead.
        x[2] = x[1];
        return 2;
    }

    return 1;
}

//---------------------------------------------------------------------------
// Helper for solve_quartic_2: stores the real candidates of the quadratic
// x^2 + p*x + q = 0 in out and returns how many values were written.
// When the discriminant is negative only the real part -p/2 of the complex
// pair is stored (1 value); otherwise both roots are stored (2 values).
static int quartic2_store_factor(double p, double q, double *out)
{
    const double disc = p * p - 4 * q;
    if (disc < 0.0)
    {
        out[0] = -p * 0.5;
        return 1;
    }

    const double root = sqrt(disc);
    out[0] = (-p + root) * 0.5;
    out[1] = (-p - root) * 0.5;
    return 2;
}

//---------------------------------------------------------------------------
// Find real candidates for the monic quartic x^4 + a*x^3 + b*x^2 + c*x + d = 0.
// The quartic is split into two quadratic factors x^2 + p1*x + q1 and
// x^2 + p2*x + q2 using a root y of an auxiliary (resolvent) cubic solved
// with solveP3.
// Results are written to x_sol, a caller-owned buffer with room for 4 doubles;
// nothing is allocated here. For a factor with complex roots only the real
// part of the pair is stored, so the return value (the number of entries
// written) is 2, 3 or 4.
int solve_quartic_2(double a, double b, double c, double d,double *x_sol)
{
    // Coefficients of the auxiliary cubic.
    const double cubic_a = -b;
    const double cubic_b = a * c - 4. * d;
    const double cubic_c = -a * a * d - c * c + 4. * b * d;

    double cubic_roots[3];
    const unsigned int n_cubic = solveP3(cubic_roots, cubic_a, cubic_b, cubic_c);

    // Use the cubic root of largest magnitude. When solveP3 reports a single
    // real root, entries 1 and 2 describe a complex pair and are ignored.
    double y = cubic_roots[0];
    if (n_cubic != 1)
    {
        for (int k = 1; k < 3; ++k)
        {
            if (abs1(cubic_roots[k]) > abs1(y)) y = cubic_roots[k];
        }
    }

    double p1, p2, q1, q2;

    // q1 + q2 = y and q1*q2 = d, i.e. q solves q^2 - y*q + d = 0.
    const double disc_q = y * y - 4 * d;
    if (abs1(disc_q) < eps)
    {
        // Treated as a zero discriminant: q1 == q2.
        q1 = q2 = y * 0.5;

        // p1 + p2 = a and p1*p2 = b - y, i.e. p solves p^2 - a*p + (b - y) = 0.
        const double disc_p = a * a - 4 * (b - y);
        if (abs1(disc_p) < eps)
        {
            p1 = p2 = a * 0.5;
        }
        else
        {
            const double root_p = sqrt(disc_p);
            p1 = (a + root_p) * 0.5;
            p2 = (a - root_p) * 0.5;
        }
    }
    else
    {
        const double root_q = sqrt(disc_q);
        q1 = (y + root_q) * 0.5;
        q2 = (y - root_q) * 0.5;
        // p1 + p2 = a and p1*q2 + p2*q1 = c.
        p1 = (a * q1 - c) / (q1 - q2);
        p2 = (c - a * q2) / (q1 - q2);
    }

    // First factor, then second factor appended right after it.
    int count = quartic2_store_factor(p1, q1, x_sol);
    count += quartic2_store_factor(p2, q2, x_sol + count);
    return count;
}










// Collect real candidates for a*x^4 + b*x^3 + c*x^2 + d*x + e = 0.
// A leading coefficient whose magnitude is below eps is treated as zero and
// the equation is reduced to the next lower degree.
// x_sol is a caller-owned buffer (room for 4 doubles); the return value is
// the number of leading entries of x_sol that hold candidates.
int solve_quartic(double a, double b, double c, double d,double e, double *x_sol)
{
    // Genuine quartic: normalise to monic form.
    if (!(abs1(a) < eps))
    {
        return solve_quartic_2(b/a, c/a, d/a, e/a, x_sol);
    }

    // Cubic: x^3 + (c/b)*x^2 + (d/b)*x + e/b = 0.
    if (!(abs1(b) < eps))
    {
        return solveP3(x_sol, c/b, d/b, e/b);
    }

    // Quadratic: c*x^2 + d*x + e = 0; no candidates if the roots are complex.
    if (!(abs1(c) < eps))
    {
        const double delta = d*d - 4*c*e;
        if (delta < 0)
        {
            return 0;
        }
        x_sol[0] =  (- d + sqrt(delta)) / (2*c);
        x_sol[1] =  (- d - sqrt(delta)) / (2*c);
        return 2;
    }

    // Linear: d*x + e = 0.
    if (!(abs1(d) < eps))
    {
        x_sol[0] = - e/d;
        return 1;
    }

    // All of a..d vanish: if e vanishes too, every x satisfies the equation
    // and 1 is returned as a representative; otherwise there is no solution.
    if (abs1(e) < eps)
    {
        x_sol[0] = 1;
        return 1;
    }
    return 0;
}


// Candidate critical points t of the unconstrained one-dimensional problem
//
//             a + b*t            (c - e) + d*t
//   min_t  -------------   +   ----------------
//           sqrt(1+t*t)             1 + t*t
//
// Setting the derivative to zero and squaring gives a quartic in t whose
// coefficients are formed below. The same quartic arises for the sign-flipped
// numerator -(a + b*t), and squaring can add spurious roots, so callers must
// evaluate the objective at each returned value.
// x_sol is a caller-owned buffer with room for 4 doubles; the return value is
// the number of candidates stored by solve_quartic.
int get_all_critical_points(const double a,const double b,const double c,const double d,const double e, double*x_sol)
{
    const double w = c -e;

    // Quartic coefficients, highest degree first.
    const double k4 = d*d - a*a;
    const double k3 = 4*w*d + 2*a*b;
    const double k2 = 4*w*w - 2*d*d -a*a - b*b;
    const double k1 = 2*b*a - 4*d*w;
    const double k0 = d*d - b*b;

    return solve_quartic(k4,k3,k2,k1,k0,x_sol);
}

// Objective A*c + B*s + C*c*c + D*c*s + E*s*s for the point (c, s), subject
// to the n linear constraints c*x[i] + s*y[i] >= 0.
// A constraint counts as violated only when its left-hand side is below -eps;
// in that case LARGE is returned instead of the objective value.
double ComputeObj_cs(double c,double s, double A,double B,double C,double D, double E,double*x,double*y,int n)
{
    for (int k = 0; k < n; ++k)
    {
        const bool violated = (c*x[k] + s*y[k]) < -eps;
        if (violated)
        {
            return LARGE;
        }
    }

    return A*c + B*s + C*c*c + D*c*s + E*s*s;
}


// Helper for nonconvex_quadratic_trigonometric_nonnegative_cc: evaluates the
// candidate (cand_c, cand_s) with ComputeObj_cs and adopts it only when its
// value is strictly below the incumbent, so earlier candidates win ties.
static void trigquad_try_point(double cand_c, double cand_s,
                               double A, double B, double C, double D, double E,
                               double *x, double *y, int R,
                               double *best_c, double *best_s, double *best_fobj)
{
    const double f = ComputeObj_cs(cand_c, cand_s, A, B, C, D, E, x, y, R);
    if (f < *best_fobj)
    {
        *best_fobj = f;
        *best_c = cand_c;
        *best_s = cand_s;
    }
}

// Helper for nonconvex_quadratic_trigonometric_nonnegative_cc: turns the slope
// t = s/c into the unit vector side*(1, t)/sqrt(1+t*t) (side is +1 for the
// c > 0 half-circle, -1 for c < 0) and offers it as a candidate.
static void trigquad_try_slope(double t, double side,
                               double A, double B, double C, double D, double E,
                               double *x, double *y, int R,
                               double *best_c, double *best_s, double *best_fobj)
{
    const double chi = 1 / sqrt(1 + t * t);
    trigquad_try_point(side * chi, side * t * chi,
                       A, B, C, D, E, x, y, R, best_c, best_s, best_fobj);
}

// Solves
//   min_{c,s} A*c + B*s + C*c*c + D*c*s + E*s*s
//   s.t.      c*x[i] + s*y[i] >= 0 for i = 0..R-1,   c*c + s*s = 1
// by enumerating a finite candidate set and keeping the best feasible point:
//   * the four axis points (1,0), (0,1), (0,-1), (-1,0);
//   * for each half-circle (c > 0 and c < 0), with t = s/c, the finite ends of
//     the feasible t-interval and the critical points returned by
//     get_all_critical_points, clamped into that interval.
// Feasibility of every candidate is checked inside ComputeObj_cs.
//
// flag:      when 1 and no feasible candidate exists, (1,0) is returned with
//            objective -LARGE instead of LARGE.
// x, y:      constraint coefficients, R entries each.
// best_c, best_s, best_fobj: outputs; *best_fobj stays LARGE when nothing is
//            feasible and flag != 1.
void nonconvex_quadratic_trigonometric_nonnegative_cc(int flag, double A,double B, double C,double D,double E,double *x, double*y,int R, double *best_c, double*best_s,double*best_fobj)
{
    double x_sol[4];
    const int num_sol = get_all_critical_points(A,B,C,D,E,x_sol);

    // Feasible slope intervals. Dividing c*x + s*y >= 0 by c gives
    //   c > 0:  x + t*y >= 0  ->  [lb1, ub1]
    //   c < 0:  x + t*y <= 0  ->  [lb2, ub2]
    // Constraints with y[i] == 0 do not bound t and are skipped here.
    double ub1 = LARGE;
    double ub2 = LARGE;
    double lb1 = -LARGE;
    double lb2 = -LARGE;
    for (int i = 0; i < R; i++)
    {
        if (y[i] > 0)
        {
            const double bound = -x[i] / y[i];
            lb1 = max1(lb1, bound);
            ub2 = min1(ub2, bound);
        }
        else if (y[i] < 0)
        {
            const double bound = -x[i] / y[i];
            lb2 = max1(lb2, bound);
            ub1 = min1(ub1, bound);
        }
    }

    // Axis points first; they also cover c == 0, which no finite t can reach.
    *best_fobj = LARGE; *best_c = 1; *best_s = 0;
    trigquad_try_point( 1,  0, A, B, C, D, E, x, y, R, best_c, best_s, best_fobj);
    trigquad_try_point( 0,  1, A, B, C, D, E, x, y, R, best_c, best_s, best_fobj);
    trigquad_try_point( 0, -1, A, B, C, D, E, x, y, R, best_c, best_s, best_fobj);
    trigquad_try_point(-1,  0, A, B, C, D, E, x, y, R, best_c, best_s, best_fobj);

    // Case 1: c > 0.
    if (ub1 >= lb1)
    {
        if (ub1 < LARGE)
        {
            trigquad_try_slope(ub1, 1.0, A, B, C, D, E, x, y, R, best_c, best_s, best_fobj);
        }
        if (lb1 > -LARGE)
        {
            trigquad_try_slope(lb1, 1.0, A, B, C, D, E, x, y, R, best_c, best_s, best_fobj);
        }
        for (int i = 0; i < num_sol; i++)
        {
            double t = x_sol[i];
            t = min1(t, ub1);
            t = max1(t, lb1);
            trigquad_try_slope(t, 1.0, A, B, C, D, E, x, y, R, best_c, best_s, best_fobj);
        }
    }

    // Case 2: c < 0. The test is >= (as in Case 1) so that an interval that
    // collapses to a single slope, e.g. from the constraint pair (x,y) and
    // (-x,-y), still contributes its one feasible point.
    if (ub2 >= lb2)
    {
        if (lb2 > -LARGE)
        {
            trigquad_try_slope(lb2, -1.0, A, B, C, D, E, x, y, R, best_c, best_s, best_fobj);
        }
        if (ub2 < LARGE)
        {
            trigquad_try_slope(ub2, -1.0, A, B, C, D, E, x, y, R, best_c, best_s, best_fobj);
        }
        for (int i = 0; i < num_sol; i++)
        {
            double t = x_sol[i];
            t = min1(t, ub2);
            t = max1(t, lb2);
            trigquad_try_slope(t, -1.0, A, B, C, D, E, x, y, R, best_c, best_s, best_fobj);
        }
    }

    // Fallback requested by the caller: nothing feasible was found, so report
    // (1,0) with an objective that beats any other value it is compared with.
    if (flag == 1 && *best_fobj >= LARGE)
    {
        *best_c = 1; *best_s = 0;
        *best_fobj = -LARGE;
    }
}


// Picks a 2x2 orthogonal matrix V (written to V[0..3], column-major) that
// minimises 0.5*vec(V)'*Q*vec(V) + <V,P> over the candidates examined by
// nonconvex_quadratic_trigonometric_nonnegative_cc, subject to
// V*[Z1(i); Z2(i)] >= 0 (elementwise, within the solver's tolerance) for
// every retained column i.
//
// Q:  16 doubles (entries 0..15 are read); the coefficient formulas below
//     assume a symmetric 4x4 matrix in column-major order — TODO confirm.
// P:  4 doubles.
// Z1, Z2: r doubles each; pairs whose two entries are both at most 1e-14 in
//     magnitude are dropped before the constraints are built.
// V:  output, 4 doubles.
//
// Two parameterisations are solved and the one with the smaller objective is
// returned (the rotation wins ties):
//   rotation    V = [ c  s; -s  c]
//   reflection  V = [-c  s;  s  c]
// The rotation solve is called with flag = 1, so if neither family has a
// feasible point the identity rotation is returned.
//
// Working storage is held in std::vector, so an allocation failure raises
// std::bad_alloc instead of writing through a null pointer.
void nonconvex_orth2d_quad_nonnegative_cc(const double *Q,const double*P,const double*Z1,const double*Z2, int r,double*V)
{
    // Retain only the columns of Z that are not numerically zero.
    std::vector<double> z1_kept;
    std::vector<double> z2_kept;
    if (r > 0)
    {
        z1_kept.reserve(static_cast<size_t>(r));
        z2_kept.reserve(static_cast<size_t>(r));
    }
    for (int i = 0; i < r; i++)
    {
        if ((abs1(Z1[i]) > 1e-14) || (abs1(Z2[i]) > 1e-14))
        {
            z1_kept.push_back(Z1[i]);
            z2_kept.push_back(Z2[i]);
        }
    }
    const int n_kept = static_cast<int>(z1_kept.size());

    // Each kept column yields two scalar constraints c*x + s*y >= 0, one per
    // row of V*Z, stored at positions i and i + n_kept.
    std::vector<double> cof_x(static_cast<size_t>(n_kept) * 2);
    std::vector<double> cof_y(static_cast<size_t>(n_kept) * 2);

    double cof_a, cof_b, cof_c, cof_d, cof_e;
    double best_c1, best_s1, best_f1;
    double best_c2, best_s2, best_f2;

    // Case 1: rotation, V = [c s; -s c].
    for (int i = 0; i < n_kept; i++)
    {
        cof_x[i]          =  z1_kept[i];
        cof_x[i + n_kept] =  z2_kept[i];
        cof_y[i]          =  z2_kept[i];
        cof_y[i + n_kept] = -z1_kept[i];
    }
    cof_a = P[0] + P[3];
    cof_b = P[2] - P[1];
    cof_c =  0.5*Q[0] + Q[12] + 0.5*Q[15];
    cof_d =  - Q[4] + Q[8] - Q[13] + Q[14];
    cof_e =  0.5*Q[5] - Q[9] + 0.5*Q[10];
    nonconvex_quadratic_trigonometric_nonnegative_cc(1,cof_a,cof_b,cof_c,cof_d,cof_e,cof_x.data(), cof_y.data(), n_kept*2, &best_c1,&best_s1,&best_f1);

    // Case 2: reflection, V = [-c s; s c].
    for (int i = 0; i < n_kept; i++)
    {
        cof_x[i]          = -z1_kept[i];
        cof_x[i + n_kept] =  z2_kept[i];
        cof_y[i]          =  z2_kept[i];
        cof_y[i + n_kept] =  z1_kept[i];
    }
    cof_a  = -P[0] + P[3];
    cof_b  =  P[1] + P[2];
    cof_c = 0.5*Q[0] - Q[12] + 0.5*Q[15];
    cof_d = - Q[4] - Q[8] +  Q[13] + Q[14];
    cof_e = 0.5*Q[5] + Q[9] + 0.5*Q[10];
    nonconvex_quadratic_trigonometric_nonnegative_cc(0,cof_a,cof_b,cof_c,cof_d,cof_e,cof_x.data(), cof_y.data(), n_kept*2, &best_c2,&best_s2,&best_f2);

    // Assemble the winning matrix in column-major order.
    if (best_f1 <= best_f2)
    {
        V[0] = best_c1;  V[1] = -best_s1; V[2] = best_s1; V[3] = best_c1;
    }
    else
    {
        V[0] = -best_c2; V[1] = best_s2;  V[2] = best_s2; V[3] = best_c2;
    }
}

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    
    double *Q       =  mxGetPr(prhs[0]);
    double *P       =  mxGetPr(prhs[1]);
    double *Z1      =  mxGetPr(prhs[2]);
    double *Z2      =  mxGetPr(prhs[3]);
    int r           =  (int)mxGetScalar(prhs[4]);
    
    plhs[0] = mxCreateDoubleMatrix(2,2,mxREAL);
    double *V = mxGetPr(plhs[0]);
    nonconvex_orth2d_quad_nonnegative_cc(Q,P,Z1,Z2,r,V);
    
 
    
}
