#include "grad_desc_dense.hpp"
#include "utils.hpp"
#include "regularizer.hpp"
#include <algorithm>
#include <cmath>
#include <ctime>
#include <cstdlib>
#include <iostream>
#include <random>
#include <string.h>
#include <sys/time.h>
#include <stdlib.h>
#include <chrono>

extern int MAX_DIM;
// SAGA on a dense, row-major data matrix.
//
// Instead of storing one gradient vector per sample, the method keeps a table
// with one scalar "core" per sample (whatever
// blackbox::first_component_oracle_core_dense returns); the per-sample
// gradient is reconstructed as core * X[row].
//
// Parameters:
//   X            - dense data, N rows of MAX_DIM doubles, row-major.
//   Y            - passed through to the model oracles.
//   N            - number of samples. Must be positive:
//                  uniform_int_distribution(0, N - 1) and the "% N" logging
//                  test below both require it.
//   model        - supplies the oracles, regularizer id/params and the
//                  starting weights; receives the final weights via
//                  update_model().
//   iteration_no - number of stochastic steps (not epochs).
//   step_size    - step size used for both the gradient step and the
//                  proximal operator.
//
// Returns: outputs(losses, times). Both vectors are heap-allocated here and
// handed to the outputs object. losses holds the initial objective twice (the
// second entry stands for the table-building pass) followed by one value every
// N steps; times holds the matching elapsed wall-clock milliseconds.
grad_desc_dense::outputs grad_desc_dense::SAGA(double* X, double* Y, int N
    , blackbox* model, int iteration_no, double step_size) {
    // Random Generator
    unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
    std::default_random_engine generator(seed);
    std::uniform_int_distribution<int> distribution(0, N - 1);

    // Row offsets are computed in size_t so that sample * MAX_DIM cannot
    // overflow int on large data sets.
    const std::size_t dim = static_cast<std::size_t>(MAX_DIM);

    // Traces live on the stack until the very end, so nothing leaks if one of
    // the model oracles throws half-way through.
    std::vector<double> loss_trace;
    std::vector<double> time_trace;

    int regular = model->get_regularizer();
    double* lambda = model->get_params();

    // Store results
    const double initial_loss = model->zero_oracle_dense(X, Y, N);
    loss_trace.push_back(initial_loss);
    // Extra Pass for Create Gradient Table
    loss_trace.push_back(initial_loss);
    time_trace.push_back(0);

    struct timeval tp;
    gettimeofday(&tp, NULL);
    const long int start_ms = tp.tv_sec * 1000 + tp.tv_usec / 1000;

    // Scratch buffers are vectors (RAII) instead of new[]/delete[].
    std::vector<double> new_weights(dim);
    std::vector<double> grad_core_table(static_cast<std::size_t>(N));
    std::vector<double> aver_grad(dim, 0.0);
    copy_vec(new_weights.data(), model->get_model());

    // Init Gradient Core Table and the running average of the table entries.
    for(int i = 0; i < N; i ++) {
        grad_core_table[i] = model->first_component_oracle_core_dense(X, Y, N, i);
        const double* x_row = X + static_cast<std::size_t>(i) * dim;
        for(std::size_t j = 0; j < dim; j ++)
            aver_grad[j] += grad_core_table[i] * x_row[j] / N;
    }
    // First pass initialization
    gettimeofday(&tp, NULL);
    time_trace.push_back(tp.tv_sec * 1000 + tp.tv_usec / 1000 - start_ms);

    for(int i = 0; i < iteration_no; i ++) {
        const int rand_samp = distribution(generator);
        const double core = model->first_component_oracle_core_dense(X, Y, N
            , rand_samp, new_weights.data());
        const double past_grad_core = grad_core_table[rand_samp];
        grad_core_table[rand_samp] = core;

        const double* x_row = X + static_cast<std::size_t>(rand_samp) * dim;
        for(std::size_t j = 0; j < dim; j ++) {
            // Update Weight: uses the table average from BEFORE this sample's
            // entry is refreshed, so the order of the next two statements matters.
            new_weights[j] -= step_size * ((core - past_grad_core) * x_row[j]
                            + aver_grad[j]);
            // Update Gradient Table Average
            aver_grad[j] -= (past_grad_core - core) * x_row[j] / N;
            regularizer::proximal_operator(regular, new_weights[j], step_size, lambda);
        }
        // Store results once every N steps.
        if(!((i + 1) % N)) {
            loss_trace.push_back(model->zero_oracle_dense(X, Y, N, new_weights.data()));
            gettimeofday(&tp, NULL);
            time_trace.push_back(tp.tv_sec * 1000 + tp.tv_usec / 1000 - start_ms);
        }
    }
    model->update_model(new_weights.data());

    // Ownership of the two heap vectors passes to the returned outputs object.
    std::vector<double>* losses = new std::vector<double>(loss_trace);
    std::vector<double>* times = new std::vector<double>(time_trace);
    return grad_desc_dense::outputs(losses, times);
}

// Proximal SVRG on a dense, row-major data matrix.
//
// Each outer iteration (epoch) evaluates the per-sample gradient cores and
// their X-weighted average at the model's current weights, then runs 2 * N
// variance-reduced proximal steps.
//
// Parameters:
//   X, Y         - dense data (N rows of MAX_DIM doubles) and the targets
//                  passed through to the model oracles.
//   N            - number of samples; must be positive
//                  (uniform_int_distribution(0, N - 1) requires it).
//   model        - supplies oracles, regularizer id/params and the weights;
//                  updated with update_model() at the end of every epoch.
//   iteration_no - number of epochs.
//   Mode         - SVRG_LAST_LAST: snapshot = last inner iterate, inner loop
//                                  continues from the last inner iterate.
//                  SVRG_AVER_LAST: snapshot = averaged iterate, inner loop
//                                  continues from the last inner iterate.
//                  SVRG_AVER_AVER: snapshot = averaged iterate, inner loop
//                                  restarts from the snapshot.
//   L            - unused in this implementation.
//   step_size    - step size for the gradient step and the proximal operator.
//
// Returns: outputs(losses, times). losses holds the initial objective followed
// by one value per epoch; times is returned empty (nothing is timed here).
//
// Throws: std::string("500 Internal Error.") for an unknown Mode. All scratch
// storage is held in std::vector so that this throw does not leak memory.
grad_desc_dense::outputs grad_desc_dense::Prox_SVRG(double* X, double* Y, int N, blackbox* model
    , int iteration_no, int Mode, double L, double step_size) {
    // Random Generator
    unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
    std::default_random_engine generator(seed);
    std::uniform_int_distribution<int> distribution(0, N - 1);

    // Row offsets are computed in size_t so that sample * MAX_DIM cannot
    // overflow int on large data sets.
    const std::size_t dim = static_cast<std::size_t>(MAX_DIM);

    // Traces stay on the stack until the function succeeds.
    std::vector<double> loss_trace;
    std::vector<double> time_trace;

    // Scratch buffers, allocated once and reused by every epoch.
    std::vector<double> inner_weights(dim);
    std::vector<double> full_grad(dim);
    std::vector<double> aver_weights(dim);
    std::vector<double> full_grad_core(static_cast<std::size_t>(N));

    //FIXME: Epoch Size(SVRG / SVRG++)
    const double m0 = (double) N * 2.0;
    int regular = model->get_regularizer();
    double* lambda = model->get_params();
    copy_vec(inner_weights.data(), model->get_model());
    // Init Weight Evaluate
    loss_trace.push_back(model->zero_oracle_dense(X, Y, N));

    // OUTTER_LOOP
    for(int i = 0 ; i < iteration_no; i ++) {
        //FIXME: SVRG / SVRG++
        const double inner_m = m0;//pow(2, i + 1) * m0;
        std::fill(aver_weights.begin(), aver_weights.end(), 0.0);
        std::fill(full_grad.begin(), full_grad.end(), 0.0);

        // Gradient cores and their X-weighted average at the snapshot
        // (the model's current weights).
        for(int j = 0; j < N; j ++) {
            full_grad_core[j] = model->first_component_oracle_core_dense(X, Y, N, j);
            const double* x_row = X + static_cast<std::size_t>(j) * dim;
            for(std::size_t k = 0; k < dim; k ++) {
                full_grad[k] += x_row[k] * full_grad_core[j] / (double) N;
            }
        }

        // Choose the starting point of the inner loop.
        switch(Mode) {
            case SVRG_LAST_LAST:
            case SVRG_AVER_LAST:
                // Continue from the last inner iterate.
                break;
            case SVRG_AVER_AVER:
                // Restart from the snapshot.
                copy_vec(inner_weights.data(), model->get_model());
                break;
            default:
                throw std::string("500 Internal Error.");
                break;
        }

        // INNER_LOOP
        for(int j = 0; j < inner_m; j ++) {
            const int rand_samp = distribution(generator);
            const double inner_core = model->first_component_oracle_core_dense(X, Y, N
                , rand_samp, inner_weights.data());
            const double core_diff = inner_core - full_grad_core[rand_samp];
            const double* x_row = X + static_cast<std::size_t>(rand_samp) * dim;
            for(std::size_t k = 0; k < dim; k ++) {
                const double vr_sub_grad = core_diff * x_row[k] + full_grad[k];
                inner_weights[k] -= step_size * vr_sub_grad;
                // The value returned by the proximal operator is accumulated
                // into the running average of the inner iterates.
                aver_weights[k] += regularizer::proximal_operator(regular, inner_weights[k], step_size, lambda) / inner_m;
            }
        }

        // Choose the next snapshot.
        switch(Mode) {
            case SVRG_LAST_LAST:
                model->update_model(inner_weights.data());
                break;
            case SVRG_AVER_LAST:
            case SVRG_AVER_AVER:
                model->update_model(aver_weights.data());
                break;
            default:
                throw std::string("500 Internal Error.");
                break;
        }
        loss_trace.push_back(model->zero_oracle_dense(X, Y, N));
    }

    // Ownership of the two heap vectors passes to the returned outputs object.
    std::vector<double>* losses = new std::vector<double>(loss_trace);
    std::vector<double>* times = new std::vector<double>(time_trace);
    return grad_desc_dense::outputs(losses, times);
}

// L2S: recursive stochastic gradient estimator with randomly triggered
// restarts, on a dense row-major data matrix.
//
// At every step, with probability 1/N the estimator v is recomputed from all
// N samples at the current point; otherwise v is corrected with the difference
// of one random sample's gradient core at the current and the previous point.
// The point is then moved by -step_size * v.
//
// Parameters:
//   X, Y         - dense data (N rows of MAX_DIM doubles) and the targets
//                  passed through to the model oracle.
//   N            - number of samples; must be positive (used as a divisor and
//                  as the upper bound of the sampling distribution).
//   model        - supplies the oracle and the starting weights; receives the
//                  final point via update_model().
//   iteration_no - budget in data passes: the loop stops once
//                  (oracle calls / N) reaches this value.
//   step_size    - step size.
//
// Returns: outputs(losses, times). losses holds the running minimum of the
// norm of the all-sample estimator (refreshed only when it is recomputed),
// with roughly one entry per data pass; times is returned empty.
grad_desc_dense::outputs grad_desc_dense::L2S(double* X, double* Y, int N
    , blackbox* model, int iteration_no, double step_size) {
    // One engine drives both the sample choice and the restart coin flip.
    // Two engines seeded from the clock back to back would receive (almost)
    // the same seed and therefore produce correlated streams.
    unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
    std::default_random_engine generator(seed);
    std::uniform_int_distribution<int> distribution(0, N - 1);
    std::bernoulli_distribution b_dist(1.0 / (double) N);

    // Row offsets are computed in size_t so that sample * MAX_DIM cannot
    // overflow int on large data sets.
    const std::size_t dim = static_cast<std::size_t>(MAX_DIM);

    // Traces and scratch buffers are RAII-managed so nothing leaks if the
    // model oracle throws.
    std::vector<double> loss_trace;
    std::vector<double> time_trace;
    std::vector<double> x(dim);
    std::vector<double> prev_x(dim);
    std::vector<double> v(dim, 0.0);
    copy_vec(x.data(), model->get_model());

    // Rebuilds v as (1/N) * sum_j core_j(point) * X[j] and returns its
    // Euclidean norm.
    auto full_gradient = [&](double* point) -> double {
        std::fill(v.begin(), v.end(), 0.0);
        for(int j = 0; j < N; j ++) {
            const double core = model->first_component_oracle_core_dense(X, Y, N, j, point);
            const double* x_row = X + static_cast<std::size_t>(j) * dim;
            for(std::size_t k = 0; k < dim; k ++)
                v[k] += x_row[k] * core / (double) N;
        }
        double sq_norm = 0;
        for(std::size_t k = 0; k < dim; k ++)
            sq_norm += v[k] * v[k];
        return sqrt(sq_norm);
    };

    // Moves to the next point, remembering the current one for the
    // recursive correction.
    auto take_step = [&]() {
        for(std::size_t j = 0; j < dim; j ++) {
            prev_x[j] = x[j];
            x[j] -= step_size * v[j];
        }
    };

    double mingrad = 99999999;
    // Counts oracle calls; long long so that N * iteration_no cannot
    // overflow.
    long long pass_counter = 0;

    // Full gradient at the starting point.
    mingrad = min(mingrad, full_gradient(x.data()));
    pass_counter += N;
    loss_trace.push_back(mingrad);

    // The first full gradient descent
    take_step();
    loss_trace.push_back(mingrad);

    int epoch_counter = 1;
    while(int((double) pass_counter / (double) N) < iteration_no) {
        // With probability 1/N: restart the estimator from all samples.
        if(b_dist(generator)) {
            mingrad = min(mingrad, full_gradient(x.data()));
            loss_trace.push_back(mingrad);
            pass_counter += N;
            epoch_counter ++;
        }
        // With probability 1 - 1/N: recursive single-sample correction.
        else {
            const int rand_samp = distribution(generator);
            const double core = model->first_component_oracle_core_dense(X, Y, N
                , rand_samp, x.data());
            const double prev_core = model->first_component_oracle_core_dense(X, Y, N
                , rand_samp, prev_x.data());

            const double* x_row = X + static_cast<std::size_t>(rand_samp) * dim;
            for(std::size_t j = 0; j < dim; j ++)
                v[j] = (core - prev_core) * x_row[j] + v[j];

            // Two oracle calls were spent on this step.
            pass_counter += 2;
            // Log once per completed data pass; the norm itself is only
            // refreshed on restarts, so the last minimum is repeated.
            if(int((double) pass_counter / (double) N) >= epoch_counter + 1) {
                loss_trace.push_back(mingrad);
                epoch_counter ++;
            }
        }
        // Update x
        take_step();
    }
    model->update_model(x.data());

    // Ownership of the two heap vectors passes to the returned outputs object.
    std::vector<double>* losses = new std::vector<double>(loss_trace);
    std::vector<double>* times = new std::vector<double>(time_trace);
    return grad_desc_dense::outputs(losses, times);
}

// Acc_SVRG_G (variant that outputs the gradient norm).
//
// Maintains three sequences: the snapshot x, the query point y and the
// auxiliary point z. Every step forms y from z and a 1/L gradient step at x,
// updates z with a variance-reduced gradient evaluated at y, and with
// probability pk promotes y to the new snapshot and recomputes the
// all-sample gradient there.
//
// Parameters:
//   X, Y         - dense data (N rows of MAX_DIM doubles) and the targets
//                  passed through to the model oracle.
//   N            - number of samples; must be positive (used as a divisor and
//                  as the upper bound of the sampling distribution).
//   model        - supplies the oracle and the starting weights; receives the
//                  final snapshot x via update_model().
//   iteration_no - budget in data passes: the loop stops once
//                  (oracle calls / N) reaches this value.
//   L            - constant used as 1/L step scaling; must be non-zero.
//
// Returns: outputs(losses, times). losses holds the running minimum of the
// norm of the all-sample gradient at the snapshot (refreshed only when the
// snapshot changes), with roughly one entry per data pass; times is returned
// empty.
grad_desc_dense::outputs grad_desc_dense::Acc_SVRG_G(double* X, double* Y, int N
    , blackbox* model, int iteration_no, double L) {
    // One engine drives both the sample choice and the snapshot coin flip.
    // Two engines seeded from the clock back to back would receive (almost)
    // the same seed and therefore produce correlated streams.
    unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
    std::default_random_engine generator(seed);
    std::uniform_int_distribution<int> distribution(0, N - 1);

    // Row offsets are computed in size_t so that sample * MAX_DIM cannot
    // overflow int on large data sets.
    const std::size_t dim = static_cast<std::size_t>(MAX_DIM);

    // Traces and scratch buffers are RAII-managed so nothing leaks if the
    // model oracle throws.
    std::vector<double> loss_trace;
    std::vector<double> time_trace;
    std::vector<double> x(dim);
    std::vector<double> y(dim);
    std::vector<double> z(dim);
    std::vector<double> full_grad(dim, 0.0);
    copy_vec(x.data(), model->get_model());
    copy_vec(y.data(), model->get_model());
    copy_vec(z.data(), model->get_model());

    // Rebuilds full_grad as (1/N) * sum_j core_j(x) * X[j] at the current
    // snapshot x and returns its Euclidean norm.
    auto snapshot_gradient = [&]() -> double {
        std::fill(full_grad.begin(), full_grad.end(), 0.0);
        for(int j = 0; j < N; j ++) {
            const double core = model->first_component_oracle_core_dense(X, Y, N, j, x.data());
            const double* x_row = X + static_cast<std::size_t>(j) * dim;
            for(std::size_t k = 0; k < dim; k ++)
                full_grad[k] += x_row[k] * core / (double) N;
        }
        double sq_norm = 0;
        for(std::size_t k = 0; k < dim; k ++)
            sq_norm += full_grad[k] * full_grad[k];
        return sqrt(sq_norm);
    };

    double mingrad = 99999999;
    // Counts oracle calls; long long so that N * iteration_no cannot
    // overflow.
    long long pass_counter = 0;

    // Full gradient at the starting point.
    mingrad = min(mingrad, snapshot_gradient());
    pass_counter += N;
    loss_trace.push_back(mingrad);

    // One pass for initialization
    loss_trace.push_back(mingrad);

    int epoch_counter = 1;
    int iter_counter = 0;
    while(int((double) pass_counter / (double) N) < iteration_no) {
        // Snapshot probability, floored at 1/N, and the matching mixing weight.
        const double pk = max(6.0 / ((double) iter_counter + 8.0), 1.0 / (double) N);
        const double tauk = 3.0 / (pk * ((double) iter_counter + 8.0));

        // y mixes z with a 1/L gradient step taken from the snapshot.
        for(std::size_t j = 0; j < dim; j ++) {
            y[j] = tauk * z[j] + (1 - tauk) * (x[j] - 1.0 / L * full_grad[j]);
        }

        const int rand_samp = distribution(generator);
        const double ycore = model->first_component_oracle_core_dense(X, Y, N
            , rand_samp, y.data());
        const double snap_core = model->first_component_oracle_core_dense(X, Y, N
            , rand_samp, x.data());
        // Two oracle calls were spent on this step.
        pass_counter += 2;
        // Log once per completed data pass; the norm itself is only refreshed
        // on snapshot changes, so the last minimum is repeated.
        if(int((double) pass_counter / (double) N) >= epoch_counter + 1) {
            loss_trace.push_back(mingrad);
            epoch_counter ++;
        }

        // z moves along the variance-reduced gradient evaluated at y.
        const double* x_row = X + static_cast<std::size_t>(rand_samp) * dim;
        for(std::size_t j = 0; j < dim; j ++)
            z[j] -= (1 - tauk) / (L * tauk) * ((ycore - snap_core) * x_row[j] + full_grad[j]);

        // pk changes with iter_counter, so the distribution is rebuilt each step.
        std::bernoulli_distribution b_dist(pk);
        // With probability pk: y becomes the new snapshot.
        if(b_dist(generator)) {
            for(std::size_t j = 0; j < dim; j ++)
                x[j] = y[j];

            // Full gradient at the new snapshot.
            mingrad = min(mingrad, snapshot_gradient());
            loss_trace.push_back(mingrad);
            pass_counter += N;
            epoch_counter ++;
        }

        iter_counter ++;
    }
    model->update_model(x.data());

    // Ownership of the two heap vectors passes to the returned outputs object.
    std::vector<double>* losses = new std::vector<double>(loss_trace);
    std::vector<double>* times = new std::vector<double>(time_trace);
    return grad_desc_dense::outputs(losses, times);
}

// Disabled alternative of Acc_SVRG_G that outputs the function value (zero_oracle_dense) instead of the gradient norm.
// grad_desc_dense::outputs grad_desc_dense::Acc_SVRG_G(double* X, double* Y, int N
//     , blackbox* model, int iteration_no, double L) {
//     // Sample Generator
//     unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
//     std::default_random_engine generator(seed);
//     std::uniform_int_distribution<int> distribution(0, N - 1);

//     // Random Update
//     unsigned seed2 = std::chrono::system_clock::now().time_since_epoch().count();
//     std::default_random_engine gen2(seed2);

//     std::vector<double>* losses = new std::vector<double>;
//     std::vector<double>* times = new std::vector<double>;

//     double* lambda = model->get_params();

//     double* x = new double[MAX_DIM];
//     double* y = new double[MAX_DIM];
//     double* z = new double[MAX_DIM];
//     double* full_grad = new double[MAX_DIM];
//     copy_vec(x, model->get_model());
//     copy_vec(y, model->get_model());
//     copy_vec(z, model->get_model());
//     int pass_counter = 0;

//     // Full gradient
//     memset(full_grad, 0, MAX_DIM * sizeof(double));
//     for(int j = 0; j < N; j ++) {
//         double core = model->first_component_oracle_core_dense(X, Y, N, j, x);
//         for(int k = 0; k < MAX_DIM; k ++) {
//             full_grad[k] += X[j * MAX_DIM + k] * core / (double) N;
//         }
//     }
//     pass_counter += N;
//     losses->push_back(model->zero_oracle_dense(X, Y, N, x));

//     // One pass for initialization
//     losses->push_back(model->zero_oracle_dense(X, Y, N, x));

//     int epoch_counter = 1;
//     int iter_counter = 0;
//     while(int((double) pass_counter / (double) N) < iteration_no) {
//         double pk = max(6.0 / ((double) iter_counter + 8.0), 1.0 / (double) N);
//         double tauk = 3.0 / (pk * ((double) iter_counter + 8.0));

//         for(int j = 0; j < MAX_DIM; j ++) {
//             y[j] = tauk * z[j] + (1 - tauk) * (x[j] - 1.0 / L * full_grad[j]);
//         }

//         int rand_samp = distribution(generator);
//         double ycore = model->first_component_oracle_core_dense(X, Y, N, rand_samp, y);
//         double snap_core = model->first_component_oracle_core_dense(X, Y, N, rand_samp, x);
//         pass_counter += 2;
//         if(int((double) pass_counter / (double) N) >= epoch_counter + 1) {
//             losses->push_back(model->zero_oracle_dense(X, Y, N, x));
//             epoch_counter ++;
//         }

//         for(int j = 0; j < MAX_DIM; j ++) 
//             z[j] -= (1 - tauk) / (L * tauk) * ((ycore - snap_core) * X[rand_samp * MAX_DIM + j] + full_grad[j]);

//         std::bernoulli_distribution b_dist(pk);
//         // With probability pk.
//         if(b_dist(gen2)) {
//             for(int j = 0; j < MAX_DIM; j ++)
//                 x[j] = y[j];

//             // Full gradient
//             memset(full_grad, 0, MAX_DIM * sizeof(double));
//             for(int j = 0; j < N; j ++) {
//                 double core = model->first_component_oracle_core_dense(X, Y, N, j, x);
//                 for(int k = 0; k < MAX_DIM; k ++) {
//                     full_grad[k] += X[j * MAX_DIM + k] * core / (double) N;
//                 }
//             }
//             losses->push_back(model->zero_oracle_dense(X, Y, N, x));            
//             pass_counter += N;
//             epoch_counter ++;
//         }
//         iter_counter ++;
//     }
//     model->update_model(x);
//     delete[] x;
//     delete[] y;
//     delete[] z;
//     delete[] full_grad;
//     return grad_desc_dense::outputs(losses, times);
// }