// Usage
// 
// ./a.out <num_training_tasks> <num_iterations> <num_within_task_iterations> <num_test_tasks> <first_test_task> <data_directory> <num_clusters>

// g++ -o tar_omn_wrt_sample_size_per_task.out tar_omn_wrt_sample_size_per_task.cc
// ./tar_omn_wrt_sample_size_per_task.out 11 50 5 5 0 ./ANONYMIZED_OMNIGLOT 2

#include <algorithm>
#include <iostream>
#include <math.h>
#include <numeric> 
#include <vector>
#include <stdlib.h>
#include <numeric>


#include "csv-reader.h"
#include "forecaster.h"
#include "hungarian.h"
#include <functional>

// Snapshot of a (possibly partial) run of d^alpha center selection.
struct AlphaState {
    std::vector<double> zs; // pre-sampled randomness; zs[i] drives the choice of the i-th center
    std::vector<std::vector<double>> D; // distance matrix
    std::vector<int> centers; // indices of the currently chosen centers
    // Zero-based index of the next center number we need to choose. Defaults to
    // 0 so that a default-constructed state never holds an indeterminate value.
    int next_center = 0;
    std::vector<double> distance_to_centers; // per point: smallest distance to any chosen center so far
};

// Builds the starting AlphaState for selecting `k` centers among the points
// described by the distance matrix `D`: no center is chosen yet, every entry of
// `distance_to_centers` is set to the largest value found in `D` (the running
// maximum starts at 0.0), and one rand01() draw is stored in `zs` per center
// that will have to be selected.
AlphaState InitAlphaState(std::vector<std::vector<double>> D, int k) {
    double largest = 0.0;
    for (const auto &row : D) {
        for (double entry : row) {
            if (largest < entry) largest = entry;
        }
    }

    AlphaState initial;
    initial.centers.clear();
    initial.next_center = 0;
    initial.distance_to_centers.assign(D.size(), largest);
    // One pre-drawn random value per center; nothing is drawn when k <= 0.
    for (int remaining = k; remaining > 0; --remaining) {
        initial.zs.push_back(rand01());
    }
    initial.D = D;
    return initial;
}

// Evaluates, at index `i`, the CDF of the discrete distribution that picks
// outcome `j` with probability proportional to `weights[j]^alpha`; that is, the
// probability of drawing an index `<= i`. An `i` below 0 gives 0, and an `i` at
// or beyond the last index gives the full mass.
double discreteCdf(std::vector<double> weights, int i, double alpha) {
    double mass_up_to_i = 0.0;
    double total_mass = 0.0;
    int position = 0;
    for (double weight : weights) {
        const double powered = pow(weight, alpha);
        total_mass += powered;
        if (position <= i) {
            mass_up_to_i += powered;
        }
        ++position;
    }
    return mass_up_to_i / total_mass;
}

// Bisects [`interval.lower`, `interval.upper`] for the `alpha` at which
// `discreteCdf(weights, i, alpha)` equals `z`. The update rule lowers the upper
// bound whenever the CDF at the midpoint is below `z`, i.e. it treats the CDF
// as non-increasing in `alpha`. Bisection stops once the bracket is no wider
// than 1e-11. If one side of the bracket never moved, that interval endpoint is
// returned unchanged; otherwise the midpoint of the final bracket is returned.
double findBoundary(double z, std::vector<double> weights, int i, Interval interval) {
    double lo = interval.lower;
    double hi = interval.upper;
    while (hi - lo > 1e-11) {
        const double mid = (hi + lo) / 2;
        const bool cdf_below_target = discreteCdf(weights, i, mid) < z;
        if (cdf_below_target) {
            hi = mid;
        } else {
            lo = mid;
        }
    }
    if (lo == interval.lower) return lo;
    if (hi == interval.upper) return hi;
    return (hi + lo) / 2;
}


// Inverse-CDF sampling driven by the randomness `z`: returns the smallest index
// `i` whose cumulative probability under the `weights[j]^alpha` distribution
// exceeds `z`. If no prefix exceeds `z`, the last index is returned.
// Special case: when `alpha` equals the largest finite double, the function
// returns `weights.size()` (one past the last valid index) without sampling.
// NOTE(review): callers that use the result as a subscript should confirm they
// never pass that `alpha`.
int discreteSample(double z, std::vector<double> weights, double alpha) {
    const size_t count = weights.size();
    if (alpha == std::numeric_limits<double>::max()) {
        return static_cast<int>(count);
    }

    // Normalizing constant of the distribution.
    double norm = 0;
    for (size_t j = 0; j < count; ++j) {
        norm += pow(weights[j], alpha);
    }

    // Walk the CDF until it passes z.
    double running = 0;
    for (size_t j = 0; j < count; ++j) {
        running += pow(weights[j], alpha) / norm;
        if (running > z) {
            return static_cast<int>(j);
        }
    }
    return static_cast<int>(count - 1);
}

// Returns the permutation of 0..v.size()-1 that orders `v` ascending: element
// `p` of the result is the position in `v` of the p-th smallest value (as
// ordered by operator<). `v` itself is not modified.
template <typename T>
std::vector<size_t> sortIndexes(const std::vector<T> &v) {
  // Identity permutation to start from.
  std::vector<size_t> order(v.size());
  for (size_t pos = 0; pos < order.size(); ++pos) {
    order[pos] = pos;
  }

  // Order the positions by the values they refer to.
  auto by_value = [&v](size_t lhs, size_t rhs) { return v[lhs] < v[rhs]; };
  std::sort(order.begin(), order.end(), by_value);

  return order;
}

/*
Enumerates every index that `discreteSample` can return for the randomness `z`
as `alpha` ranges over [`alpha_interval.lower`, `alpha_interval.upper`].

Returns a list of (i, [alpha1, alpha2]) pairs meaning "index i is sampled for
alpha in [alpha1, alpha2]". The index `i` refers to a position in the caller's
original (unsorted) `weights`. Pieces are emitted in increasing order of alpha,
and each piece starts where the previous one ended. A zero-length
`alpha_interval` produces an empty list.

The weights are sorted ascending internally; the candidate indices are the
sorted positions between the sample at the lower endpoint and the sample at the
upper endpoint, and `findBoundary` locates the alpha at which the sample moves
from one position to the next.

NOTE(review): `discreteSample` returns `weights.size()` when alpha equals the
largest finite double; with such an upper endpoint the final lookup below would
index past `sorted_perm`. Confirm callers never pass that bound.
*/
std::vector<std::pair<int, Interval>> allSamples(double z, std::vector<double> weights, Interval alpha_interval) {
    std::vector<std::pair<int, Interval>> results;
    // Degenerate interval: nothing to report, so skip the sorting/sampling work.
    if (alpha_interval.length() == 0) {
        return results;
    }

    // sorted_perm[p] is the original position of the p-th smallest weight.
    const std::vector<size_t> sorted_perm = sortIndexes(weights);
    std::sort(weights.begin(), weights.end());

    // Sorted positions sampled at the two endpoints (kept as integers, since
    // they are used as loop bounds and subscripts).
    const int i_lo = discreteSample(z, weights, alpha_interval.lower);
    const int i_hi = discreteSample(z, weights, alpha_interval.upper);

    // Find values of alpha for which z is on the boundary between two bins.
    double last_critical_alpha = alpha_interval.lower;
    for (int i = i_lo; i < i_hi; ++i) {
        const double critical_alpha = findBoundary(z, weights, i, alpha_interval);
        // Skip empty pieces (two consecutive boundaries at the same alpha).
        if (last_critical_alpha != critical_alpha) {
            results.push_back(std::make_pair(static_cast<int>(sorted_perm[i]),
                                             Interval(last_critical_alpha, critical_alpha)));
        }
        last_critical_alpha = critical_alpha;
        if (alpha_interval.lower > critical_alpha || critical_alpha > alpha_interval.upper) {
            printf("critical_alpha outside alpha_interval!");
        }
    }
    // Remaining stretch up to the upper endpoint belongs to the last position.
    if (last_critical_alpha != alpha_interval.upper) {
        results.push_back(std::make_pair(static_cast<int>(sorted_perm[i_hi]),
                                         Interval(last_critical_alpha, alpha_interval.upper)));
    }
    return results;
}


// Samples the randomness `zs` and returns every possible `AlphaSamplingState`
// reachable using d^alpha sampling for any value of `alpha` in the interval
// [`alpha_min`, `alpha_max`].
std::vector<std::pair<AlphaState, Interval>> alphaSamplingExhaustive(
    std::vector<std::vector<double>> D, int k, double alpha_min, double alpha_max) {
    auto step = [](AlphaState state, Interval params) {
        int nc = state.next_center;  // next center number
        auto next_centers = allSamples(state.zs[nc], state.distance_to_centers, params);
        std::vector<std::pair<AlphaState, Interval>> output_state_intervals;
        for (auto next_center : next_centers) {
            std::vector<int> centers = state.centers;  // copy of current centers
            centers.push_back(next_center.first);  // case: given next_center
            auto dist_to_centers = state.distance_to_centers;
            for (int k = 0; k < dist_to_centers.size(); ++k) {
                dist_to_centers[k] = std::min(dist_to_centers[k], state.D[k][centers.back()]);
            }
            AlphaState as = state;
            as.centers = centers;
            as.next_center = nc+1;
            as.distance_to_centers = dist_to_centers;
            output_state_intervals.push_back(std::make_pair(as, next_center.second));
        }
        return output_state_intervals;
    };

    auto terminated = [](AlphaState state) {
        return state.next_center >= state.zs.size();
    };

    std::function<std::vector<std::pair<AlphaState, Interval>>(AlphaState, Interval)> recur;
    recur = [&step, &terminated, &recur](AlphaState state, Interval params)
        -> std::vector<std::pair<AlphaState, Interval>> {
        if (terminated(state)) return {std::make_pair(state, params)};
        else {
            std::vector<std::pair<AlphaState, Interval>> leaves;
            for (auto child_state_params : step(state, params)) {
                auto recurred_list = recur(child_state_params.first, child_state_params.second);
                leaves.insert(leaves.end(), recurred_list.begin(), recurred_list.end());
            }
            return leaves;
        }
    };

    auto state = InitAlphaState(D, k);
    return recur(state, Interval(alpha_min, alpha_max));
}

// Assigns every point to its nearest center.
//
// `D` is the distance matrix (`D[j][c]` is read as the distance from point `j`
// to point `c`); `centers` holds the point indices of the chosen centers.
// Returns a vector with one entry per row of `D`: entry `j` is the position
// within `centers` of the center closest to point `j`. Ties go to the center
// that appears first in `centers`. If `centers` is empty every entry is -1.
//
// A point is always assigned once at least one center exists, even when it is
// at the maximum distance from every center; leaving such points at -1 would
// hand downstream code an invalid cluster id.
std::vector<int> alphaSampling(std::vector<std::vector<double>> D, std::vector<int> centers) {
    const size_t n = D.size();
    std::vector<int> clusters(n, -1);
    std::vector<double> best(n, 0.0);  // only meaningful once clusters[j] != -1
    for (size_t c = 0; c < centers.size(); ++c) {
        const int center = centers[c];
        for (size_t j = 0; j < n; ++j) {
            const double d = D[j][center];
            // First center seen always claims the point; later ones must be strictly closer.
            if (clusters[j] == -1 || d < best[j]) {
                clusters[j] = static_cast<int>(c);
                best[j] = d;
            }
        }
    }
    return clusters;
}


// Samples `k` initial centers using d^`alpha` sampling. Returns a clustering
// mapping data points to centers: entry `j` of the result is the point index of
// the sampled center closest to point `j` (or -1 for every point when `k <= 0`).
//
// All randomness is drawn up front (one rand01() value per center). Each center
// is sampled with `discreteSample` from the points' current distance-to-centers,
// which starts at the largest entry of `D` for every point.
//
// A point is always assigned once the first center has been sampled, even if
// it sits at the maximum distance from every center; the distances used as
// sampling weights are unaffected by this.
std::vector<int> alphaSampling(std::vector<std::vector<double>> D, int k, double alpha) {
    const int n = D.size();

    std::vector<double> zs;
    for (int i = 0; i < k; ++i) {
        zs.push_back(rand01());
    }

    double maxD = 0.0;
    for (const auto &row : D) {
        for (const auto &d : row) {
            maxD = std::max(maxD, d);
        }
    }

    std::vector<double> dcs(n, maxD);  // distance of each point to its closest center so far
    std::vector<int> clusters(n, -1);

    for (int i = 0; i < k; ++i) {
        const int center = discreteSample(zs[i], dcs, alpha);
        for (int j = 0; j < n; ++j) {
            const double d = D[j][center];
            // Unassigned points take the first center; otherwise require a strictly closer one.
            if (clusters[j] == -1 || d < dcs[j]) {
                clusters[j] = center;
                dcs[j] = d;
            }
        }
    }
    return clusters;
}

// Scores the predicted cluster ids `qs` against the labels `ys` using `k`
// clusters/labels. Builds a k-by-k matrix whose entry [q][y] counts the points
// predicted as cluster `q` whose label differs from `y`, passes it to
// HungarianAlgorithm::Solve (presumably a minimum-cost matching of clusters to
// labels; see hungarian.h), and returns 1 - cost / n, where `n` is the number
// of labels. Every `qs[i]` is used directly as a row index.
double hammingCost(int k, std::vector<int> ys, std::vector<int> qs) {
    const int n = ys.size();

    // k x k matrix of zeros (empty when k <= 0).
    const int side = k > 0 ? k : 0;
    std::vector<std::vector<double>> cost_matrix(side, std::vector<double>(side, 0.0));

    for (int point = 0; point < n; ++point) {
        const int truth = ys[point];
        for (int label = 0; label < k; ++label) {
            if (label == truth) continue;
            cost_matrix[qs[point]][label] += 1;
        }
    }

    std::vector<int> assignment;
    HungarianAlgorithm HungAlgo;
    const double cost = HungAlgo.Solve(cost_matrix, assignment);
    return 1 - cost / n;
}

// Reads the file `<data_dir>/D-<iter>-<i>-of-<T-1>` through CSVReader and
// returns whatever CSVReader::getData() parses from it.
std::vector<std::vector<double>> readData(string data_dir, int i, int T, int iter) {
    string path(data_dir);
    path += "/D-";
    path += std::to_string(iter);
    path += "-";
    path += std::to_string(i);
    path += "-of-";
    path += std::to_string(T - 1);

    CSVReader reader(path);
    return reader.getData();
}

// Reads the file `<data_dir>/y-<iter>-<i>-of-<T-1>` through CSVReader and
// returns whatever CSVReader::getLabels() parses from it.
std::vector<int> readLabels(string data_dir, int i, int T, int iter) {
    string path(data_dir);
    path += "/y-";
    path += std::to_string(iter);
    path += "-";
    path += std::to_string(i);
    path += "-of-";
    path += std::to_string(T - 1);

    CSVReader reader(path);
    return reader.getLabels();
}

// Turns the exhaustive sampling result `res` into a piecewise constant function
// of alpha: for each (state, interval) pair, the points are assigned to the
// state's centers, the assignment is scored with hammingCost against the
// labels `y`, and a constant piece with that score is emitted over
// [interval.lower, interval.upper].
//
// `y` is shifted down by one before scoring (presumably the labels in the data
// files are 1-based - confirm against the data format). `D` is the distance
// matrix the states were computed from and `k` the number of clusters.
PiecewiseConstantFunction piecewiseHamming(std::vector<std::pair<AlphaState, Interval>> &res,
  std::vector<int> y, std::vector<std::vector<double>> D, int k) {
    // The shifted labels do not depend on the piece, so build them once.
    std::vector<int> zero_based;
    zero_based.reserve(y.size());
    for (int label : y) {
        zero_based.push_back(label - 1);
    }

    std::vector<ConstantFunction> pieces;
    pieces.reserve(res.size());
    // Iterate by reference: each AlphaState carries a full copy of the distance matrix.
    for (const auto &r : res) {
        const std::vector<int> clusters = alphaSampling(D, r.first.centers);
        const double h = hammingCost(k, zero_based, clusters);
        pieces.push_back(ConstantFunction(h, r.second.lower, r.second.upper));
    }
    // NOTE(review): `u` is built from a pointer to the local `pieces`; this is
    // only safe if the constructor copies the pieces - confirm in forecaster.h.
    PiecewiseConstantFunction u(&pieces);
    return u;
}

// Arithmetic mean of `list`: the left-to-right sum of the elements divided by
// their count. An empty list yields 0.0 / 0, i.e. NaN.
double mean(std::vector<double> list) {
    const double total = std::accumulate(list.begin(), list.end(), 0.0);
    return total / list.size();
}

// Population standard deviation of `list`: the square root of the average
// squared deviation from mean(list), dividing by the element count (not
// count - 1).
double stdev(std::vector<double> list) {
    const double center = mean(list);
    double squared_deviations = 0.0;
    for (std::vector<double>::const_iterator it = list.begin(); it != list.end(); ++it) {
        const double delta = *it - center;
        squared_deviations += delta * delta;
    }
    return sqrt(squared_deviations / list.size());
}

// Experiment driver.
//
// Command line (see the usage comment at the top of the file):
//   argv[1]  exclusive upper bound on the number of tasks `iter` (starts at 1)
//   argv[2]  number of repeated trials per task count
//   argv[3]  number of samples per task (`T`)
//   argv[4], argv[5]  not read by this program
//   argv[6]  data directory passed to readData/readLabels
//   argv[7]  number of clusters `k`
//
// For every task count, each trial feeds the per-sample piecewise-constant
// utilities to an ExponentialForecaster and an OfflineOptimalForecaster and
// records (oo2.payoff() - ef1.payoff() + extra_payoff) / (T * iter) after the
// last sample of each task. The sum of the recorded values divided by the
// number of trials is printed to stdout and appended to
// tar_omn_wrt_sample_size_per_task.txt.
//
// Returns 1 when too few arguments are given or the output file cannot be
// opened, 0 otherwise.
int main(int argc, char *argv[]) {
    // argv[6] and argv[7] are read below, so all seven arguments must be present.
    if (argc < 8) {
        fprintf(stderr,
                "usage: %s <num_training_tasks> <num_iterations> <num_within_task_iterations> "
                "<num_test_tasks> <first_test_task> <data_directory> <num_clusters>\n",
                argc > 0 ? argv[0] : "a.out");
        return 1;
    }

    // Parse every argument once instead of re-parsing inside the loops.
    const int max_tasks = std::stoi(argv[1]);
    const int num_iterations = std::stoi(argv[2]);
    const int num_training_samples_per_task = std::stoi(argv[3]);
    std::string data_dir = argv[6];
    int k = std::stoi(argv[7]);
    double a_min = 0.0, a_max = 10.0;

    srand(time(NULL));

    for (int iter = 1; iter < max_tasks; ++iter) {
        // Opened before the computation so an unwritable output is detected early.
        FILE *fp = fopen("tar_omn_wrt_sample_size_per_task.txt", "a+");
        if (fp == NULL) {
            perror("tar_omn_wrt_sample_size_per_task.txt");
            return 1;
        }

        Interval I(a_min, a_max);
        std::vector<double> thirty_shots;
        for (int IT = 0; IT < num_iterations; ++IT) {
            PiecewiseConstantFunction histogram(I);
            for (int it = 0; it < iter; ++it) {
                int T = num_training_samples_per_task;
                ExponentialForecaster<PiecewiseConstantFunction> ef1(/*R=*/(a_max-a_min),
                /*L=*/0,/*H=*/1,/*T=*/T,/*w=*/0.5,/*d=*/1,/*I=*/I);
                OfflineOptimalForecaster<PiecewiseConstantFunction> oo2(/*I=*/I);
                double extra_payoff = ef1.payoff();
                for (int i = 0; i < T; ++i) {
                    std::vector<std::vector<double>> D = readData(data_dir, i, T, it);
                    std::vector<int> y = readLabels(data_dir, i, T, it);
                    std::vector<std::pair<AlphaState, Interval>> res(alphaSamplingExhaustive(D, /*k=*/k, a_min, a_max));
                    // Compute a piecewise constant function for hamming cost of alpha.
                    PiecewiseConstantFunction u = piecewiseHamming(res, y, D, k);
                    // Both forecasters are called for their effect on payoff(); the returned values are not used.
                    ef1.Forecast(u);
                    oo2.Forecast(u);
                    if (i == T-1) {
                        thirty_shots.push_back((oo2.payoff()-ef1.payoff()+extra_payoff)/(T * iter));
                    }
                    Add(&histogram, u);
                }
            }
        }

        const double average = std::accumulate(thirty_shots.begin(), thirty_shots.end(), 0.0)/num_iterations;
        printf("%d tasks, %d trials, %d samples per task, %d clusters, thirty_shots: %.5f\n", iter, num_iterations, num_training_samples_per_task, k, average);
        fprintf(fp, "%d tasks, %d trials, %d samples per task, %d clusters, thirty_shots: %.5f\n", iter, num_iterations, num_training_samples_per_task, k, average);
        fclose(fp);
    }
    return 0;
}