#!/usr/bin/env python3
"""
CVRP Oracle implementation for reference upper bounds.

Supports multiple oracle backends:
- PyVRP: Python-based VRP solver (requires pyvrp package)
- LKH-3: High-quality heuristic solver (requires LKH-3 executable)

Note: This provides heuristic solutions (upper bounds), not exact/proven optimal solutions.
"""

from __future__ import annotations

import warnings
import numpy as np
from typing import Dict, Optional, Any

from .oracle_utils.lkh3_cvrp import LKH3CVRPOracle, LKHResult


def create_cvrp_oracle(config=None, oracle_type: str = "pyvrp"):
    """
    Build a CVRPOracle for the requested backend.

    Args:
        config: Optional config object, forwarded to the oracle unchanged.
        oracle_type: Backend selector:
            - "pyvrp": PyVRP (Python-based VRP solver, requires pyvrp package)
            - "lkh3": LKH-3 (high-quality heuristic, requires LKH-3 executable)
            - "none": No backend; the returned oracle's solve_oracle() yields None
            A falsy value (None or "") falls back to ``config.oracle_type``
            and, failing that, to "pyvrp".

    Returns:
        CVRPOracle instance
    """
    if oracle_type:
        resolved_type = oracle_type
    elif config:
        resolved_type = getattr(config, 'oracle_type', 'pyvrp')
    else:
        resolved_type = 'pyvrp'
    return CVRPOracle(config, resolved_type)


class CVRPOracle:
    """
    CVRP oracle for computing reference upper bounds.
    
    Supports multiple oracle backends:
    - PyVRP: Python-based VRP solver (requires pyvrp package)
    - LKH-3: High-quality heuristic solver (requires LKH-3 executable)
    
    Note: This provides heuristic solutions (upper bounds), not exact/proven optimal solutions.
    
    Args:
        config: CVRPConfig or any config object with oracle_type, oracle_timeout, lkh3_path attributes
        oracle_type: Oracle type ("pyvrp", "lkh3", or "none")
    """
    
    def __init__(self, config=None, oracle_type: str = "pyvrp"):
        """
        Store the config, resolve and validate the backend, then prepare it.

        Args:
            config: Optional config object; read for ``oracle_type`` when the
                ``oracle_type`` argument is falsy.
            oracle_type: "pyvrp", "lkh3" or "none".

        Raises:
            ValueError / ImportError: propagated from ``_validate_oracle``.
            FileNotFoundError: propagated from ``_init_lkh_oracle``.
        """
        self.config = config

        # Explicit argument wins; otherwise consult the config, then default.
        if oracle_type:
            self.oracle_type = oracle_type
        elif config:
            self.oracle_type = getattr(config, 'oracle_type', 'pyvrp')
        else:
            self.oracle_type = 'pyvrp'
        self._validate_oracle()

        # Backend state; only the selected backend fills in its slot.
        self._lkh_oracle = None
        self._pyvrp_available = None
        if self.oracle_type == "pyvrp":
            self._check_pyvrp()
        elif self.oracle_type == "lkh3":
            self._init_lkh_oracle()
    
    def _validate_oracle(self):
        """
        Reject unknown backends and fail fast if PyVRP is selected but missing.

        Raises:
            ValueError: ``self.oracle_type`` is not "pyvrp", "lkh3" or "none".
            ImportError: the backend is "pyvrp" and the package cannot be imported.
        """
        supported_types = ["pyvrp", "lkh3", "none"]
        selected = self.oracle_type
        if selected not in supported_types:
            raise ValueError(f"Unknown oracle type: {selected}. Supported: {supported_types}")

        # Only the PyVRP backend needs an import probe.
        if selected != "pyvrp":
            return

        try:
            # Imported purely to prove the names resolve; the values are unused here.
            from pyvrp import Model
            from pyvrp.stop import MaxRuntime
        except ImportError:
            raise ImportError(
                "PyVRP not available. Please install it with: pip install pyvrp"
            )
    
    def _init_lkh_oracle(self):
        """
        Locate the LKH-3 executable and construct the LKH3CVRPOracle wrapper.

        Reads ``lkh3_path``, ``oracle_timeout``, ``lkh_runs``, ``lkh_max_trials``
        and ``debug_mode`` from ``self.config`` (each with a default) and stores
        the wrapper in ``self._lkh_oracle``.

        Raises:
            FileNotFoundError: the resolved LKH path is not an existing file.
        """
        import os
        import tempfile

        cfg = self.config
        lkh_binary = getattr(cfg, 'lkh3_path', None) or "/data1//tools/LKH3-v.1.0/LKH"
        if not os.path.isfile(lkh_binary):
            raise FileNotFoundError(
                f"LKH-3 executable not found at: {lkh_binary}. "
                f"Please set lkh3_path in config or ensure LKH executable exists."
            )

        options = {
            "lkh_bin": lkh_binary,
            # System temp directory rather than /dev/shm (chosen by the original
            # author to avoid flush issues).
            "work_dir": tempfile.gettempdir(),
            # Small budgets by default: this oracle is meant for quick feasibility.
            "default_time_limit_s": getattr(cfg, 'oracle_timeout', 3),
            "runs": getattr(cfg, 'lkh_runs', 1),
            "move_type": 5,
            "patching_c": 3,
            "patching_a": 2,
            "max_trials": getattr(cfg, 'lkh_max_trials', 300),
        }
        # Debug mode both raises LKH's trace level and keeps its working files.
        verbose = getattr(cfg, 'debug_mode', False)
        options["trace_level"] = 1 if verbose else 0
        options["debug_keep_files"] = verbose

        self._lkh_oracle = LKH3CVRPOracle(**options)
    
    def _check_pyvrp(self):
        """Record in ``self._pyvrp_available`` whether the PyVRP imports succeed."""
        try:
            # Imported only as a probe; the names themselves are not used here.
            from pyvrp import Model
            from pyvrp.stop import MaxRuntime
        except ImportError:
            importable = False
        else:
            importable = True
        self._pyvrp_available = importable
    
    def solve_oracle(self, instance: Dict, timeout_seconds: Optional[int] = None) -> Optional[Dict[str, Any]]:
        """
        Dispatch a CVRP instance to the configured backend.

        The result is a reference upper bound (heuristic solution), not an
        exact/proven optimal solution.

        Args:
            instance: CVRP instance dict with:
                - 'depot': list [x, y] coordinates
                - 'customers': list of dicts with 'coords' [x, y] and 'demand' int
                - 'vehicle_capacity': int
                - 'num_vehicles': int (optional)
                - 'distance_matrix_int': np.ndarray (optional, (n+1) x (n+1) integer distance matrix)
                - 'distance_scale': float (optional, default 1000.0, scale factor for distance matrix)
            timeout_seconds: Timeout in seconds (backends fall back to the config value if None)

        Returns:
            None when oracle_type is "none"; otherwise a dict with keys:
                - 'cost': float total distance, or None if no solution was produced
                - 'status': "FEASIBLE", "NO_SOLUTION", "TIMEOUT" or "ERROR"
                - 'solver': "pyvrp", "lkh3", or "unknown" for an unrecognised backend
        """
        backend = self.oracle_type

        if backend == "none":
            return None
        if backend == "pyvrp":
            return self._solve_with_pyvrp(instance, timeout_seconds)
        if backend == "lkh3":
            return self._solve_with_lkh3(instance, timeout_seconds)

        # Not reachable with a backend accepted by _validate_oracle; kept as a
        # defensive answer should oracle_type be reassigned after construction.
        return {"cost": None, "status": "ERROR", "solver": "unknown"}
    



    def _solve_with_pyvrp(self, instance: Dict, timeout_seconds: Optional[int] = None) -> Dict[str, Any]:
        """
        Solve a CVRP instance with the PyVRP backend.

        Args:
            instance: CVRP instance dict (see ``solve_oracle``).
            timeout_seconds: Time budget forwarded to ``_solve_with_pyvrp_model``;
                defaults to ``config.oracle_timeout`` (2 if unset).

        Returns:
            Dict with 'cost', 'status' and 'solver' ("pyvrp"). Status is
            "FEASIBLE" (cost set), "NO_SOLUTION" (a demand exceeds capacity or
            the solver found nothing feasible) or "ERROR" (PyVRP unavailable or
            an exception was raised while solving).
        """
        def outcome(status, cost=None):
            return {"cost": cost, "status": status, "solver": "pyvrp"}

        def cfg_value(name, default):
            # Config is optional: a missing or falsy config yields the default.
            return getattr(self.config, name, default) if getattr(self, "config", None) else default

        if not getattr(self, "_pyvrp_available", False):
            warnings.warn("PyVRP not available, cannot compute solution")
            return outcome("ERROR")

        verbose = cfg_value("debug_mode", False)

        clients = instance.get("customers", [])
        client_count = len(clients)
        if client_count == 0:
            # Nothing to visit: the empty plan costs nothing.
            return outcome("FEASIBLE", 0.0)

        cap = int(instance["vehicle_capacity"])

        # A single customer that does not fit in a vehicle makes the instance infeasible.
        largest_demand = max(int(c["demand"]) for c in clients)
        if largest_demand > cap:
            if verbose:
                print(f"      [CVRP Oracle] PyVRP(Model): infeasible max_demand={largest_demand} > cap={cap}", flush=True)
            return outcome("NO_SOLUTION")

        # Fleet size: use the instance value, else ceil(total demand / capacity) plus a buffer.
        fleet = instance.get("num_vehicles", None)
        if fleet is not None:
            fleet = int(fleet)
        else:
            demand_sum = sum(int(c["demand"]) for c in clients)
            min_routes = max(1, (demand_sum + cap - 1) // cap)
            slack = int(cfg_value("vehicles_buffer", 2))
            fleet = max(1, min_routes + slack)

        # NOTE(review): raising to at least client_count and then capping at
        # client_count always yields client_count, so the value computed above
        # never reaches the solver. Behaviour preserved as-is; confirm intent.
        fleet = max(fleet, client_count)
        fleet = min(client_count, fleet)

        # Solver budget.
        max_iters = int(cfg_value("pyvrp_max_iters", 1000))
        rng_seed = int(cfg_value("seed", 42))
        if timeout_seconds is not None:
            budget_s = int(timeout_seconds)
        else:
            budget_s = int(getattr(self.config, "oracle_timeout", 2))

        try:
            raw_cost = self._solve_with_pyvrp_model(
                instance=instance,
                num_vehicles=fleet,
                capacity=cap,
                iters=max_iters,
                timeout=budget_s,
                seed=rng_seed,
                debug_mode=verbose,
            )

            if raw_cost is None:
                if verbose:
                    print("      [CVRP Oracle] PyVRP(Model): No feasible solution", flush=True)
                return outcome("NO_SOLUTION")

            # Convert the integer objective back to the instance's distance units.
            cost = float(raw_cost)
            if 'distance_scale' in instance:
                scale = float(instance.get("distance_scale", 1.0))
                rescale = scale > 0
            else:
                # No explicit scale: the config flag decides (backward
                # compatibility), with 1000.0 as the assumed scale.
                rescale = bool(cfg_value("pyvrp_rescale_cost", True))
                scale = 1000.0
            if rescale:
                cost /= scale

            if verbose:
                print(f"      [CVRP Oracle] PyVRP(Model): Success! cost={cost:.6f} (rescale={rescale})", flush=True)

            return outcome("FEASIBLE", cost)

        except Exception as e:
            if verbose:
                import traceback
                print(f"      [CVRP Oracle] PyVRP(Model): Error: {type(e).__name__}: {e}", flush=True)
                traceback.print_exc()
            warnings.warn(f"PyVRP unexpected error: {type(e).__name__}: {e}")
            return outcome("ERROR")

    def _solve_with_pyvrp_model(
        self,
        instance: Dict,
        num_vehicles: int,
        capacity: int,
        iters: int = 1000,
        timeout: int = 2,
        seed: int = 42,
        debug_mode: bool = False,
    ) -> Optional[float]:
        """
        Build a PyVRP Model directly from distance_matrix_int and demands, then solve.
        Returns objective value if feasible else None.
        """
        from pyvrp import Model
        from pyvrp.stop import MaxIterations, MaxRuntime

        dist = instance.get("distance_matrix_int", None)
        if dist is None:
            raise ValueError("distance_matrix_int missing; PyVRP(Model) path requires explicit distance matrix.")

        customers = instance["customers"]
        n = len(customers)
        dim = n + 1

        # sanity check matrix shape
        if len(dist) != dim or any(len(row) != dim for row in dist):
            raise ValueError(f"distance_matrix_int must be shape ({dim},{dim}), got ({len(dist)},{len(dist[0]) if dist else 0})")

        m = Model()

        # Vehicle type API differs across versions: try common signatures.
        added = False
        for kwargs in (
            {"num_available": int(num_vehicles), "capacity": int(capacity)},
            {"num_vehicles": int(num_vehicles), "capacity": int(capacity)},
            {"capacity": int(capacity)},  # last resort: unlimited vehicles
        ):
            try:
                m.add_vehicle_type(**kwargs)
                added = True
                break
            except TypeError:
                continue
        if not added:
            raise TypeError("Could not add vehicle type: PyVRP API mismatch (num_available/num_vehicles).")

        # Locations: coords irrelevant because we use explicit distances
        depot = m.add_depot(x=0, y=0)
        clients = [m.add_client(x=0, y=0, delivery=int(c["demand"])) for c in customers]
        locs = [depot] + clients

        # Add edges from FULL_MATRIX
        for i in range(dim):
            li = locs[i]
            row = dist[i]
            for j in range(dim):
                m.add_edge(li, locs[j], distance=int(row[j]))

        # Stop criterion (choose ONE)
        stop = MaxIterations(int(iters))
        # or runtime:
        # stop = MaxRuntime(float(timeout))

        res = m.solve(stop=stop, seed=int(seed), display=False)

        if not res.is_feasible():
            return None

        sol = res.best
        return float(self._pyvrp_solution_objective(sol))




    def _pyvrp_solution_objective(self, sol) -> float:
        """
        Read the objective value off a PyVRP solution object across versions.

        Attribute names are probed in order, "objective" first and then a
        small set of legacy/alternative names. Each hit may be a plain value or
        a zero-argument callable; the first one that converts to float wins.

        Raises:
            AttributeError: none of the candidate names produced a number.
        """
        absent = object()
        candidates = (
            "objective",
            "objective_value",
            "obj",
            "cost",
            "total_cost",
            "distance",
            "total_distance",
        )
        for attr in candidates:
            value = getattr(sol, attr, absent)
            if value is absent:
                continue
            try:
                if callable(value):
                    value = value()
                return float(value)
            except Exception:
                # Best effort: an unusable candidate just moves us to the next name.
                continue

        raise AttributeError("Cannot extract objective/cost from PyVRP Solution (API mismatch).")








        
    def _create_vrplib_file(self, instance: Dict, num_vehicles: int) -> str:
        """
        Write the instance to a temporary VRPLIB-style ``.vrp`` file.

        When 'distance_matrix_int' is present it is written as an EXPLICIT
        FULL_MATRIX edge-weight section; otherwise EUC_2D node coordinates are
        written from 'depot' and each customer's 'coords' (the legacy key
        'coord' is accepted as well).

        Args:
            instance: Dict with 'depot', 'customers' (each with 'demand' and,
                for the EUC_2D fallback, 'coords') and 'vehicle_capacity';
                optionally 'distance_matrix_int'.
            num_vehicles: Value written to the VEHICLES header line.

        Returns:
            Path of the created file. The caller is responsible for deleting it.

        Raises:
            KeyError: a required instance key is missing.
            OSError: the temporary file could not be created or written.
        """
        import os
        import tempfile

        depot = instance["depot"]
        customers = instance["customers"]
        capacity = int(instance["vehicle_capacity"])
        dim = len(customers) + 1  # include depot

        # Build the whole document first, so malformed input raises before any
        # file exists on disk.
        lines = [
            "NAME : tmp_cvrp",
            "TYPE : CVRP",
            f"DIMENSION : {dim}",
            f"CAPACITY : {capacity}",
            f"VEHICLES : {int(num_vehicles)}",
        ]

        # Prefer explicit distance matrix if provided
        dist_int = instance.get("distance_matrix_int", None)
        if dist_int is not None:
            lines += [
                "EDGE_WEIGHT_TYPE : EXPLICIT",
                "EDGE_WEIGHT_FORMAT : FULL_MATRIX",
                "EDGE_WEIGHT_SECTION",
            ]
            lines.extend(
                " ".join(str(int(x)) for x in dist_int[i]) for i in range(dim)
            )
        else:
            # Fallback to EUC_2D using coordinates (only when no matrix is available)
            lines += [
                "EDGE_WEIGHT_TYPE : EUC_2D",
                "NODE_COORD_SECTION",
                f"1 {float(depot[0])} {float(depot[1])}",
            ]
            for idx, c in enumerate(customers, start=2):
                # 'coords' is the documented key; 'coord' kept for older callers.
                x, y = c["coords"] if "coords" in c else c["coord"]
                lines.append(f"{idx} {float(x)} {float(y)}")

        # Demands (node 1 is the depot and has none)
        lines += ["DEMAND_SECTION", "1 0"]
        lines.extend(
            f"{idx} {int(c['demand'])}" for idx, c in enumerate(customers, start=2)
        )

        # Depot section
        lines += ["DEPOT_SECTION", "1", "-1", "EOF"]

        # Keep using /dev/shm where it exists and is writable; elsewhere fall
        # back to the system temp directory instead of failing.
        shm = "/dev/shm"
        target_dir = shm if os.path.isdir(shm) and os.access(shm, os.W_OK) else None
        fd, path = tempfile.mkstemp(prefix="pyvrp_", suffix=".vrp", dir=target_dir)
        try:
            with open(fd, "w") as f:
                f.write("\n".join(lines) + "\n")
        except BaseException:
            # Do not leave a half-written temp file behind.
            try:
                os.unlink(path)
            except OSError:
                pass
            raise

        return path






    def _solve_with_lkh3(self, instance: Dict, timeout_seconds: Optional[int] = None) -> Dict[str, Any]:
        """
        Solve a CVRP instance with the LKH-3 backend.

        Mirrors the pre-checks of the PyVRP path: an instance without customers
        costs 0.0, and a customer demand above the vehicle capacity is reported
        as "NO_SOLUTION" without invoking the solver.

        Args:
            instance: CVRP instance dict (see ``solve_oracle``). If
                'distance_matrix_int' is absent, ``instance_to_solver_inputs``
                is called on the instance and is expected to add it.
            timeout_seconds: Time limit passed to LKH-3; defaults to
                ``config.oracle_timeout`` (3 if unset).

        Returns:
            Dict with 'cost', 'status' and 'solver' ("lkh3"). On "FEASIBLE" the
            integer LKH cost is divided by 'distance_scale' (default 1000.0);
            a non-positive scale leaves the cost unscaled, as in the PyVRP path.
        """
        def outcome(status, cost=None):
            return {"cost": cost, "status": status, "solver": "lkh3"}

        if self._lkh_oracle is None:
            return outcome("ERROR")

        # Resolved once; the config is optional.
        debug_mode = getattr(self.config, 'debug_mode', False) if self.config else False

        customers = instance.get('customers', [])
        if len(customers) == 0:
            # Nothing to visit: the empty plan costs nothing.
            return outcome("FEASIBLE", 0.0)

        # Hard feasibility check: every single demand must fit in one vehicle.
        vehicle_capacity = instance['vehicle_capacity']
        max_demand = max(int(c['demand']) for c in customers)
        if max_demand > int(vehicle_capacity):
            if debug_mode:
                print(f"      [CVRP Oracle] LKH-3: infeasible max_demand={max_demand} > cap={int(vehicle_capacity)}", flush=True)
            return outcome("NO_SOLUTION")

        # Ensure instance has distance_matrix_int
        if 'distance_matrix_int' not in instance:
            # Try to compute it from coordinates
            from .evolution.shared.solver.solve_instance import instance_to_solver_inputs
            instance_to_solver_inputs(instance)

        if 'distance_matrix_int' not in instance:
            warnings.warn("CVRP instance missing distance_matrix_int, cannot solve with LKH-3")
            return outcome("ERROR")

        # Convert numpy array to list of lists if needed
        distance_matrix_int = instance['distance_matrix_int']
        if hasattr(distance_matrix_int, 'tolist'):
            distance_matrix_int = distance_matrix_int.tolist()

        # Prepare instance dict for LKH3CVRPOracle
        lkh_instance = {
            'customers': customers,
            'vehicle_capacity': vehicle_capacity,
            'distance_matrix_int': distance_matrix_int,
        }

        try:
            # Always pass an explicit time limit (never None) so the training
            # oracle budget is controlled by the caller or the config.
            tl = int(timeout_seconds) if timeout_seconds is not None else int(getattr(self.config, "oracle_timeout", 3))

            if debug_mode:
                print(f"      [CVRP Oracle] LKH-3: Calling solve with time_limit_s={tl}", flush=True)

            result: LKHResult = self._lkh_oracle.solve(lkh_instance, time_limit_s=tl)

            if debug_mode:
                print(f"      [CVRP Oracle] LKH-3: Result status={result.status}, cost={result.cost}", flush=True)
                if result.stderr:
                    print(f"      [CVRP Oracle] LKH-3: stderr={result.stderr[:500]}", flush=True)

            # Convert LKHResult to our format
            if result.status == "FEASIBLE" and result.cost is not None:
                cost_float = float(result.cost)
                distance_scale = float(instance.get('distance_scale', 1000.0))
                # Guard against a zero/negative scale instead of turning a
                # feasible solution into a ZeroDivisionError -> "ERROR".
                if distance_scale > 0:
                    cost_float /= distance_scale
                return outcome("FEASIBLE", cost_float)

            if result.status == "TIMEOUT":
                return outcome("TIMEOUT")

            if result.status == "ERROR":
                if debug_mode:
                    print(f"      [CVRP Oracle] LKH-3: Error: {result.stderr}", flush=True)
                return outcome("ERROR")

            return outcome("NO_SOLUTION")

        except Exception as e:
            if debug_mode:
                import traceback
                print(f"      [CVRP Oracle] LKH-3: Exception: {type(e).__name__}: {e}", flush=True)
                traceback.print_exc()
            warnings.warn(f"LKH3 unexpected error: {type(e).__name__}: {e}")
            return outcome("ERROR")
    
    def solve_exact(self, instance: Dict, timeout_seconds: Optional[int] = None) -> Optional[float]:
        """
        Legacy alias kept for backward compatibility.

        Deprecated: prefer solve_oracle(), which also reports status. Despite
        the name, the value comes from the same heuristic backends.

        Returns:
            The 'cost' entry of the solve_oracle() result, or None when the
            oracle is disabled or produced no cost.
        """
        outcome = self.solve_oracle(instance, timeout_seconds)
        return None if outcome is None else outcome.get("cost")
