#include "runner.h"

#include <algorithm>
#include <cctype>
#include <filesystem>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <limits>
#include <string>

#include "candidates.h"
#include "dataset.h"
#include "distance.h"
#include "mmas.h"
#include "paths.h"

// Returns `s` lowercased byte by byte via std::tolower (current C locale).
// Each byte is widened to unsigned char first, because passing a negative char
// value to std::tolower is undefined behaviour.
static std::string to_lower(std::string s) {
    std::string lowered;
    lowered.reserve(s.size());
    for (const unsigned char byte : s) {
        lowered.push_back(static_cast<char>(std::tolower(byte)));
    }
    return lowered;
}

// Returns true when the configured heatmap comes from one of the sources this
// runner treats in confidence mode: difusco, utsp, attgcn or dimes
// (case-insensitive).
//
// The source name is taken from cfg.heatmap_type when set. Otherwise it is the
// last component of cfg.heatmap_dir, and that name must match exactly. If
// neither matches, cfg.heatmap_file is checked for any of the names as a
// substring.
static bool is_confidence_heatmap(const Config &cfg) {
    static const char *const kConfidenceSources[] = {"difusco", "utsp", "attgcn", "dimes"};

    std::string source;
    if (!cfg.heatmap_type.empty()) {
        source = cfg.heatmap_type;
    } else if (!cfg.heatmap_dir.empty()) {
        std::filesystem::path dir(cfg.heatmap_dir);
        // path("heatmaps/difusco/").filename() is empty. Walk up past trailing
        // separators so that form is recognised like "heatmaps/difusco".
        // The equality guard makes termination independent of parent_path()
        // implementation details.
        while (!dir.has_filename() && dir.has_relative_path()) {
            const std::filesystem::path parent = dir.parent_path();
            if (parent == dir) {
                break;
            }
            dir = parent;
        }
        source = dir.filename().string();
    }

    const std::string lowered_source = to_lower(source);
    for (const char *known : kConfidenceSources) {
        if (lowered_source == known) {
            return true;
        }
    }

    // Fallback: accept a known source name appearing anywhere in the heatmap file path.
    if (!cfg.heatmap_file.empty()) {
        const std::string lowered_file = to_lower(cfg.heatmap_file);
        for (const char *known : kConfidenceSources) {
            if (lowered_file.find(known) != std::string::npos) {
                return true;
            }
        }
    }
    return false;
}

int run_with_config(Config cfg, int n) {
    if (cfg.output_dir.empty()) {
        cfg.output_dir = default_output_dir(cfg, n);
    }
    std::filesystem::create_directories(cfg.output_dir);

    DataSet dataset(cfg.data_file, n);
    int total = dataset.count();

    if (cfg.run_all) {
        std::ofstream summary(cfg.output_dir + "/summary_all.csv");
        std::ofstream convergence(cfg.output_dir + "/convergence_all.csv");
        summary << "instance_index,num_nodes,best_length,best_gap,opt_length,best_iteration,"
                   "iterations_run,total_seconds,seed,ants,iterations,candidate_list_size,nn_ants,"
                   "nn_ls,alpha,beta,rho,q,q0,p_best,tau_min,tau_max,lambda,branch_fac,u_gb,dlb_flag,"
                   "local_search,max_2opt_passes,max_3opt_passes,lk_passes,lk_depth,heatmap_mode,"
                   "heatmap_threshold,confidence_gamma,heatmap_path,heatmap_root,heatmap_dir,heatmap_type,distance_mode,"
                   "distance_precompute_limit,threads,data_file\n";
        convergence << "instance_index,iteration,best_length,iter_best_length,elapsed_seconds,best_gap\n";
        summary << std::setprecision(10);
        convergence << std::setprecision(10);

        for (int idx = 0; idx < total; ++idx) {
            Instance inst = dataset.load_instance(idx);
            DistanceProvider dist(inst, cfg);
            if (inst.has_opt && inst.opt_tour.size() == static_cast<std::size_t>(inst.n)) {
                inst.opt_length = tour_length(dist, inst.opt_tour);
            }

            int nn_ants = std::min(cfg.nn_ants, inst.n - 1);
            int nn_ls = std::min(cfg.nn_ls, inst.n - 1);
            int cand_k = std::max(nn_ants, nn_ls);
            if (cand_k <= 0) {
                std::cerr << "candidate list size must be >= 1\n";
                return 1;
            }

            std::cout << "[MMAS] Running instance " << idx << " / " << total
                      << " (N=" << inst.n << ")\n";

            CandidateList nn = build_nn_candidates(inst, dist, cand_k, cfg.beta, cfg.threads);
            CandidateList cand = nn;
            std::string heatmap_path;
            if (cfg.heatmap_mode == "threshold") {
                bool confidence_mode = is_confidence_heatmap(cfg);
                heatmap_path = build_heatmap_path(cfg, inst.n, idx);
                if (heatmap_path.empty()) {
                    std::cerr << "heatmap_file or heatmap_dir is required when heatmap_mode=threshold\n";
                    return 1;
                }
                cand = build_heatmap_candidates(inst, dist, nn, cand_k, cfg.beta,
                                                heatmap_path, cfg.heatmap_threshold, confidence_mode,
                                                cfg.confidence_gamma);
            }

            MMASSolver solver(inst, dist, cand, cfg);
            MMASResult result = solver.run();
            for (const auto &st : result.stats) {
                convergence << idx << "," << st.iteration << "," << st.best << ","
                            << st.iter_best << "," << st.elapsed << "," << st.gap << "\n";
            }

            double best_gap = 0.0;
            if (inst.opt_length > 0.0) {
                best_gap = 100.0 * (result.best_length - inst.opt_length) / inst.opt_length;
            }
            double total_seconds = result.stats.empty() ? 0.0 : result.stats.back().elapsed;
            summary << idx << "," << inst.n << "," << result.best_length << "," << best_gap
                    << "," << inst.opt_length << "," << result.best_iter << ","
                    << result.stats.size() << "," << total_seconds << ","
                    << cfg.seed << "," << cfg.ants << "," << cfg.iterations << ","
                    << cand_k << "," << nn_ants << "," << nn_ls << ","
                    << cfg.alpha << "," << cfg.beta << "," << cfg.rho << ","
                    << cfg.q << "," << cfg.q0 << "," << cfg.p_best << ","
                    << cfg.tau_min << "," << cfg.tau_max << ","
                    << cfg.lambda << "," << cfg.branch_fac << "," << cfg.u_gb << "," << cfg.dlb_flag << ","
                    << cfg.local_search << "," << cfg.max_2opt_passes << "," << cfg.max_3opt_passes << ","
                    << cfg.lk_passes << "," << cfg.lk_depth << "," << cfg.heatmap_mode << ","
                    << cfg.heatmap_threshold << "," << cfg.confidence_gamma << ","
                    << heatmap_path << "," << cfg.heatmap_root << "," << cfg.heatmap_dir << ","
                    << cfg.heatmap_type << "," << cfg.distance_mode << ","
                    << cfg.distance_precompute_limit << "," << cfg.threads << ","
                    << cfg.data_file << "\n";
        }
        return 0;
    }

    // Single-instance mode
    if (cfg.instance_index < 0 || cfg.instance_index >= total) {
        std::cerr << "instance_index out of range\n";
        return 1;
    }

    Instance inst = dataset.load_instance(cfg.instance_index);
    DistanceProvider dist(inst, cfg);
    if (inst.has_opt && inst.opt_tour.size() == static_cast<std::size_t>(inst.n)) {
        inst.opt_length = tour_length(dist, inst.opt_tour);
    }

    int nn_ants = std::min(cfg.nn_ants, inst.n - 1);
    int nn_ls = std::min(cfg.nn_ls, inst.n - 1);
    int cand_k = std::max(nn_ants, nn_ls);
    if (cand_k <= 0) {
        std::cerr << "candidate list size must be >= 1\n";
        return 1;
    }

    std::cout << "[MMAS] Loading instance N=" << inst.n
              << " index=" << cfg.instance_index << "\n";

    CandidateList nn = build_nn_candidates(inst, dist, cand_k, cfg.beta, cfg.threads);
    CandidateList cand = nn;
    std::string heatmap_path;
    if (cfg.heatmap_mode == "threshold") {
        bool confidence_mode = is_confidence_heatmap(cfg);
        heatmap_path = build_heatmap_path(cfg, inst.n, cfg.instance_index);
        if (heatmap_path.empty()) {
            std::cerr << "heatmap_file or heatmap_dir is required when heatmap_mode=threshold\n";
            return 1;
        }
        cand = build_heatmap_candidates(inst, dist, nn, cand_k, cfg.beta,
                                        heatmap_path, cfg.heatmap_threshold, confidence_mode,
                                        cfg.confidence_gamma);
    }

    MMASSolver solver(inst, dist, cand, cfg);
    MMASResult result = solver.run();

    {
        std::ofstream out(cfg.output_dir + "/convergence.csv");
        out << "iteration,best_length,iter_best_length,elapsed_seconds,best_gap\n";
        out << std::setprecision(10);
        for (const auto &st : result.stats) {
            out << st.iteration << "," << st.best << "," << st.iter_best
                << "," << st.elapsed << "," << st.gap << "\n";
        }
    }

    {
        std::ofstream out(cfg.output_dir + "/summary.csv");
        out << "metric,value\n";
        out << "best_length," << std::setprecision(10) << result.best_length << "\n";
        if (inst.opt_length > 0.0) {
            out << "opt_length," << std::setprecision(10) << inst.opt_length << "\n";
            out << "best_gap," << std::setprecision(10)
                << 100.0 * (result.best_length - inst.opt_length) / inst.opt_length << "\n";
        }
        out << "best_iteration," << result.best_iter << "\n";
        out << "iterations_run," << result.stats.size() << "\n";
        out << "total_seconds," << std::setprecision(6)
            << (result.stats.empty() ? 0.0 : result.stats.back().elapsed) << "\n";
        out << "seed," << cfg.seed << "\n";
        out << "ants," << cfg.ants << "\n";
        out << "iterations," << cfg.iterations << "\n";
        out << "candidate_list_size," << cand_k << "\n";
        out << "nn_ants," << nn_ants << "\n";
        out << "nn_ls," << nn_ls << "\n";
        out << "alpha," << cfg.alpha << "\n";
        out << "beta," << cfg.beta << "\n";
        out << "rho," << cfg.rho << "\n";
        out << "q," << cfg.q << "\n";
        out << "q0," << cfg.q0 << "\n";
        out << "p_best," << cfg.p_best << "\n";
        out << "tau_min," << cfg.tau_min << "\n";
        out << "tau_max," << cfg.tau_max << "\n";
        out << "lambda," << cfg.lambda << "\n";
        out << "branch_fac," << cfg.branch_fac << "\n";
        out << "u_gb," << cfg.u_gb << "\n";
        out << "dlb_flag," << cfg.dlb_flag << "\n";
        out << "local_search," << cfg.local_search << "\n";
        out << "max_2opt_passes," << cfg.max_2opt_passes << "\n";
        out << "max_3opt_passes," << cfg.max_3opt_passes << "\n";
        out << "lk_passes," << cfg.lk_passes << "\n";
        out << "lk_depth," << cfg.lk_depth << "\n";
        out << "heatmap_mode," << cfg.heatmap_mode << "\n";
        out << "heatmap_threshold," << cfg.heatmap_threshold << "\n";
        out << "confidence_gamma," << cfg.confidence_gamma << "\n";
        out << "heatmap_path," << heatmap_path << "\n";
        out << "heatmap_root," << cfg.heatmap_root << "\n";
        out << "heatmap_dir," << cfg.heatmap_dir << "\n";
        out << "heatmap_type," << cfg.heatmap_type << "\n";
        out << "distance_mode," << cfg.distance_mode << "\n";
        out << "distance_precompute_limit," << cfg.distance_precompute_limit << "\n";
        out << "threads," << cfg.threads << "\n";
        out << "data_file," << cfg.data_file << "\n";
        out << "instance_index," << cfg.instance_index << "\n";
        out << "num_nodes," << inst.n << "\n";
    }

    {
        std::ofstream out(cfg.output_dir + "/best_tour.txt");
        for (std::size_t i = 0; i < result.best_tour.size(); ++i) {
            if (i) {
                out << ' ';
            }
            out << (result.best_tour[i] + 1);
        }
        out << "\n";
    }

    std::cout << "[MMAS] Best length: " << result.best_length << "\n";
    return 0;
}
