// cppimport

#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstring>
#include <fstream>
#include <iostream>
#include <iterator>
#include <numeric>
#include <random>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_set>
#include <vector>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace py = pybind11;

#include "Instance.h"
#include "Utils.h"
#include "Solution.h"
#include "Tour.h"
#include "Operations.h"





/**
 * Annealing-style destroy/repair search starting from `Solution(instance)`.
 *
 * Every iteration copies the current solution into a ModifiedSolution, destroys and
 * repairs it, and accepts the result when
 *     msol.totalCosts < sol.totalCosts - T * log(getRandomFraction()).
 * (getRandomFraction() is presumably a uniform draw from [0, 1] -- confirm in Utils.h.)
 *
 * @param instance   Problem instance the solutions are built for.
 * @param iterations Number of iterations to run. Ignored when `timelimit > 0`
 *                   (it is then replaced by 10^9 and the time limit stops the loop).
 * @param timelimit  Wall-clock budget in seconds; values <= 0 select iteration mode.
 * @param T_0        Start temperature.
 * @param T_f        Final temperature. In iteration mode T is multiplied each step by a
 *                   constant chosen so that T_0 reaches T_f after `iterations` steps; in
 *                   time mode T = T_f * (T_0 / T_f)^(1 - elapsed / timelimit).
 * @return Tuple of (best solution found, minutes after start at which that best solution
 *         was found or -1 if the starting solution was never improved, value of the loop
 *         counter when the loop ended).
 */
std::tuple<Solution, float, int> search(const Instance& instance, int iterations, int timelimit, double T_0, double T_f) {

    const std::chrono::steady_clock::time_point t_start = std::chrono::steady_clock::now();

    // Fractional seconds since the search started. Using a floating-point duration keeps
    // the time-based cooling schedule smooth instead of stepping once per whole second.
    const auto elapsedSeconds = [t_start]() {
        return std::chrono::duration<double>(std::chrono::steady_clock::now() - t_start).count();
    };

    float incumbenTime = -1;

    Solution sol = Solution(instance);

    const bool timeLimited = timelimit > 0;
    if (timeLimited) {
        // Effectively "run until the time limit": the loop is ended by the break below.
        iterations = 1000000000;
    }

    double T = T_0;
    // Geometric cooling factor for iteration mode: T_0 * c^iterations == T_f.
    const double c = std::pow(T_f / T_0, 1.0 / iterations);
    Solution incumbent = sol;

    int i;
    for (i = 0; i < iterations; ++i) {
        ModifiedSolution msol(sol);

        std::unordered_set<int> A = msol.destroy(10, 10, 0.01);
        std::vector<int> A_ordered(A.begin(), A.end());
        sort_abs_cust(A_ordered, sol.instance);
        msol.repair(A_ordered, 0.01, false);

        // log(u) <= 0 for u in (0, 1], so the threshold is never below the current cost:
        // improving candidates always pass, worsening ones pass with a T-dependent chance.
        if (msol.totalCosts < sol.totalCosts - T * std::log(getRandomFraction())) {
            sol.acceptModifiedSolution(msol);

            if (sol.totalCosts < incumbent.totalCosts) {
                incumbent = sol;
                incumbenTime = static_cast<float>(elapsedSeconds() / 60.0);
            }
        }

        if (i % 1000000 == 0) {
            std::cout << "Best " << incumbent.totalCosts << " Current "  << sol.totalCosts << " Time Best " << incumbenTime << std::endl;
        }

        if (timeLimited) {
            // Reading the clock is comparatively expensive, so only do it every 10000 iterations.
            if (i % 10000 == 0) {
                const double elapsed = elapsedSeconds();
                T = T_f * std::pow(T_0 / T_f, 1.0 - elapsed / timelimit);

                if (elapsed >= timelimit) {
                    break;
                }
            }
        }
        else {
            T = c * T;
        }
    }

    // Return the best solution seen. The current solution `sol` may be worse, because
    // the acceptance rule above deliberately lets worsening moves through.
    return std::make_tuple(incumbent, incumbenTime, i);
}

/**
 * Builds a starting solution by repeatedly removing a random set of customers from
 * `Solution(instance)` and re-inserting them via remove_recreate_allImp.
 *
 * @param instance      Problem instance; customers are numbered 1..instance.numCustomers.
 * @param nbImprovement Number of remove/recreate rounds to perform.
 * @param nbDestroy     Number of distinct customers removed per round. Values outside
 *                      [0, instance.numCustomers] are clamped into that range, because the
 *                      customers are drawn without replacement.
 * @return The solution after the last round.
 */
Solution create_starting_solution(const Instance& instance, int nbImprovement, int nbDestroy) {
    static std::random_device rd;
    static std::mt19937 gen(rd());

    Solution sol = Solution(instance);
    std::vector<float> costs;

    // All customer indices 1..numCustomers.
    std::vector<int> allNodes(instance.numCustomers);
    std::iota(allNodes.begin(), allNodes.end(), 1);

    // Guard the iterator arithmetic below: begin() + nbDestroy must stay inside allNodes.
    const int available = static_cast<int>(allNodes.size());
    const int removeCount = std::max(0, std::min(nbDestroy, available));

    for (int i = 0; i < nbImprovement; ++i) {
        // Shuffle, then take the prefix: a random sample without replacement.
        std::shuffle(allNodes.begin(), allNodes.end(), gen);
        std::vector<int> nodesToRemove(allNodes.begin(), allNodes.begin() + removeCount);

        // remove_recreate_allImp takes a list of removal sets; here it gets exactly one.
        std::vector<std::vector<int>> A = {nodesToRemove};

        // The meaning of the trailing arguments (0.0, 1, 0, false) is defined in Operations.h.
        std::tie(sol, costs) = remove_recreate_allImp(sol, A, 0.0, 1, 0, false);
    }

    return sol;
}


/**
 * Loads the instance at `path`, runs search() in iteration mode and prints the average
 * cost and average time-to-best over all runs, followed by the total wall-clock time.
 *
 * The iteration budget is linearly interpolated between itMin (for <= 100 customers) and
 * itMax (for >= 1000 customers) based on the instance's number of customers.
 *
 * @param path Path of the instance file; a null pointer is reported on stderr and ignored.
 */
void run_sisr(char* path) {
    if (path == nullptr) {
        // Streaming a null char* is undefined behaviour, so bail out before printing it.
        std::cerr << "run_sisr: no instance path given\n";
        return;
    }

    std::cout << "Solving instance: " << path << "\n";

    const int minProblemSize = 100;
    const int maxProblemSize = 1000;

    // Given values for it(100) and it(1000)
    const unsigned int itMin = 1e7; //3e7
    const unsigned int itMax = 1e8; //3e8

    // Number of independent runs that are averaged.
    const int numRuns = 1;

    const std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now();

    float cost_sum = 0;
    float time_sum = 0;

    for (int z = 0; z < numRuns; ++z) {
        Instance inst = Instance(path, false);

        // Clamp the problem size into the range the interpolation is defined for.
        const int numCustomers = inst.numCustomers;
        const int problemSize = std::max(minProblemSize, std::min(numCustomers, maxProblemSize));

        // Linear interpolation, done in double so the result is not rounded to float precision.
        const double fraction = static_cast<double>(problemSize - minProblemSize) / (maxProblemSize - minProblemSize);
        const int iterations = static_cast<int>(itMin + (itMax - itMin) * fraction);

        float time;
        Solution sol(inst);
        // The third tuple element (iterations executed) is not needed here.
        std::tie(sol, time, std::ignore) = search(inst, iterations, 0, 100, 1);

        // Accumulate so the averages printed below reflect the runs performed so far.
        cost_sum += sol.totalCosts;
        time_sum += time;

        std::cout << "Avg Cost: " << (cost_sum / ((float)z + 1.0)) << " Avg Time: " << (time_sum / ((float)z + 1.0)) << std::endl;
    }

    const std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();

    std::cout << "Time difference = " << std::chrono::duration_cast<std::chrono::seconds> (end - begin).count() << "[s]" << std::endl;
}

/**
 * Stress test for remove_recreate_allImp: repeatedly applies it with random removal sets
 * and checks after every call that the solution's cached values are consistent.
 *
 * Checks per call (each failure prints one "Error: ..." line to stdout):
 *   - each tour's stored cost and demand match the values recomputed by updateCostAndDemand,
 *   - no customer is reached after the end of its time window,
 *   - the sum of the stored tour costs matches sol.totalCosts,
 *   - the tours together contain exactly inst.numCustomers nodes.
 *
 * @param path Path of the instance file; a null pointer is reported on stderr and ignored.
 */
void test_iterated_repair(char* path) {
    if (path == nullptr) {
        // Streaming a null char* is undefined behaviour, so bail out before printing it.
        std::cerr << "test_iterated_repair: no instance path given\n";
        return;
    }

    const std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now();

    for (int z = 0; z < 10; ++z) {
        Instance inst = Instance(path, true);
        std::cout << "Solving instance: " << path << "\n";

        // A removal set holds distinct customers, so it can never be larger than the
        // instance; without this cap the sampling loop below would never terminate on
        // instances with fewer than 15 customers.
        const int numCustomers = inst.numCustomers;
        const int removalSize = std::min(15, numCustomers);

        Solution sol = create_starting_solution(inst, 30, removalSize);

        for (int i = 0; i < 1000; ++i) {
            // 200 random removal sets of `removalSize` distinct customers each.
            // NOTE(review): assumes getRandomNumber(1, n) can return every value in 1..n -- confirm in Utils.h.
            std::vector<std::vector<int>> A;
            A.reserve(200);
            for (int k = 0; k < 200; ++k) {
                std::unordered_set<int> a;
                while (static_cast<int>(a.size()) < removalSize) {
                    a.insert(getRandomNumber(1, numCustomers));
                }
                A.emplace_back(a.begin(), a.end());
            }

            std::vector<float> costs;
            std::tie(sol, costs) = remove_recreate_allImp(sol, A, 0.01, 5, 10);

            float totalCosts = 0;
            int visited_nodes = 0;
            for (auto& t : sol.tours) {
                // Remember the cached values, recompute them, then compare.
                const float storedCosts = t.costs;
                const auto storedDemand = t.demand;
                t.updateCostAndDemand(inst);
                if (std::abs(t.costs - storedCosts) > 0.0001) {
                    std::cout << "Error: Tour costs" << "\n";
                }
                if (t.demand != storedDemand) {
                    std::cout << "Error: Tour demand" << "\n";
                }
                totalCosts += storedCosts;
                visited_nodes += static_cast<int>(t.nodes.size());

                // Walk the tour from node 0 and verify every arrival time against the
                // customer's time window; waiting until startTW is allowed.
                float time = 0;
                int last_c = 0;
                for (const auto& c : t.nodes) {
                    time += inst.distanceMatrix[last_c][c];
                    if (time > inst.endTW[c]) {
                        std::cout << "Error: TW" << "\n";
                    }
                    time = std::max(time, inst.startTW[c]) + inst.serviceTime[c];
                    last_c = c;
                }
            }

            if (std::abs(totalCosts - sol.totalCosts) > 0.0001) {
                std::cout << "Error: Solution costs" << "\n";
            }
            if (visited_nodes != numCustomers) {
                std::cout << "Error: NumCust" << "\n";
            }
        }
        std::cout << sol.totalCosts << "\n";
    }

    const std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();

    // The duration is measured in milliseconds, so label it accordingly.
    std::cout << "Runtime: " << std::chrono::duration_cast<std::chrono::milliseconds> (end - begin).count() << "[ms]" << std::endl;
}


/**
 * Command-line entry point.
 *
 * Usage: <program> sisr <instance-path>   -- run the SISR search on the instance
 *        <program> test <instance-path>   -- run the iterated-repair consistency test
 *
 * @return 0 after running a known mode, 1 (with a usage message on stderr) otherwise.
 */
int main(int argc, char* argv[])
{
    // Both modes read argv[2], so at least three arguments are required; with only two,
    // argv[2] is the terminating null pointer.
    if (argc >= 3 && std::strcmp(argv[1], "sisr") == 0)
    {
        run_sisr(argv[2]);
        return 0;
    }

    if (argc >= 3 && std::strcmp(argv[1], "test") == 0)
    {
        test_iterated_repair(argv[2]);
        return 0;
    }

    const char* program = (argc > 0 && argv[0] != nullptr) ? argv[0] : "SISRs";
    std::cerr << "Usage: " << program << " <sisr|test> <instance-path>\n";
    return 1;
}

// Python bindings: exposes Instance, Solution and the search/repair entry points as the
// extension module "SISRs".
PYBIND11_MODULE(SISRs, m) {
    m.doc() = "SISRs"; // module docstring

    // Use the build-provided version string when there is one, "dev" otherwise.
#ifdef VERSION_INFO
    m.attr("__version__") = VERSION_INFO;
#else
    m.attr("__version__") = "dev";
#endif

    using IntVector = std::vector<int>;
    using FloatVector = std::vector<float>;
    using FloatMatrix = std::vector<std::vector<float>>;
    using TourList = std::vector<std::vector<int>>;

    // --- Instance: constructor plus the fields that are readable/writable from Python.
    py::class_<Instance> instanceClass(m, "Instance");
    instanceClass.def(py::init<int, int, const IntVector&, const FloatVector&, const FloatVector&, const FloatVector&, const FloatMatrix&>());
    instanceClass.def_readwrite("numNodes", &Instance::numNodes);
    instanceClass.def_readwrite("numCustomers", &Instance::numCustomers);
    instanceClass.def_readwrite("vehicleCapacity", &Instance::vehicleCapacity);
    instanceClass.def_readwrite("distanceMatrix", &Instance::distanceMatrix);

    // --- Solution: built from an Instance and a list of tours.
    // NOTE(review): the Instance is taken by value here; if Solution keeps a reference to
    // it, the Python-side lifetime needs checking (e.g. py::keep_alive) -- confirm in Solution.h.
    py::class_<Solution> solutionClass(m, "Solution");
    solutionClass.def(py::init<Instance, const TourList&>());
    solutionClass.def_readwrite("totalCosts", &Solution::totalCosts);
    solutionClass.def("getTourList", &Solution::getTourList);
    solutionClass.def("getHash", &Solution::getHash);

    // --- Free functions; all are registered with the same return value policy.
    const py::return_value_policy ownership = py::return_value_policy::take_ownership;

    m.def("create_starting_solution", &create_starting_solution, "create_starting_solution", ownership);
    m.def("remove_recreate_allImp", &remove_recreate_allImp, "Remove and Recreate function", ownership);
    m.def("remove_recreate_singleImp", &remove_recreate_singleImp, "Remove and Recreate function", ownership);
    m.def("heuristic_deconstruction_selection", &heuristic_deconstruction_selection, "Remove and Recreate function", ownership);
    m.def("search", &search, "Start default SISR search", ownership);
    // "permutate" is intentionally not exported at the moment:
    //m.def("permutate", &permutate, "permutate", ownership);
}

/*
<%
setup_pybind11(cfg)
cfg['extra_compile_args'] = ['-O2']
cfg['sources'] = ['Instance.cpp', 'Operations.cpp', 'Solution.cpp', 'Tour.cpp', 'Utils.cpp']
%>
*/