#include "mex.h"
#include <math.h>
#include <time.h>
#include <stdlib.h>
#include <stdio.h>
#include <blas.h>    // dgemv
#include <lapack.h>  // dstegr
#include <string.h>
#include <omp.h>
#include <algorithm> // std::sort
#include <iostream>
//#include <iomanip>
// #include <blas.h>    // dgemv
// #include <lapack.h>  // dstegr
using namespace std;
//#include <iomanip>
// #include <blas.h>    // dgemv
// #include <lapack.h>  // dstegr



// Small helper macros. Each argument may be evaluated more than once, so do
// not pass expressions with side effects.

// Absolute value via a branch (a negative zero is returned unchanged).
#define  abs1(a)         ((a) < 0.0 ? -(a) : (a))
// Sign of a as an int: 0 when a == 0, 1 when a > 0, otherwise -1.
// The whole expansion is parenthesised so the macro can be embedded in a
// larger expression (e.g. 2 * sign1(x)) without the surrounding operators
// being absorbed into the conditional.
#define  sign1(a)        (((a)==0) ? 0 : (((a)>0.0)?1:(-1)))
#define  max1(a,b)       ((a) > (b) ? (a) : (b))
#define  min1(a,b)       ((a) < (b) ? (a) : (b))
// PI is a long double literal; expressions mixing it with doubles are
// therefore evaluated in long double.
# define PI 3.141592653589793238463L
# define M_2PI (2 * PI)
// Absolute tolerance used throughout this file for "is zero" tests.
# define eps 2.2204e-16
// Stand-in for +infinity (and -LARGE for -infinity).
# define LARGE 1e100

// Debug helper: writes the first n entries of x to stdout with "%f",
// each followed by a space, then a newline followed by a single space.
void pvec(double *x,int n)
{
    int k = 0;
    while (k < n)
    {
        printf("%f ",x[k]);
        ++k;
    }

    printf("\n ");
}

// One bottom-up pass over the implicit binary tree stored in a[0..n-1].
//
// For every index from i down to 0, the element is swapped with its larger
// child when that child is strictly greater. Only a single-level swap is
// performed per node (no repeated sift-down), so the result is not
// necessarily a full max-heap; however, when started at i = n/2 - 1 the
// maximum of a[0..n-1] ends up in a[0], which is what HeapSort relies on.
//
// a : array holding at least n elements
// i : index of the first (highest) node to process; a negative i is a no-op
// n : number of valid elements
//
// Nodes without a child inside [0, n) are skipped, so no element at index
// >= n is ever read or written, whatever i is passed.
void CreatHeap(double a[], int i, int  n)
{
    for (; i >= 0; --i)
    {
        const int left = i * 2 + 1;
        if (left >= n)
            continue;   // leaf: there is no child inside the valid range

        const int right = left + 1;
        int j = left;
        // Take the right child unless the left one is strictly greater.
        if (right < n && !(a[left] > a[right]))
            j = right;

        if (a[j] > a[i])
        {
            const double tmp = a[i];
            a[i] = a[j];
            a[j] = tmp;
        }
    }
}

// Sorts a[0..n-1] in increasing order, in place.
//
// The previous implementation rebuilt the whole remaining prefix for every
// extracted element, which made it quadratic in n; std::sort gives the same
// ordering in O(n log n).
//
// a : array of n doubles; the values must not contain NaN (std::sort needs
//     a strict weak ordering)
// n : number of elements; nothing is done when n < 2
void HeapSort(double a[], int n)
{
    if (n < 2)
        return;
    std::sort(a, a + n);
}


//---------------------------------------------------------------------------
// Solve the monic cubic x^3 + a*x^2 + b*x + c = 0.
// x must have room for 3 doubles; all three entries are always written.
// Return value and layout of x:
//   3 : three real roots in x[0], x[1], x[2]
//   2 : real roots x[0] and x[1] (x[2] is set equal to x[1])
//   1 : one real root x[0] plus the complex pair x[1] +/- i*x[2]
unsigned int solveP3(double* x, double a, double b, double c)
{
    const double a_sq = a * a;
    const double q = (a_sq - 3 * b) / 9;
    const double r = (a * (2 * a_sq - 9 * b) + 27 * c) / 54;
    const double r_sq = r * r;
    const double q_cu = q * q * q;
    const double shift = a / 3;   // every result below is offset by -a/3

    if (!(r_sq < q_cu))
    {
        // Cube-root branch: yields one real root, or two when the
        // imaginary part of the remaining pair vanishes.
        double A = -pow(abs1(r) + sqrt(r_sq - q_cu), 1. / 3);
        if (r < 0)
            A = -A;

        double B = 0;
        if (!(0 == A))
            B = q / A;

        const double sum = A + B;
        x[0] = sum - shift;
        x[1] = -0.5 * sum - shift;
        x[2] = 0.5 * sqrt(3.) * (A - B);
        if (abs1(x[2]) < eps)
        {
            x[2] = x[1];
            return 2;
        }
        return 1;
    }

    // Trigonometric branch: three real roots.
    double cos_arg = r / sqrt(q_cu);
    if (cos_arg < -1) cos_arg = -1;   // clamp so acos() stays in its domain
    if (cos_arg > 1) cos_arg = 1;
    const double theta = acos(cos_arg);
    const double amp = -2 * sqrt(q);

    x[0] = amp * cos(theta / 3) - shift;
    x[1] = amp * cos((theta + M_2PI) / 3) - shift;
    x[2] = amp * cos((theta - M_2PI) / 3) - shift;
    return 3;
}

//---------------------------------------------------------------------------
// Helper for solve_quartic_2: handles one quadratic factor x^2 + p*x + q.
// Appends its values to out starting at index count and returns the new
// count. With a negative discriminant a single value, the real part -p/2 of
// the complex pair, is stored; otherwise both roots are stored.
static int quartic2_append_quadratic(double p, double q, double *out, int count)
{
    const double disc = p * p - 4 * q;
    if (disc < 0.0)
    {
        out[count] = -p * 0.5;
        return count + 1;
    }

    const double root = sqrt(disc);
    out[count] = (-p + root) * 0.5;
    out[count + 1] = (-p - root) * 0.5;
    return count + 2;
}

//---------------------------------------------------------------------------
// Solve the monic quartic x^4 + a*x^3 + b*x^2 + c*x + d = 0.
// The quartic is split into (x^2 + p1*x + q1)*(x^2 + p2*x + q2) using a root
// of its cubic resolvent. Results are written consecutively into the
// caller-provided buffer x_sol (room for 4 doubles is needed) and the number
// of values written is returned. Nothing is allocated here.
// Note: for a quadratic factor with complex roots, the real part of the pair
// is written as one value, so not every returned value is a real root.
int solve_quartic_2(double a, double b, double c, double d,double *x_sol)
{
    // Cubic resolvent: y^3 - b*y^2 + (a*c - 4*d)*y - a^2*d - c^2 + 4*b*d = 0
    const double res_a = -b;
    const double res_b = a * c - 4. * d;
    const double res_c = -a * a * d - c * c + 4. * b * d;

    double res_roots[3];
    const unsigned int res_count = solveP3(res_roots, res_a, res_b, res_c);

    // Pick the resolvent root with the largest absolute value. When solveP3
    // reports a single real root, entries 1 and 2 describe a complex pair and
    // are not considered.
    double y = res_roots[0];
    if (res_count != 1)
    {
        for (int k = 1; k < 3; ++k)
        {
            if (abs1(res_roots[k]) > abs1(y)) y = res_roots[k];
        }
    }

    double p1, p2, q1, q2;

    // q1 + q2 = y and q1*q2 = d  <=>  q^2 - y*q + d = 0
    const double disc_q = y * y - 4 * d;
    if (!(abs1(disc_q) < eps))
    {
        const double root_q = sqrt(disc_q);
        q1 = (y + root_q) * 0.5;
        q2 = (y - root_q) * 0.5;
        // p1 + p2 = a and p1*q2 + p2*q1 = c, solved by Cramer's rule
        p1 = (a * q1 - c) / (q1 - q2);
        p2 = (c - a * q2) / (q1 - q2);
    }
    else
    {
        // Discriminant treated as zero: q1 == q2.
        q1 = q2 = y * 0.5;
        // p1 + p2 = a and p1*p2 = b - y  <=>  p^2 - a*p + (b - y) = 0
        const double disc_p = a * a - 4 * (b - y);
        if (abs1(disc_p) < eps)
        {
            p1 = p2 = a * 0.5;
        }
        else
        {
            const double root_p = sqrt(disc_p);
            p1 = (a + root_p) * 0.5;
            p2 = (a - root_p) * 0.5;
        }
    }

    // First x^2 + p1*x + q1 = 0, then x^2 + p2*x + q2 = 0.
    int num = 0;
    num = quartic2_append_quadratic(p1, q1, x_sol, num);
    num = quartic2_append_quadratic(p2, q2, x_sol, num);
    return num;
}








// Solve a*x^4 + b*x^3 + c*x^2 + d*x + e = 0 for general (possibly zero)
// coefficients. A leading coefficient whose absolute value is below eps is
// treated as zero and the equation is handled as one of lower degree.
// Values are written to x_sol (room for 4 doubles is needed); the return
// value is how many of them are to be used.
int solve_quartic(double a, double b, double c, double d,double e, double *x_sol)
{
    // Degree 4: normalise to a monic quartic.
    if (!(abs1(a) < eps))
    {
        return solve_quartic_2(b/a,c/a,d/a,e/a,x_sol);
    }

    // Degree 3: x^3 + c/b*x^2 + d/b*x + e/b = 0. solveP3 writes three
    // entries of x_sol regardless of the count it returns.
    if (!(abs1(b) < eps))
    {
        return solveP3(x_sol, c/b, d/b, e/b);
    }

    // Degree 2: c*x^2 + d*x + e = 0, real roots only.
    if (!(abs1(c) < eps))
    {
        const double delta = d*d - 4*c*e;
        if (delta < 0)
        {
            return 0;
        }
        x_sol[0] =  (- d + sqrt(delta)) / (2*c);
        x_sol[1] =  (- d - sqrt(delta)) / (2*c);
        return 2;
    }

    // Degree 1: d*x + e = 0.
    if (!(abs1(d) < eps))
    {
        x_sol[0] = - e/d;
        return 1;
    }

    // Degree 0: every x satisfies the equation when e is zero too; x = 1 is
    // reported as a representative. Otherwise there is no solution.
    if (abs1(e) < eps)
    {
        x_sol[0] = 1;
        return 1;
    }
    return 0;
}


// Evaluates
//   A*c + B*s + C*c*c + D*c*s + E*s*s + lambda * sum_k |c*x[k] + s*y[k]|
// where x and y are arrays of length n.
double ComputeObj_cs(const double c,const double s,const double A,const double B,const double C,const double D,const double E,const double lambda,const double*x,const double*y,const int n)
{
    double l1_sum = 0;
    for (int k = 0; k < n; ++k)
    {
        l1_sum += abs1(c*x[k] + s*y[k]);
    }
    return l1_sum * lambda + A*c + B*s + C*c*c + D*c*s + E*s*s;
}


// Helper: evaluates the objective at (c, s) and, if it is strictly smaller
// than *best_fobj, records the point and its value as the new incumbent.
static void trig_l1_keep_if_better(const double c, const double s,
                                   const double A, const double B, const double C, const double D, const double E,
                                   const double lambda, const double *x, const double *y, const int n,
                                   double *best_c, double *best_s, double *best_fobj)
{
    const double f = ComputeObj_cs(c,s,A,B,C,D,E,lambda,x,y,n);
    if (f < *best_fobj)
    {
        *best_fobj = f;
        *best_c = c;
        *best_s = s;
    }
}

// Helper: builds the quartic coefficients from (c00, c11, c0, c1, c2) and
// solves it with solve_quartic. roots needs room for 4 doubles; the number
// of values stored is returned.
static int trig_l1_interval_quartic(const double c00, const double c11,
                                    const double c0, const double c1, const double c2,
                                    double *roots)
{
    const double cof0 = c00*c00 - c0*c0;
    const double cof1 = 2*c00*c11 - 2*c0*c1;
    const double cof2 = c11*c11  + c00*c00 - 2*c0*c2 - c1*c1;
    const double cof3 = 2*c00*c11 - 2*c2*c1;
    const double cof4 = c11*c11 - c2*c2;
    return solve_quartic(cof4,cof3,cof2,cof1,cof0,roots);
}

// min_{c,s} A*c + B*s + C*c*c + D*c*s + E*s*s + lambda*norm(c*x+s*y,1),
// s.t. c*c + s*s = 1
//
// x, y are arrays of length R. The best point found and its objective value
// are returned through best_c, best_s and best_fobj.
//
// Candidates are generated with the substitution t = s/c, i.e.
// c = +/- 1/sqrt(1+t*t), s = t*c, and compared by direct evaluation:
//   1. the four axis points (0,1), (1,0), (0,-1), (-1,0);
//   2. the finite ratios t = -x[i]/y[i], where a term of the L1 norm is zero;
//   3. for each interval between consecutive breakpoints, the values returned
//      by a quartic solve that fall inside that interval (presumably the
//      stationarity condition with the signs of the L1 terms frozen at the
//      interval midpoint -- confirm against the derivation).
void nonconvex_quadratic_trigonometry_L1(const double A,const double B,const double C,const double D, const double E,const double lambda,const double *x,const double*y,const int R, double *best_c, double*best_s,double*best_fobj)
{
    const double W = C-E;
    double roots[4];

    double *ratios = (double *)malloc(sizeof(double)*R);
    double *magnitudes = (double *)malloc(sizeof(double)*R);
    double *breaks = (double *)malloc(sizeof(double)*(2*R+2)); // 2*n_ratio + 2 entries are used

    // Collect the finite ratios; infinite or NaN values fail the range test.
    int n_ratio = 0;
    for (int k = 0; k < R; ++k)
    {
        const double ratio = - x[k] / y[k];
        if (ratio <= LARGE && ratio >= -LARGE)
        {
            ratios[n_ratio] = ratio;
            magnitudes[n_ratio] = abs1(ratio);
            ++n_ratio;
        }
    }

    // Breakpoints in increasing order:
    // -LARGE, -|t| (largest magnitude first), +|t| (smallest first), +LARGE.
    HeapSort(magnitudes,n_ratio);
    breaks[0] = -LARGE;
    breaks[2*n_ratio+1] = LARGE;
    for (int k = 0; k < n_ratio; ++k)
    {
        breaks[k+1] = - magnitudes[n_ratio-k-1];
        breaks[k+n_ratio+1] = magnitudes[k];
    }

    // Start from "nothing found" and try the four axis points.
    *best_fobj = LARGE;
    *best_c = -1;
    *best_s = -1;
    const double axis_c[4] = { 0, 1,  0, -1 };
    const double axis_s[4] = { 1, 0, -1,  0 };
    for (int k = 0; k < 4; ++k)
    {
        trig_l1_keep_if_better(axis_c[k], axis_s[k], A,B,C,D,E,lambda,x,y,R, best_c,best_s,best_fobj);
    }

    // Ratio candidates, first with c > 0 and then with c < 0.
    for (int k = 0; k < n_ratio; ++k)
    {
        const double tt = ratios[k];
        double cc = 1 / sqrt(tt*tt+1);
        double ss = tt*cc;
        trig_l1_keep_if_better(cc, ss, A,B,C,D,E,lambda,x,y,R, best_c,best_s,best_fobj);
        cc = - 1 / sqrt(tt*tt+1);
        ss = tt*cc;
        trig_l1_keep_if_better(cc, ss, A,B,C,D,E,lambda,x,y,R, best_c,best_s,best_fobj);
    }

    // These three inputs of the quartic do not depend on the interval.
    const double c0 = D;
    const double c1 = - W*2;
    const double c2 = -D ;

    for (int iv = 0; iv < (2*n_ratio+1); ++iv)
    {
        const double lo = breaks[iv];
        const double hi = breaks[iv+1];
        const double t = 0.5*(lo + hi);

        // Sign pattern of x + t*y at the interval midpoint.
        double oy = 0;
        double ox = 0;
        for (int j = 0; j < R; ++j)
        {
            double o = sign1(x[j] + t*y[j]);
            oy += o*y[j];
            ox += o*x[j];
        }

        // Branch with c > 0.
        int num = trig_l1_interval_quartic(- lambda*oy - B, A  + lambda*ox, c0, c1, c2, roots);
        for (int k = 0; k < num; ++k)
        {
            const double ttt = roots[k];
            if (ttt >= lo && ttt <= hi)
            {
                const double cc = 1   / sqrt(1+ttt*ttt);
                const double ss = ttt*cc;
                trig_l1_keep_if_better(cc, ss, A,B,C,D,E,lambda,x,y,R, best_c,best_s,best_fobj);
            }
        }

        // Branch with c < 0: the sign sums are negated.
        oy = -oy;
        ox = -ox;
        num = trig_l1_interval_quartic(lambda*oy + B, -A - lambda*ox, c0, c1, c2, roots);
        for (int k = 0; k < num; ++k)
        {
            const double ttt = roots[k];
            if (ttt >= lo && ttt <= hi)
            {
                const double cc = -1 / sqrt(1+ttt*ttt);
                const double ss = ttt*cc;
                trig_l1_keep_if_better(cc, ss, A,B,C,D,E,lambda,x,y,R, best_c,best_s,best_fobj);
            }
        }
    }

    free(ratios);
    free(magnitudes);
    free(breaks);
}

// HandleObj = @(V)0.5*vec(V)'*Q*vec(V) + mdot(V,P) + lambda* L1Norm(V*Z);
// min_{V} HandleObj(V), s.t. V'V = I   (V is 2x2, written to V[0..3])
//
// Q is read at indices 0..15, P at indices 0..3; Z1 and Z2 are arrays of
// length r. Two parametrisations of V by (c, s) with c*c + s*s = 1 are each
// reduced to nonconvex_quadratic_trigonometry_L1, and the one with the
// smaller objective value is written to V.
void nonconvex_orth2d_quad_L1(const double *Q,const double*P,const double lambda,const double*Z1,const double*Z2, int r,double*V)
{
    // Keep only the positions where Z1 or Z2 exceeds eps in absolute value.
    double *z1_kept = (double *)malloc(sizeof(double)*r);
    double *z2_kept = (double *)malloc(sizeof(double)*r);
    int kept = 0;
    for (int k = 0; k < r; ++k)
    {
        const bool z1_nonzero = abs1(Z1[k]) > eps;
        const bool z2_nonzero = abs1(Z2[k]) > eps;
        if (!z1_nonzero && !z2_nonzero)
        {
            continue;
        }
        z1_kept[kept] = Z1[k];
        z2_kept[kept] = Z2[k];
        ++kept;
    }

    const int len = kept*2;
    double *cof_x = (double *)malloc(sizeof(double)*kept*2);
    double *cof_y = (double *)malloc(sizeof(double)*kept*2);

    // Case 1: rotation matrix, V = [c s; -s c].
    double rot_c, rot_s, rot_f;
    for (int k = 0; k < kept; ++k)
    {
        cof_x[k]      = z1_kept[k];
        cof_x[k+kept] = z2_kept[k];
        cof_y[k]      = z2_kept[k];
        cof_y[k+kept] = -z1_kept[k];
    }
    nonconvex_quadratic_trigonometry_L1(
        P[0] + P[3],
        P[2] - P[1],
        0.5*Q[0] + Q[12] + 0.5*Q[15],
        - Q[4] + Q[8] -Q[13] + Q[14],
        0.5*Q[5] - Q[9] + 0.5*Q[10],
        lambda, cof_x, cof_y, len, &rot_c, &rot_s, &rot_f);

    // Case 2: reflection matrix, V = [-c s; s c].
    double ref_c, ref_s, ref_f;
    for (int k = 0; k < kept; ++k)
    {
        cof_x[k]      = -z1_kept[k];
        cof_x[k+kept] = z2_kept[k];
        cof_y[k]      = z2_kept[k];
        cof_y[k+kept] = z1_kept[k];
    }
    nonconvex_quadratic_trigonometry_L1(
        -P[0] + P[3],
        P[1] + P[2],
        0.5*Q[0] - Q[12] + 0.5*Q[15],
        - Q[4] - Q[8] +  Q[13] + Q[14],
        0.5*Q[5] + Q[9] + 0.5*Q[10],
        lambda, cof_x, cof_y, len, &ref_c, &ref_s, &ref_f);

    // Write out the better of the two; the reflection wins ties.
    if (rot_f < ref_f)
    {
        V[0] = rot_c;
        V[1] = -rot_s;
        V[2] = rot_s;
        V[3] = rot_c;
    }
    else
    {
        V[0] = -ref_c;
        V[1] = ref_s;
        V[2] = ref_s;
        V[3] = ref_c;
    }

    free(z1_kept);
    free(z2_kept);
    free(cof_x);
    free(cof_y);
}
// Dot product of the first n entries of a and b, accumulated from index 0
// upwards. Returns 0 when n <= 0.
double dot1(const double*a,const double*b,const ptrdiff_t n)
{
    double acc = 0;
    ptrdiff_t k = 0;
    while (k < n)
    {
        acc += a[k]*b[k];
        ++k;
    }
    return acc;
}

// Randomised search for the index pair (i, j), i != j, that maximises
//   | <X_i, G_j> - <G_i, X_j> |
// over up to min(P, n) randomly drawn pairs.
//
// X_transpose, G_transpose : arrays in which item k (1-based) occupies the r
//                            consecutive doubles starting at offset r*(k-1)
// n : number of items
// r : length of each item
// P : requested number of trial pairs (capped at n)
// B : output, 2 doubles; receives the 1-based indices of the best pair.
//     It is initialised to (1, 2) and keeps that value when no trial is run.
//
// Pairs are drawn with rand(); no srand() call is visible in this file, so
// the sequence is whatever the process-wide generator state yields.
// NOTE(review): rand() % n cannot produce indices above RAND_MAX, so for
// very large n only a prefix of the items is ever sampled.
void wws_greedy_maximum_violating_pair_cc(double *X_transpose,double*G_transpose,int n, int r,int P, double*B)
{
    double max_val = -LARGE;
    B[0]=1; B[1]=2;

    // With fewer than two items no pair of distinct indices exists; without
    // this guard n == 1 would redraw the same equal pair forever.
    if (n < 2)
    {
        return;
    }

    int test_time = 0;
    P = min1(P,n);
    while (test_time<P)
    {
        const int i_test = 1+rand()%n;
        const int j_test = 1+rand()%n;
        if(i_test==j_test)continue;   // redraw; does not count as a trial

        // Offsets are computed in ptrdiff_t so that r*(index-1) cannot
        // overflow int for large inputs.
        const ptrdiff_t off_i = static_cast<ptrdiff_t>(r)*(i_test-1);
        const ptrdiff_t off_j = static_cast<ptrdiff_t>(r)*(j_test-1);
        const double *Gi = G_transpose + off_i;
        const double *Gj = G_transpose + off_j;
        const double *Xi = X_transpose + off_i;
        const double *Xj = X_transpose + off_j;

        double Dij = dot1(Xi,Gj,r) - dot1(Gi,Xj,r);
        Dij = abs1(Dij);
        if(Dij>max_val)
        {
            max_val = Dij;
            B[0] = i_test; B[1] = j_test;
        }
        test_time = test_time + 1;
    }
}


         
            
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    
    double *X_transpose       =  mxGetPr(prhs[0]);
    double *G_transpose       =  mxGetPr(prhs[1]);
    int n                     =  (int)mxGetScalar(prhs[2]);
    int r                     =  (int)mxGetScalar(prhs[3]);
    int P                     =  (int)mxGetScalar(prhs[4]);
    
    plhs[0] = mxCreateDoubleMatrix(2,1,mxREAL);
    double *B = mxGetPr(plhs[0]);
    wws_greedy_maximum_violating_pair_cc(X_transpose,G_transpose,n,r,P,B);
    
}
