#include "local_search.h"

#include <algorithm>
#include <array>
#include <cstdio>
#include <cstdlib>
#include <limits>
#include <numeric>

// Overwrites `perm` with the labels 0, 1, ..., perm.size()-1 and then
// shuffles them in place, drawing randomness from `rng`.
static void fill_random_permutation(std::vector<int> &perm, std::mt19937_64 &rng) {
    int next_label = 0;
    for (int &slot : perm) {
        slot = next_label;
        ++next_label;
    }
    std::shuffle(perm.begin(), perm.end(), rng);
}

// Brings `buf` to exactly `wanted` elements; a buffer that already has that
// size is left untouched.
template <typename Buffer>
static void ls_resize_exact(Buffer &buf, int wanted) {
    if (static_cast<int>(buf.size()) == wanted) {
        return;
    }
    buf.resize(wanted);
}

// Sizes every scratch buffer for a tour of `n` cities. `tour_ext` gets one
// extra slot (n + 1 entries); all other buffers get exactly n entries.
void LSWorkspace::ensure(int n) {
    ls_resize_exact(tour_ext, n + 1);
    ls_resize_exact(pos, n);
    ls_resize_exact(dlb, n);
    ls_resize_exact(random, n);
    ls_resize_exact(h_tour, n);
    ls_resize_exact(hh_tour, n);
}

void two_opt_first(const DistanceProvider &dist,
                   const CandidateList &cand,
                   int nn_ls,
                   bool use_dlb,
                   int max_passes,
                   std::vector<int> &tour,
                   std::mt19937_64 &rng,
                   LSWorkspace &ws) {
    // ACOTSP 2-opt: random scan + DLB + fixed-radius search (only over nn_ls candidates)
    int n = static_cast<int>(tour.size());
    nn_ls = std::min(nn_ls, cand.k);
    ws.ensure(n);
    auto &t = ws.tour_ext;
    auto &pos = ws.pos;
    auto &dlb = ws.dlb;
    auto &rand_vec = ws.random;

    for (int i = 0; i < n; ++i) {
        t[i] = tour[i];
        pos[t[i]] = i;
        dlb[i] = 0;
    }
    t[n] = t[0];

    fill_random_permutation(rand_vec, rng);

    bool improvement_flag = true;
    int pass = 0;
    while (improvement_flag) {
        if (max_passes > 0 && pass >= max_passes) {
            break;
        }
        improvement_flag = false;

        for (int l = 0; l < n; ++l) {
            int c1 = rand_vec[l];
            if (use_dlb && dlb[c1]) {
                continue;
            }
            int pos_c1 = pos[c1];
            int s_c1 = t[pos_c1 + 1];
            double radius = dist.distance(c1, s_c1);

            int h1 = 0, h2 = 0, h3 = 0, h4 = 0;
            int c2 = 0, s_c2 = 0, p_c1 = 0, p_c2 = 0;
            double gain = 0.0;

            // 1) successor(c1) is one endpoint of the broken edge
            for (int h = 0; h < nn_ls; ++h) {
                c2 = cand.row(c1)[h];
                if (radius > dist.distance(c1, c2)) {
                    s_c2 = t[pos[c2] + 1];
                    gain = -radius + dist.distance(c1, c2)
                           + dist.distance(s_c1, s_c2) - dist.distance(c2, s_c2);
                    if (gain < -1e-12) {
                        h1 = c1;
                        h2 = s_c1;
                        h3 = c2;
                        h4 = s_c2;
                        goto exchange2opt;
                    }
                } else {
                    break;
                }
            }

            // 2) predecessor(c1) is one endpoint of the broken edge
            p_c1 = (pos_c1 > 0) ? t[pos_c1 - 1] : t[n - 1];
            radius = dist.distance(p_c1, c1);
            for (int h = 0; h < nn_ls; ++h) {
                c2 = cand.row(c1)[h];
                if (radius > dist.distance(c1, c2)) {
                    int pos_c2 = pos[c2];
                    p_c2 = (pos_c2 > 0) ? t[pos_c2 - 1] : t[n - 1];
                    if (p_c2 == c1 || p_c1 == c2) {
                        continue;
                    }
                    gain = -radius + dist.distance(c1, c2)
                           + dist.distance(p_c1, p_c2) - dist.distance(p_c2, c2);
                    if (gain < -1e-12) {
                        h1 = p_c1;
                        h2 = c1;
                        h3 = p_c2;
                        h4 = c2;
                        goto exchange2opt;
                    }
                } else {
                    break;
                }
            }

            dlb[c1] = 1;
            continue;

            exchange2opt:
            improvement_flag = true;
            dlb[h1] = 0;
            dlb[h2] = 0;
            dlb[h3] = 0;
            dlb[h4] = 0;

            if (pos[h3] < pos[h1]) {
                std::swap(h1, h3);
                std::swap(h2, h4);
            }

            if (pos[h3] - pos[h2] < n / 2 + 1) {
                int i = pos[h2];
                int j = pos[h3];
                while (i < j) {
                    int c_a = t[i];
                    int c_b = t[j];
                    t[i] = c_b;
                    t[j] = c_a;
                    pos[c_a] = j;
                    pos[c_b] = i;
                    i++;
                    j--;
                }
            } else {
                int i = pos[h1];
                int j = pos[h4];
                int help = (j > i) ? (n - (j - i) + 1) : ((i - j) + 1);
                help /= 2;
                for (int h = 0; h < help; ++h) {
                    int c_a = t[i];
                    int c_b = t[j];
                    t[i] = c_b;
                    t[j] = c_a;
                    pos[c_a] = j;
                    pos[c_b] = i;
                    i--;
                    j++;
                    if (i < 0) {
                        i = n - 1;
                    }
                    if (j >= n) {
                        j = 0;
                    }
                }
                t[n] = t[0];
            }
        }

        if (improvement_flag) {
            pass++;
        }
    }

    for (int i = 0; i < n; ++i) {
        tour[i] = t[i];
    }
}

/**
 * First-improvement 2.5-opt local search (2-opt moves plus single-city
 * re-insertion) restricted to candidate neighbours.
 *
 * Cities are scanned in a random order drawn from `rng`. For each city c1 the
 * first `nn_ls` candidates (capped at `cand.k`) are examined, first against
 * the edge (c1, successor) and then against the edge (predecessor, c1). The
 * first move whose length change is below -1e-12 is applied immediately.
 *
 * @param dist       distance lookup, queried through distance(a, b)
 * @param cand       candidate lists, queried through row(city)[rank]
 * @param nn_ls      number of candidates examined per city
 * @param use_dlb    when true, cities whose don't-look bit is set are skipped
 * @param max_passes cap on the number of improving sweeps; values <= 0 mean no cap
 * @param tour       tour to improve; overwritten with the result
 * @param rng        random source for the scan order
 * @param ws         scratch buffers, sized through ws.ensure(n)
 *
 * If a move is flagged as found but neither move type is set, the function
 * prints a message to stderr and terminates the process with exit code 1.
 */
void two_h_opt_first(const DistanceProvider &dist,
                     const CandidateList &cand,
                     int nn_ls,
                     bool use_dlb,
                     int max_passes,
                     std::vector<int> &tour,
                     std::mt19937_64 &rng,
                     LSWorkspace &ws) {
    // ACOTSP 2.5-opt: keep original behavior (with DLB)
    int n = static_cast<int>(tour.size());
    nn_ls = std::min(nn_ls, cand.k);
    ws.ensure(n);
    auto &t = ws.tour_ext;
    auto &pos = ws.pos;
    auto &dlb = ws.dlb;
    auto &rand_vec = ws.random;

    // Copy the tour into the workspace, record each city's index in `pos`
    // and clear all don't-look bits.
    for (int i = 0; i < n; ++i) {
        t[i] = tour[i];
        pos[t[i]] = i;
        dlb[i] = 0;
    }
    // Closing slot: makes t[pos + 1] valid for the city stored at index n - 1.
    t[n] = t[0];

    fill_random_permutation(rand_vec, rng);

    bool improvement_flag = true;
    int pass = 0;
    while (improvement_flag) {
        // Only sweeps that found an improvement count towards max_passes.
        if (max_passes > 0 && pass >= max_passes) {
            break;
        }
        improvement_flag = false;
        // Type of the move selected for the current city; both are reset
        // after every applied move.
        bool two_move = false;
        bool node_move = false;

        for (int l = 0; l < n; ++l) {
            int c1 = rand_vec[l];
            if (use_dlb && dlb[c1]) {
                continue;
            }

            bool improve_node = false;
            int pos_c1 = pos[c1];
            int s_c1 = t[pos_c1 + 1];
            double radius = dist.distance(c1, s_c1);

            // h1..h5 receive the cities involved in the selected move.
            int h1 = 0, h2 = 0, h3 = 0, h4 = 0, h5 = 0;
            int c2 = 0, s_c2 = 0, p_c1 = 0, p_c2 = 0;
            double gain = 0.0;

            // successor(c1) as the broken edge endpoint
            // The scan ends at the first candidate that is not strictly closer
            // to c1 than `radius` (assumes candidate rows are ordered by
            // increasing distance - TODO confirm).
            for (int h = 0; h < nn_ls; ++h) {
                c2 = cand.row(c1)[h];
                if (radius > dist.distance(c1, c2)) {
                    int pos_c2 = pos[c2];
                    s_c2 = t[pos_c2 + 1];
                    // 2-opt: replace (c1, s_c1) and (c2, s_c2) by (c1, c2) and (s_c1, s_c2).
                    gain = -radius + dist.distance(c1, c2)
                           + dist.distance(s_c1, s_c2) - dist.distance(c2, s_c2);
                    if (gain < -1e-12) {
                        h1 = c1;
                        h2 = s_c1;
                        h3 = c2;
                        h4 = s_c2;
                        improve_node = true;
                        two_move = true;
                        node_move = false;
                        goto exchange2h;
                    }

                    if (pos_c2 > 0) {
                        p_c2 = t[pos_c2 - 1];
                    } else {
                        p_c2 = t[n - 1];
                    }
                    // Node insertion: take c2 out from between p_c2 and s_c2
                    // and place it between c1 and s_c1.
                    gain = -radius + dist.distance(c1, c2) + dist.distance(c2, s_c1)
                           + dist.distance(p_c2, s_c2) - dist.distance(c2, s_c2)
                           - dist.distance(p_c2, c2);
                    if (c2 == s_c1 || p_c2 == s_c1) {
                        gain = 0;
                    }
                    // The gain is overwritten unconditionally, so the branch
                    // below is never taken and this insertion variant is
                    // effectively disabled.
                    // NOTE(review): presumably retained to match the reference
                    // behavior mentioned in the comment at the top of this
                    // function - confirm before removing.
                    gain = 0;
                    if (gain < -1e-12) {
                        h1 = c1;
                        h2 = s_c1;
                        h3 = c2;
                        h4 = p_c2;
                        h5 = s_c2;
                        improve_node = true;
                        node_move = true;
                        two_move = false;
                        goto exchange2h;
                    }
                } else {
                    break;
                }
            }

            // predecessor(c1) as the broken edge endpoint
            p_c1 = (pos_c1 > 0) ? t[pos_c1 - 1] : t[n - 1];
            radius = dist.distance(p_c1, c1);
            for (int h = 0; h < nn_ls; ++h) {
                c2 = cand.row(c1)[h];
                if (radius > dist.distance(c1, c2)) {
                    int pos_c2 = pos[c2];
                    p_c2 = (pos_c2 > 0) ? t[pos_c2 - 1] : t[n - 1];
                    // Skip candidates that are direct tour neighbours of c1.
                    if (p_c2 == c1 || p_c1 == c2) {
                        continue;
                    }
                    // 2-opt: replace (p_c1, c1) and (p_c2, c2) by (c1, c2) and (p_c1, p_c2).
                    gain = -radius + dist.distance(c1, c2)
                           + dist.distance(p_c1, p_c2) - dist.distance(p_c2, c2);
                    if (gain < -1e-12) {
                        h1 = p_c1;
                        h2 = c1;
                        h3 = p_c2;
                        h4 = c2;
                        improve_node = true;
                        two_move = true;
                        node_move = false;
                        goto exchange2h;
                    }

                    s_c2 = t[pos_c2 + 1];
                    // Node insertion: take c2 out from between p_c2 and s_c2
                    // and place it between p_c1 and c1.
                    gain = -radius + dist.distance(c2, c1) + dist.distance(p_c1, c2)
                           + dist.distance(p_c2, s_c2) - dist.distance(c2, s_c2)
                           - dist.distance(p_c2, c2);
                    if (p_c1 == c2 || p_c1 == s_c2) {
                        gain = 0;
                    }
                    if (gain < -1e-12) {
                        h1 = p_c1;
                        h2 = c1;
                        h3 = c2;
                        h4 = p_c2;
                        h5 = s_c2;
                        improve_node = true;
                        node_move = true;
                        two_move = false;
                        goto exchange2h;
                    }
                } else {
                    break;
                }
            }

            // Reached both through the gotos above (move found) and by falling
            // through (no move); `improve_node` tells the two cases apart.
            exchange2h:
            if (improve_node) {
                if (two_move) {
                    improvement_flag = true;
                    dlb[h1] = 0;
                    dlb[h2] = 0;
                    dlb[h3] = 0;
                    dlb[h4] = 0;
                    // Normalise so that h1 comes before h3 in the array.
                    if (pos[h3] < pos[h1]) {
                        std::swap(h1, h3);
                        std::swap(h2, h4);
                    }
                    if (pos[h3] - pos[h2] < n / 2 + 1) {
                        // Reverse the inner segment pos[h2]..pos[h3].
                        int i = pos[h2];
                        int j = pos[h3];
                        while (i < j) {
                            int c_a = t[i];
                            int c_b = t[j];
                            t[i] = c_b;
                            t[j] = c_a;
                            pos[c_a] = j;
                            pos[c_b] = i;
                            i++;
                            j--;
                        }
                    } else {
                        // Reverse the complementary segment instead, swapping
                        // outwards from pos[h1] / pos[h4] with wrap-around.
                        int i = pos[h1];
                        int j = pos[h4];
                        int help = (j > i) ? (n - (j - i) + 1) : ((i - j) + 1);
                        help /= 2;
                        for (int h = 0; h < help; ++h) {
                            int c_a = t[i];
                            int c_b = t[j];
                            t[i] = c_b;
                            t[j] = c_a;
                            pos[c_a] = j;
                            pos[c_b] = i;
                            i--;
                            j++;
                            if (i < 0) {
                                i = n - 1;
                            }
                            if (j >= n) {
                                j = 0;
                            }
                        }
                        // Index 0 may have changed; refresh the closing slot.
                        t[n] = t[0];
                    }
                } else if (node_move) {
                    improvement_flag = true;
                    dlb[h1] = 0;
                    dlb[h2] = 0;
                    dlb[h3] = 0;
                    dlb[h4] = 0;
                    dlb[h5] = 0;
                    if (pos[h3] < pos[h1]) {
                        // h3 lies before h1: shift the cities after h3 up to
                        // and including h1 one slot towards the front, then
                        // put h3 into the slot h1 vacated (directly after h1).
                        int help = pos[h1] - pos[h3];
                        int i = pos[h3];
                        for (int h = 0; h < help; ++h) {
                            int c_a = t[i + 1];
                            t[i] = c_a;
                            pos[c_a] = i;
                            i++;
                        }
                        t[i] = h3;
                        pos[h3] = i;
                        t[n] = t[0];
                    } else {
                        // h3 lies after h1: shift the cities strictly between
                        // h1 and h3 one slot towards the back, then put h3
                        // into the slot directly after h1.
                        int help = pos[h3] - pos[h1];
                        int i = pos[h3];
                        for (int h = 0; h < help - 1; ++h) {
                            int c_a = t[i - 1];
                            t[i] = c_a;
                            pos[c_a] = i;
                            i--;
                        }
                        t[i] = h3;
                        pos[h3] = i;
                        t[n] = t[0];
                    }
                } else {
                    // Inconsistent state: a move was flagged without a type.
                    std::fprintf(stderr, "2.5-opt move state error\n");
                    std::exit(1);
                }
                two_move = false;
                node_move = false;
            } else {
                // Nothing improving around c1: set its don't-look bit.
                dlb[c1] = 1;
            }
        }

        if (improvement_flag) {
            pass++;
        }
    }

    // Write the improved tour back, dropping the closing slot.
    for (int i = 0; i < n; ++i) {
        tour[i] = t[i];
    }
}

void three_opt_first(const DistanceProvider &dist,
                     const CandidateList &cand,
                     int nn_ls,
                     bool use_dlb,
                     int max_passes,
                     std::vector<int> &tour,
                     std::mt19937_64 &rng,
                     LSWorkspace &ws) {
    // ACOTSP 3-opt: strict port of original logic (search only nn_ls candidates)
    int n = static_cast<int>(tour.size());
    nn_ls = std::min(nn_ls, cand.k);
    ws.ensure(n);
    auto &t = ws.tour_ext;
    auto &pos = ws.pos;
    auto &dlb = ws.dlb;
    auto &rand_vec = ws.random;
    auto &h_tour = ws.h_tour;
    auto &hh_tour = ws.hh_tour;

    for (int i = 0; i < n; ++i) {
        t[i] = tour[i];
        pos[t[i]] = i;
        dlb[i] = 0;
    }
    t[n] = t[0];

    fill_random_permutation(rand_vec, rng);

    bool improvement_flag = true;
    int pass = 0;
    while (improvement_flag) {
        if (max_passes > 0 && pass >= max_passes) {
            break;
        }
        double move_value = 0.0;
        improvement_flag = false;

        for (int l = 0; l < n; ++l) {
            int c1 = rand_vec[l];
            if (use_dlb && dlb[c1]) {
                continue;
            }
            bool opt2_flag = false;
            int move_flag = 0;

            int pos_c1 = pos[c1];
            int s_c1 = t[pos_c1 + 1];
            int p_c1 = (pos_c1 > 0) ? t[pos_c1 - 1] : t[n - 1];

            int h1 = 0, h2 = 0, h3 = 0, h4 = 0, h5 = 0, h6 = 0;
            int c2 = 0, c3 = 0;
            int s_c2 = 0, s_c3 = 0, p_c2 = 0, p_c3 = 0;
            int pos_c2 = 0, pos_c3 = 0;
            double diffs = 0.0, diffp = 0.0;
            int between = 0;
            double radius = 0.0, add1 = 0.0, add2 = 0.0;
            double decrease_breaks = 0.0;
            int val[3] = {0, 0, 0};
            int n1 = 0, n2 = 0, n3 = 0;

            int h = 0;
            while (h < nn_ls) {
                c2 = cand.row(c1)[h];
                pos_c2 = pos[c2];
                s_c2 = t[pos_c2 + 1];
                p_c2 = (pos_c2 > 0) ? t[pos_c2 - 1] : t[n - 1];

                diffs = 0;
                diffp = 0;
                radius = dist.distance(c1, s_c1);
                add1 = dist.distance(c1, c2);

                if (radius > add1) {
                    decrease_breaks = -radius - dist.distance(c2, s_c2);
                    diffs = decrease_breaks + add1 + dist.distance(s_c1, s_c2);
                    diffp = -radius - dist.distance(c2, p_c2)
                            + dist.distance(c1, p_c2)
                            + dist.distance(s_c1, c2);
                } else {
                    break;
                }
                if (p_c2 == c1) {
                    diffp = 0;
                }
                if ((diffs < move_value) || (diffp < move_value)) {
                    improvement_flag = true;
                    if (diffs <= diffp) {
                        h1 = c1;
                        h2 = s_c1;
                        h3 = c2;
                        h4 = s_c2;
                        move_value = diffs;
                        opt2_flag = true;
                        move_flag = 0;
                    } else {
                        h1 = c1;
                        h2 = s_c1;
                        h3 = p_c2;
                        h4 = c2;
                        move_value = diffp;
                        opt2_flag = true;
                        move_flag = 0;
                    }
                }

                int g = 0;
                while (g < nn_ls) {
                    c3 = cand.row(s_c1)[g];
                    pos_c3 = pos[c3];
                    s_c3 = t[pos_c3 + 1];
                    p_c3 = (pos_c3 > 0) ? t[pos_c3 - 1] : t[n - 1];

                    if (c3 == c1) {
                        g++;
                        continue;
                    } else {
                        add2 = dist.distance(s_c1, c3);
                        if (decrease_breaks + add1 < add2) {
                            if (pos_c2 > pos_c1) {
                                between = (pos_c3 <= pos_c2 && pos_c3 > pos_c1) ? 1 : 0;
                            } else if (pos_c2 < pos_c1) {
                                between = (pos_c3 > pos_c1 || pos_c3 < pos_c2) ? 1 : 0;
                            } else {
                                between = 0;
                            }

                            if (between) {
                                double gain = decrease_breaks - dist.distance(c3, p_c3)
                                              + add1 + add2
                                              + dist.distance(p_c3, s_c2);
                                if (gain < move_value) {
                                    improvement_flag = true;
                                    move_value = gain;
                                    opt2_flag = false;
                                    move_flag = 1;
                                    h1 = c1;
                                    h2 = s_c1;
                                    h3 = c2;
                                    h4 = s_c2;
                                    h5 = p_c3;
                                    h6 = c3;
                                    goto exchange3opt;
                                }
                            } else {
                                double gain = decrease_breaks - dist.distance(c3, s_c3)
                                              + add1 + add2
                                              + dist.distance(s_c2, s_c3);
                                if (pos_c2 == pos_c3) {
                                    gain = 20000;
                                }
                                if (gain < move_value) {
                                    improvement_flag = true;
                                    move_value = gain;
                                    opt2_flag = false;
                                    move_flag = 2;
                                    h1 = c1;
                                    h2 = s_c1;
                                    h3 = c2;
                                    h4 = s_c2;
                                    h5 = c3;
                                    h6 = s_c3;
                                    goto exchange3opt;
                                }

                                gain = -radius - dist.distance(p_c2, c2)
                                       - dist.distance(p_c3, c3)
                                       + add1 + add2
                                       + dist.distance(p_c2, p_c3);
                                if (c3 == c2 || c2 == c1 || c1 == c3 || p_c2 == c1) {
                                    gain = 2000000;
                                }
                                if (gain < move_value) {
                                    improvement_flag = true;
                                    move_value = gain;
                                    opt2_flag = false;
                                    move_flag = 3;
                                    h1 = c1;
                                    h2 = s_c1;
                                    h3 = p_c2;
                                    h4 = c2;
                                    h5 = p_c3;
                                    h6 = c3;
                                    goto exchange3opt;
                                }

                                gain = -radius - dist.distance(p_c2, c2)
                                       - dist.distance(c3, s_c3)
                                       + add1 + add2
                                       + dist.distance(p_c2, s_c3);
                                if (gain < move_value) {
                                    improvement_flag = true;
                                    move_value = gain;
                                    opt2_flag = false;
                                    move_flag = 4;
                                    h1 = c1;
                                    h2 = s_c1;
                                    h3 = p_c2;
                                    h4 = c2;
                                    h5 = c3;
                                    h6 = s_c3;
                                    goto exchange3opt;
                                }
                            }
                        } else {
                            g = nn_ls + 1;
                        }
                    }
                    g++;
                }
                h++;
            }

            if (move_flag || opt2_flag) {
                exchange3opt:
                move_value = 0.0;
                if (move_flag) {
                    dlb[h1] = 0;
                    dlb[h2] = 0;
                    dlb[h3] = 0;
                    dlb[h4] = 0;
                    dlb[h5] = 0;
                    dlb[h6] = 0;
                    pos_c1 = pos[h1];
                    pos_c2 = pos[h3];
                    pos_c3 = pos[h5];

                    if (move_flag == 4) {
                        if (pos_c2 > pos_c1) {
                            n1 = pos_c2 - pos_c1;
                        } else {
                            n1 = n - (pos_c1 - pos_c2);
                        }
                        if (pos_c3 > pos_c2) {
                            n2 = pos_c3 - pos_c2;
                        } else {
                            n2 = n - (pos_c2 - pos_c3);
                        }
                        if (pos_c1 > pos_c3) {
                            n3 = pos_c1 - pos_c3;
                        } else {
                            n3 = n - (pos_c3 - pos_c1);
                        }

                        val[0] = n1;
                        val[1] = n2;
                        val[2] = n3;
                        int longest = std::numeric_limits<int>::min();
                        int idx = 0;
                        for (int g = 0; g <= 2; g++) {
                            if (longest < val[g]) {
                                longest = val[g];
                                idx = g;
                            }
                        }

                        if (idx == 0) {
                            int j = pos[h4];
                            int h_end = pos[h5];
                            int i = 0;
                            h_tour[i] = t[j];
                            n1 = 1;
                            while (j != h_end) {
                                i++;
                                j++;
                                if (j >= n) {
                                    j = 0;
                                }
                                h_tour[i] = t[j];
                                n1++;
                            }

                            j = pos[h4];
                            i = pos[h6];
                            t[j] = t[i];
                            pos[t[i]] = j;
                            while (i != pos_c1) {
                                i++;
                                if (i >= n) {
                                    i = 0;
                                }
                                j++;
                                if (j >= n) {
                                    j = 0;
                                }
                                t[j] = t[i];
                                pos[t[i]] = j;
                            }

                            j++;
                            if (j >= n) {
                                j = 0;
                            }
                            for (i = 0; i < n1; i++) {
                                t[j] = h_tour[i];
                                pos[h_tour[i]] = j;
                                j++;
                                if (j >= n) {
                                    j = 0;
                                }
                            }
                            t[n] = t[0];
                        } else if (idx == 1) {
                            int j = pos[h6];
                            int h_end = pos[h1];
                            int i = 0;
                            h_tour[i] = t[j];
                            n1 = 1;
                            while (j != h_end) {
                                i++;
                                j++;
                                if (j >= n) {
                                    j = 0;
                                }
                                h_tour[i] = t[j];
                                n1++;
                            }

                            j = pos[h6];
                            i = pos[h2];
                            t[j] = t[i];
                            pos[t[i]] = j;
                            while (i != pos_c2) {
                                i++;
                                if (i >= n) {
                                    i = 0;
                                }
                                j++;
                                if (j >= n) {
                                    j = 0;
                                }
                                t[j] = t[i];
                                pos[t[i]] = j;
                            }

                            j++;
                            if (j >= n) {
                                j = 0;
                            }
                            for (i = 0; i < n1; i++) {
                                t[j] = h_tour[i];
                                pos[h_tour[i]] = j;
                                j++;
                                if (j >= n) {
                                    j = 0;
                                }
                            }
                            t[n] = t[0];
                        } else {
                            int j = pos[h2];
                            int h_end = pos[h3];
                            int i = 0;
                            h_tour[i] = t[j];
                            n1 = 1;
                            while (j != h_end) {
                                i++;
                                j++;
                                if (j >= n) {
                                    j = 0;
                                }
                                h_tour[i] = t[j];
                                n1++;
                            }

                            j = pos[h2];
                            i = pos[h4];
                            t[j] = t[i];
                            pos[t[i]] = j;
                            while (i != pos_c3) {
                                i++;
                                if (i >= n) {
                                    i = 0;
                                }
                                j++;
                                if (j >= n) {
                                    j = 0;
                                }
                                t[j] = t[i];
                                pos[t[i]] = j;
                            }

                            j++;
                            if (j >= n) {
                                j = 0;
                            }
                            for (i = 0; i < n1; i++) {
                                t[j] = h_tour[i];
                                pos[h_tour[i]] = j;
                                j++;
                                if (j >= n) {
                                    j = 0;
                                }
                            }
                            t[n] = t[0];
                        }
                    } else if (move_flag == 1) {
                        if (pos_c3 < pos_c2) {
                            n1 = pos_c2 - pos_c3;
                        } else {
                            n1 = n - (pos_c3 - pos_c2);
                        }
                        if (pos_c3 > pos_c1) {
                            n2 = pos_c3 - pos_c1 + 1;
                        } else {
                            n2 = n - (pos_c1 - pos_c3 + 1);
                        }
                        if (pos_c2 > pos_c1) {
                            n3 = n - (pos_c2 - pos_c1 + 1);
                        } else {
                            n3 = pos_c1 - pos_c2 + 1;
                        }

                        val[0] = n1;
                        val[1] = n2;
                        val[2] = n3;
                        int longest = std::numeric_limits<int>::min();
                        int idx = 0;
                        for (int g = 0; g <= 2; g++) {
                            if (longest < val[g]) {
                                longest = val[g];
                                idx = g;
                            }
                        }

                        if (idx == 0) {
                            int j = pos[h5];
                            int h_end = pos[h2];
                            int i = 0;
                            h_tour[i] = t[j];
                            n1 = 1;
                            while (j != h_end) {
                                i++;
                                j--;
                                if (j < 0) {
                                    j = n - 1;
                                }
                                h_tour[i] = t[j];
                                n1++;
                            }

                            j = pos[h1];
                            h_end = pos[h4];
                            i = 0;
                            hh_tour[i] = t[j];
                            n2 = 1;
                            while (j != h_end) {
                                i++;
                                j--;
                                if (j < 0) {
                                    j = n - 1;
                                }
                                hh_tour[i] = t[j];
                                n2++;
                            }

                            j = pos[h4];
                            for (i = 0; i < n2; i++) {
                                t[j] = hh_tour[i];
                                pos[hh_tour[i]] = j;
                                j++;
                                if (j >= n) {
                                    j = 0;
                                }
                            }

                            for (i = 0; i < n1; i++) {
                                t[j] = h_tour[i];
                                pos[h_tour[i]] = j;
                                j++;
                                if (j >= n) {
                                    j = 0;
                                }
                            }
                            t[n] = t[0];
                        } else if (idx == 1) {
                            int j = pos[h3];
                            int h_end = pos[h6];
                            int i = 0;
                            h_tour[i] = t[j];
                            n1 = 1;
                            while (j != h_end) {
                                i++;
                                j--;
                                if (j < 0) {
                                    j = n - 1;
                                }
                                h_tour[i] = t[j];
                                n1++;
                            }

                            j = pos[h6];
                            i = pos[h4];
                            t[j] = t[i];
                            pos[t[i]] = j;
                            while (i != pos_c1) {
                                i++;
                                j++;
                                if (j >= n) {
                                    j = 0;
                                }
                                if (i >= n) {
                                    i = 0;
                                }
                                t[j] = t[i];
                                pos[t[i]] = j;
                            }

                            j++;
                            if (j >= n) {
                                j = 0;
                            }
                            i = 0;
                            t[j] = h_tour[i];
                            pos[h_tour[i]] = j;
                            while (j != pos_c1) {
                                j++;
                                if (j >= n) {
                                    j = 0;
                                }
                                i++;
                                t[j] = h_tour[i];
                                pos[h_tour[i]] = j;
                            }
                            t[n] = t[0];
                        } else {
                            int j = pos[h2];
                            int h_end = pos[h5];
                            int i = 0;
                            h_tour[i] = t[j];
                            n1 = 1;
                            while (j != h_end) {
                                i++;
                                j++;
                                if (j >= n) {
                                    j = 0;
                                }
                                h_tour[i] = t[j];
                                n1++;
                            }

                            j = pos_c2;
                            h_end = pos[h6];
                            i = 0;
                            hh_tour[i] = t[j];
                            n2 = 1;
                            while (j != h_end) {
                                i++;
                                j--;
                                if (j < 0) {
                                    j = n - 1;
                                }
                                hh_tour[i] = t[j];
                                n2++;
                            }

                            j = pos[h2];
                            for (i = 0; i < n2; i++) {
                                t[j] = hh_tour[i];
                                pos[hh_tour[i]] = j;
                                j++;
                                if (j >= n) {
                                    j = 0;
                                }
                            }

                            for (i = 0; i < n1; i++) {
                                t[j] = h_tour[i];
                                pos[h_tour[i]] = j;
                                j++;
                                if (j >= n) {
                                    j = 0;
                                }
                            }
                            t[n] = t[0];
                        }
                    } else if (move_flag == 2) {
                        if (pos_c3 < pos_c1) {
                            n1 = pos_c1 - pos_c3;
                        } else {
                            n1 = n - (pos_c3 - pos_c1);
                        }
                        if (pos_c3 > pos_c2) {
                            n2 = pos_c3 - pos_c2;
                        } else {
                            n2 = n - (pos_c2 - pos_c3);
                        }
                        if (pos_c2 > pos_c1) {
                            n3 = pos_c2 - pos_c1;
                        } else {
                            n3 = n - (pos_c1 - pos_c2);
                        }

                        val[0] = n1;
                        val[1] = n2;
                        val[2] = n3;
                        int longest = std::numeric_limits<int>::min();
                        int idx = 0;
                        for (int g = 0; g <= 2; g++) {
                            if (longest < val[g]) {
                                longest = val[g];
                                idx = g;
                            }
                        }

                        if (idx == 0) {
                            int j = pos[h3];
                            int h_end = pos[h2];
                            int i = 0;
                            h_tour[i] = t[j];
                            n1 = 1;
                            while (j != h_end) {
                                i++;
                                j--;
                                if (j < 0) {
                                    j = n - 1;
                                }
                                h_tour[i] = t[j];
                                n1++;
                            }

                            j = pos[h5];
                            h_end = pos[h4];
                            i = 0;
                            hh_tour[i] = t[j];
                            n2 = 1;
                            while (j != h_end) {
                                i++;
                                j--;
                                if (j < 0) {
                                    j = n - 1;
                                }
                                hh_tour[i] = t[j];
                                n2++;
                            }

                            j = pos[h2];
                            for (i = 0; i < n1; i++) {
                                t[j] = h_tour[i];
                                pos[h_tour[i]] = j;
                                j++;
                                if (j >= n) {
                                    j = 0;
                                }
                            }

                            for (i = 0; i < n2; i++) {
                                t[j] = hh_tour[i];
                                pos[hh_tour[i]] = j;
                                j++;
                                if (j >= n) {
                                    j = 0;
                                }
                            }
                            t[n] = t[0];
                        } else if (idx == 1) {
                            int j = pos[h2];
                            int h_end = pos[h3];
                            int i = 0;
                            h_tour[i] = t[j];
                            n1 = 1;
                            while (j != h_end) {
                                i++;
                                j++;
                                if (j >= n) {
                                    j = 0;
                                }
                                h_tour[i] = t[j];
                                n1++;
                            }

                            j = pos[h1];
                            h_end = pos[h6];
                            i = 0;
                            hh_tour[i] = t[j];
                            n2 = 1;
                            while (j != h_end) {
                                i++;
                                j--;
                                if (j < 0) {
                                    j = n - 1;
                                }
                                hh_tour[i] = t[j];
                                n2++;
                            }

                            j = pos[h6];
                            for (i = 0; i < n1; i++) {
                                t[j] = h_tour[i];
                                pos[h_tour[i]] = j;
                                j++;
                                if (j >= n) {
                                    j = 0;
                                }
                            }
                            for (i = 0; i < n2; i++) {
                                t[j] = hh_tour[i];
                                pos[hh_tour[i]] = j;
                                j++;
                                if (j >= n) {
                                    j = 0;
                                }
                            }
                            t[n] = t[0];
                        } else {
                            int j = pos[h1];
                            int h_end = pos[h6];
                            int i = 0;
                            h_tour[i] = t[j];
                            n1 = 1;
                            while (j != h_end) {
                                i++;
                                j--;
                                if (j < 0) {
                                    j = n - 1;
                                }
                                h_tour[i] = t[j];
                                n1++;
                            }

                            j = pos[h4];
                            h_end = pos[h5];
                            i = 0;
                            hh_tour[i] = t[j];
                            n2 = 1;
                            while (j != h_end) {
                                i++;
                                j++;
                                if (j >= n) {
                                    j = 0;
                                }
                                hh_tour[i] = t[j];
                                n2++;
                            }

                            j = pos[h4];
                            for (i = 0; i < n1; i++) {
                                t[j] = h_tour[i];
                                pos[h_tour[i]] = j;
                                j++;
                                if (j >= n) {
                                    j = 0;
                                }
                            }
                            for (i = 0; i < n2; i++) {
                                t[j] = hh_tour[i];
                                pos[hh_tour[i]] = j;
                                j++;
                                if (j >= n) {
                                    j = 0;
                                }
                            }
                            t[n] = t[0];
                        }
                    } else if (move_flag == 3) {
                        if (pos_c3 < pos_c1) {
                            n1 = pos_c1 - pos_c3;
                        } else {
                            n1 = n - (pos_c3 - pos_c1);
                        }
                        if (pos_c3 > pos_c2) {
                            n2 = pos_c3 - pos_c2;
                        } else {
                            n2 = n - (pos_c2 - pos_c3);
                        }
                        if (pos_c2 > pos_c1) {
                            n3 = pos_c2 - pos_c1;
                        } else {
                            n3 = n - (pos_c1 - pos_c2);
                        }

                        val[0] = n1;
                        val[1] = n2;
                        val[2] = n3;
                        int longest = std::numeric_limits<int>::min();
                        int idx = 0;
                        for (int g = 0; g <= 2; g++) {
                            if (longest < val[g]) {
                                longest = val[g];
                                idx = g;
                            }
                        }

                        if (idx == 0) {
                            int j = pos[h3];
                            int h_end = pos[h2];
                            int i = 0;
                            h_tour[i] = t[j];
                            n1 = 1;
                            while (j != h_end) {
                                i++;
                                j--;
                                if (j < 0) {
                                    j = n - 1;
                                }
                                h_tour[i] = t[j];
                                n1++;
                            }

                            j = pos[h2];
                            h_end = pos[h5];
                            i = pos[h4];
                            t[j] = h4;
                            pos[h4] = j;
                            while (i != h_end) {
                                i++;
                                if (i >= n) {
                                    i = 0;
                                }
                                j++;
                                if (j >= n) {
                                    j = 0;
                                }
                                t[j] = t[i];
                                pos[t[i]] = j;
                            }
                            j++;
                            if (j >= n) {
                                j = 0;
                            }
                            for (i = 0; i < n1; i++) {
                                t[j] = h_tour[i];
                                pos[h_tour[i]] = j;
                                j++;
                                if (j >= n) {
                                    j = 0;
                                }
                            }
                            t[n] = t[0];
                        } else if (idx == 1) {
                            int j = pos[h3];
                            int h_end = pos[h2];
                            int i = 0;
                            h_tour[i] = t[j];
                            n1 = 1;
                            while (j != h_end) {
                                i++;
                                j--;
                                if (j < 0) {
                                    j = n - 1;
                                }
                                h_tour[i] = t[j];
                                n1++;
                            }

                            j = pos[h6];
                            h_end = pos[h1];
                            i = 0;
                            hh_tour[i] = t[j];
                            n2 = 1;
                            while (j != h_end) {
                                i++;
                                j++;
                                if (j >= n) {
                                    j = 0;
                                }
                                hh_tour[i] = t[j];
                                n2++;
                            }

                            j = pos[h6];
                            for (i = 0; i < n1; i++) {
                                t[j] = h_tour[i];
                                pos[h_tour[i]] = j;
                                j++;
                                if (j >= n) {
                                    j = 0;
                                }
                            }
                            for (i = 0; i < n2; i++) {
                                t[j] = hh_tour[i];
                                pos[hh_tour[i]] = j;
                                j++;
                                if (j >= n) {
                                    j = 0;
                                }
                            }
                            t[n] = t[0];
                        } else {
                            int j = pos[h5];
                            int h_end = pos[h4];
                            int i = 0;
                            h_tour[i] = t[j];
                            n1 = 1;
                            while (j != h_end) {
                                i++;
                                j--;
                                if (j < 0) {
                                    j = n - 1;
                                }
                                h_tour[i] = t[j];
                                n1++;
                            }

                            j = pos[h1];
                            h_end = pos[h6];
                            i = 0;
                            hh_tour[i] = t[j];
                            n2 = 1;
                            while (j != h_end) {
                                i++;
                                j--;
                                if (j < 0) {
                                    j = n - 1;
                                }
                                hh_tour[i] = t[j];
                                n2++;
                            }

                            j = pos[h4];
                            for (i = 0; i < n1; i++) {
                                t[j] = h_tour[i];
                                pos[h_tour[i]] = j;
                                j++;
                                if (j >= n) {
                                    j = 0;
                                }
                            }
                            for (i = 0; i < n2; i++) {
                                t[j] = hh_tour[i];
                                pos[hh_tour[i]] = j;
                                j++;
                                if (j >= n) {
                                    j = 0;
                                }
                            }
                            t[n] = t[0];
                        }
                    } else {
                        std::printf("3-opt move error\n");
                        std::exit(1);
                    }
                }

                if (opt2_flag) {
                    dlb[h1] = 0;
                    dlb[h2] = 0;
                    dlb[h3] = 0;
                    dlb[h4] = 0;
                    if (pos[h3] < pos[h1]) {
                        std::swap(h1, h3);
                        std::swap(h2, h4);
                    }
                    if (pos[h3] - pos[h2] < n / 2 + 1) {
                        int i = pos[h2];
                        int j = pos[h3];
                        while (i < j) {
                            int c_a = t[i];
                            int c_b = t[j];
                            t[i] = c_b;
                            t[j] = c_a;
                            pos[c_a] = j;
                            pos[c_b] = i;
                            i++;
                            j--;
                        }
                    } else {
                        int i = pos[h1];
                        int j = pos[h4];
                        int help = (j > i) ? (n - (j - i) + 1) : ((i - j) + 1);
                        help /= 2;
                        for (int h = 0; h < help; ++h) {
                            int c_a = t[i];
                            int c_b = t[j];
                            t[i] = c_b;
                            t[j] = c_a;
                            pos[c_a] = j;
                            pos[c_b] = i;
                            i--;
                            j++;
                            if (i < 0) {
                                i = n - 1;
                            }
                            if (j >= n) {
                                j = 0;
                            }
                        }
                        t[n] = t[0];
                    }
                }
            } else {
                dlb[c1] = 1;
            }
        }

        if (improvement_flag) {
            pass++;
        }
    }

    for (int i = 0; i < n; ++i) {
        tour[i] = t[i];
    }
}

// Derives successor/predecessor lookup tables from a tour stored as a visiting
// order.
//
// Both tables are resized to tour.size() and pre-filled with -1. Then, for
// every position in the tour, next[city] receives the city at the following
// position and prev[that city] receives `city`; the last position wraps around
// to the first one.
//
// The tour entries are used directly as indices into `next`/`prev`, so they
// must lie in [0, tour.size()).
static void build_next_prev(const std::vector<int> &tour,
                            std::vector<int> &next,
                            std::vector<int> &prev) {
    const int count = static_cast<int>(tour.size());
    next.assign(count, -1);
    prev.assign(count, -1);

    for (int idx = 0; idx < count; ++idx) {
        const int from = tour[idx];
        // The successor of the final position is the first city (closed tour).
        const int to = (idx + 1 == count) ? tour[0] : tour[idx + 1];
        next[from] = to;
        prev[to] = from;
    }
}

// Applies the k-exchange described by `seq` = (t1, t2, ..., t2k) to `tour`.
//
// The tour edges (t1,t2), (t3,t4), ..., (t2k-1,t2k) are removed and the edges
// (t2,t3), (t4,t5), ..., (t2k,t1) are inserted. The resulting edge set is then
// walked starting from t1 to rebuild a visiting order.
//
// Returns true and overwrites `tour` only when the result is a single cycle
// that visits every city exactly once. Returns false and leaves `tour`
// untouched when
//   - `seq` is empty, has odd length, or names a city outside [0, n),
//   - an inserted edge would give a city a third neighbour,
//   - the walk hits a city with a missing neighbour, or
//   - the exchange splits the tour into several sub-cycles.
//
// `tour` is expected to hold cities in [0, n), since its entries index the
// adjacency table directly.
static bool apply_exchange(const std::vector<int> &seq, std::vector<int> &tour) {
    const int n = static_cast<int>(tour.size());

    // Reject malformed sequences up front: seq.front()/seq.back() on an empty
    // vector is undefined behaviour, an odd length leaves an unpaired endpoint,
    // and every entry is used as an index into the adjacency table below.
    if (seq.empty() || seq.size() % 2 != 0) {
        return false;
    }
    for (int city : seq) {
        if (city < 0 || city >= n) {
            return false;
        }
    }

    // Undirected adjacency of the current tour: two neighbour slots per city,
    // -1 marks a free slot.
    std::vector<std::array<int, 2>> adj(n, { -1, -1 });

    for (int i = 0; i < n; ++i) {
        int a = tour[i];
        int b = tour[(i + 1) % n];
        if (adj[a][0] == -1) {
            adj[a][0] = b;
        } else {
            adj[a][1] = b;
        }
        if (adj[b][0] == -1) {
            adj[b][0] = a;
        } else {
            adj[b][1] = a;
        }
    }

    // Frees the slot holding v at u (and u at v); a no-op if the edge is absent.
    auto remove_edge = [&](int u, int v) {
        for (int i = 0; i < 2; ++i) {
            if (adj[u][i] == v) {
                adj[u][i] = -1;
                break;
            }
        }
        for (int i = 0; i < 2; ++i) {
            if (adj[v][i] == u) {
                adj[v][i] = -1;
                break;
            }
        }
    };

    // Inserts edge (u, v) into the first free slot of each endpoint; fails if
    // either endpoint already has two neighbours.
    auto add_edge = [&](int u, int v) {
        int slot_u = (adj[u][0] == -1) ? 0 : 1;
        int slot_v = (adj[v][0] == -1) ? 0 : 1;
        if (adj[u][slot_u] != -1 || adj[v][slot_v] != -1) {
            return false;
        }
        adj[u][slot_u] = v;
        adj[v][slot_v] = u;
        return true;
    };

    const int k = static_cast<int>(seq.size() / 2);
    for (int i = 0; i < k; ++i) {
        remove_edge(seq[2 * i], seq[2 * i + 1]);
    }
    for (int i = 0; i < k - 1; ++i) {
        if (!add_edge(seq[2 * i + 1], seq[2 * i + 2])) {
            return false;
        }
    }
    // Closing edge t2k -> t1.
    if (!add_edge(seq.back(), seq.front())) {
        return false;
    }

    // Rebuild the visiting order by walking the adjacency from t1.
    std::vector<int> new_tour;
    new_tour.reserve(n);
    std::vector<char> visited(n, 0);
    const int start = seq.front();
    int prev = -1;
    int curr = start;
    for (int i = 0; i < n; ++i) {
        // Reaching an already visited city before all n cities were emitted
        // means the exchange produced disjoint sub-cycles. Checking only
        // "did we end on start after n steps" is not enough: a sub-cycle whose
        // length divides n is simply traversed several times and would be
        // accepted with duplicated cities.
        if (visited[curr]) {
            return false;
        }
        visited[curr] = 1;
        new_tour.push_back(curr);
        int nxt = (adj[curr][0] != prev) ? adj[curr][0] : adj[curr][1];
        prev = curr;
        curr = nxt;
        if (curr == -1) {
            return false;
        }
    }
    if (curr != start) {
        return false;
    }
    tour = std::move(new_tour);
    return true;
}

// Best exchange recorded so far during a Lin-Kernighan search from one start
// city.
struct LKMove {
    // Gain value recorded for the exchange by the search; 0.0 while no move
    // has been recorded.
    double gain{0.0};
    // Alternating endpoint sequence t1, t2, ... of the exchange, in the layout
    // consumed by apply_exchange(); empty while no move has been recorded.
    std::vector<int> seq{};
};

// Depth-first extension of a partial Lin-Kernighan exchange.
//
// `seq` holds the alternating endpoints chosen so far (t1, t2, ..., even
// length); its last two entries form the tour edge most recently selected for
// removal. Each step picks a candidate neighbour `t_odd` of the last endpoint
// as the far end of a new edge, then one of t_odd's tour neighbours `t_even2`
// so that (t_odd, t_even2) becomes the next edge to remove.
//
// Parameters:
//   dist       distance oracle, queried through dist.distance(a, b)
//   cand       candidate lists; cand.row(c) yields cand.k neighbours of c
//   next/prev  tour successor / predecessor of every city
//   t1         first city of the exchange; the closing edge returns to it
//   seq        endpoints chosen so far; extended and restored on backtrack
//   used       per-city flag for cities already in `seq`; extended and
//              restored on backtrack
//   gain       accumulated gain of the exchange, not yet counting the removal
//              of the last edge in `seq`
//   depth      current recursion depth
//   max_depth  recursion stops once depth + 1 reaches this value
//   best       receives the closed move with the largest gain seen so far
//
// Moves are only recorded, never applied: whether the recorded sequence forms
// a valid tour is not checked here.
static void lk_dfs(const DistanceProvider &dist,
                   const CandidateList &cand,
                   const std::vector<int> &next,
                   const std::vector<int> &prev,
                   int t1,
                   std::vector<int> &seq,
                   std::vector<char> &used,
                   double gain,
                   int depth,
                   int max_depth,
                   LKMove &best) {
    int t_prev = seq[seq.size() - 2];
    int t_even = seq.back();
    const int *row = cand.row(t_even);
    for (int c = 0; c < cand.k; ++c) {
        int t_odd = row[c];
        if (t_odd == t1 || used[t_odd]) {
            continue;
        }
        // An edge to a current tour neighbour would re-add an existing edge.
        if (t_odd == next[t_even] || t_odd == prev[t_even]) {
            continue;
        }
        // Gain criterion: remove (t_prev, t_even), add (t_even, t_odd); the
        // partial sum must stay positive to be worth extending.
        double gain1 = gain + dist.distance(t_prev, t_even) - dist.distance(t_even, t_odd);
        if (gain1 <= 1e-12) {
            continue;
        }

        int neighs[2] = { next[t_odd], prev[t_odd] };
        for (int t_even2 : neighs) {
            if (t_even2 == -1 || t_even2 == t_even || t_even2 == t1) {
                continue;
            }
            if (used[t_even2]) {
                continue;
            }
            // Closing the exchange removes (t_odd, t_even2) and adds the edge
            // from the last endpoint back to t1, i.e. (t_even2, t1). This is
            // the edge apply_exchange() inserts (seq.back() -> seq.front()),
            // so the closing cost has to be measured from t_even2, not t_odd.
            double close_gain = gain1 + dist.distance(t_odd, t_even2) - dist.distance(t_even2, t1);
            if (close_gain > best.gain + 1e-12) {
                best.gain = close_gain;
                best.seq = seq;
                best.seq.push_back(t_odd);
                best.seq.push_back(t_even2);
            }
            if (depth + 1 >= max_depth) {
                continue;
            }
            // Extend the partial exchange, recurse, then undo (backtracking).
            used[t_odd] = 1;
            used[t_even2] = 1;
            seq.push_back(t_odd);
            seq.push_back(t_even2);
            lk_dfs(dist, cand, next, prev, t1, seq, used, gain1, depth + 1, max_depth, best);
            seq.pop_back();
            seq.pop_back();
            used[t_odd] = 0;
            used[t_even2] = 0;
        }
    }
}

// Searches for an improving exchange that starts at city `t1`.
//
// `best` is reset, then lk_dfs() is launched once for each tour edge incident
// to t1 (towards next[t1] and towards prev[t1]; entries equal to -1 are
// skipped). Each launch starts from the two-element sequence {t1, t2} with
// zero accumulated gain and depth 1.
//
// Returns true when the recorded best gain exceeds 1e-12; in that case
// `best.seq` holds the endpoint sequence of that move.
static bool lk_improve_from(const DistanceProvider &dist,
                            const CandidateList &cand,
                            const std::vector<int> &next,
                            const std::vector<int> &prev,
                            int t1,
                            int max_depth,
                            LKMove &best) {
    best.seq.clear();
    best.gain = 0.0;

    // Per-city marker for endpoints already part of the partial exchange.
    std::vector<char> used(cand.n, 0);
    used[t1] = 1;

    const int first_edges[2] = { next[t1], prev[t1] };
    for (int side = 0; side < 2; ++side) {
        const int t2 = first_edges[side];
        if (t2 != -1) {
            std::vector<int> seq;
            seq.push_back(t1);
            seq.push_back(t2);

            used[t2] = 1;
            lk_dfs(dist, cand, next, prev, t1, seq, used, 0.0, 1, max_depth, best);
            used[t2] = 0;
        }
    }

    const bool found = best.gain > 1e-12;
    return found;
}

// Depth-limited Lin-Kernighan variable k-opt search.
//
// Repeatedly scans the tour for a start city from which lk_improve_from()
// finds an exchange, applies it through apply_exchange(), and restarts the
// scan after each accepted move. A pass ends when a full scan accepts nothing;
// the search ends after `passes` passes or as soon as a pass accepts no move.
//
// Parameters:
//   dist       distance oracle, queried through dist.distance(a, b)
//   cand       candidate lists forwarded to the move search
//   tour       visiting order, improved in place
//   passes     upper bound on the number of outer passes
//   max_depth  depth limit forwarded to the move search
//
// Returns the length of the closed tour left in `tour`.
double lk_search(const DistanceProvider &dist,
                 const CandidateList &cand,
                 std::vector<int> &tour,
                 int passes,
                 int max_depth) {
    const int n = static_cast<int>(tour.size());

    // Length of the closed tour described by `order`.
    auto tour_length = [&dist, n](const std::vector<int> &order) {
        double total = 0.0;
        for (int i = 0; i < n; ++i) {
            total += dist.distance(order[i], order[(i + 1) % n]);
        }
        return total;
    };

    double best_len = tour_length(tour);

    // Scratch buffers reused across scans to avoid reallocating per iteration.
    std::vector<int> next;
    std::vector<int> prev;
    LKMove move;

    for (int pass = 0; pass < passes; ++pass) {
        bool pass_improved = false;
        bool improved = true;
        while (improved) {
            improved = false;
            build_next_prev(tour, next, prev);
            for (int i = 0; i < n; ++i) {
                const int t1 = tour[i];
                if (!lk_improve_from(dist, cand, next, prev, t1, max_depth, move)) {
                    continue;
                }
                std::vector<int> candidate = tour;
                if (!apply_exchange(move.seq, candidate)) {
                    continue;
                }
                // Accept only exchanges that really shorten the tour. The gain
                // reported by the move search is an estimate made before the
                // exchange is applied; trusting it blindly could lengthen the
                // tour or keep the restart loop running forever.
                const double candidate_len = tour_length(candidate);
                if (!(candidate_len < best_len - 1e-12)) {
                    continue;
                }
                tour = std::move(candidate);
                best_len = candidate_len;
                improved = true;
                pass_improved = true;
                // next/prev are stale now; restart the scan on the new tour.
                break;
            }
        }
        if (!pass_improved) {
            break;
        }
    }
    return best_len;
}
