// Based on the code from https://en.wikipedia.org/wiki/2-opt
// Based on the code from https://github.com/ArGintum/RTD-Lite

#include <stdio.h>

#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <map>
#include <random>
#include <string>
#include <vector>

#include "rtdl.h"

using namespace std;
using namespace std::chrono;

// Arithmetic mean of `numbers`; 0 for an empty vector.
double mean(const std::vector<double>& numbers) {
    if (numbers.empty()) {
        return 0.0;
    }
    double total = 0.0;
    for (const double value : numbers) {
        total += value;
    }
    return total / numbers.size();
}

// Sample standard deviation of `numbers` (Bessel-corrected: divides the sum of
// squared deviations by size - 1).
// Returns 0 when fewer than two values are given: the sample standard
// deviation is undefined there, and dividing by size - 1 == 0 would yield NaN.
double standard_dev(const std::vector<double>& numbers) {
    if (numbers.size() < 2) {
        return 0.0;
    }
    const double mean_value = mean(numbers);
    double sum_sq = 0.0;
    for (const double value : numbers) {
        const double diff = value - mean_value;
        sum_sq += diff * diff;  // cheaper and exact compared to pow(diff, 2)
    }
    return std::sqrt(sum_sq / static_cast<double>(numbers.size() - 1));
}

// A city location in the Euclidean plane.
class Point {
public:
  double x = 0.0;
  double y = 0.0;

  // Constructs the point (x, y).
  Point(double x, double y) : x(x), y(y) {}
  // Constructs the origin (0, 0).
  Point() = default;

  // Euclidean distance between this point and `other`.
  double dist(const Point &other) const {
    const double diffx = x - other.x;
    const double diffy = y - other.y;
    return sqrt(diffx * diffx + diffy * diffy);
  }
};

// Length of the closed tour problem[path[0]] -> problem[path[1]] -> ... ->
// problem[path[n-1]] -> problem[path[0]].
// Entries of `path` are used as indices into `problem` (not bounds-checked).
// An empty path has length 0.
double pathLength(const vector<Point> &problem, const vector<int> &path) {
    if (path.empty()) {
        return 0.0;
    }
    const size_t n = path.size();
    // Closing edge first, then the open chain: same summation order as before,
    // so results stay bit-identical.
    double length = problem[path[n - 1]].dist(problem[path[0]]);
    for (size_t i = 0; i + 1 < n; ++i) {
        length += problem[path[i]].dist(problem[path[i + 1]]);
    }
    return length;
}

// Reverses path[i..j] in place (both ends inclusive); does nothing when i >= j.
// In 2-opt this reversal turns edges (path[i-1], path[i]) and (path[j], path[j+1])
// into (path[i-1], path[j]) and (path[i], path[j+1]).
void swap_edges(vector<int> &path, int i, int j) {
  for (int lo = i, hi = j; lo < hi; ++lo, --hi) {
    const int held = path[lo];
    path[lo] = path[hi];
    path[hi] = held;
  }
}

void compute_distance_matrix(const vector<Point> &problem, vector<vector<double>> &distance_matrix) {
    int n = problem.size();
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < i; j++) {
            distance_matrix[j][i] = distance_matrix[i][j] = problem[i].dist(problem[j]);
        }
    }
}

// Orders tour positions for the 2-opt search. Tour position p stands for the tour
// edge (path[p-1], path[p]) (cyclically, so position 0 is (path[n-1], path[0])).
// Positions whose edge appears in the RTD-Lite penalty list come first, in order
// of descending penalty; all remaining positions follow in ascending order.
class EdgeSelector {
    public:
        int n;  // number of cities, taken from the distance matrix
        // (edge, penalty) pairs filled by rtdl_obj.run(); sorted by descending
        // penalty after get_rtdl_weights(). Each edge is presumably a 2-element
        // {city, city} vector — TODO confirm against rtdl.h.
        vector<pair<vector<int>, double>> edge_penalty;
        RTD_Lite_TSP_Penalty rtdl_obj;

        EdgeSelector(const vector<vector<double>> &distance_matrix)
            : n(static_cast<int>(distance_matrix.size())), rtdl_obj(distance_matrix) {}

        // Recomputes edge_penalty for `path` and sorts it, highest penalty first.
        void get_rtdl_weights(const vector<int> &path) {
            edge_penalty.clear();
            rtdl_obj.run(path, edge_penalty);
            sort(edge_penalty.begin(), edge_penalty.end(),
                 [](const auto &a, const auto &b) { return a.second > b.second; });
        }

        // Looks `edge` up in `edge2point` in both orientations and returns the mapped
        // tour position, or -1 if it is not a tour edge (or has fewer than two entries).
        int get_points(const map<vector<int>, int> &edge2point, const vector<int> &edge) {
            if (edge.size() < 2) {
                return -1;
            }
            auto found = edge2point.find(edge);
            if (found != edge2point.end()) {
                return found->second;
            }
            found = edge2point.find(vector<int>{edge[1], edge[0]});
            return found != edge2point.end() ? found->second : -1;
        }

        // Appends every tour position 0..n-1 to `plist` exactly once: first the
        // positions of penalised tour edges in edge_penalty order, then the rest in
        // ascending order. Assumes path.size() >= n.
        void get_sorted_edges(const vector<int> &path, vector<int> &plist) {
            // Directed tour edge path[i] -> path[i+1] is keyed to the position of its head.
            map<vector<int>, int> edge2point;
            for (int i = 0; i < n; i++) {
                const int next = (i + 1) % n;
                edge2point[{path[i], path[next]}] = next;
            }
            vector<bool> used_index(n, false);
            plist.reserve(plist.size() + static_cast<size_t>(n));
            for (const auto &entry : edge_penalty) {
                const int value = get_points(edge2point, entry.first);
                // Skip non-tour edges, and positions already emitted in case the
                // same tour edge is listed more than once; duplicates would only
                // repeat candidate pairs the 2-opt scan has already rejected.
                if (value != -1 && !used_index[value]) {
                    plist.push_back(value);
                    used_index[value] = true;
                }
            }
            for (int i = 0; i < n; i++) {
                if (!used_index[i]) {
                    plist.push_back(i);
                }
            }
        }
};

// Scans unordered pairs of candidate tour positions in the order given and
// applies the first improving 2-opt move. Reversing path[i..j] replaces edges
// (path[i-1], path[i]) and (path[j], path[j+1]) with (path[i-1], path[j]) and
// (path[i], path[j+1]), indices taken cyclically. Returns true if a move was applied.
static bool apply_first_improving_2opt_move(const vector<vector<double>> &distance_matrix,
                                            vector<int> &path,
                                            const vector<int> &candidates) {
    const int n = static_cast<int>(path.size());
    for (const int end_point : candidates) {
        for (const int v : candidates) {
            const int i = min(v, end_point);
            const int i_prev = (n + i - 1) % n;
            const int j_next = max(v, end_point);
            const int j = (n + j_next - 1) % n;
            // Empty segment: nothing to reverse.
            if (i >= j) {
                continue;
            }
            // Same position twice (only reaches here for position 0, where j wraps
            // to n-1): not a real move.
            if (i_prev == j && i == j_next) {
                continue;
            }
            const double lengthDelta =
                -distance_matrix[path[i_prev]][path[i]] + distance_matrix[path[i_prev]][path[j]]
                - distance_matrix[path[j]][path[j_next]] + distance_matrix[path[i]][path[j_next]];
            if (lengthDelta < 0) {
                swap_edges(path, i, j);
                return true;
            }
        }
    }
    return false;
}

// Improves `path` in place with first-improvement 2-opt whose candidate
// positions are ordered by RTD-Lite edge penalties (see EdgeSelector).
//
//   problem          city coordinates; `path` is assumed to be a permutation of
//                    0..problem.size()-1 (not checked here).
//   path             tour to improve, modified in place.
//   max_iterations   cap on applied moves once the progressive phase is over
//                    (see NOTE below).
//   epoch_freq       penalties are recomputed whenever the move count is a
//                    multiple of this; values <= 0 are treated as 1.
//   progressive_step only positions < opt_len are move candidates; opt_len starts
//                    at progressive_step and grows by it whenever no improving
//                    move is found, until it covers the whole tour. Values <= 0
//                    make every position a candidate from the start.
// Returns the number of improving moves applied.
//
// NOTE(review): the loop keeps running while opt_len < n even after
// max_iterations moves, so the result can exceed max_iterations.
// NOTE(review): when no move is found and opt_len grows, `iteration` does not
// change, so penalties may be recomputed for an unchanged path; caching them
// would be safe only if RTD_Lite_TSP_Penalty::run is deterministic — verify.
int run_2_opt_rtdl(const vector<Point> &problem, vector<int> &path, int max_iterations, int epoch_freq, int progressive_step) {
    const int n = static_cast<int>(path.size());
    // Every closed tour through three or fewer cities has the same length, so no
    // move can improve it; returning early also stops rounding noise in
    // lengthDelta from producing spurious "improvements" on such inputs.
    if (n < 4) {
        return 0;
    }
    // epoch_freq == 0 would be a modulo by zero below; a non-positive step would
    // never let opt_len reach n (infinite loop).
    const int refresh_period = epoch_freq > 0 ? epoch_freq : 1;
    const int step = progressive_step > 0 ? progressive_step : n;

    vector<vector<double>> distance_matrix(n, vector<double>(n, 0));
    compute_distance_matrix(problem, distance_matrix);
    EdgeSelector edge_sel(distance_matrix);

    int iteration = 0;
    int opt_len = step;
    bool foundImprovement = true;
    vector<int> vlist_raw;
    vector<int> vlist;

    while ((foundImprovement && iteration < max_iterations) || opt_len < n) {
        if (iteration % refresh_period == 0) {
            edge_sel.get_rtdl_weights(path);
        }
        // get_sorted_edges appends, so the reused buffers are cleared first.
        vlist_raw.clear();
        edge_sel.get_sorted_edges(path, vlist_raw);
        vlist.clear();
        for (const int x : vlist_raw) {
            if (x < opt_len) {
                vlist.push_back(x);
            }
        }

        foundImprovement = apply_first_improving_2opt_move(distance_matrix, path, vlist);
        if (foundImprovement) {
            ++iteration;
        } else if (opt_len < n) {
            // Widen the candidate window; adding at most n - opt_len avoids overflow.
            opt_len += min(step, n - opt_len);
            foundImprovement = true;
        }
    }
    return iteration;
}

// Reads `num_problems` TSP instances from the text file `filename`. Each instance
// consists of:
//   - `problem_size` whitespace-separated "x y" coordinate pairs;
//   - one separator token (read and discarded);
//   - `problem_size` integer city indices forming the initial tour.
//
// `problems` and `init_paths` must already be sized num_problems x problem_size;
// they are filled in place.
// Returns true on success. If the file cannot be opened or a value fails to
// parse, it prints an error to stderr and returns false; values read before the
// failure are kept.
bool load_init_tours(const string &filename, int problem_size, int num_problems, vector<vector<Point>> &problems, vector<vector<int>> &init_paths) {
    ifstream file(filename);
    if (!file.is_open()) {
        cerr << "Error: cannot open tour file " << filename << endl;
        return false;
    }
    for (int i = 0; i < num_problems; ++i) {
        for (int j = 0; j < problem_size; ++j) {
            double x = 0.0;
            double y = 0.0;
            if (!(file >> x >> y)) {
                cerr << "Error: malformed coordinates for problem " << i << " in " << filename << endl;
                return false;
            }
            problems[i][j] = Point(x, y);
        }
        string sep;
        if (!(file >> sep)) {
            cerr << "Error: missing separator for problem " << i << " in " << filename << endl;
            return false;
        }
        for (int j = 0; j < problem_size; ++j) {
            int index = 0;
            if (!(file >> index)) {
                cerr << "Error: malformed tour for problem " << i << " in " << filename << endl;
                return false;
            }
            init_paths[i][j] = index;
        }
    }
    return true;
}

// Echoes the run's command-line hyperparameters to stdout, one per line,
// followed by a blank line.
void print_configuration(string filename, int problem_size, int num_problems, int max_iterations, int epoch_freq, int progressive_step) {
    std::cout << "CONFIGURATION HYPERPARAMETERS" << std::endl
              << "Filename: " << filename << std::endl
              << "Problem size: " << problem_size << std::endl
              << "Number of problems: " << num_problems << std::endl
              << "Max iterations: " << max_iterations << std::endl
              << "Epoch frequency: " << epoch_freq << std::endl
              << "Progressive step: " << progressive_step << std::endl
              << std::endl;
}

// Returns true when `path` holds exactly `size` entries forming a permutation of
// 0..size-1, i.e. a tour the optimiser can index safely.
static bool is_valid_tour(const vector<int> &path, int size) {
    if (static_cast<int>(path.size()) != size) {
        return false;
    }
    vector<bool> seen(size, false);
    for (const int city : path) {
        if (city < 0 || city >= size || seen[city]) {
            return false;
        }
        seen[city] = true;
    }
    return true;
}

// Command-line driver. It:
//   - loads initial tours;
//   - improves each one with RTD-Lite guided 2-opt;
//   - prints per-run and aggregate statistics;
//   - writes the optimised tours, one per line, to
//     "<save_prefix>_<max_iterations>_<epoch_freq>_<progressive_step>_tour.txt".
//
// Arguments: <tour_file> <problem_size> <num_problems> <max_iterations>
//            <epoch_freq> <progressive_step> <save_prefix>
// Returns 0 on success, 1 on bad arguments, invalid input tours, or an
// unwritable output file.
int main(int argc, char* argv[]) {
    if (argc < 8) {
        cerr << "Usage: " << (argc > 0 ? argv[0] : "rtdl_2opt")
             << " <tour_file> <problem_size> <num_problems> <max_iterations>"
                " <epoch_freq> <progressive_step> <save_prefix>" << endl;
        return 1;
    }
    const string tour_filename = argv[1];
    const int problem_size = atoi(argv[2]);
    const int num_problems = atoi(argv[3]);
    const int max_iterations = atoi(argv[4]);
    const int epoch_freq = atoi(argv[5]);
    const int progressive_step = atoi(argv[6]);
    const string save_filename_prefix = argv[7];

    // atoi() returns 0 for malformed numbers. Reject values that would size
    // containers negatively, divide by zero (epoch_freq) or keep the progressive
    // phase from ever finishing (progressive_step).
    if (problem_size <= 0 || num_problems <= 0 || max_iterations < 0 || epoch_freq <= 0 || progressive_step <= 0) {
        cerr << "Error: problem_size, num_problems, epoch_freq and progressive_step must be positive "
                "and max_iterations non-negative" << endl;
        return 1;
    }

    const string opt_tour_filename = save_filename_prefix + "_" + to_string(max_iterations) + "_" + to_string(epoch_freq) + "_" + to_string(progressive_step) + "_tour.txt";
    print_configuration(tour_filename, problem_size, num_problems, max_iterations, epoch_freq, progressive_step);

    vector<vector<Point>> problems(num_problems, vector<Point>(problem_size));
    vector<vector<int>> paths(num_problems, vector<int>(problem_size));
    load_init_tours(tour_filename, problem_size, num_problems, problems, paths);
    // A missing or truncated file leaves zero-filled entries, which are not
    // permutations; refuse to optimise such tours.
    for (int i = 0; i < num_problems; ++i) {
        if (!is_valid_tour(paths[i], problem_size)) {
            cerr << "Error: initial tour " << i << " in " << tour_filename
                 << " is not a permutation of 0.." << problem_size - 1 << endl;
            return 1;
        }
    }

    vector<double> init_len_vector(num_problems);
    vector<double> opt_len_vector(num_problems);
    vector<double> duration_vector(num_problems);
    vector<double> iter_vector(num_problems);

    const auto start_time_total = high_resolution_clock::now();
    for (int i = 0; i < num_problems; ++i) {
        cout << "RUN " << i << endl;
        const double init_len = pathLength(problems[i], paths[i]);
        init_len_vector[i] = init_len;

        const auto run_start_time = high_resolution_clock::now();
        const int run_iter = run_2_opt_rtdl(problems[i], paths[i], max_iterations, epoch_freq, progressive_step);
        const auto run_stop_time = high_resolution_clock::now();
        const auto run_duration = duration_cast<milliseconds>(run_stop_time - run_start_time);
        duration_vector[i] = run_duration.count();
        iter_vector[i] = run_iter;

        cout << "Time elapsed (milliseconds): " << run_duration.count() << endl;
        cout << "Num iterations: " << run_iter << endl;
        const double opt_len = pathLength(problems[i], paths[i]);
        opt_len_vector[i] = opt_len;
        cout << init_len << " " << opt_len << endl;
        cout << endl;
    }
    const auto stop_time_total = high_resolution_clock::now();
    const auto duration_total = duration_cast<seconds>(stop_time_total - start_time_total);

    cout << "Avg. Init Tour Length: " << setprecision(6) << mean(init_len_vector) << " ± " << setprecision(6) << standard_dev(init_len_vector) << endl;
    cout << "Avg. Optimized Tour Length: " << setprecision(6) << mean(opt_len_vector) << " ± "<< setprecision(6) << standard_dev(opt_len_vector) << endl;
    cout << "Avg. Time (milliseconds): " << setprecision(6) << mean(duration_vector) << " ± " << setprecision(6) << standard_dev(duration_vector) << endl;
    cout << "Total Time (seconds): " << setprecision(6) << duration_total.count() << endl;
    cout << "Avg. Num Iterations: " << setprecision(6) << mean(iter_vector) << " ± "<< setprecision(6) <<  standard_dev(iter_vector) << endl;

    ofstream out_file(opt_tour_filename);
    if (!out_file.is_open()) {
        cerr << "Error: could not open " << opt_tour_filename << " for writing" << endl;
        return 1;
    }
    for (const auto &path : paths) {
        for (const int city : path) {
            out_file << city << " ";
        }
        out_file << "\n";
    }
    return 0;
}
