#!/usr/bin/env python3
"""
TSP Oracle implementations for computing exact solutions and gaps.
"""

import numpy as np
import subprocess
import tempfile
import os
import time
import signal
import sys
import re
from typing import List, Optional, Callable


class TSPOracle:
    """
    TSP oracle for computing (near-)exact tour lengths via external solvers.

    Depending on ``config.oracle_type`` the instance is solved with LKH3,
    Concorde, or (for ``"none"``) a nearest-neighbor heuristic. Whenever an
    external solver fails, times out, or produces unparsable output, the
    nearest-neighbor tour length is returned instead.

    Args:
        config: TSPGLSConfig or any config object with oracle_type, oracle_timeout/timeout_seconds,
                lkh3_path, concorde_path, lkh_runs/runs attributes
    """

    # Coordinates are multiplied by this factor and rounded to integers when
    # the TSPLIB file is written; costs reported by the solvers are therefore
    # divided by the same factor to get back to the caller's units.
    COORD_SCALE = 1000

    # Command used for LKH3 when the config does not provide ``lkh3_path``.
    DEFAULT_LKH3_CMD = "/data1//tools/LKH3-v.1.0/LKH"

    def __init__(self, config):
        self.config = config
        self._validate_oracle()

    def _validate_oracle(self):
        """
        Validate that the configured oracle executable is available.

        Raises:
            FileNotFoundError: if ``oracle_type`` is ``"lkh3"`` and a
                configured ``lkh3_path`` does not exist.

        For Concorde a missing configured path only triggers a warning and
        resets ``config.concorde_path`` to ``None`` so that ``concorde`` is
        looked up on PATH later.
        """
        oracle_type = getattr(self.config, 'oracle_type', 'none')
        if oracle_type == "lkh3":
            lkh3_path = getattr(self.config, 'lkh3_path', None)
            if lkh3_path and not os.path.exists(lkh3_path):
                raise FileNotFoundError(f"LKH3 executable not found at {lkh3_path}")
        elif oracle_type == "concorde":
            concorde_path = getattr(self.config, 'concorde_path', None)
            if concorde_path and not os.path.exists(concorde_path):
                import warnings
                warnings.warn(
                    f"Concorde executable not found at {concorde_path}. "
                    f"Will try to use 'concorde' from PATH if available.",
                    UserWarning
                )
                self.config.concorde_path = None

    def solve_exact(self, coords: np.ndarray) -> float:
        """
        Solve TSP instance with the configured oracle and return tour length.

        Args:
            coords: Array of shape (n_cities, 2) with city coordinates

        Returns:
            Tour length reported by the oracle (nearest-neighbor length when
            ``oracle_type`` is ``"none"`` or the external solver fails)

        Raises:
            ValueError: if ``oracle_type`` is not one of
                ``"none"``, ``"lkh3"``, ``"concorde"``.
        """
        oracle_type = getattr(self.config, 'oracle_type', 'none')
        if oracle_type == "none":
            # Fallback: use nearest neighbor as approximation
            return self._nearest_neighbor_tour_length(coords)
        elif oracle_type == "lkh3":
            return self._solve_with_lkh3(coords)
        elif oracle_type == "concorde":
            return self._solve_with_concorde(coords)
        else:
            raise ValueError(f"Unknown oracle type: {oracle_type}")

    def _nearest_neighbor_tour_length(self, coords: np.ndarray) -> float:
        """
        Fallback: compute the length of the nearest-neighbor tour from city 0.

        Distances are computed one row at a time with NumPy instead of
        filling a full n x n matrix in a Python double loop. Ties are broken
        in favor of the lowest city index.

        Args:
            coords: Array of shape (n_cities, dim) with city coordinates

        Returns:
            Length of the closed tour; 0.0 for zero or one city
        """
        n = len(coords)
        if n <= 1:
            return 0.0

        pts = np.asarray(coords, dtype=float).reshape(n, -1)
        visited = np.zeros(n, dtype=bool)
        visited[0] = True
        current = 0
        total_length = 0.0

        for _ in range(n - 1):
            # Only unvisited cities are candidates, so a visited city can
            # never be selected even when distances are inf/NaN.
            candidates = np.flatnonzero(~visited)
            dists = np.linalg.norm(pts[candidates] - pts[current], axis=1)
            pick = int(np.argmin(dists))
            total_length += float(dists[pick])
            current = int(candidates[pick])
            visited[current] = True

        # Close the tour back to the start city.
        total_length += float(np.linalg.norm(pts[current] - pts[0]))
        return total_length

    # ------------------------------------------------------------------
    # Shared helpers for running external solvers
    # ------------------------------------------------------------------

    def _timeout_seconds(self):
        """Return the wall-clock timeout: ``timeout_seconds``, else ``oracle_timeout``, else 30."""
        return getattr(self.config, 'timeout_seconds', None) or getattr(self.config, 'oracle_timeout', 30)

    @staticmethod
    def _temp_dir():
        """Return ``/dev/shm`` on non-Windows hosts where it exists, else None (system default)."""
        if sys.platform != 'win32' and os.path.exists('/dev/shm'):
            return '/dev/shm'
        return None

    @staticmethod
    def _record_oracle_time(elapsed):
        """Best-effort: report solver wall time to the project's evaluation timer, if importable."""
        try:
            from ...utils.evaluation_timer import get_timer
            get_timer().add_lkh_oracle_time(elapsed)
        except Exception:
            # Timing is optional; the timer module may be unavailable.
            pass

    @staticmethod
    def _remove_files(*paths):
        """Delete the given files, ignoring unset paths and filesystem errors."""
        for file_path in paths:
            if not file_path:
                continue
            try:
                os.unlink(file_path)
            except OSError:
                pass  # Missing file or cleanup failure: nothing more to do

    @staticmethod
    def _terminate_single(process):
        """Terminate just ``process``; escalate to kill if it does not exit within 1s."""
        try:
            process.terminate()
            process.wait(timeout=1)
        except (subprocess.TimeoutExpired, ProcessLookupError):
            try:
                process.kill()
            except ProcessLookupError:
                pass  # Process already dead

    @classmethod
    def _terminate_process(cls, process):
        """
        Stop ``process`` (its whole process group on Unix) and reap it.

        SIGTERM is sent first and SIGKILL after a one second grace period.
        Afterwards the pipes are drained with ``communicate`` so the child
        does not linger as a zombie and its pipe descriptors are closed.
        """
        try:
            if sys.platform != 'win32':
                try:
                    pgid = os.getpgid(process.pid)
                    os.killpg(pgid, signal.SIGTERM)
                    try:
                        process.wait(timeout=1)
                    except subprocess.TimeoutExpired:
                        os.killpg(pgid, signal.SIGKILL)
                except OSError:
                    # Process group not found or already gone (includes
                    # ProcessLookupError): fall back to the single process.
                    cls._terminate_single(process)
            else:
                cls._terminate_single(process)
        except Exception:
            # Any other error in cleanup - just try to kill the process
            try:
                if process.poll() is None:
                    process.kill()
            except Exception:
                pass  # Ignore cleanup errors

        try:
            # Bounded wait: a process outside the killed group could in
            # principle still hold the pipe open.
            process.communicate(timeout=5)
        except (subprocess.TimeoutExpired, OSError, ValueError):
            pass

    def _run_solver(self, cmd, label, cwd=None):
        """
        Run an external solver command under the configured timeout.

        Args:
            cmd: Argument list passed to ``subprocess.Popen``
            label: Solver name used in the timeout message
            cwd: Optional working directory for the child process

        Returns:
            ``(returncode, stdout, stderr)`` on completion, or ``None`` if the
            timeout expired (the child is killed and reaped in that case).
        """
        popen_kwargs = {
            'stdout': subprocess.PIPE,
            'stderr': subprocess.PIPE,
            'text': True
        }
        if cwd is not None:
            popen_kwargs['cwd'] = cwd
        if sys.platform != 'win32':
            # New session => new process group, so the whole tree can be killed.
            popen_kwargs['start_new_session'] = True

        timeout_seconds = self._timeout_seconds()
        process = subprocess.Popen(cmd, **popen_kwargs)
        t0 = time.time()
        try:
            stdout, stderr = process.communicate(timeout=timeout_seconds)
        except subprocess.TimeoutExpired:
            print(f"        {label} timeout after {timeout_seconds}s, killing process and using nearest neighbor")
            self._terminate_process(process)
            return None
        except BaseException:
            # The child runs in its own session and would otherwise keep
            # running if we are interrupted; stop it before propagating.
            self._terminate_process(process)
            raise

        self._record_oracle_time(time.time() - t0)
        return process.returncode, stdout, stderr

    # ------------------------------------------------------------------
    # LKH3
    # ------------------------------------------------------------------

    def _solve_with_lkh3(self, coords: np.ndarray) -> float:
        """
        Solve using LKH3 (if available) with robust timeout and cleanup.

        Writes a temporary TSPLIB file and LKH parameter file, runs the LKH
        executable, and parses the reported cost. Falls back to the
        nearest-neighbor tour length on timeout, non-zero exit status,
        unparsable output, or any other error. Temporary files are always
        removed.
        """
        tsp_file = None
        par_file = None
        tmpdir = self._temp_dir()

        try:
            with tempfile.NamedTemporaryFile(mode='w', suffix='.tsp', delete=False, dir=tmpdir) as f:
                tsp_file = f.name
                self._write_tsplib_file(coords, f)

            with tempfile.NamedTemporaryFile(mode='w', suffix='.par', delete=False, dir=tmpdir) as f:
                par_file = f.name
                self._write_lkh3_par_file(tsp_file, f)

            lkh3_cmd = getattr(self.config, 'lkh3_path', None) or self.DEFAULT_LKH3_CMD
            result = self._run_solver([lkh3_cmd, par_file], "LKH3")
            if result is None:
                return self._nearest_neighbor_tour_length(coords)

            returncode, stdout, stderr = result
            if returncode != 0:
                print(f"        LKH3 failed (returncode {returncode}): {stderr[:200] if stderr else 'no error message'}")
                return self._nearest_neighbor_tour_length(coords)

            tour_length = self._parse_lkh3_output(stdout)
            if np.isfinite(tour_length):
                return tour_length
            print("        LKH3 parsing failed, using nearest neighbor")
            return self._nearest_neighbor_tour_length(coords)

        except FileNotFoundError as e:
            print(f"        LKH3 executable not found: {e}, using nearest neighbor")
            return self._nearest_neighbor_tour_length(coords)
        except Exception as e:
            print(f"        LKH3 unexpected error: {type(e).__name__}: {e}, using nearest neighbor")
            return self._nearest_neighbor_tour_length(coords)
        finally:
            self._remove_files(tsp_file, par_file)

    def _write_tsplib_file(self, coords: np.ndarray, file_handle):
        """
        Write coordinates in TSPLIB format (EUC_2D, 1-based node ids).

        Coordinates are multiplied by ``COORD_SCALE`` and rounded to integers.

        Args:
            coords: Iterable of (x, y) pairs
            file_handle: Writable text file object
        """
        n = len(coords)
        file_handle.write("NAME: temp_tsp\n")
        file_handle.write("TYPE: TSP\n")
        file_handle.write(f"DIMENSION: {n}\n")
        file_handle.write("EDGE_WEIGHT_TYPE: EUC_2D\n")
        file_handle.write("NODE_COORD_SECTION\n")

        scale = self.COORD_SCALE
        for node_id, (x, y) in enumerate(coords, start=1):
            x_int = int(round(x * scale))
            y_int = int(round(y * scale))
            file_handle.write(f"{node_id} {x_int} {y_int}\n")

        file_handle.write("EOF\n")

    def _write_lkh3_par_file(self, tsp_file: str, file_handle):
        """
        Write LKH3 parameter file.

        ``RUNS`` comes from ``lkh_runs`` (or ``runs``), clamped to at least 1.
        ``TIME_LIMIT`` is set 3 seconds below the subprocess timeout (but at
        least 5) so the solver normally stops before it is killed.

        Args:
            tsp_file: Path of the TSPLIB problem file
            file_handle: Writable text file object
        """
        file_handle.write(f"PROBLEM_FILE = {tsp_file}\n")
        file_handle.write("MOVE_TYPE = 3\n")
        file_handle.write("PATCHING_C = 3\n")
        file_handle.write("PATCHING_A = 2\n")
        runs = getattr(self.config, 'lkh_runs', None) or getattr(self.config, 'runs', 1)
        if runs < 1:
            runs = 1
        file_handle.write(f"RUNS = {runs}\n")
        try:
            time_limit = max(5, int(self._timeout_seconds()) - 3)
        except Exception:
            # Timeout not convertible to int: use the value matching the 30s default.
            time_limit = 27
        file_handle.write(f"TIME_LIMIT = {time_limit}\n")
        file_handle.write("TRACE_LEVEL = 0\n")

    def _parse_lkh3_output(self, output: str) -> float:
        """
        Parse LKH3 output to extract tour length.

        Looks for a line such as
        ``Cost.min = 4000, Cost.avg = 4000.00, Cost.max = 4000`` and returns
        the ``Cost.min`` value divided by ``COORD_SCALE``.

        Returns:
            Tour length in original coordinate units, or ``inf`` if no cost
            could be parsed
        """
        marker = 'Cost.min ='
        for line in output.split('\n'):
            start_idx = line.find(marker)
            if start_idx == -1:
                continue
            cost_value = line[start_idx + len(marker):].split(',')[0].strip()
            try:
                # Scale back from integer coordinates to original coordinates
                return float(cost_value) / self.COORD_SCALE
            except ValueError:
                continue
        # If parsing fails, return a large number
        return float('inf')

    # ------------------------------------------------------------------
    # Concorde
    # ------------------------------------------------------------------

    def _solve_with_concorde(self, coords: np.ndarray) -> float:
        """
        Solve using Concorde (if available) with robust timeout and cleanup.

        Writes a temporary TSPLIB file, runs Concorde in that file's
        directory, and derives the tour length from the ``.sol`` file or from
        stdout. Falls back to the nearest-neighbor tour length on timeout,
        non-zero exit status, unparsable output, or any other error. The
        ``.tsp`` and ``.sol`` files are always removed.
        """
        tsp_file = None
        sol_file = None
        tmpdir = self._temp_dir()

        try:
            # Create temporary TSPLIB file
            with tempfile.NamedTemporaryFile(mode='w', suffix='.tsp', delete=False, dir=tmpdir) as f:
                tsp_file = f.name
                self._write_tsplib_file(coords, f)

            # The solution is expected next to the input with a .sol extension.
            # splitext only touches the extension, unlike str.replace which
            # would also rewrite '.tsp' occurring earlier in the path.
            sol_file = os.path.splitext(tsp_file)[0] + '.sol'

            concorde_cmd = getattr(self.config, 'concorde_path', None) or "concorde"

            # Concorde command: concorde input.tsp (run inside the temp dir)
            result = self._run_solver(
                [concorde_cmd, os.path.basename(tsp_file)],
                "Concorde",
                cwd=os.path.dirname(tsp_file),
            )
            if result is None:
                return self._nearest_neighbor_tour_length(coords)

            returncode, stdout, stderr = result
            if returncode != 0:
                print(f"        Concorde failed (returncode {returncode}): {stderr[:200] if stderr else 'no error message'}")
                return self._nearest_neighbor_tour_length(coords)

            tour_length = self._parse_concorde_output(tsp_file, sol_file, stdout, coords)
            if np.isfinite(tour_length):
                return tour_length
            print("        Concorde parsing failed, using nearest neighbor")
            return self._nearest_neighbor_tour_length(coords)

        except FileNotFoundError as e:
            print(f"        Concorde executable not found: {e}, using nearest neighbor")
            return self._nearest_neighbor_tour_length(coords)
        except Exception as e:
            print(f"        Concorde unexpected error: {type(e).__name__}: {e}, using nearest neighbor")
            return self._nearest_neighbor_tour_length(coords)
        finally:
            # Always clean up temporary files
            self._remove_files(tsp_file, sol_file)

    def _parse_concorde_output(self, tsp_file: str, sol_file: str, stdout: str, coords: np.ndarray) -> float:
        """
        Parse Concorde output to extract tour length.

        Method 1 reads the tour from ``sol_file`` (a leading line equal to the
        number of cities is treated as the dimension header; remaining
        integers are 0-based city indices) and measures it on ``coords``. The
        tour is only accepted if it visits every city exactly once.

        Method 2 falls back to a cost printed on stdout. Such costs refer to
        the scaled integer instance written by ``_write_tsplib_file`` and are
        divided by ``COORD_SCALE`` so both methods return the same units.

        Args:
            tsp_file: Path of the TSPLIB input file (unused, kept for interface)
            sol_file: Path of the expected solution file
            stdout: Captured standard output of Concorde
            coords: Array of shape (n_cities, 2) with city coordinates

        Returns:
            Tour length in original coordinate units, or ``inf`` if nothing
            could be parsed
        """
        n = len(coords)

        # Method 1: Try to read from .sol file (tour order)
        if os.path.exists(sol_file):
            try:
                with open(sol_file, 'r') as f:
                    lines = f.readlines()

                tour = []
                for i, raw in enumerate(lines):
                    line = raw.strip()
                    if not line or line.startswith('#'):
                        continue
                    if i == 0:
                        # First line is usually the dimension; skip it.
                        try:
                            if int(line) == n:
                                continue
                        except ValueError:
                            pass
                    # Tour indices can be space-separated on one line
                    for part in line.split():
                        try:
                            idx = int(part)
                        except ValueError:
                            continue
                        if 0 <= idx < n:
                            tour.append(idx)

                # Require a true permutation: right length and no repeats.
                if n > 0 and len(tour) == n and len(set(tour)) == n:
                    total_length = 0.0
                    for pos, city in enumerate(tour):
                        nxt = tour[(pos + 1) % n]
                        total_length += np.linalg.norm(coords[city] - coords[nxt])
                    return total_length
            except (OSError, ValueError):
                pass  # Fall through to stdout parsing

        # Method 2: Try to parse from stdout
        scale = self.COORD_SCALE
        for line in (stdout or "").split('\n'):
            if 'Optimal Solution:' in line:
                match = re.search(r'Optimal Solution:\s*([\d.]+)', line)
                if match:
                    try:
                        cost = float(match.group(1))
                    except ValueError:
                        continue
                    if 0 < cost < 1e10:
                        return cost / scale
            elif 'cost' in line.lower() or 'length' in line.lower():
                # NOTE(review): heuristic - takes the last number on any line
                # mentioning cost/length; assumed to be in scaled units too.
                numbers = re.findall(r'[\d.]+', line)
                if numbers:
                    try:
                        cost = float(numbers[-1])
                    except ValueError:
                        continue
                    # Check if it's reasonable (not too small or too large)
                    if 0 < cost < 1e10:
                        return cost / scale

        # If all parsing fails, return infinity
        return float('inf')

    def compute_gap(self, heuristic_cost: float, optimal_cost: float) -> float:
        """
        Compute gap percentage.

        Args:
            heuristic_cost: Cost found by heuristic
            optimal_cost: Optimal cost

        Returns:
            Gap percentage: (heuristic_cost / optimal_cost - 1) * 100,
            or ``inf`` when ``optimal_cost`` is not positive
        """
        if optimal_cost <= 0:
            return float('inf')
        return (heuristic_cost / optimal_cost - 1.0) * 100.0


# Factory function for easy oracle creation
def create_tsp_oracle(config=None, oracle_type: str = None, **kwargs) -> TSPOracle:
    """
    Build a TSPOracle from an existing config, from keyword arguments, or both.

    Args:
        config: TSPGLSConfig or HeuPSROConfig object (preferred). When given
                together with ``oracle_type`` and/or ``kwargs``, a new
                TSPGLSConfig is built from the config's attributes with those
                values layered on top; when given alone it is used as-is.
        oracle_type: Oracle type string (used if config is None or to override config)
        **kwargs: Additional config parameters (used if config is None or to override config).
                  ``runs`` and ``timeout_seconds`` are renamed to ``lkh_runs``
                  and ``oracle_timeout`` unless those keys are already present.

    Returns:
        TSPOracle instance
    """
    def _normalize_aliases(params):
        # Map the alternative spellings onto the names passed to TSPGLSConfig,
        # leaving an explicitly supplied canonical key untouched.
        if 'runs' in params and 'lkh_runs' not in params:
            params['lkh_runs'] = params.pop('runs')
        if 'timeout_seconds' in params and 'oracle_timeout' not in params:
            params['oracle_timeout'] = params.pop('timeout_seconds')
        return params

    # No config object: construct one purely from the arguments.
    if config is None:
        from .config import TSPGLSConfig
        resolved_type = oracle_type or kwargs.get('oracle_type', 'none')
        params = _normalize_aliases(kwargs.copy())
        return TSPOracle(TSPGLSConfig(oracle_type=resolved_type, **params))

    # Config object and nothing to override: use it directly.
    if oracle_type is None and not kwargs:
        return TSPOracle(config)

    # Config object plus overrides: copy its attributes, then layer overrides.
    from .config import TSPGLSConfig
    if hasattr(config, '__dict__'):
        params = config.__dict__.copy()
    else:
        known_attrs = ('oracle_type', 'oracle_timeout', 'timeout_seconds',
                       'lkh3_path', 'concorde_path', 'lkh_runs', 'runs')
        params = {}
        for attr in known_attrs:
            if hasattr(config, attr):
                params[attr] = getattr(config, attr)

    if oracle_type is not None:
        params['oracle_type'] = oracle_type
    params.update(kwargs)
    return TSPOracle(TSPGLSConfig(**_normalize_aliases(params)))
