// Usage
// 
// ./a.out <num_iterations> <num_test_tasks> <first_test_task> <data_directory> <num_clusters>

#include <algorithm>
#include <functional>
#include <iostream>
#include <limits>
#include <math.h>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include "csv-reader.h"
#include "forecaster.h"
#include "hungarian.h"

// Snapshot of a (possibly partial) run of d^alpha center sampling.
struct AlphaState {
    std::vector<double> zs; // pre-sampled randomness, indexed by center number
    std::vector<std::vector<double>> D; // distance matrix
    std::vector<int> centers; // indices of the currently chosen centers
    // Zero-based index of next center number we need to choose. Defaulted so a
    // default-constructed state never carries an indeterminate value.
    int next_center = 0;
    std::vector<double> distance_to_centers; // per-point running minimum distance to the chosen centers
};

// Constructs the initial state for choosing `k` centers from points with distance
// matrix `D`: no centers chosen yet, `k` pre-sampled random draws (one per
// center), and every point's distance-to-centers set to the largest entry of
// `D` (0 when `D` has no positive entry).
AlphaState InitAlphaState(std::vector<std::vector<double>> D, int k) {
    double maxD = 0.0;
    for (const auto &row : D) {
        for (const auto &d : row) {
            maxD = std::max(maxD, d);
        }
    }

    AlphaState as;
    as.next_center = 0;
    as.distance_to_centers.assign(D.size(), maxD);
    // All randomness is drawn up front so the same draws can be replayed for
    // every alpha. NOTE(review): rand01() is declared elsewhere; presumably it
    // returns a uniform value in [0, 1) - confirm.
    for (int i = 0; i < k; ++i) {
        as.zs.push_back(rand01());
    }
    // `D` is already this function's private copy; hand it over instead of
    // copying the whole matrix a second time.
    as.D = std::move(D);
    return as;
}

// Computes the CDF of the density that outputs outcome `j` with probability
// proportional to `weights[j]^alpha` at the value `i` (i.e., the probability of
// outputting an index `<= i`).
//
// `weights` is taken by const reference: this function is evaluated repeatedly
// during bisection and must not deep-copy the vector on every call.
// There is no guard against a zero normalizer: for empty `weights`, or when
// every `weights[j]^alpha` is 0, the result is 0/0 (NaN).
double discreteCdf(const std::vector<double> &weights, int i, double alpha) {
    double cumulative_weight = 0.0;
    double normalizing_constant = 0.0;
    const int count = static_cast<int>(weights.size());
    for (int j = 0; j < count; ++j) {
        const double w = pow(weights[j], alpha);
        if (j <= i) {
            cumulative_weight += w;
        }
        normalizing_constant += w;
    }
    return cumulative_weight / normalizing_constant;
}

// Returns the value of `alpha` in [`interval.lower`, `interval.upper`] for which
// `discreteCdf(weights, i, alpha) == z`, located by bisection. If this does not
// occur in the provided interval, it returns one of the interval endpoints.
//
// The search direction treats the CDF as non-increasing in `alpha`: a CDF value
// below `z` moves the upper bound down, anything else moves the lower bound up.
double findBoundary(double z, const std::vector<double> &weights, int i, Interval interval) {
    double a = interval.lower;
    double b = interval.upper;
    while (b-a > 1e-11) {
        const double m = (b+a)/2;
        // For large bounds the spacing between adjacent doubles exceeds the
        // 1e-11 tolerance; the midpoint then collapses onto an endpoint and the
        // bracket can no longer shrink. Stop instead of looping forever.
        if (m <= a || m >= b) {
            break;
        }
        if (discreteCdf(weights, i, m) < z) {
            b = m;
        } else {
            a = m;
        }
    }
    // A bound that never moved means the crossing lies outside the interval on
    // that side; report the endpoint itself rather than a point just inside.
    if (a == interval.lower) return a;
    if (b == interval.upper) return b;
    return (b+a)/2;
}


// Returns the smallest index `i` such that `discreteCdf(weights, i, alpha) > z`,
// i.e. a sampling according to randomness `z`. Falls back to the last index
// when rounding keeps the running total from exceeding `z` (and yields -1 for
// empty `weights`).
//
// `alpha == std::numeric_limits<double>::max()` is treated as alpha = infinity:
// all probability mass sits on the largest weight(s), split evenly between
// ties. The result is therefore always a valid index into `weights` (the
// previous code returned `weights.size()`, one past the end, in this case).
int discreteSample(double z, const std::vector<double> &weights, double alpha) {
    const int count = static_cast<int>(weights.size());
    if (alpha == std::numeric_limits<double>::max()) {
        if (weights.empty()) {
            return -1;
        }
        const double heaviest = *std::max_element(weights.begin(), weights.end());
        const double share =
            1.0 / static_cast<double>(std::count(weights.begin(), weights.end(), heaviest));
        double total = 0;
        int last_heaviest = count - 1;
        for (int i = 0; i < count; ++i) {
            if (weights[i] != heaviest) {
                continue;  // zero probability in the limit
            }
            total += share;
            if (total > z) {
                return i;
            }
            last_heaviest = i;
        }
        return last_heaviest;
    }

    double Z = 0;  // normalizing constant
    for (const double w : weights) {
        Z += pow(w, alpha);
    }
    double total = 0;
    for (int i = 0; i < count; ++i) {
        total += pow(weights[i], alpha) / Z;
        if (total > z) {
            return i;
        }
    }
    return count - 1;
}

// Returns the permutation that orders `v` ascending: entry `r` of the result is
// the position in `v` of the r-th smallest value. `v` itself is not modified.
template <typename T>
std::vector<size_t> sortIndexes(const std::vector<T> &v) {
  // Start from the identity permutation 0, 1, ..., v.size()-1.
  std::vector<size_t> order(v.size());
  for (size_t pos = 0; pos < order.size(); ++pos) {
    order[pos] = pos;
  }

  // Reorder the positions by the values they refer to.
  const auto by_value = [&v](size_t lhs, size_t rhs) { return v[lhs] < v[rhs]; };
  std::sort(order.begin(), order.end(), by_value);

  return order;
}

/*
Returns all indices that could potentially be sampled by `discreteSample` with
parameter `z` for any `alpha` in  [`alpha_interval.lower`, `alpha_interval.upper`].
i.e. returns a list of (i, [alpha1, alpha2]) s.t. i is sampled for [alpha1, alpha2]

The weights are sorted ascending internally and the sampled positions are mapped
back to the caller's original indices. The result is empty when `weights` is
empty or when `alpha_interval` has zero length.
*/
std::vector<std::pair<int, Interval>> allSamples(double z, std::vector<double> weights, Interval alpha_interval) {
    std::vector<std::pair<int, Interval>> results;
    // Nothing to enumerate; bail out before doing any sorting or sampling.
    if (weights.empty() || alpha_interval.length() == 0) {
        return results;
    }

    // Indices of weights in the sorted order
    const std::vector<size_t> sorted_perm = sortIndexes(weights);
    std::sort(weights.begin(), weights.end());

    // Sorted positions sampled at the two ends of the alpha interval. The upper
    // one is clamped so it is always usable as an index into `sorted_perm`.
    const int last_index = static_cast<int>(weights.size()) - 1;
    const int i_lo = discreteSample(z, weights, alpha_interval.lower);
    const int i_hi = std::min(discreteSample(z, weights, alpha_interval.upper), last_index);

    // Find values of alpha for which z is on the boundary between two bins
    double last_critical_alpha = alpha_interval.lower;
    for (int i = i_lo; i < i_hi; ++i) {
        const double critical_alpha = findBoundary(z, weights, i, alpha_interval);
        // Skip zero-width pieces: position i is never actually sampled then.
        if (last_critical_alpha != critical_alpha) {
            results.push_back(std::make_pair(static_cast<int>(sorted_perm[i]),
                                             Interval(last_critical_alpha, critical_alpha)));
        }
        last_critical_alpha = critical_alpha;
        if (alpha_interval.lower > critical_alpha || critical_alpha > alpha_interval.upper) {
            printf("critical_alpha outside alpha_interval!");
        }
    }
    // Whatever remains up to the upper end belongs to the position sampled there.
    if (last_critical_alpha != alpha_interval.upper) {
        results.push_back(std::make_pair(static_cast<int>(sorted_perm[i_hi]),
                                         Interval(last_critical_alpha, alpha_interval.upper)));
    }
    return results;
}


// Samples the randomness `zs` and returns every possible `AlphaSamplingState`
// reachable using d^alpha sampling for any value of `alpha` in the interval
// [`alpha_min`, `alpha_max`].
std::vector<std::pair<AlphaState, Interval>> alphaSamplingExhaustive(
    std::vector<std::vector<double>> D, int k, double alpha_min, double alpha_max) {
    auto step = [](AlphaState state, Interval params) {
        int nc = state.next_center;  // next center number
        auto next_centers = allSamples(state.zs[nc], state.distance_to_centers, params);
        std::vector<std::pair<AlphaState, Interval>> output_state_intervals;
        for (auto next_center : next_centers) {
            std::vector<int> centers = state.centers;  // copy of current centers
            centers.push_back(next_center.first);  // case: given next_center
            auto dist_to_centers = state.distance_to_centers;
            for (int k = 0; k < dist_to_centers.size(); ++k) {
                dist_to_centers[k] = std::min(dist_to_centers[k], state.D[k][centers.back()]);
            }
            AlphaState as = state;
            as.centers = centers;
            as.next_center = nc+1;
            as.distance_to_centers = dist_to_centers;
            output_state_intervals.push_back(std::make_pair(as, next_center.second));
        }
        return output_state_intervals;
    };

    auto terminated = [](AlphaState state) {
        return state.next_center >= state.zs.size();
    };

    std::function<std::vector<std::pair<AlphaState, Interval>>(AlphaState, Interval)> recur;
    recur = [&step, &terminated, &recur](AlphaState state, Interval params)
        -> std::vector<std::pair<AlphaState, Interval>> {
        if (terminated(state)) return {std::make_pair(state, params)};
        else {
            std::vector<std::pair<AlphaState, Interval>> leaves;
            for (auto child_state_params : step(state, params)) {
                auto recurred_list = recur(child_state_params.first, child_state_params.second);
                leaves.insert(leaves.end(), recurred_list.begin(), recurred_list.end());
            }
            return leaves;
        }
    };

    auto state = InitAlphaState(D, k);
    return recur(state, Interval(alpha_min, alpha_max));
}

// Assigns every point to its nearest center. `D` is the distance matrix and
// `centers` holds point indices; entry `j` of the result is the position within
// `centers` of the center closest to point `j` (the earliest one on ties), or
// -1 when `centers` is empty.
//
// Best distances start at +infinity rather than at the largest entry of `D`:
// with the latter, a point lying exactly at the maximum distance from its
// nearest center never passed the strict `<` test and was left labelled -1.
std::vector<int> alphaSampling(const std::vector<std::vector<double>> &D, const std::vector<int> &centers) {
    const size_t n = D.size();
    std::vector<int> clusters(n, -1);
    std::vector<double> dcs(n, std::numeric_limits<double>::infinity());
    for (size_t i = 0; i < centers.size(); ++i) {
        const int center = centers[i];
        for (size_t j = 0; j < n; ++j) {
            const double d = D[j][center];
            if (d < dcs[j]) {
                clusters[j] = static_cast<int>(i);
                dcs[j] = d;
            }
        }
    }
    return clusters;
}


// Samples `k` initial centers using d^`alpha` sampling. Returns a clustering mapping
// data points to centers: entry `j` is the point index of the center assigned
// to point `j` (note: the point index, not the center's ordinal), or -1 when
// no center was sampled (`k <= 0`).
//
// Each center is drawn with probability proportional to
// (distance to the nearest already chosen center)^alpha; before any center is
// chosen every point uses the largest entry of `D` as its distance.
std::vector<int> alphaSampling(const std::vector<std::vector<double>> &D, int k, double alpha) {
    const int n = static_cast<int>(D.size());

    // Draw all randomness first, one value per center.
    std::vector<double> zs;
    for (int i = 0; i < k; ++i) {
        zs.push_back(rand01());
    }

    double maxD = 0.0;
    for (const auto &row : D) {
        for (const auto &d : row) {
            maxD = std::max(maxD, d);
        }
    }
    std::vector<double> dcs(n, maxD);   // sampling weights / nearest-center distances
    std::vector<int> clusters(n, -1);

    for (int i = 0; i < k; ++i) {
        const int center = discreteSample(zs[i], dcs, alpha);
        for (int j = 0; j < n; ++j) {
            const double d = D[j][center];
            // A still-unassigned point takes this center even when it is not
            // strictly closer than the initial `maxD`; otherwise a point at
            // exactly the maximum distance would stay labelled -1.
            if (clusters[j] == -1 || d < dcs[j]) {
                clusters[j] = center;
                dcs[j] = std::min(dcs[j], d);
            }
        }
    }
    return clusters;
}

double hammingCost(int k, std::vector<int> ys, std::vector<int> qs) {
    int n = ys.size();
    std::vector<std::vector<double>> cost_matrix;
    std::vector<double> cm;
    for (int j = 0; j < k; ++j) {
        cm.push_back(0.0);
    }
    for (int i = 0; i < k; ++i) {
        cost_matrix.push_back(cm);
    }
    for (int i = 0; i < n; ++i) {
        for (int y = 0; y < k; ++y) {
            if (y != ys[i]) cost_matrix[qs[i]][y] += 1;
        }
    }
    vector<int> assignment;
    HungarianAlgorithm HungAlgo;
    double cost = HungAlgo.Solve(cost_matrix, assignment);
    return 1 - cost / n;
}

// Loads the data rows of one task file. The file name is
// `<data_dir><iter>-<i>-of-<T-1>`; no separator is inserted after `data_dir`,
// so it must already end with one. The name is also echoed to stdout.
std::vector<std::vector<double>> readData(string data_dir, int i, int T, int iter) {
    string path = data_dir;
    path += std::to_string(iter);
    path += "-";
    path += std::to_string(i);
    path += "-of-";
    path += std::to_string(T-1);
    std::cout << path;
    CSVReader csv(path);
    return csv.getData();
}

// Loads the labels of one task file. The file name is
// `<data_dir><iter>-<i>-of-<T-1>`; no separator is inserted after `data_dir`,
// so it must already end with one.
std::vector<int> readLabels(string data_dir, int i, int T, int iter) {
    string path = data_dir;
    path += std::to_string(iter);
    path += "-";
    path += std::to_string(i);
    path += "-of-";
    path += std::to_string(T-1);
    CSVReader csv(path);
    return csv.getLabels();
}

// Builds the piecewise constant function of `alpha` described by `res`: for each
// (state, interval) pair, the points are assigned to the state's centers using
// distance matrix `D`, the resulting clustering is scored against the labels `y`
// with `hammingCost` over `k` classes, and that score becomes the function's
// value on the pair's alpha interval.
PiecewiseConstantFunction piecewiseHamming(std::vector<std::pair<AlphaState, Interval>> &res,
  std::vector<int> y, std::vector<std::vector<double>> D, int k) {
    std::vector<ConstantFunction> pieces;
    // Iterate by reference: each AlphaState owns a full distance matrix, so a
    // per-iteration copy of the pair is expensive and unnecessary.
    for (const auto &r : res) {
        const std::vector<int> clusters = alphaSampling(D, r.first.centers);
        const double h = hammingCost(k, y, clusters);
        pieces.push_back(ConstantFunction(h,r.second.lower,r.second.upper));
    }
    // NOTE(review): the constructor receives a pointer to this local vector;
    // returning `u` is only safe if it copies the pieces - confirm in the
    // PiecewiseConstantFunction definition.
    PiecewiseConstantFunction u(&pieces);
    return u;
}

// Arithmetic mean of `list`. An empty list yields 0/0, i.e. NaN.
double mean(std::vector<double> list) {
    const double total = std::accumulate(list.begin(), list.end(), 0.0);
    return total / list.size();
}

// Population standard deviation of `list`: the square root of the average
// squared deviation from the arithmetic mean (divides by N, not N-1).
double stdev(std::vector<double> list) {
    // Arithmetic mean, accumulated front to back.
    double total = 0.0;
    for (size_t idx = 0; idx < list.size(); ++idx) {
        total += list[idx];
    }
    const double center = total / list.size();

    double squared_gaps = 0.0;
    for (size_t idx = 0; idx < list.size(); ++idx) {
        squared_gaps += (list[idx] - center) * (list[idx] - center);
    }
    return sqrt(squared_gaps / list.size());
}

int main(int argc, char *argv[]) {
std::string data_dir = argv[4];
srand(time(NULL)); 
int k = std::stoi(argv[5]);
double a_min = 0.0, a_max = 10.0;
int iter = std::stoi(argv[2]);
Interval I(a_min,a_max);
std::vector<double> one_shots;
std::vector<double> five_shots;
for (int IT = 0; IT<std::stoi(argv[1]); ++IT) {
    for (int it = std::stoi(argv[3]); it < std::stoi(argv[3])+iter; ++it){
        int T = 30;
        ExponentialForecaster<PiecewiseConstantFunction> ef1(/*R=*/(a_max-a_min),
        /*L=*/0,/*H=*/1,/*T=*/T,/*w=*/0.5,/*d=*/1,/*I=*/I);
        OfflineOptimalForecaster<PiecewiseConstantFunction> oo2(/*I=*/I);
        PiecewiseConstantFunction U(I);
        for (int i = 0; i < T; ++i) {
            if (i == 5) break;
            int j = i;
            std::vector<std::vector<double>> D = readData(data_dir, j, T, it);
            std::vector<int> y = readLabels(data_dir, j, T,it);
            std::vector<std::pair<AlphaState, Interval>> res(alphaSamplingExhaustive(D, /*k=*/k, a_min, a_max));
            // Compute a piecewise constant function for hamming cost of alpha.
            PiecewiseConstantFunction u = piecewiseHamming(res, y, D, k);
            double rho1 = ef1.Forecast(u);
            double   opt_arg = oo2.Forecast(u);
            if (i==0) one_shots.push_back(oo2.payoff()-ef1.payoff());
            if (i==4) five_shots.push_back((oo2.payoff()-ef1.payoff())/5);
        }
    }
}
std::cout<<"one_shots: "<<mean(one_shots)<<" +/- "<<stdev(one_shots);
std::cout<<"five_shots: "<<mean(five_shots)<<" +/- "<<stdev(five_shots);
  
  return 0;
}