#include "mmas.h"

#include <algorithm>
#include <atomic>
#include <barrier>
#include <chrono>
#include <cmath>
#include <limits>
#include <random>
#include <thread>

#include "local_search.h"

// -----------------------------
// Internal helpers for MMAS construction and pheromone updates
// -----------------------------
namespace {

// Per-ant scratch buffers, kept alive across tour constructions so they are not reallocated.
struct AntWorkspace {
    std::vector<int> tour;       // city sequence under construction (-1 = slot not yet filled)
    std::vector<char> visited;   // non-zero once the city has been placed in `tour`
    std::vector<int> choices;    // unvisited candidate cities considered at the current step
    std::vector<double> weights; // selection weight for each entry of `choices`
    LSWorkspace ls;              // scratch state passed to the local-search routines

    // Size all buffers for an n-city instance. `tour` and `visited` are reset; `choices`
    // and `weights` are only resized because their contents are rewritten before use.
    void init(int n) {
        const std::size_t count = static_cast<std::size_t>(n);
        visited.assign(count, 0);
        tour.assign(count, -1);
        weights.resize(count);
        choices.resize(count);
    }
};

// tau raised to the power alpha, with cheap paths for alpha == 1 and alpha == 2
// so the general (and comparatively slow) std::pow is only used for other exponents.
static inline double pow_alpha(double tau, double alpha) {
    const bool fast_path = (alpha == 1.0) || (alpha == 2.0);
    if (!fast_path) {
        return std::pow(tau, alpha);
    }
    return (alpha == 2.0) ? tau * tau : tau;
}

// Precompute (tau^alpha * eta) on candidate edges to avoid recomputation during construction
// Fill `total` with tau^alpha * row_inv for every candidate edge, laid out row-major as
// cand.n rows of cand.k entries (row i, slot c at index i * cand.k + c). build_tour reads this
// cache so the per-step selection does not recompute powers.
// NOTE(review): row_inv is presumably the precomputed heuristic eta^beta for each candidate
// edge — confirm against CandidateList.
static void compute_candidate_total(const TriMatrix &pheromone,
                                    const CandidateList &cand,
                                    double alpha,
                                    std::vector<float> &total) {
    const int rows = cand.n;
    const int width = cand.k;
    // Resizing to the current size is a no-op, so the cache is only reallocated when needed.
    total.resize(static_cast<std::size_t>(rows) * width);

    float *out = total.data();
    for (int from = 0; from < rows; ++from) {
        const int *neighbors = cand.row(from);
        const float *eta = cand.row_inv(from);
        for (int slot = 0; slot < width; ++slot, ++out) {
            const double tau_term = pow_alpha(pheromone.get(from, neighbors[slot]), alpha);
            *out = static_cast<float>(tau_term * eta[slot]);
        }
    }
}

// Construct an ant tour using the candidate list and q0 decision rule
// Construct one ant's closed tour into ws.tour, starting from a uniformly random city.
//
// At each step the first nn_ants entries of the current city's candidate row are scored with
// the precomputed `total` cache (tau^alpha * row_inv). With probability q0 (only drawn when
// cfg.q0 > 0) the highest-scoring unvisited candidate is taken greedily; otherwise a
// roulette-wheel draw proportional to the scores chooses the next city. If no candidate is
// usable, choose_best_next() performs a wider greedy scan.
//
// Parameters:
//   inst        - only inst.n (number of cities) is read.
//   dist        - distance oracle used by the fallbacks (distance / inv_dist_beta).
//   cand        - candidate lists, cand.k entries per city row.
//   total       - flat score cache indexed as city * cand.k + slot.
//   pheromone   - current pheromone matrix (read only).
//   cfg         - reads alpha, beta and q0.
//   nn_ants     - construction radius; clamped to cand.k below.
//   sparse_only - greedy fallback first restricts itself to candidate-list edges.
//   ws          - per-thread scratch; re-initialised if sized for a different n. The result
//                 is left in ws.tour and ws.visited is fully set on return.
//   rng, u01    - random source. The exact order of draws below determines reproducibility
//                 for a given seed.
static void build_tour(const Instance &inst,
                       const DistanceProvider &dist,
                       const CandidateList &cand,
                       const std::vector<float> &total,
                       const TriMatrix &pheromone,
                       const Config &cfg,
                       int nn_ants,
                       bool sparse_only,
                       AntWorkspace &ws,
                       std::mt19937_64 &rng,
                       std::uniform_real_distribution<double> &u01) {
    int n = inst.n;
    if (static_cast<int>(ws.tour.size()) != n) {
        ws.init(n);
    }
    std::fill(ws.visited.begin(), ws.visited.end(), 0);

    std::uniform_int_distribution<int> start_dist(0, n - 1);
    int current = start_dist(rng);
    ws.tour[0] = current;
    ws.visited[current] = 1;

    // Never read past the stored candidate rows.
    nn_ants = std::min(nn_ants, cand.k);

    // Greedy choice used when the candidate-restricted rule has nothing usable.
    // sparse_only: best tau^alpha * row_inv over ALL cand.k candidates (not just nn_ants).
    // otherwise:   best tau^alpha * inv_dist_beta over every unvisited city.
    auto choose_best_next = [&](int cur) -> int {
        int best = -1;
        double best_val = -1.0;
        if (sparse_only) {
            const int *row = cand.row(cur);
            const float *row_inv = cand.row_inv(cur);
            for (int c = 0; c < cand.k; ++c) {
                int node = row[c];
                if (ws.visited[node]) {
                    continue;
                }
                double tau = pheromone.get(cur, node);
                double val = pow_alpha(tau, cfg.alpha) * row_inv[c];
                if (val > best_val) {
                    best_val = val;
                    best = node;
                }
            }
            if (best != -1) {
                return best;
            }
        } else {
            for (int node = 0; node < n; ++node) {
                if (ws.visited[node]) {
                    continue;
                }
                double tau = pheromone.get(cur, node);
                double eta = dist.inv_dist_beta(cur, node, cfg.beta);
                double val = pow_alpha(tau, cfg.alpha) * eta;
                if (val > best_val) {
                    best_val = val;
                    best = node;
                }
            }
            if (best != -1) {
                return best;
            }
        }
        // Last resort (either mode, when no scored choice was made, e.g. all candidates
        // already visited in sparse mode): nearest unvisited city by true distance.
        double best_d = std::numeric_limits<double>::max();
        int best_node = -1;
        for (int node = 0; node < n; ++node) {
            if (ws.visited[node]) {
                continue;
            }
            double d = dist.distance(cur, node);
            if (d < best_d) {
                best_d = d;
                best_node = node;
            }
        }
        return best_node;
    };

    for (int step = 1; step < n; ++step) {
        const int *row = cand.row(current);
        const float *row_total = total.data() + static_cast<std::size_t>(current) * cand.k;

        // Gather unvisited candidates among the first nn_ants, their weights, the weight sum,
        // and the single best candidate (for the greedy q0 branch).
        double sum = 0.0;
        int best_node = -1;
        double best_weight = -1.0;
        int count = 0;

        for (int c = 0; c < nn_ants; ++c) {
            int node = row[c];
            if (ws.visited[node]) {
                continue;
            }
            double w = row_total[c];
            ws.choices[count] = node;
            ws.weights[count] = w;
            sum += w;
            if (w > best_weight) {
                best_weight = w;
                best_node = node;
            }
            count++;
        }

        int next = -1;
        // u01 is only drawn here when q0 > 0, so q0 == 0 does not consume a random number.
        if (cfg.q0 > 0.0 && u01(rng) < cfg.q0) {
            if (best_node >= 0) {
                next = best_node;
            } else {
                next = choose_best_next(current);
            }
        } else {
            if (sum <= 0.0 || count == 0) {
                next = choose_best_next(current);
            } else {
                // Roulette wheel: first candidate whose cumulative weight reaches r.
                double r = u01(rng) * sum;
                double acc = 0.0;
                for (int i = 0; i < count; ++i) {
                    acc += ws.weights[i];
                    if (acc >= r) {
                        next = ws.choices[i];
                        break;
                    }
                }
            }
        }
        if (next == -1) {
            // Should not happen; fallback to a random unvisited node
            count = 0;
            for (int node = 0; node < n; ++node) {
                if (!ws.visited[node]) {
                    ws.choices[count++] = node;
                }
            }
            std::uniform_int_distribution<int> pick(0, count - 1);
            next = ws.choices[pick(rng)];
        }

        ws.tour[step] = next;
        ws.visited[next] = 1;
        current = next;
    }
}

// Length of a greedy nearest-neighbour tour that starts at a uniformly random city
// (drawn from `rng`), optionally improved by one 2-opt local search when use_local_search
// is set. run() uses the result to derive the initial tau_max.
//
// Returns 0.0 for an empty instance (inst.n <= 0). Otherwise the uniform start-city
// distribution would be built with an empty range (precondition violation), and tour[0]
// would be written into an empty vector.
static double nearest_neighbor_tour_length(const Instance &inst,
                                           const DistanceProvider &dist,
                                           const CandidateList &cand,
                                           int nn_ls,
                                           bool use_local_search,
                                           bool use_dlb,
                                           int max_2opt_passes,
                                           std::mt19937_64 &rng,
                                           LSWorkspace &ls_ws) {
    // ACOTSP uses nearest-neighbor construction and runs one 2-opt when local search is enabled
    const int n = inst.n;
    if (n <= 0) {
        return 0.0;
    }

    std::vector<int> tour(n, -1);
    std::vector<char> visited(n, 0);
    std::uniform_int_distribution<int> start_dist(0, n - 1);
    int current = start_dist(rng);
    tour[0] = current;
    visited[current] = 1;

    // O(n^2): each step scans every city for the closest unvisited one.
    for (int i = 1; i < n; ++i) {
        double best_d = std::numeric_limits<double>::max();
        int best_node = -1;
        for (int node = 0; node < n; ++node) {
            if (visited[node]) {
                continue;
            }
            const double d = dist.distance(current, node);
            if (d < best_d) {
                best_d = d;
                best_node = node;
            }
        }
        tour[i] = best_node;
        visited[best_node] = 1;
        current = best_node;
    }

    if (use_local_search) {
        two_opt_first(dist, cand, nn_ls, use_dlb, max_2opt_passes, tour, rng, ls_ws);
    }
    return tour_length(dist, tour);
}

static double node_branching(const TriMatrix &pheromone,
                             const CandidateList &cand,
                             int nn_ants,
                             double lambda) {
    int n = cand.n;
    if (nn_ants <= 0) {
        return 0.0;
    }
    double avg = 0.0;
    int start = (nn_ants > 1) ? 1 : 0;
    for (int m = 0; m < n; ++m) {
        const int *row = cand.row(m);
        double min_val = pheromone.get(m, row[start]);
        double max_val = min_val;
        for (int i = start; i < nn_ants; ++i) {
            double val = pheromone.get(m, row[i]);
            min_val = std::min(min_val, val);
            max_val = std::max(max_val, val);
        }
        double cutoff = min_val + lambda * (max_val - min_val);
        double count = 0.0;
        for (int i = 0; i < nn_ants; ++i) {
            if (pheromone.get(m, row[i]) > cutoff) {
                count += 1.0;
            }
        }
        avg += count;
    }
    return avg / (static_cast<double>(n) * 2.0);
}

// True when `target` occurs among the first `limit` entries of a candidate row.
static bool candidate_row_contains(const int *row, int limit, int target) {
    return std::find(row, row + limit, target) != row + limit;
}

// Local-search-mode MMAS evaporation: the pheromone on every candidate edge (the first
// nn_ants entries of each city's candidate row) is multiplied by (1 - rho) and then raised
// to tau_min if it fell below it. Non-candidate edges are left untouched.
//
// TriMatrix is symmetric and stores each undirected edge once. When i and j list each other
// as candidates, the same cell is therefore reachable from both rows. Processing both rows
// naively would apply (1 - rho)^2 to exactly those mutually-near (typically short) edges,
// while a directed n x n matrix evaporates each direction once. Each undirected edge is thus
// evaporated exactly once: when it is reachable from both endpoints, only the row of the
// smaller-indexed city handles it.
//
// nn_ants is clamped to cand.k so candidate rows are never read past their stored length.
static void mmas_evaporation_nn_list(TriMatrix &pheromone,
                                     const CandidateList &cand,
                                     int nn_ants,
                                     double rho,
                                     double tau_min) {
    const int n = cand.n;
    const int limit = std::min(nn_ants, cand.k);
    if (limit <= 0) {
        return;
    }
    for (int i = 0; i < n; ++i) {
        const int *row = cand.row(i);
        for (int j = 0; j < limit; ++j) {
            const int node = row[j];
            // Already evaporated while scanning the smaller city's row.
            if (node < i && candidate_row_contains(cand.row(node), limit, i)) {
                continue;
            }
            float val = pheromone.get(i, node);
            val = static_cast<float>((1.0 - rho) * val);
            if (val < tau_min) {
                val = static_cast<float>(tau_min);
            }
            pheromone.set(i, node, val);
        }
    }
}

// Clamp every stored pheromone value into [tau_min, tau_max]. The lower bound is tested first,
// so a value below tau_min always becomes tau_min, even if tau_min > tau_max.
static void check_pheromone_trail_limits(TriMatrix &pheromone,
                                         double tau_min,
                                         double tau_max) {
    const float floor_val = static_cast<float>(tau_min);
    const float ceil_val = static_cast<float>(tau_max);
    for (auto &cell : pheromone.raw()) {
        if (cell < tau_min) {
            cell = floor_val;
            continue;
        }
        if (cell > tau_max) {
            cell = ceil_val;
        }
    }
}

// Add q / length to the pheromone of every edge of the closed tour. Edges are visited as
// (tour[k], tour[k + 1]) in order, finishing with the closing edge (last city -> first city).
static void deposit_pheromone(TriMatrix &pheromone,
                              const std::vector<int> &tour,
                              double length,
                              double q) {
    const float amount = static_cast<float>(q / length);
    const std::size_t count = tour.size();
    for (std::size_t pos = 0; pos < count; ++pos) {
        const int from = tour[pos];
        const int to = (pos + 1 < count) ? tour[pos + 1] : tour.front();
        pheromone.set(from, to, pheromone.get(from, to) + amount);
    }
}

} // namespace

// -----------------------------
// Shared utility functions
// -----------------------------
// Total length of the closed tour: the sum of dist.distance(tour[k], tour[k + 1]) in visiting
// order, with the closing edge (last city -> first city) added last. An empty tour has length 0.
double tour_length(const DistanceProvider &dist, const std::vector<int> &tour) {
    const std::size_t count = tour.size();
    double length = 0.0;
    for (std::size_t pos = 0; pos < count; ++pos) {
        const std::size_t next = (pos + 1 == count) ? 0 : pos + 1;
        length += dist.distance(tour[pos], tour[next]);
    }
    return length;
}

MMASSolver::MMASSolver(const Instance &inst,
                       const DistanceProvider &dist,
                       const CandidateList &cand,
                       const Config &cfg)
    : inst_(inst), dist_(dist), cand_(cand), cfg_(cfg) {}

// Run MMAS for up to cfg_.iterations iterations, stopping early once cfg_.max_time_seconds has
// elapsed (when positive). Returns the best tour found, its length and the iteration it was
// found in, plus one stats record per completed iteration. Returns a default-constructed
// MMASResult when cfg_.ants <= 0 or cfg_.iterations <= 0.
//
// Threading: `threads` persistent workers each construct a fixed slice of the ants in every
// iteration. start_barrier and end_barrier have threads + 1 parties (the extra one is this
// thread) and bracket the construction phase. Workers therefore only read `pheromone` and
// `total_cache` while this thread is parked between the two barriers, and this thread only
// mutates them after end_barrier and before the next start_barrier.
//
// NOTE(review): if anything throws after the workers are started and before they are joined,
// the still-joinable std::thread objects in `workers` are destroyed, which calls std::terminate.
MMASResult MMASSolver::run() {
    MMASResult result;

    if (cfg_.ants <= 0 || cfg_.iterations <= 0) {
        return result;
    }

    // NOTE(review): any value other than "none" counts as local search here, even one the
    // worker dispatch below does not recognise (which then runs no local search).
    const bool use_local_search = (cfg_.local_search != "none");
    const bool use_dlb = cfg_.dlb_flag != 0;
    const bool sparse_only = (cfg_.heatmap_mode == "threshold");
    // nn_ants for construction, nn_ls for local search (ACOTSP's two candidate radii)
    // Both are clamped to [1, n - 1] but NOT to cand_.k. build_tour bounds nn_ants by cand_.k
    // itself; any other consumer that indexes candidate rows with it must do the same.
    const int nn_ants = std::max(1, std::min(cfg_.nn_ants, inst_.n - 1));
    const int nn_ls = std::max(1, std::min(cfg_.nn_ls, inst_.n - 1));

    // Initial nearest-neighbor tour length (used by ACOTSP to set tau_max)
    // Uses its own RNG seeded with cfg_.seed, separate from the worker RNGs.
    std::mt19937_64 nn_rng(static_cast<unsigned long long>(cfg_.seed));
    LSWorkspace nn_ws;
    double nn_length = nearest_neighbor_tour_length(inst_, dist_, cand_, nn_ls,
                                                     use_local_search, use_dlb,
                                                     cfg_.max_2opt_passes, nn_rng, nn_ws);

    // ACOTSP init: tau_max=1/(rho*L_nn), tau_min=tau_max/(2N)
    // A positive cfg_.tau_max / cfg_.tau_min overrides the derived value here and in the loop.
    double tau_max = cfg_.tau_max > 0.0 ? cfg_.tau_max : (1.0 / (cfg_.rho * nn_length));
    double tau_min = cfg_.tau_min > 0.0 ? cfg_.tau_min : (tau_max / (2.0 * inst_.n));

    TriMatrix pheromone(inst_.n);
    pheromone.fill(static_cast<float>(tau_max));

    result.best_tour.assign(inst_.n, 0);
    result.best_length = std::numeric_limits<double>::max();
    result.best_iter = 0;
    result.stats.reserve(cfg_.iterations);

    // Worker count: cfg_.threads if positive, otherwise hardware_concurrency(). It is at least 1
    // and at most cfg_.ants, so every worker's ant slice [start, end) below is non-empty.
    int threads = cfg_.threads > 0 ? cfg_.threads : static_cast<int>(std::thread::hardware_concurrency());
    threads = std::max(1, threads);
    threads = std::min(threads, cfg_.ants);

    // Set once by this thread after the main loop; workers check it right after start_barrier.
    std::atomic<bool> stop{false};
    struct IterBest {
        double length = std::numeric_limits<double>::max();
        std::vector<int> tour;
    } iter_best;
    iter_best.tour.resize(inst_.n);

    // Each thread keeps its own best for this iteration to avoid mutex contention
    std::vector<double> thread_best_len(threads, std::numeric_limits<double>::max());
    std::vector<std::vector<int>> thread_best_tour(
        threads, std::vector<int>(inst_.n, -1));

    std::barrier start_barrier(threads + 1);
    std::barrier end_barrier(threads + 1);

    // Score cache read by build_tour; refreshed by this thread after every pheromone update.
    std::vector<float> total_cache;
    compute_candidate_total(pheromone, cand_, cfg_.alpha, total_cache);

    std::vector<std::thread> workers;
    workers.reserve(threads);
    for (int t = 0; t < threads; ++t) {
        workers.emplace_back([&, t]() {
            // Worker-private scratch buffers, RNG and ant slice [start, end).
            AntWorkspace ws;
            ws.init(inst_.n);
            auto &local_best_tour = thread_best_tour[t];
            double local_best_len = std::numeric_limits<double>::max();
            // Seed depends on the worker index, so results depend on the number of threads.
            std::mt19937_64 rng(static_cast<unsigned long long>(cfg_.seed) + 1315423911ULL * (t + 1));
            std::uniform_real_distribution<double> u01(0.0, 1.0);
            int start = (cfg_.ants * t) / threads;
            int end = (cfg_.ants * (t + 1)) / threads;

            while (true) {
                // Wait until the main thread releases an iteration (or signals shutdown).
                start_barrier.arrive_and_wait();
                if (stop.load()) {
                    // Still arrive at end_barrier so the main thread's matching wait completes.
                    end_barrier.arrive_and_wait();
                    break;
                }
                local_best_len = std::numeric_limits<double>::max();
                for (int idx = start; idx < end; ++idx) {
                    (void)idx;
                    build_tour(inst_, dist_, cand_, total_cache, pheromone, cfg_,
                               nn_ants, sparse_only, ws, rng, u01);
                    // Optional local search selected by cfg_.local_search. Any other value
                    // (including "none") only measures the constructed tour.
                    double length = 0.0;
                    if (cfg_.local_search == "2opt") {
                        two_opt_first(dist_, cand_, nn_ls, use_dlb, cfg_.max_2opt_passes,
                                      ws.tour, rng, ws.ls);
                        length = tour_length(dist_, ws.tour);
                    } else if (cfg_.local_search == "2.5opt" || cfg_.local_search == "2hopt") {
                        two_h_opt_first(dist_, cand_, nn_ls, use_dlb, cfg_.max_2opt_passes,
                                        ws.tour, rng, ws.ls);
                        length = tour_length(dist_, ws.tour);
                    } else if (cfg_.local_search == "3opt") {
                        three_opt_first(dist_, cand_, nn_ls, use_dlb, cfg_.max_3opt_passes,
                                        ws.tour, rng, ws.ls);
                        length = tour_length(dist_, ws.tour);
                    } else if (cfg_.local_search == "lk") {
                        length = lk_search(dist_, cand_, ws.tour, cfg_.lk_passes, cfg_.lk_depth);
                    } else {
                        length = tour_length(dist_, ws.tour);
                    }
                    if (length < local_best_len) {
                        local_best_len = length;
                        local_best_tour = ws.tour;
                    }
                }

                // Published for the main thread, which reads it after end_barrier completes.
                thread_best_len[t] = local_best_len;
                end_barrier.arrive_and_wait();
            }
        });
    }

    // restart_best is used by ACOTSP's MMAS update schedule
    std::vector<int> restart_best_tour(inst_.n, -1);
    double restart_best_length = std::numeric_limits<double>::max();
    int restart_found_best = 0;
    int restart_iteration = 1;
    // NOTE(review): cfg_.u_gb (or the default 25) only governs iteration 1. Both branches of
    // the schedule at the end of each iteration overwrite u_gb unconditionally.
    int u_gb = cfg_.u_gb > 0 ? cfg_.u_gb : 25;

    auto start_time = std::chrono::steady_clock::now();
    for (int iter = 1; iter <= cfg_.iterations; ++iter) {
        // Signal threads to start tour construction
        // Then block until every worker has finished its ants for this iteration.
        start_barrier.arrive_and_wait();
        end_barrier.arrive_and_wait();

        iter_best.length = std::numeric_limits<double>::max();
        for (int t = 0; t < threads; ++t) {
            double len = thread_best_len[t];
            if (len < iter_best.length) {
                iter_best.length = len;
                iter_best.tour = thread_best_tour[t];
            }
        }

        // Update best-so-far / restart-best and adjust tau_min/tau_max per ACOTSP rules
        if (iter_best.length < result.best_length) {
            result.best_length = iter_best.length;
            result.best_tour = iter_best.tour;
            result.best_iter = iter;
            restart_best_length = iter_best.length;
            restart_best_tour = iter_best.tour;
            restart_found_best = iter;

            if (cfg_.tau_max <= 0.0) {
                tau_max = 1.0 / (cfg_.rho * result.best_length);
            }
            if (cfg_.tau_min <= 0.0) {
                if (use_local_search) {
                    tau_min = tau_max / (2.0 * inst_.n);
                } else {
                    // Without local search tau_min is derived from p_best (0.05 when the
                    // configured value is outside (0, 1)).
                    double p = cfg_.p_best;
                    if (p <= 0.0 || p >= 1.0) {
                        p = 0.05;
                    }
                    double p_x = std::exp(std::log(p) / static_cast<double>(inst_.n));
                    double denom = p_x * ((nn_ants + 1) / 2.0);
                    tau_min = (denom <= 0.0) ? (tau_max * 0.1) : (tau_max * (1.0 - p_x) / denom);
                }
            }
        } else if (iter_best.length < restart_best_length) {
            restart_best_length = iter_best.length;
            restart_best_tour = iter_best.tour;
            restart_found_best = iter;
        }

        // Evaporation phase: with local search, evaporate only candidate edges and clamp to tau_min
        // Without local search every stored value evaporates; limits are applied after deposit.
        if (use_local_search) {
            mmas_evaporation_nn_list(pheromone, cand_, nn_ants, cfg_.rho, tau_min);
        } else {
            float evap = static_cast<float>(1.0 - cfg_.rho);
            for (auto &val : pheromone.raw()) {
                val *= evap;
            }
        }

        // MMAS update schedule: iteration-best / restart-best / best-so-far
        // Iterations not divisible by u_gb deposit the iteration best. Every u_gb-th iteration
        // deposits the restart best, or the best-so-far when u_gb == 1 and restart-best has
        // not improved for more than 50 iterations.
        const std::vector<int> *deposit_tour = nullptr;
        double deposit_length = 0.0;
        if (iter % u_gb) {
            deposit_tour = &iter_best.tour;
            deposit_length = iter_best.length;
        } else {
            if (u_gb == 1 && (iter - restart_found_best > 50)) {
                deposit_tour = &result.best_tour;
                deposit_length = result.best_length;
            } else {
                deposit_tour = &restart_best_tour;
                deposit_length = restart_best_length;
            }
        }
        if (deposit_tour && deposit_length > 0.0) {
            deposit_pheromone(pheromone, *deposit_tour, deposit_length, cfg_.q);
        }

        if (!use_local_search) {
            check_pheromone_trail_limits(pheromone, tau_min, tau_max);
        }

        // ACOTSP u_gb schedule; only effective when local search is enabled
        // The longer since the last restart, the more often restart-best is deposited.
        if (use_local_search) {
            int since_restart = iter - restart_iteration;
            if (since_restart < 25) {
                u_gb = 25;
            } else if (since_restart < 75) {
                u_gb = 5;
            } else if (since_restart < 125) {
                u_gb = 3;
            } else if (since_restart < 250) {
                u_gb = 2;
            } else {
                u_gb = 1;
            }
        } else {
            u_gb = 25;
        }

        // Restart pheromones when branching factor drops below the threshold
        // Checked every 100 iterations, and only when restart-best has not improved for more
        // than 250 iterations.
        if (iter % 100 == 0) {
            double branching_factor = node_branching(pheromone, cand_, nn_ants, cfg_.lambda);
            if (branching_factor < cfg_.branch_fac && (iter - restart_found_best > 250)) {
                restart_best_length = std::numeric_limits<double>::max();
                pheromone.fill(static_cast<float>(tau_max));
                restart_iteration = iter;
            }
        }

        // Refresh the construction cache before workers are released for the next iteration.
        compute_candidate_total(pheromone, cand_, cfg_.alpha, total_cache);

        auto now = std::chrono::steady_clock::now();
        double elapsed = std::chrono::duration<double>(now - start_time).count();
        double gap = 0.0;
        if (inst_.opt_length > 0.0) {
            gap = 100.0 * (result.best_length - inst_.opt_length) / inst_.opt_length;
        }
        result.stats.push_back({iter, result.best_length, iter_best.length, elapsed, gap});

        if (cfg_.max_time_seconds > 0 && elapsed >= cfg_.max_time_seconds) {
            break;
        }
    }

    // Signal threads to exit
    // Workers see `stop` after start_barrier, arrive at end_barrier once more, then return.
    stop.store(true);
    start_barrier.arrive_and_wait();
    end_barrier.arrive_and_wait();
    for (auto &th : workers) {
        th.join();
    }

    return result;
}
