#include "mex.h"
#include <math.h>
#include <time.h>
#include <stdlib.h>
#include <stdio.h>
#include <blas.h>    // dgemv
#include <lapack.h>  // dstegr
#include <string.h>
#include <omp.h>
#include <iostream>
//#include <iomanip>
// #include <blas.h>    // dgemv
// #include <lapack.h>  // dstegr
using namespace std;
//#include <iomanip>
// #include <blas.h>    // dgemv
// #include <lapack.h>  // dstegr



// Small numeric helper macros used throughout this file.
// NOTE: these are function-like macros, so every argument may be evaluated
// more than once; only pass expressions without side effects.

// Absolute value of a.
#define  abs1(a)         ((a) < 0.0 ? -(a) : (a))
// Sign of a: 0 when a == 0, 1 when a > 0, -1 otherwise.
// The whole expansion is parenthesized so that the macro composes correctly
// inside larger expressions such as 2 * sign1(x) or sign1(x) + 1.
#define  sign1(a)        (((a) == 0) ? 0 : (((a) > 0.0) ? 1 : (-1)))
// Larger / smaller of a and b.
#define  max1(a,b)       ((a) > (b) ? (a) : (b))
#define  min1(a,b)       ((a) < (b) ? (a) : (b))
// pi and 2*pi as long double constants.
# define PI 3.141592653589793238463L
# define M_2PI (2 * PI)
// Absolute tolerance used for "is this value zero" tests.
# define eps 2.2204e-16
// Sentinel used as the initial "worst possible" objective value.
# define LARGE 1e100

// Debug helper: prints the first n entries of x to stdout using "%f ",
// followed by a newline and a single space.
void pvec(double *x,int n)
{
    const double *cursor = x;
    for (int left = n; left > 0; --left)
    {
        printf("%f ", *cursor);
        ++cursor;
    }

    printf("\n ");
}

 

 

//---------------------------------------------------------------------------
// Solve the monic cubic  x^3 + a*x^2 + b*x + c = 0.
// x must point to an array with room for 3 doubles. The return value tells
// how to interpret its contents:
//   3 : three real roots in x[0], x[1], x[2]
//   2 : x[0] and x[1] are real roots (x[2] is set to a copy of x[1])
//   1 : one real root in x[0]; the remaining two roots are the complex
//       pair x[1] +/- i*x[2]
unsigned int solveP3(double* x, double a, double b, double c)
{
    const double a_sq = a * a;
    const double q = (a_sq - 3 * b) / 9;
    const double r = (a * (2 * a_sq - 9 * b) + 27 * c) / 54;
    const double r_sq = r * r;
    const double q_cu = q * q * q;

    if (!(r_sq < q_cu))
    {
        // Single real root (or a repeated root): Cardano-style formula.
        double A = -pow(abs1(r) + sqrt(r_sq - q_cu), 1. / 3);
        if (r < 0) A = -A;

        double B = 0;
        if (A != 0) B = q / A;

        const double shift = a / 3;
        x[0] = (A + B) - shift;
        x[1] = -0.5 * (A + B) - shift;
        x[2] = 0.5 * sqrt(3.) * (A - B);

        // A vanishing imaginary part means x[1] is a real (double) root.
        if (abs1(x[2]) < eps)
        {
            x[2] = x[1];
            return 2;
        }
        return 1;
    }

    // Three real roots: trigonometric form. The ratio is clamped to [-1, 1]
    // before it is handed to acos.
    double ratio = r / sqrt(q_cu);
    if (ratio < -1) ratio = -1;
    if (ratio > 1) ratio = 1;
    const double theta = acos(ratio);

    const double shift = a / 3;
    const double amp = -2 * sqrt(q);
    x[0] = amp * cos(theta / 3) - shift;
    x[1] = amp * cos((theta + M_2PI) / 3) - shift;
    x[2] = amp * cos((theta - M_2PI) / 3) - shift;
    return 3;
}

//---------------------------------------------------------------------------
// Solve the monic quartic  x^4 + a*x^3 + b*x^2 + c*x + d = 0.
//
// A root y of the cubic resolvent is used to build two quadratic factors
//     x^2 + p[0]*x + q[0]   and   x^2 + p[1]*x + q[1],
// which are then solved one after the other.
//
// Results are written to the caller-provided array x_sol, which must have
// room for 4 doubles (nothing is allocated here):
//   * a factor whose discriminant is not negative contributes its two roots;
//   * a factor with a negative discriminant contributes ONE value, namely the
//     real part -p/2 of its complex-conjugate root pair.
// Returns the number of values written (2, 3 or 4).
int solve_quartic_2(double a, double b, double c, double d,double *x_sol)
{
    // Cubic resolvent:  y^3 - b*y^2 + (a*c - 4*d)*y - a^2*d - c^2 + 4*b*d = 0
    const double res_a = -b;
    const double res_b = a * c - 4. * d;
    const double res_c = -a * a * d - c * c + 4. * b * d;

    double res_root[3];
    const unsigned int res_kind = solveP3(res_root, res_a, res_b, res_c);

    // Take the resolvent root with the largest magnitude. When solveP3
    // returns 1 only res_root[0] is a real root, so it is used as is.
    double y = res_root[0];
    if (res_kind != 1)
    {
        for (int k = 1; k < 3; ++k)
        {
            if (abs1(res_root[k]) > abs1(y)) y = res_root[k];
        }
    }

    double p[2];
    double q[2];

    // q[0], q[1] are the roots of  h^2 - y*h + d = 0.
    const double disc_q = y * y - 4 * d;
    if (!(abs1(disc_q) < eps))
    {
        const double root_q = sqrt(disc_q);
        q[0] = (y + root_q) * 0.5;
        q[1] = (y - root_q) * 0.5;
        // p[0] + p[1] = a  and  p[0]*q[1] + p[1]*q[0] = c  (Cramer's rule)
        p[0] = (a * q[0] - c) / (q[0] - q[1]);
        p[1] = (c - a * q[1]) / (q[0] - q[1]);
    }
    else
    {
        // Discriminant treated as zero: q[0] == q[1].
        q[0] = q[1] = y * 0.5;
        // p[0], p[1] are then the roots of  g^2 - a*g + (b - y) = 0.
        const double disc_p = a * a - 4 * (b - y);
        if (abs1(disc_p) < eps)
        {
            p[0] = p[1] = a * 0.5;
        }
        else
        {
            const double root_p = sqrt(disc_p);
            p[0] = (a + root_p) * 0.5;
            p[1] = (a - root_p) * 0.5;
        }
    }

    // Solve  x^2 + p[f]*x + q[f] = 0  for both factors, appending to x_sol.
    int num = 0;
    for (int f = 0; f < 2; ++f)
    {
        const double disc = p[f] * p[f] - 4 * q[f];
        if (disc < 0.0)
        {
            // Complex pair: only its real part is reported.
            x_sol[num] = -p[f] * 0.5;
            num = num + 1;
        }
        else
        {
            const double root = sqrt(disc);
            x_sol[num]     = (-p[f] + root) * 0.5;
            x_sol[num + 1] = (-p[f] - root) * 0.5;
            num = num + 2;
        }
    }

    return num;
}








// Solve  a*x^4 + b*x^3 + c*x^2 + d*x + e = 0  for general (possibly vanishing)
// leading coefficients. A coefficient is treated as zero when its magnitude
// is below eps, and the equation degrades to the next lower degree:
//   quartic   -> solve_quartic_2 on the normalized coefficients
//   cubic     -> solveP3 on the normalized coefficients (its return code is
//                passed through; it writes x_sol[0..2])
//   quadratic -> two roots when the discriminant is not negative, else none
//   linear    -> the single root -e/d
//   constant  -> 1 is reported as a representative solution when e is also
//                (numerically) zero, otherwise there is no solution
// x_sol must have room for 4 doubles. Returns the count reported by the
// branch that was taken.
int solve_quartic(double a, double b, double c, double d,double e, double *x_sol)
{
    // Genuine quartic.
    if (!(abs1(a) < eps))
        return solve_quartic_2(b/a, c/a, d/a, e/a, x_sol);

    // Cubic:  x^3 + c/b*x^2 + d/b*x + e/b = 0
    if (!(abs1(b) < eps))
        return solveP3(x_sol, c/b, d/b, e/b);

    // Quadratic:  c*x^2 + d*x + e = 0
    if (!(abs1(c) < eps))
    {
        const double delta = d*d - 4*c*e;
        if (delta < 0)
            return 0;

        const double root = sqrt(delta);
        x_sol[0] = (- d + root) / (2*c);
        x_sol[1] = (- d - root) / (2*c);
        return 2;
    }

    // Linear:  d*x + e = 0
    if (!(abs1(d) < eps))
    {
        x_sol[0] = - e/d;
        return 1;
    }

    // Constant:  e = 0
    if (abs1(e) < eps)
    {
        x_sol[0] = 1;
        return 1;
    }
    return 0;
}


 

// Objective of the (c, s) subproblem:
//     A*c + B*s + C*c*c + D*c*s + E*s*s
//       + lambda * #{ i in [0, n) : |c*x[i] + s*y[i]| > myeps }
// i.e. a quadratic in (c, s) plus lambda times the number of entries of
// c*x + s*y whose magnitude exceeds the threshold myeps.
//   x, y  : arrays of length n (only read)
//   myeps : magnitude above which an entry counts as nonzero
// Returns the objective value.
double ComputeObj_cs(double c,double s, double A,double B,double C,double D, double E, double lambda, double*x,double*y,int n,double myeps)
{
    double nonzeros = 0;
    for (int i = 0; i < n; ++i)
    {
        // fabs evaluates the residual once; the abs1 macro would expand the
        // whole expression twice.
        if (fabs(c*x[i] + s*y[i]) > myeps) nonzeros += 1;
    }
    return nonzeros * lambda + A*c + B*s + C*c*c + D*c*s + E*s*s;
}



// Candidate stationary points t of the unconstrained one-dimensional problem
//
//          +/- a +/- b*t         c - e + d*t
//   min_t  --------------   +   -------------
//           sqrt(1 + t*t)          1 + t*t
//
// Setting the derivative to zero and squaring away the square root yields the
// quartic  k4*t^4 + k3*t^3 + k2*t^2 + k1*t + k0 = 0  assembled below; its
// solutions, as reported by solve_quartic, are written to x_sol (room for 4
// doubles is required). Returns the count reported by solve_quartic.
int get_all_critical_points(const double a,const double b,const double c,const double d,const double e, double*x_sol)
{
    const double w = c -e;

    // Quartic coefficients, highest degree first.
    const double k4 = d*d - a*a;
    const double k3 = 4*w*d + 2*a*b;
    const double k2 = 4*w*w - 2*d*d -a*a - b*b;
    const double k1 = 2*b*a - 4*d*w;
    const double k0 = d*d - b*b;

    return solve_quartic(k4, k3, k2, k1, k0, x_sol);
}



// Helper for nonconvex_quadratic_trigonometry_L0_cc: evaluates the objective
// at the candidate (cand_c, cand_s) and, when the value is strictly below
// *best_fobj, stores the candidate and its value in the best_* outputs.
static void consider_cs_candidate(double cand_c, double cand_s, double myeps,
                                  double A, double B, double C, double D, double E, double lambda,
                                  double *x, double *y, int R,
                                  double *best_c, double *best_s, double *best_fobj)
{
    const double cand_f = ComputeObj_cs(cand_c, cand_s, A, B, C, D, E, lambda, x, y, R, myeps);
    if (cand_f < *best_fobj)
    {
        *best_fobj = cand_f;
        *best_c = cand_c;
        *best_s = cand_s;
    }
}

// Candidate search for
//   min_{c,s} A*c + B*s + C*c*c + D*c*s + E*s*s + lambda*||c*x+s*y||_0
//   s.t.      c*c + s*s = 1
// where the "L0" term counts the entries of c*x + s*y (length R) whose
// magnitude exceeds myeps (see ComputeObj_cs).
//
// The objective is evaluated at a finite list of points on the unit circle
// and the best one (first wins on ties) is returned through
// best_c / best_s / best_fobj:
//   1. the four axis points (1,0), (-1,0), (0,1), (0,-1);
//   2. for every value t returned by get_all_critical_points, the points
//      +/-(1, t)/sqrt(1 + t*t);
//   3. for every i with |y[i]| > eps, t = x[i]/y[i] and all four sign
//      combinations (+/-1, +/-t)/sqrt(1 + t*t).
void nonconvex_quadratic_trigonometry_L0_cc(double myeps, double A,double B, double C,double D,double E,double lambda,double *x, double*y,int R, double *best_c, double*best_s,double*best_fobj)
{
    // Start from (1, 0) with a sentinel objective; any value below LARGE
    // replaces it.
    *best_c = 1; *best_s = 0; *best_fobj = LARGE;

    // 1. Axis points.
    static const double axis_cs[4][2] = { {1, 0}, {-1, 0}, {0, 1}, {0, -1} };
    for (int k = 0; k < 4; ++k)
    {
        consider_cs_candidate(axis_cs[k][0], axis_cs[k][1], myeps, A, B, C, D, E, lambda, x, y, R, best_c, best_s, best_fobj);
    }

    // 2. Points derived from the critical values of the smooth part.
    double crit_t[4];
    const int n_crit = get_all_critical_points(A, B, C, D, E, crit_t);
    for (int k = 0; k < n_crit; ++k)
    {
        const double t = crit_t[k];
        const double scale = 1/sqrt(1+t*t);
        consider_cs_candidate( scale,  t*scale, myeps, A, B, C, D, E, lambda, x, y, R, best_c, best_s, best_fobj);
        consider_cs_candidate(-scale, -t*scale, myeps, A, B, C, D, E, lambda, x, y, R, best_c, best_s, best_fobj);
    }

    // 3. Points built from the ratio x[i]/y[i] of each pair; entries with a
    //    (numerically) zero y[i] are skipped.
    for (int i = 0; i < R; ++i)
    {
        if (!(abs1(y[i]) > eps)) continue;

        const double t = x[i] / y[i];
        const double scale = 1 / sqrt(1+t*t);
        consider_cs_candidate( scale,  t*scale, myeps, A, B, C, D, E, lambda, x, y, R, best_c, best_s, best_fobj);
        consider_cs_candidate( scale, -t*scale, myeps, A, B, C, D, E, lambda, x, y, R, best_c, best_s, best_fobj);
        consider_cs_candidate(-scale, -t*scale, myeps, A, B, C, D, E, lambda, x, y, R, best_c, best_s, best_fobj);
        consider_cs_candidate(-scale,  t*scale, myeps, A, B, C, D, E, lambda, x, y, R, best_c, best_s, best_fobj);
    }
}




void nonconvex_orth2d_quad_L0_cc(double myeps, const double *Q,const double*P,const double lambda,const double*Z1,const double*Z2, int r,double*V)
{
// HandleObj = @(V)0.5*vec(V)'*Q*vec(V) + mdot(V,P) + lambda* L1Norm(V*Z);
// min_{V} HandleObj(V), s.t. V'V = I
    
    // delete some Z
    int i;int j;
    double cof_a; double cof_b;double cof_c;double cof_d;double cof_e;
    double best_c1; double best_s1; double best_f1;
    double best_c2; double best_s2; double best_f2;
    
    double *Z1_true = (double *)malloc(sizeof(double)*r);
    double *Z2_true = (double *)malloc(sizeof(double)*r);
    int N_true=0;
    for(i=0;i<r;i++)
    {
        if(  (abs1(Z1[i])>eps)   ||   (abs1(Z2[i])>eps)  )
        {
            Z1_true[N_true] = Z1[i];
            Z2_true[N_true] = Z2[i];
            N_true++;
        }
    }
    
    double *cof_x = (double *)malloc(sizeof(double)*N_true*2);
    double *cof_y = (double *)malloc(sizeof(double)*N_true*2);
    
    // Case1: rotation matrix. Compute V1 and its objective function, V = [c s; -s c];
    for(i=0;i<N_true;i++){cof_x[i] = Z1_true[i]; cof_x[i+N_true] = Z2_true[i]; cof_y[i] = Z2_true[i]; cof_y[i+N_true] = -Z1_true[i];}
    cof_a = P[0] + P[3]; cof_b = P[2] - P[1]; cof_c =  0.5*Q[0] + Q[12] + 0.5*Q[15]; cof_d =  - Q[4] + Q[8] -Q[13] + Q[14]; cof_e =  0.5*Q[5] - Q[9] + 0.5*Q[10];
    nonconvex_quadratic_trigonometry_L0_cc(myeps,cof_a,cof_b,cof_c,cof_d,cof_e,lambda,cof_x, cof_y, N_true*2, &best_c1,&best_s1,&best_f1);
    
    // Case2: reflection matrix. Compute V2 and its objective function, V = [ -c s; s c];
    for(i=0;i<N_true;i++){cof_x[i] = -Z1_true[i]; cof_x[i+N_true] = Z2_true[i]; cof_y[i] = Z2_true[i];  cof_y[i+N_true] = Z1_true[i];}
    cof_a  = -P[0] + P[3]; cof_b  =  P[1] + P[2]; cof_c = 0.5*Q[0] - Q[12] + 0.5*Q[15]; cof_d = - Q[4] - Q[8] +  Q[13] + Q[14]; cof_e = 0.5*Q[5] + Q[9] + 0.5*Q[10];
    nonconvex_quadratic_trigonometry_L0_cc(myeps,cof_a,cof_b,cof_c,cof_d,cof_e,lambda,cof_x, cof_y, N_true*2, &best_c2,&best_s2,&best_f2);
    
    if(best_f1<best_f2){V[0] = best_c1; V[1] = -best_s1; V[2] = best_s1; V[3] = best_c1;}
    else {V[0] = -best_c2; V[1] = best_s2; V[2] = best_s2; V[3] = best_c2;}
       
    free(Z1_true);free(Z2_true); free(cof_x); free(cof_y);
}

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    
    double *Q       =  mxGetPr(prhs[0]);
    double *P       =  mxGetPr(prhs[1]);
    double lambda   =  mxGetScalar(prhs[2]);
    double *Z1      =  mxGetPr(prhs[3]);
    double *Z2      =  mxGetPr(prhs[4]);
    int r           =  (int)mxGetScalar(prhs[5]);
    double myeps    =  mxGetScalar(prhs[6]);

    plhs[0] = mxCreateDoubleMatrix(2,2,mxREAL);
    double *V = mxGetPr(plhs[0]);
    nonconvex_orth2d_quad_L0_cc(myeps,Q,P,lambda,Z1,Z2,r,V);
//    nonconvex_quadratic_trigonometry_L1(A,B,C,D,E,lambda,x,y,R,best_c,best_s);
    
//    solve_quartic(c1/c0,c2/c0,c3/c0,c4/c0,out1,out2);
    
}
