#include "Operations.h"

#include <mutex>
#include <thread>
#include <utility>

#include <omp.h> // OpenMP for parallelism

/**
 * Reorders a list of customer indices (in place) before they are reinserted.
 *
 * @param A        Customer indices; reordered in place.
 * @param instance Problem data; `demand[c]` and `distanceMatrix[c][0]` are read per customer c.
 * @param order    'R' = random shuffle,
 *                 'D' = by demand, descending,
 *                 'F' = by distanceMatrix[c][0], descending (node 0 presumably the depot — TODO confirm),
 *                 any other value (e.g. 'C') = by distanceMatrix[c][0], ascending.
 *                 '\0' draws one of 'R', 'D', 'F', 'C' at random with weights 4:4:2:1.
 *
 * This function is called from OpenMP worker threads (e.g. inside parallel loops over
 * removal sets), so the random engine and the distribution are thread_local: concurrent
 * use of a single shared std::mt19937 would be a data race.
 */
void sort_abs_cust(std::vector<int>& A, const Instance& instance, char order) {
    // One engine per thread, seeded once per thread from std::random_device.
    thread_local std::mt19937 gen(std::random_device{}());

    if (order == '\0') {
        // Read-only after (thread-safe) initialization, so it can be shared.
        static const std::vector<char> options = { 'R', 'D', 'F', 'C' };
        // discrete_distribution::operator() is non-const, so keep one per thread too.
        thread_local std::discrete_distribution<> dist({ 4, 4, 2, 1 });
        order = options[dist(gen)];
    }

    if (order == 'R') {
        std::shuffle(A.begin(), A.end(), gen);
    }
    else if (order == 'D') {
        std::sort(A.begin(), A.end(), [&](int a, int b) {
            return instance.demand[a] > instance.demand[b];
            });
    }
    else if (order == 'F') {
        std::sort(A.begin(), A.end(), [&](int a, int b) {
            return instance.distanceMatrix[a][0] > instance.distanceMatrix[b][0];
            });
    }
    else {
        std::sort(A.begin(), A.end(), [&](int a, int b) {
            return instance.distanceMatrix[a][0] < instance.distanceMatrix[b][0];
            });
    }
}

/**
 * Ruin-and-recreate search: m rounds, each destroying part of the current solution and
 * trying n repairs of it.
 *
 * In every round the current solution is copied and partially destroyed via
 * ModifiedSolution::destroy(c_bar, 10, 0.01). The removed customers are then reinserted
 * n times with ModifiedSolution::repair. The first attempt uses the order produced by
 * destroy(); each later attempt reshuffles that order. The cheapest repair of the round
 * is then accepted if its cost is below
 *     solution.totalCosts - T * log(getRandomFraction(0, 1)).
 * If getRandomFraction(0, 1) is uniform on (0, 1), this is the Metropolis rule: a worse
 * solution is accepted with probability exp(-delta / T). (Assumption — TODO confirm the
 * distribution of getRandomFraction.)
 *
 * @return (final solution, costs), where costs[r] is the best repair cost found in round r
 *         (infinity if n <= 0).
 */
std::tuple<Solution, std::vector<float>> ruin_recreate(Solution solution, int m, float beta, int n, float T, float c_bar, bool insertInNewToursOnly) {
    // Holds the best repair of the most recent round; reused across rounds.
    ModifiedSolution bestSol(solution);

    std::vector<float> costs;
    for (int round = 0; round < m; ++round)
    {
        ModifiedSolution ruined(solution);
        const std::unordered_set<int> removedSet = ruined.destroy(c_bar, 10, 0.01);
        std::vector<int> removedCustomers(removedSet.begin(), removedSet.end());

        float roundBest = std::numeric_limits<float>::infinity();
        for (int attempt = 0; attempt < n; ++attempt)
        {
            // Attempt 0 keeps the order returned by destroy(); later attempts reshuffle it.
            if (attempt != 0) {
                sort_abs_cust(removedCustomers, solution.instance, 'R');
            }

            ModifiedSolution candidate(ruined);
            candidate.repair(removedCustomers, beta, insertInNewToursOnly);

            if (candidate.totalCosts < roundBest)
            {
                bestSol = candidate;
                roundBest = candidate.totalCosts;
            }
        }
        costs.push_back(roundBest);

        const float thresh = solution.totalCosts - T * std::log(getRandomFraction(0, 1));
        if (thresh > bestSol.totalCosts) {
            solution.acceptModifiedSolution(bestSol);
        }
    }

    return std::make_tuple(solution, costs);
}

/**
 * Produces m independent candidate removal sets for `solution`.
 *
 * Each set comes from calling ModifiedSolution::destroy(c_bar, 10, 0.01) on a fresh copy
 * of `solution`. Inside each set, the customers appear in the iteration order of the
 * unordered_set returned by destroy().
 *
 * @return m vectors of customer indices (empty if m <= 0).
 */
std::vector<std::vector<int>> heuristic_deconstruction_selection(Solution & solution, float c_bar, int m) {
    std::vector<std::vector<int>> removalSets;
    removalSets.reserve(m > 0 ? static_cast<std::size_t>(m) : 0);

    for (int k = 0; k < m; ++k) {
        ModifiedSolution scratch(solution);
        const std::unordered_set<int> removed = scratch.destroy(c_bar, 10, 0.01);
        removalSets.emplace_back(removed.begin(), removed.end());
    }
    return removalSets;
}

/**
 * Sequential remove-and-recreate over a given list of removal sets.
 *
 * For each A[idx], in order, the following steps run:
 *   1. The current solution is copied and A[idx] is removed from the copy.
 *   2. n repairs are tried. The first uses A[idx]'s current order; later ones reshuffle it.
 *   3. The cheapest repair is accepted if its cost is below
 *      solution.totalCosts - T * log(getRandomFraction(0, 1)).
 *
 * Each removal set therefore builds on whatever solution the previous one left behind.
 *
 * Side effect: A[idx] is shuffled in place (it is accessed through a reference).
 *
 * @return (final solution, costs), where costs[idx] is the best repair cost for A[idx]
 *         (infinity if n <= 0).
 */
std::tuple<Solution, std::vector<float>> remove_recreate_allImp(Solution solution, std::vector<std::vector<int>>& A, float beta, int n, float T, bool insertInNewToursOnly) {
    // Holds the best repair of the most recent removal set; reused across sets.
    ModifiedSolution bestSol(solution);
    std::vector<float> costs(A.size());

    for (std::size_t idx = 0; idx < A.size(); ++idx)
    {
        // Alias into the caller's data: the shuffles below reorder A[idx] itself.
        std::vector<int>& toReinsert = A[idx];

        ModifiedSolution partial(solution);
        partial.removeCustomers(toReinsert);

        float bestCosts = std::numeric_limits<float>::infinity();
        for (int attempt = 0; attempt < n; ++attempt)
        {
            if (attempt > 0) {
                sort_abs_cust(toReinsert, solution.instance, 'R');
            }

            ModifiedSolution candidate(partial);
            candidate.repair(toReinsert, beta, insertInNewToursOnly);

            if (candidate.totalCosts < bestCosts) {
                bestSol = candidate;
                bestCosts = candidate.totalCosts;
            }
        }
        costs[idx] = bestCosts;

        const float thresh = solution.totalCosts - T * std::log(getRandomFraction(0,1));
        if (thresh > bestSol.totalCosts) {
            solution.acceptModifiedSolution(bestSol);
        }
    }

    return std::make_tuple(solution, costs); // TODO use a struct to return refs to objects
}

/**
 * Remove-and-recreate that applies exactly one improvement at the end.
 *
 * Every removal set A[idx] is removed from the ORIGINAL `solution` (which is not changed
 * inside the loop) and repaired n times. The first repair uses A[idx]'s current order;
 * later ones reshuffle it in place. The cheapest repair over all sets is kept and
 * unconditionally accepted once, after the loop.
 *
 * accept_last_only (legacy, marked for removal): while processing the LAST removal set,
 * every repair overwrites the kept solution, so the final repair attempt is what gets
 * accepted. `costs` still tracks the best cost ever seen.
 *
 * @return (solution after acceptance, costs), where costs[idx] is the best repair cost
 *         over removal sets 0..idx (a running minimum).
 */
std::tuple<Solution, std::vector<float>> remove_recreate_singleImp(Solution solution, std::vector<std::vector<int>>& A, float beta, int n, bool accept_last_only, bool insertInNewToursOnly) {
    ModifiedSolution bestSol(solution);
    float bestCosts = std::numeric_limits<float>::infinity(); // running best over all sets
    std::vector<float> costs(A.size());

    for (std::size_t idx = 0; idx < A.size(); ++idx)
    {
        std::vector<int>& toReinsert = A[idx];
        ModifiedSolution partial(solution);
        partial.removeCustomers(toReinsert);

        // DELETE AT SOME FUTURE VERSION
        const bool overwriteWithLatest = accept_last_only && idx + 1 == A.size();

        for (int attempt = 0; attempt < n; ++attempt)
        {
            if (attempt > 0) {
                sort_abs_cust(toReinsert, solution.instance, 'R');
            }

            ModifiedSolution candidate(partial);
            candidate.repair(toReinsert, beta, insertInNewToursOnly);

            const bool improved = candidate.totalCosts < bestCosts;
            if (improved) {
                bestCosts = candidate.totalCosts;
            }
            if (improved || overwriteWithLatest) {
                bestSol = candidate;
            }
        }
        costs[idx] = bestCosts;
    }

    solution.acceptModifiedSolution(bestSol);

    return std::make_tuple(solution, costs); // TODO use a struct to return refs to objects
}

/**
 * Parallel remove-and-recreate over a list of removal sets.
 *
 * Every removal set A[i] is removed from the SAME starting solution and repaired n times.
 * The first repair uses A[i]'s current order; later ones reshuffle it in place. The work
 * is spread over `num_processes` OpenMP threads. After the loop, the single cheapest
 * repair over all sets is accepted if its cost is below
 *     solution.totalCosts - T * log(getRandomFraction(0, 1)).
 *
 * @param A              Removal sets; each A[i] is reordered in place and touched by one thread only.
 * @param n              Repair attempts per removal set.
 * @param T              Acceptance temperature.
 * @param num_processes  Number of OpenMP threads.
 * @return (solution after the acceptance test, costs), where costs[i] is the best repair
 *         cost for A[i] (infinity if n <= 0).
 *
 * sort_abs_cust is called from several threads at once here, so its random engine must
 * not be shared between threads.
 */
std::tuple<Solution, std::vector<float>> remove_recreate_allImp_multi(Solution solution, std::vector<std::vector<int>>& A, float beta, int n, float T, bool insertInNewToursOnly, int num_processes) {
    // Signed loop counter: OpenMP 2.0 (e.g. MSVC) only accepts signed loop variables.
    const int num_tasks = static_cast<int>(A.size());
    std::vector<float> costs(A.size(), std::numeric_limits<float>::infinity());

    // Best repair over all tasks; only read/written inside the critical section below.
    float globalBestCost = std::numeric_limits<float>::infinity();
    ModifiedSolution globalBestSol(solution);

    #pragma omp parallel for num_threads(num_processes) schedule(static)
    for (int i = 0; i < num_tasks; ++i) {
        std::vector<int>& removedCustomers = A[i];
        ModifiedSolution msol(solution);
        msol.removeCustomers(removedCustomers);

        ModifiedSolution bestSol(solution);
        float bestCosts = std::numeric_limits<float>::infinity();
        for (int j = 0; j < n; j++) {
            if (j > 0) {
                sort_abs_cust(removedCustomers, solution.instance, 'R');
            }
            ModifiedSolution msol_copy(msol);
            msol_copy.repair(removedCustomers, beta, insertInNewToursOnly);
            if (msol_copy.totalCosts < bestCosts) {
                bestSol = msol_copy;
                bestCosts = msol_copy.totalCosts;
            }
        }
        costs[i] = bestCosts; // each iteration writes its own slot: no synchronization needed

        // bestSol is not used after this point, so move it instead of copying while holding the lock.
        #pragma omp critical
        {
            if (bestCosts < globalBestCost) {
                globalBestCost = bestCosts;
                globalBestSol = std::move(bestSol);
            }
        }
    }

    const float thresh = solution.totalCosts - T * std::log(getRandomFraction(0,1));
    if (thresh > globalBestSol.totalCosts) {
        solution.acceptModifiedSolution(globalBestSol);
    }
    return std::make_tuple(solution, costs);
}

std::vector<std::pair<Solution, std::vector<float>>> remove_recreate_allImp_batch(
    std::vector<Solution> solutions,
    std::vector<std::vector<std::vector<int>>> A_batch,
    float beta, int n, float T, bool insertInNewToursOnly, int num_processes)
{
    int batch_size = solutions.size();
    std::vector<std::pair<Solution, std::vector<float>>> results;
    results.reserve(batch_size);

    #pragma omp parallel num_threads(num_processes)
    {
        std::vector<std::pair<Solution, std::vector<float>>> local_results;
        #pragma omp for schedule(dynamic) nowait
        for (int i = 0; i < batch_size; ++i) {
            auto result_tuple = remove_recreate_allImp(solutions[i], A_batch[i], beta, n, T, insertInNewToursOnly);
            local_results.emplace_back(std::get<0>(result_tuple), std::get<1>(result_tuple));
        }
        #pragma omp critical
        results.insert(results.end(), local_results.begin(), local_results.end());
    }
    return results;
}

