#include "Solution.h"
#include <algorithm>
#include <unordered_map>


Solution::Solution(const Instance& instance) : instance(instance)
{
    // Initial solution: every customer 1..numCustomers is served by its own
    // single-node tour. totalCosts = sum of tour costs + instance.total_prizes.
    totalCosts = 0.0;

    for (int customer = 1; customer <= instance.numCustomers; ++customer) {
        Tour singleton;
        singleton.nodes = { customer };
        singleton.updateCostAndDemand(instance);

        tours.push_back(singleton);
        totalCosts += singleton.costs;
    }

    totalCosts += instance.total_prizes;

    // Build the customer -> tour-index lookup for the tours created above.
    generateCustomerToTourMap();
}

Solution::Solution(const Instance& instance, const std::vector<std::vector<int>>& tours) : instance(instance)
{
    // Builds a solution from explicit node sequences, one inner vector per tour.
    // The parameter `tours` shadows the member, hence the explicit `this->tours`.
    // totalCosts = sum of tour costs + instance.total_prizes.
    totalCosts = 0.0;

    for (std::size_t k = 0; k < tours.size(); ++k) {
        Tour built;
        built.nodes = tours[k];
        built.updateCostAndDemand(instance);

        this->tours.push_back(built);
        totalCosts += built.costs;
    }

    totalCosts += instance.total_prizes;

    // Build the customer -> tour-index lookup for the tours copied above.
    generateCustomerToTourMap();
}

void Solution::generateCustomerToTourMap() {
    // Rebuilds customerToTourMap from scratch: entry c holds the index (into `tours`)
    // of the tour containing node c, or -1 if c is not part of any tour.
    //
    // assign() resets EVERY entry to -1. The previous resize() + loop over
    // 0..numCustomers-1 skipped index numCustomers (customers are numbered
    // 1..numCustomers). That left the last customer mapped to 0, i.e. "tour 0",
    // whenever it was unrouted, and kept stale values on repeated calls.
    customerToTourMap.assign(static_cast<std::size_t>(instance.numNodes), -1);

    for (std::size_t tourIndex = 0; tourIndex < tours.size(); ++tourIndex) {
        for (const int c : tours[tourIndex].nodes) {
            customerToTourMap[c] = static_cast<int>(tourIndex);
        }
    }
}



void Solution::acceptModifiedSolution(ModifiedSolution & modSol) {
    // Commits the changes recorded in modSol to this solution:
    //   1. tours listed in modSol.removedToursId are deleted (swap-with-last + pop),
    //   2. modSol.newTours are appended,
    //   3. customerToTourMap is updated accordingly; modSol.nonInsertedNodes become -1.
    // Note: modSol.removedToursId is sorted and de-duplicated in place.
    totalCosts = modSol.totalCosts;

    // Remove from the highest index down. Every removed index above the current one is
    // already gone, so the current last tour is never itself scheduled for removal.
    // A duplicate id would delete an unrelated tour, so duplicates are dropped.
    auto& removedIds = modSol.removedToursId;
    std::sort(removedIds.rbegin(), removedIds.rend());
    removedIds.erase(std::unique(removedIds.begin(), removedIds.end()), removedIds.end());

    // Customers of removed tours lose their assignment. Those that were re-routed get
    // their new index when the new tours are mapped below.
    for (const auto index : removedIds) {
        for (const int c : tours[index].nodes) {
            customerToTourMap[c] = -1;
        }
    }

    // Swap-and-pop removal. When the removed slot IS the last one, just pop. Moving the
    // tour onto itself and then reading it after pop_back() (the previous code) read a
    // destroyed object.
    for (const auto index : removedIds) {
        const std::size_t slot = static_cast<std::size_t>(index);
        const std::size_t last = tours.size() - 1;
        if (slot != last) {
            tours[slot] = std::move(tours[last]);
            // The moved tour changed position; its customers must point to the new slot.
            for (const int c : tours[slot].nodes) {
                customerToTourMap[c] = static_cast<int>(slot);
            }
        }
        tours.pop_back();
    }

    // Append the new tours and map their customers.
    const std::size_t firstNewTour = tours.size();
    tours.insert(tours.end(), modSol.newTours.begin(), modSol.newTours.end());

    for (std::size_t tourIndex = firstNewTour; tourIndex < tours.size(); ++tourIndex) {
        for (const int c : tours[tourIndex].nodes) {
            customerToTourMap[c] = static_cast<int>(tourIndex);
        }
    }

    // Customers the repair step decided not to route.
    for (const int c : modSol.nonInsertedNodes) {
        customerToTourMap[c] = -1;
    }
}

std::vector<std::vector<int>> Solution::getTourList() const {
    // Returns a copy of every tour's node sequence, in the same order as `tours`.
    std::vector<std::vector<int>> result(tours.size());
    for (std::size_t k = 0; k < tours.size(); ++k) {
        result[k] = tours[k].nodes;
    }
    return result;
}

std::size_t Solution::getHash()
{
    // Each tour is hashed with a boost::hash_combine-style mix seeded by its length.
    // The per-tour hashes are summed, so the result does not depend on the order of
    // the tours, but it does depend on the node order within each tour.
    std::size_t combined = 0;
    for (const auto& tour : tours) {
        std::size_t tourSeed = tour.nodes.size();
        for (const auto& node : tour.nodes) {
            tourSeed ^= node + 0x9e3779b9 + (tourSeed << 6) + (tourSeed >> 2);
        }
        combined += tourSeed;
    }
    return combined;
}

Solution& Solution::operator=(const Solution & other)
{
    // Copies only the solution state. `instance` is left untouched; it is presumably a
    // reference member bound at construction (it is initialised from a const Instance&).
    this->totalCosts = other.totalCosts;
    this->tours = other.tours;
    this->customerToTourMap = other.customerToTourMap;
    return *this;
}





ModifiedSolution::ModifiedSolution(Solution& originalSolution) : originalSolution(originalSolution), instance(originalSolution.instance) {
    // A fresh modification starts with the cost of the unmodified solution.
    this->totalCosts = originalSolution.totalCosts;
}

void ModifiedSolution::removeCustomers(const std::vector<int>& A) {
    std::unordered_map<int, std::vector<int>> tourModifications;

    totalCosts = 0;

    // Populate the map with customers to be removed per tour
    for (int customer : A) {
        int tourIndex = originalSolution.customerToTourMap[customer];
        if (tourIndex >= 0) {
            tourModifications[tourIndex].push_back(customer);
        }
    }

    // Apply removals and update tours
    for (const auto& entry : tourModifications) {
        int tourIndex = entry.first;
        const std::vector<int>& customersToRemove = entry.second;

        Tour& tour = originalSolution.tours[tourIndex];
        Tour newTour;
        // Reserve space in advance to avoid multiple allocations
        newTour.nodes.reserve(tour.nodes.size() - customersToRemove.size());

        // Remove specified customers from the tour
        for (int node_id : tour.nodes) {
            if (std::find(customersToRemove.begin(), customersToRemove.end(), node_id) == customersToRemove.end()) {
                newTour.nodes.push_back(node_id);
            }
        }

        if (!newTour.nodes.empty()) {
            newTour.updateCostAndDemand(instance);
            totalCosts += newTour.costs;
            newTours.push_back(newTour);
        }
        removedToursId.push_back(tourIndex);
    }

    // Include the costs of unaffected tours
    for (size_t i = 0; i < originalSolution.tours.size(); ++i) {
        if (tourModifications.find(i) == tourModifications.end()) {
            totalCosts += originalSolution.tours[i].costs;
        }
    }

    totalCosts += instance.total_prizes;
}


std::unordered_set<int> ModifiedSolution::destroy(float c_bar, int L_max, float alpha) {
    // Ruin step (appears to follow the SISR string-removal scheme). Starting from a random
    // seed customer, walk its neighbour list instance.adj[seed_c] and cut one string
    // (Tour::stringRemoval) or split string (Tour::splitStringRemoval) out of each tour
    // met on the way, until ks tours are ruined or more than 2 * c_bar customers are
    // collected.
    //   c_bar - drives the number of tours to ruin (ks) and caps |A|
    //   L_max - upper bound on the string length
    //   alpha - forwarded to Tour::splitStringRemoval
    // Side effects: the remaining part of each ruined tour (if non-empty) is appended to
    // newTours, the ruined tour indices to removedToursId, and totalCosts is recomputed.
    // Returns A: the removed customers plus any visited neighbours that were unrouted.
    std::unordered_set<int> A;
    std::unordered_set<Tour*> R;  // tours of originalSolution already ruined

    // Average customers per tour. Guarded: in the PCVRP all customers may be unrouted,
    // and dividing by an empty tour list gives NaN (casting NaN to int is UB).
    float avg_tour_cardinality = 0;
    if (!originalSolution.tours.empty()) {
        for (const auto& t : originalSolution.tours) { avg_tour_cardinality += t.nodes.size(); }
        avg_tour_cardinality /= originalSolution.tours.size();
    }

    int ls_max = std::min(L_max, static_cast<int>(avg_tour_cardinality));

    float ks_max = (4 * c_bar) / (1 + ls_max) - 1;
    int ks = static_cast<int>(getRandomFraction(1, ks_max + 0.9999));

    int seed_c = getRandomNumber(1, instance.numCustomers);

    for (int c : instance.adj[seed_c]) {

        if (c == 0) {  // depot
            continue;
        }

        // MOD FOR PCVRP: abort once enough tours are ruined or enough customers removed.
        if (R.size() >= static_cast<std::size_t>(ks) || A.size() > c_bar * 2) {
            break;
        }

        const auto c_tour_idx = originalSolution.customerToTourMap[c];

        // MOD FOR PCVRP: customers without a tour are added to A directly. There is no
        // tour to ruin, so move on. Falling through would evaluate tours[-1], which is UB.
        if (c_tour_idx == -1) {
            A.insert(c);
            continue;
        }

        Tour* c_tour = &originalSolution.tours[c_tour_idx];

        if (A.count(c) == 0 && R.count(c_tour) == 0) {
            int c_star = c;

            int c_tour_card = static_cast<int>(c_tour->nodes.size());
            int lt_max = std::min(c_tour_card, ls_max);

            int lt = static_cast<int>(getRandomFraction(1, lt_max + 0.9999));

            std::vector<int> removed_cust;
            Tour newTour;

            if (lt < 2 || lt == lt_max || getRandomFractionFast() < 0.5) {
                c_tour->stringRemoval(removed_cust, newTour, lt, c_star);
            }
            else {
                c_tour->splitStringRemoval(removed_cust, newTour, lt, c_star, alpha);
            }

            A.insert(removed_cust.begin(), removed_cust.end());
            R.insert(c_tour);

            // The remainder of the ruined tour survives as a new tour.
            if (!newTour.nodes.empty()) {
                newTour.updateCostAndDemand(instance);
                newTours.push_back(std::move(newTour));
            }
        }
    }

    // totalCosts = untouched original tours + all new tours + total_prizes.
    totalCosts = 0;
    for (int i = static_cast<int>(originalSolution.tours.size()) - 1; i >= 0; --i) {
        Tour* tour = &originalSolution.tours[i];
        if (R.count(tour) == 0)
        {
            totalCosts += tour->costs;
        }
        else {
            removedToursId.push_back(i);
        }
    }
    for (const auto& t : newTours)
    {
        totalCosts += t.costs;
    }
    totalCosts += instance.total_prizes;

    return A;
}


void ModifiedSolution::repair(const std::vector<int>& A, float beta, bool insertInNewToursOnly) {
    const int& capacity = instance.vehicleCapacity;

    int best_original_tour_idx;
    int best_new_tour_idx;
    int best_ins_pos;
    int next_node;
    float insertionCosts;
    float bestInsertionCost;
    std::vector<int> oldTourIds;
    nonInsertedNodes.clear();

    if (!insertInNewToursOnly) {
        for (int t_idx = 0; t_idx < originalSolution.tours.size(); ++t_idx) {
            if (std::find(removedToursId.begin(), removedToursId.end(), t_idx) == removedToursId.end()) {
                oldTourIds.push_back(t_idx);
            }
        }
    }


    for (int c : A) {
        bestInsertionCost = std::numeric_limits<float>::infinity();
        const int& c_demand = instance.demand[c];
        best_original_tour_idx = -1;
        best_new_tour_idx = -1;

        if (!insertInNewToursOnly) {
            for (int t_idx : oldTourIds) {
                const auto& tour = originalSolution.tours[t_idx];

                if (tour.demand + c_demand <= capacity) {

                    int prev_node = 0;
                    for (int new_pos = 0; new_pos < tour.nodes.size(); ++new_pos) {
                        next_node = tour.nodes[new_pos];

                        insertionCosts = instance.distanceMatrix[prev_node][c] + instance.distanceMatrix[c][next_node] - instance.distanceMatrix[prev_node][next_node] - instance.prizes[c];;

                        prev_node = next_node;

                        if (insertionCosts < bestInsertionCost) {
                            if (getRandomFractionFast() < (1 - beta)) {
                                best_original_tour_idx = t_idx;
                                best_ins_pos = new_pos;
                                bestInsertionCost = insertionCosts;
                            }
                        }
                    }

                    insertionCosts = instance.distanceMatrix[prev_node][c] + instance.distanceMatrix[c][0] - instance.distanceMatrix[prev_node][0] - instance.prizes[c];;

                    if (insertionCosts < bestInsertionCost) {
                        if (getRandomFractionFast() < (1 - beta)) {
                            best_original_tour_idx = t_idx;
                            best_ins_pos = tour.nodes.size();
                            bestInsertionCost = insertionCosts;
                        }
                    }

                }
            }
        }


        for (int t_idx = 0; t_idx < newTours.size(); ++t_idx) {
            const auto& tour = newTours[t_idx];

            if (tour.demand + c_demand <= capacity) {

                int prev_node = 0;
                for (int new_pos = 0; new_pos < tour.nodes.size(); ++new_pos) {
                    next_node = tour.nodes[new_pos];

                    insertionCosts = instance.distanceMatrix[prev_node][c] + instance.distanceMatrix[c][next_node] - instance.distanceMatrix[prev_node][next_node] - instance.prizes[c];

                    prev_node = next_node;

                    if (insertionCosts < bestInsertionCost) {
                        if (getRandomFractionFast() < (1 - beta)) {
                            best_new_tour_idx = t_idx;
                            best_ins_pos = new_pos;
                            bestInsertionCost = insertionCosts;
                        }
                    }
                }

                insertionCosts = instance.distanceMatrix[prev_node][c] + instance.distanceMatrix[c][0] - instance.distanceMatrix[prev_node][0] - instance.prizes[c];

                if (insertionCosts < bestInsertionCost) {
                    if (getRandomFractionFast() < (1 - beta)) {
                        best_new_tour_idx = t_idx;
                        best_ins_pos = tour.nodes.size();
                        bestInsertionCost = insertionCosts;
                    }
                }

            }
        }

        if (bestInsertionCost < 0) {
            if (best_new_tour_idx != -1) {
                auto& tour = newTours[best_new_tour_idx];
                tour.nodes.insert(tour.nodes.begin() + best_ins_pos, c);
                tour.demand += c_demand;
                tour.costs += bestInsertionCost;
                totalCosts += bestInsertionCost;

            }
            else if (best_original_tour_idx != -1) {
                Tour newTour = originalSolution.tours[best_original_tour_idx];
                removedToursId.push_back(best_original_tour_idx);
                oldTourIds.erase(std::remove(oldTourIds.begin(), oldTourIds.end(), best_original_tour_idx), oldTourIds.end());

                newTour.nodes.insert(newTour.nodes.begin() + best_ins_pos, c);
                newTour.demand += c_demand;
                newTour.costs += bestInsertionCost;

                newTours.push_back(newTour);
                totalCosts += bestInsertionCost;

            }
        }
        else if ((instance.distanceMatrix[0][c] + instance.distanceMatrix[c][0] - instance.prizes[c] < 0) ||
            getRandomFractionFast() < 0.5) {
            Tour newTour;
            newTour.nodes = { c };
            newTour.costs = instance.distanceMatrix[0][c] + instance.distanceMatrix[c][0] - instance.prizes[c];
            newTour.demand = c_demand;
            totalCosts += newTour.costs;
            newTours.push_back(newTour);
        }
        else {
            nonInsertedNodes.push_back(c);
        }
    }
}


ModifiedSolution& ModifiedSolution::operator=(const ModifiedSolution& other) {
    // Copies only the recorded modification (cost, new tours, removed tour ids,
    // non-inserted customers). originalSolution and instance are not reassigned;
    // they are presumably reference members bound at construction.
    if (this == &other) {
        return *this;
    }

    this->totalCosts = other.totalCosts;
    this->newTours = other.newTours;
    this->removedToursId = other.removedToursId;
    this->nonInsertedNodes = other.nonInsertedNodes;
    return *this;
}
