# oracle_utils/lkh3_cvrp.py
"""
Lightweight LKH-3 wrapper for CVRP using VRPLIB (TSPLIB) file format.

Design goals:
- Deterministic, race-free file handling under parallel evaluation.
- Minimal contract: input instance with distance_matrix_int + demands, output integer tour length.
- Keep interface stable for oracle.py:
    from .oracle_utils.lkh3_cvrp import LKH3CVRPOracle, LKHResult
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Dict, Any, List
import os
import re
import shutil
import subprocess
import tempfile


@dataclass
class LKHResult:
    """Outcome of a single ``LKH3CVRPOracle.solve`` call.

    The path fields record where the per-call files were written. Unless the
    oracle was created with ``debug_keep_files=True``, the run directory is
    removed before ``solve`` returns, so these paths may no longer exist.
    """

    # Integer tour length parsed from the output tour file; None if none was found.
    cost: Optional[int]
    status: str  # "FEASIBLE" | "TIMEOUT" | "ERROR" | "NO_SOLUTION"
    # Captured solver output; on an internal Python error, stderr carries
    # "<ExceptionType>: <message>" instead.
    stdout: str = ""
    stderr: str = ""
    # Locations of the files used for this run (inside run_dir).
    tour_path: Optional[str] = None
    par_path: Optional[str] = None
    vrp_path: Optional[str] = None
    # Unique per-call working directory.
    run_dir: Optional[str] = None


class LKH3CVRPOracle:
    """
    Minimal LKH-3 oracle for CVRP with EXPLICIT FULL_MATRIX distances.

    Expected instance schema:
        instance = {
            "customers": [{"demand": int, ...}, ...],
            "vehicle_capacity": int,
            "distance_matrix_int": [[int]*(n+1)]*(n+1),
        }

    In the generated problem file node 1 is the depot (row/column 0 of the
    matrix) and customer ``k`` (0-based) becomes node ``k + 2``.
    """

    # Matches the "Length = <int>" comment line of the output tour file.
    _LENGTH_RE = re.compile(r"Length\s*=\s*(\d+)")

    def __init__(
        self,
        lkh_bin: str,
        work_dir: str,
        default_time_limit_s: int = 15,
        runs: int = 1,
        move_type: int = 5,
        patching_c: int = 3,
        patching_a: int = 2,
        max_trials: int = 1000,
        trace_level: int = 0,
        debug_keep_files: bool = False,
        use_subprocess_timeout: bool = False,
    ):
        """
        Configure the oracle.

        Args:
            lkh_bin: Path to the LKH-3 executable (stored as an absolute path).
            work_dir: Parent directory for per-call run directories; created
                if missing (stored as an absolute path).
            default_time_limit_s: TIME_LIMIT used when ``solve`` gets none.
            runs: Value written as RUNS.
            move_type: Value written as MOVE_TYPE.
            patching_c: Value written as PATCHING_C.
            patching_a: Value written as PATCHING_A.
            max_trials: Value written as MAX_TRIALS.
            trace_level: Value written as TRACE_LEVEL.
            debug_keep_files: Keep the run directory after ``solve`` returns.
            use_subprocess_timeout: Also enforce a hard wall-clock timeout on
                the child process (see ``solve``).

        Raises:
            FileNotFoundError: If ``lkh_bin`` is not an existing file.
        """
        # Both paths must be absolute: solve() launches LKH with cwd set to the
        # run directory, and a relative executable path is resolved against
        # that cwd. A relative work_dir can likewise yield relative file paths
        # in the .par file (mkdtemp only guarantees absolute paths on 3.12+).
        self.lkh_bin = os.path.abspath(lkh_bin)
        self.work_dir = os.path.abspath(work_dir)
        self.default_time_limit_s = int(default_time_limit_s)
        self.runs = int(runs)
        self.move_type = int(move_type)
        self.patching_c = int(patching_c)
        self.patching_a = int(patching_a)
        self.max_trials = int(max_trials)
        self.trace_level = int(trace_level)
        self.debug_keep_files = bool(debug_keep_files)
        self.use_subprocess_timeout = bool(use_subprocess_timeout)

        if not os.path.isfile(self.lkh_bin):
            raise FileNotFoundError(f"LKH-3 binary not found: {self.lkh_bin}")

        os.makedirs(self.work_dir, exist_ok=True)

    # -------------------------- public API --------------------------

    def solve(self, instance: Dict[str, Any], time_limit_s: Optional[int] = None) -> LKHResult:
        """
        Run LKH-3 on ``instance`` and return the parsed result.

        A unique per-call directory is used to avoid file collisions under
        parallel runs. It is deleted before returning unless
        ``debug_keep_files`` is set, so the paths in the result are only
        valid in debug mode.

        Args:
            instance: Problem description (see the class docstring).
            time_limit_s: TIME_LIMIT for LKH; defaults to
                ``default_time_limit_s``.

        Returns:
            LKHResult with status "FEASIBLE" when a tour length could be
            parsed (even after a subprocess timeout), "TIMEOUT" when the
            child was killed without producing one, "ERROR" on a non-zero
            exit code or any Python-side exception, otherwise "NO_SOLUTION".
        """
        tl = int(time_limit_s) if time_limit_s is not None else int(self.default_time_limit_s)

        # Unique per-call directory (critical for parallel evaluation)
        run_dir = tempfile.mkdtemp(prefix="lkh3_run_", dir=self.work_dir)
        vrp_path = os.path.join(run_dir, "problem.vrp")
        par_path = os.path.join(run_dir, "params.par")
        tour_path = os.path.join(run_dir, "output.tour")
        paths = {
            "tour_path": tour_path,
            "par_path": par_path,
            "vrp_path": vrp_path,
            "run_dir": run_dir,
        }

        try:
            self._write_vrplib_explicit(instance, out_path=vrp_path)
            self._write_par(par_path, vrp_path, tour_path, tl)

            # Generous safety net on top of LKH's own TIME_LIMIT:
            # e.g. tl=3s -> 10*3+60 = 90s, raised to the 300s floor.
            timeout = max(10 * tl + 60, 300) if self.use_subprocess_timeout else None

            try:
                proc = subprocess.run(
                    [self.lkh_bin, par_path],
                    cwd=run_dir,
                    capture_output=True,
                    text=True,
                    timeout=timeout,
                )
            except subprocess.TimeoutExpired as e:
                # The child was killed, but it may already have written a tour.
                cost = self._parse_tour_length(tour_path)
                return LKHResult(
                    cost=cost,
                    status="FEASIBLE" if cost is not None else "TIMEOUT",
                    # Captured output on timeout is bytes even with text=True.
                    stdout=self._to_text(getattr(e, "stdout", None)),
                    stderr=self._to_text(getattr(e, "stderr", None)),
                    **paths,
                )

            cost = self._parse_tour_length(tour_path)
            if cost is not None:
                status = "FEASIBLE"
            elif proc.returncode != 0:
                status = "ERROR"
            else:
                status = "NO_SOLUTION"
            return LKHResult(
                cost=cost,
                status=status,
                stdout=proc.stdout or "",
                stderr=proc.stderr or "",
                **paths,
            )

        except Exception as e:
            # Boundary handler: callers get an ERROR result rather than an exception.
            return LKHResult(
                cost=None,
                status="ERROR",
                stdout="",
                stderr=f"{type(e).__name__}: {e}",
                **paths,
            )

        finally:
            if not self.debug_keep_files:
                self._safe_rmdir(run_dir)

    # -------------------------- file writers --------------------------

    def _write_par(self, par_path: str, vrp_path: str, tour_path: str, tl: int) -> None:
        """
        Write the LKH parameter file to ``par_path``.

        Kept minimal to reduce compatibility issues; ``tl`` is written as
        TIME_LIMIT, the remaining values come from the constructor settings.
        """
        settings = [
            ("PROBLEM_FILE", vrp_path),
            ("OUTPUT_TOUR_FILE", tour_path),
            ("RUNS", int(self.runs)),
            ("MAX_TRIALS", int(self.max_trials)),
            ("MOVE_TYPE", int(self.move_type)),
            ("PATCHING_C", int(self.patching_c)),
            ("PATCHING_A", int(self.patching_a)),
            ("TIME_LIMIT", int(tl)),
            ("TRACE_LEVEL", int(self.trace_level)),
        ]
        with open(par_path, "w") as f:
            f.writelines(f"{key} = {value}\n" for key, value in settings)

    def _write_vrplib_explicit(self, instance: Dict[str, Any], out_path: str) -> None:
        """
        Write a VRPLIB/TSPLIB CVRP instance with EXPLICIT FULL_MATRIX edge weights.

        NOTE: LKH requires strict formatting. Do not add extra sections.

        Raises:
            KeyError: If ``vehicle_capacity`` is missing.
            ValueError: If ``distance_matrix_int`` is missing or is not a
                square matrix of size ``len(customers) + 1``.
        """
        customers = instance.get("customers", [])
        cap = int(instance["vehicle_capacity"])
        dist = instance.get("distance_matrix_int", None)
        if dist is None:
            raise ValueError("distance_matrix_int missing for LKH-3.")

        dim = len(customers) + 1  # depot + customers

        # Basic shape check
        if len(dist) != dim or any(len(row) != dim for row in dist):
            # len() instead of truthiness so array-like matrices report cleanly.
            n_cols = len(dist[0]) if len(dist) else 0
            raise ValueError(
                f"distance_matrix_int must be shape ({dim},{dim}), got ({len(dist)},{n_cols})"
            )

        with open(out_path, "w") as f:
            f.write("NAME : tmp_cvrp\n")
            f.write("TYPE : CVRP\n")
            f.write(f"DIMENSION : {dim}\n")
            f.write(f"CAPACITY : {cap}\n")
            f.write("EDGE_WEIGHT_TYPE : EXPLICIT\n")
            f.write("EDGE_WEIGHT_FORMAT : FULL_MATRIX\n")
            f.write("EDGE_WEIGHT_SECTION\n")
            for row in dist:
                f.write(" ".join(str(int(value)) for value in row) + "\n")

            # Node 1 is the depot (demand 0); customers are numbered from 2.
            f.write("DEMAND_SECTION\n")
            f.write("1 0\n")
            for idx, c in enumerate(customers, start=2):
                f.write(f"{idx} {int(c.get('demand', 0))}\n")

            f.write("DEPOT_SECTION\n")
            f.write("1\n")
            f.write("-1\n")
            f.write("EOF\n")

    # -------------------------- parsers / utils --------------------------

    @staticmethod
    def _to_text(data: Any) -> str:
        """Return captured process output as ``str`` (decoding bytes, mapping None to "")."""
        if not data:
            return ""
        if isinstance(data, bytes):
            return data.decode("utf-8", errors="replace")
        return data

    def _parse_tour_length(self, tour_path: str) -> Optional[int]:
        """
        Parse 'COMMENT : Length = <int>' from the output tour file.

        Returns the first matching integer, or None when the file is missing,
        unreadable, or contains no such line.
        """
        if not tour_path or not os.path.exists(tour_path):
            return None

        try:
            with open(tour_path, "r") as f:
                for line in f:
                    m = self._LENGTH_RE.search(line)
                    if m:
                        return int(m.group(1))
        except (OSError, ValueError):
            # Unreadable or undecodable file: treat as "no tour available".
            return None

        return None

    def _safe_rmdir(self, run_dir: str) -> None:
        """
        Remove the entire run directory, never raising.

        Called from ``solve``'s ``finally`` clause, so a failure here must not
        mask the result being returned.
        """
        try:
            if run_dir and os.path.isdir(run_dir):
                shutil.rmtree(run_dir, ignore_errors=True)
        except Exception:
            # Deliberate best-effort cleanup.
            pass
