#include "mex.h"
#include <math.h>
#include <time.h>
#include <stdlib.h>
#include <stdio.h>
#include <blas.h>    // dgemv
#include <lapack.h>  // dstegr
#include <string.h>
#include <omp.h>
#include <algorithm>
#include <iostream>
//#include <iomanip>
// #include <blas.h>    // dgemv
// #include <lapack.h>  // dstegr
using namespace std;
//#include <iomanip>
// #include <blas.h>    // dgemv
// #include <lapack.h>  // dstegr



// Function-like helper macros. Each argument may be evaluated more than once,
// so never pass an expression with side effects (e.g. abs1(x++)).
#define  abs1(a)         ((a) < 0.0 ? -(a) : (a))
// Sign of a as 0, 1 or -1. The whole expansion is parenthesized so that it
// stays a single operand inside larger expressions such as 2 * sign1(x).
#define  sign1(a)        (((a) == 0) ? 0 : (((a) > 0.0) ? 1 : (-1)))
#define  max1(a,b)       ((a) > (b) ? (a) : (b))
#define  min1(a,b)       ((a) < (b) ? (a) : (b))
// Numeric constants.
# define PI 3.141592653589793238463L   // long double literal
# define M_2PI (2 * PI)                // 2*pi, also long double
# define eps 2.2204e-16                // absolute tolerance for "is zero" tests in this file
# define LARGE 1e100                   // starting value for running minima

// Debug helper: prints the first n entries of x on one line using "%f ",
// then ends the line. Nothing is printed for the entries when n <= 0.
void pvec(double *x,int n)
{
    int idx = 0;
    while (idx < n)
    {
        printf("%f ", x[idx]);
        ++idx;
    }

    // Line break followed by a single space, exactly as the format string says.
    printf("\n ");
}

// Turns a[0..n-1] into a binary max-heap by sifting down every node from
// index i back to the root (index 0). Children of node k live at 2k+1 and
// 2k+2; only indices strictly below n are ever read or written.
//
//   a : array holding the heap
//   i : highest node index to start from; n/2-1 (the last node that has a
//       child) is enough to heapify the whole array, larger values are
//       clamped to it
//   n : number of elements that belong to the heap
//
// After the call a[0] is the largest of a[0..n-1] whenever i >= n/2-1.
// Nothing happens for n <= 1 or i < 0.
void CreatHeap(double a[], int i, int  n)
{
    // Nodes past n/2-1 have no child inside the heap. Clamping skips them and
    // keeps 2*node+1 from touching (or overflowing past) the end of the array.
    const int last_parent = n / 2 - 1;
    if (i > last_parent)
        i = last_parent;

    for (; i >= 0; --i)
    {
        int node = i;
        for (;;)
        {
            const int left = 2 * node + 1;
            if (left >= n)
                break;                      // node is a leaf
            const int right = left + 1;

            // Pick the larger child; on a tie the right child wins.
            int child = left;
            if (right < n && !(a[left] > a[right]))
                child = right;

            if (!(a[child] > a[node]))
                break;                      // heap property already holds here

            const double tmp = a[node];
            a[node] = a[child];
            a[child] = tmp;

            // Keep following the displaced value so the subtree stays a heap.
            node = child;
        }
    }
}

// Sorts a[0..n-1] in place into increasing order using a binary max-heap
// (std::make_heap followed by std::sort_heap), i.e. O(n log n) comparisons.
// Arrays with fewer than two elements (including n <= 0) are left untouched.
// NaN entries are not supported: they break the ordering the heap routines
// rely on.
void HeapSort(double a[], int n)
{
    if (n < 2)
        return;

    double *const last = a + n;
    std::make_heap(a, last);   // largest element moves to a[0]
    std::sort_heap(a, last);   // repeatedly pops the maximum to the back
}


//---------------------------------------------------------------------------
// Solves the monic cubic x^3 + a*x^2 + b*x + c = 0.
//
//   x : output array of size 3 (all three entries are always written)
//
// Return value and layout of x:
//   3 : three real roots in x[0], x[1], x[2]
//   2 : two real roots in x[0], x[1] (x[2] is set equal to x[1])
//   1 : one real root x[0]; x[1] +/- i*x[2] is the complex-conjugate pair
unsigned int solveP3(double* x, double a, double b, double c)
{
    const double a_sq = a * a;
    double q = (a_sq - 3 * b) / 9;
    const double r = (a * (2 * a_sq - 9 * b) + 27 * c) / 54;
    const double r_sq = r * r;
    const double q_cu = q * q * q;

    if (r_sq < q_cu)
    {
        // Three real roots, obtained in trigonometric form.
        double t = r / sqrt(q_cu);
        // Clamp before acos so rounding cannot push the argument outside [-1, 1].
        if (t < -1) t = -1;
        if (t > 1) t = 1;
        t = acos(t);

        a /= 3;
        q = -2 * sqrt(q);
        x[0] = q * cos(t / 3) - a;
        x[1] = q * cos((t + M_2PI) / 3) - a;
        x[2] = q * cos((t - M_2PI) / 3) - a;
        return 3;
    }

    // Otherwise: build the real root from a cube root whose sign is opposite
    // to the sign of r.
    const double magnitude = pow(abs1(r) + sqrt(r_sq - q_cu), 1. / 3);
    const double A = (r < 0) ? magnitude : -magnitude;
    double B = 0;
    if (A != 0)
        B = q / A;

    a /= 3;
    x[0] = (A + B) - a;
    x[1] = -0.5 * (A + B) - a;              // real part of the remaining pair
    x[2] = 0.5 * sqrt(3.) * (A - B);        // imaginary part of the remaining pair

    // A vanishing imaginary part means the pair collapses to a real double root.
    if (abs1(x[2]) < eps)
    {
        x[2] = x[1];
        return 2;
    }
    return 1;
}

//---------------------------------------------------------------------------
// Helper for solve_quartic_2: processes one quadratic factor x^2 + p*x + q.
// With a non-negative discriminant both roots are written to out[0], out[1]
// and 2 is returned. With a negative discriminant only -p/2 (the real part of
// the complex pair) is written to out[0] and 1 is returned.
static int quartic2_store_quadratic(double p, double q, double *out)
{
    const double disc = p * p - 4 * q;
    if (disc < 0.0)
    {
        out[0] = -p * 0.5;
        return 1;
    }

    const double root = sqrt(disc);
    out[0] = (-p + root) * 0.5;
    out[1] = (-p - root) * 0.5;
    return 2;
}

//---------------------------------------------------------------------------
// Handles the monic quartic x^4 + a*x^3 + b*x^2 + c*x + d = 0 by splitting it
// into two quadratic factors (x^2 + p1*x + q1)(x^2 + p2*x + q2) with the help
// of the cubic resolvent.
//
//   x_sol : caller-provided output buffer; between 2 and 4 values are written,
//           so it must have room for 4 doubles. Nothing is allocated here.
//
// Returns the number of values stored in x_sol. For a quadratic factor with
// two real roots both are stored; for a factor with complex roots the single
// value -p/2 (their real part) is stored and counted instead.
int solve_quartic_2(double a, double b, double c, double d,double *x_sol)
{
    // Cubic resolvent:
    //   y^3 - b*y^2 + (a*c - 4*d)*y - a^2*d - c^2 + 4*b*d = 0
    const double res_a = -b;
    const double res_b = a * c - 4. * d;
    const double res_c = -a * a * d - c * c + 4. * b * d;

    double res_root[3];
    const unsigned int n_res = solveP3(res_root, res_a, res_b, res_c);

    // Use the real resolvent root of largest absolute value. When solveP3
    // reports a single real root, entries 1 and 2 are not roots and are skipped.
    double y = res_root[0];
    if (n_res != 1)
    {
        for (int k = 1; k < 3; ++k)
        {
            if (abs1(res_root[k]) > abs1(y)) y = res_root[k];
        }
    }

    double p1, p2, q1, q2;

    // q1 + q2 = y and q1*q2 = d, i.e. q1, q2 solve h^2 - y*h + d = 0.
    const double disc_q = y * y - 4 * d;
    if (abs1(disc_q) < eps)
    {
        // Discriminant treated as zero: q1 == q2.
        q1 = q2 = y * 0.5;

        // p1 + p2 = a and p1*p2 = b - y, i.e. p1, p2 solve g^2 - a*g + (b - y) = 0.
        const double disc_p = a * a - 4 * (b - y);
        if (abs1(disc_p) < eps)
        {
            p1 = p2 = a * 0.5;
        }
        else
        {
            const double root_p = sqrt(disc_p);
            p1 = (a + root_p) * 0.5;
            p2 = (a - root_p) * 0.5;
        }
    }
    else
    {
        const double root_q = sqrt(disc_q);
        q1 = (y + root_q) * 0.5;
        q2 = (y - root_q) * 0.5;

        // p1 + p2 = a and p1*q2 + p2*q1 = c, solved by Cramer's rule.
        p1 = (a * q1 - c) / (q1 - q2);
        p2 = (c - a * q2) / (q1 - q2);
    }

    // First factor x^2 + p1*x + q1, then second factor x^2 + p2*x + q2,
    // written back to back into x_sol.
    int count = quartic2_store_quadratic(p1, q1, x_sol);
    count += quartic2_store_quadratic(p2, q2, x_sol + count);
    return count;
}








// Solves a*x^4 + b*x^3 + c*x^2 + d*x + e = 0, falling back to lower-degree
// solvers when leading coefficients are smaller than eps in absolute value.
//
//   x_sol : caller-provided output buffer (up to 4 values may be written)
//
// Returns the number of values stored at the front of x_sol.
int solve_quartic(double a, double b, double c, double d,double e, double *x_sol)
{
    // Genuine quartic: normalise to a monic polynomial.
    if (!(abs1(a) < eps))
    {
        return solve_quartic_2(b/a, c/a, d/a, e/a, x_sol);
    }

    // Cubic: solve x^3 + c/b*x^2 + d/b*x + e/b = 0.
    if (!(abs1(b) < eps))
    {
        return static_cast<int>(solveP3(x_sol, c/b, d/b, e/b));
    }

    // Quadratic: solve c*x^2 + d*x + e = 0, real roots only.
    if (!(abs1(c) < eps))
    {
        const double delta = d*d - 4*c*e;
        if (delta < 0)
        {
            return 0;
        }
        x_sol[0] =  (- d + sqrt(delta)) / (2*c);
        x_sol[1] =  (- d - sqrt(delta)) / (2*c);
        return 2;
    }

    // Linear: solve d*x + e = 0.
    if (!(abs1(d) < eps))
    {
        x_sol[0] = - e/d;
        return 1;
    }

    // All coefficients vanish: every x satisfies the equation, report x = 1.
    if (abs1(e) < eps)
    {
        x_sol[0] = 1;
        return 1;
    }

    // Non-zero constant only: no solution.
    return 0;
}


// Evaluates
//   A*c + B*s + C*c*c + D*c*s + E*s*s + lambda * sum_i |c*x[i] + s*y[i]|
// over the first n entries of x and y.
double ComputeObj_cs(const double c,const double s,const double A,const double B,const double C,const double D,const double E,const double lambda,const double*x,const double*y,const int n)
{
    // l1 accumulates the 1-norm of c*x + s*y.
    double l1 = 0;
    for (int k = 0; k < n; ++k)
    {
        const double term = c*x[k] + s*y[k];
        l1 += abs1(term);
    }

    // Summation order is kept as a single left-to-right expression.
    return l1 * lambda + A*c + B*s + C*c*c + D*c*s + E*s*s;
}

   
   
 
 
// Returns the dot product of the first n entries of a and b, accumulated in
// index order. Yields 0 when n <= 0.
double dot1(const double*a,const double*b,const ptrdiff_t n)
{
    double acc = 0;
    const double *pa = a;
    const double *pb = b;
    for (ptrdiff_t remaining = n; remaining > 0; --remaining)
    {
        acc += (*pa) * (*pb);
        ++pa;
        ++pb;
    }
    return acc;
}

// Randomly samples pairs of distinct 1-based indices (i, j) in [1, n] and
// reports the pair with the smallest score seen.
//
//   X_transpose, G_transpose : read as n consecutive rows of r doubles each;
//                              row k starts at offset r*(k-1)
//                              (presumably r-by-n MATLAB matrices - confirm
//                              against the caller)
//   n : number of rows / largest index
//   r : length of each row
//   P : requested number of trials, capped at n; draws with i == j are
//       redrawn and do not count as a trial
//   B : output, 2 doubles: B[0] = i, B[1] = j of the best pair. Stays at the
//       default {1, 2} when no trial is run.
//
// Score of a pair, with Tab = dot(G row a, X row b):
//   w1 = -(Tii+Tjj) - sqrt((Tii+Tjj)^2 + (Tij-Tji)^2)
//   w2 = -(Tii+Tjj) - sqrt((Tjj-Tii)^2 + (Tij+Tji)^2)
//   score = min(w1, w2)
// Indices come from rand(); this function does not seed the generator.
void wws_greedy_maximum_violating_pair_cc(double *X_transpose,double*G_transpose,int n, int r,int P, double*B)
{
    B[0]=1; B[1]=2;

    // With fewer than two rows there is no pair of distinct indices, so the
    // redraw-on-equal loop below could never complete a trial (it would spin
    // forever for n == 1). Return the default pair instead.
    if (n < 2) return;

    // Widen the row length once so r*(index-1) is computed in ptrdiff_t and
    // cannot overflow int for large matrices.
    const ptrdiff_t row_len = r;

    double min_val = LARGE;
    int test_time = 0;
    P = min1(P,n);
    while (test_time<P)
    {
        int i_test = 1+rand()%n;
        int j_test = 1+rand()%n;
        if(i_test==j_test)continue;

        const double *Gi = G_transpose + row_len*(i_test-1);
        const double *Gj = G_transpose + row_len*(j_test-1);
        const double *Xi = X_transpose + row_len*(i_test-1);
        const double *Xj = X_transpose + row_len*(j_test-1);

        const double Tii = dot1(Gi,Xi,r);
        const double Tjj = dot1(Gj,Xj,r);
        const double Tij = dot1(Gi,Xj,r);
        const double Tji = dot1(Gj,Xi,r);

        const double AA = Tii + Tjj;
        const double BB = Tij - Tji;
        const double CC = Tjj - Tii;
        const double DD = Tij + Tji;
        const double w1 = - AA - sqrt(AA*AA+BB*BB);
        const double w2 = - AA - sqrt(CC*CC+DD*DD);
        const double Dij = min1(w1,w2);

        // Keep the pair with the smallest score so far.
        if(Dij<min_val)
        {
            min_val = Dij;
            B[0] = i_test; B[1] = j_test;
        }
        test_time = test_time + 1;
    }
}

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    
    double *X_transpose       =  mxGetPr(prhs[0]);
    double *G_transpose       =  mxGetPr(prhs[1]);
    int n                     =  (int)mxGetScalar(prhs[2]);
    int r                     =  (int)mxGetScalar(prhs[3]);
    int P                     =  (int)mxGetScalar(prhs[4]);
    
    plhs[0] = mxCreateDoubleMatrix(2,1,mxREAL);
    double *B = mxGetPr(plhs[0]);
    wws_greedy_maximum_violating_pair_cc(X_transpose,G_transpose,n,r,P,B);
    
}
