import numpy as np
import time
import os
import re
import json
import math

# --- Helper Functions for Data Loading ---

def read_distance_matrix_from_file(file_path: str) -> np.ndarray:
    """
    Load a pre-computed distance matrix from a whitespace-separated text file.

    Only lines whose first non-blank character is a digit are considered.
    Each such line is split on whitespace and every token is converted to
    float; a line with any token that is not a valid float is dropped as a
    whole. Blank lines and lines starting with anything else (headers, the
    optimal-distance line) are ignored.

    Args:
        file_path: Path of the text file to read (decoded as UTF-8).

    Returns:
        The parsed rows wrapped in a NumPy array, in file order.

    Raises:
        FileNotFoundError: If file_path does not exist (propagated from open).
        ValueError: If the file contains no parsable numeric row.
    """
    parsed_rows = []
    with open(file_path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            text = raw_line.strip()
            # Cheap pre-filter: matrix rows are expected to begin with a digit.
            if text == "" or not text[0].isdigit():
                continue

            try:
                values = list(map(float, text.split()))
            except ValueError:
                # Starts with a digit but is not purely numeric; not a row.
                continue

            if values:
                parsed_rows.append(values)

    if len(parsed_rows) == 0:
        raise ValueError(f"Could not parse any matrix data from {os.path.basename(file_path)}.")

    return np.array(parsed_rows)


def parse_optimal_distance_from_file(file_path: str) -> float | None:
    """
    Parses a pre-saved optimal distance from a file using regex.
    It looks for a line formatted like 'Total Distance: 1234.56'.
    """
    pattern = re.compile(r"Total Distance:\s*([\d.]+)")
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in reversed(f.readlines()):
                match = pattern.search(line)
                if match:
                    return float(match.group(1))
    except (FileNotFoundError, IOError):
        return None
    return None

# --- Main TSP Solver (Dynamic Programming) ---

def solve_tsp_dp(dist_matrix: np.ndarray, max_n: int = 20):
    """
    Find an exact TSP solution with the Held-Karp dynamic programme.

    The tour starts at city 0, visits every city once and returns to city 0.
    dist_matrix[i][j] is used as the cost of travelling from i to j, so
    asymmetric matrices are handled as well.

    Args:
        dist_matrix: Square (n x n) distance matrix.
        max_n: Largest problem size that will be attempted. The DP tables
            hold n * 2^n entries each and the run time is O(n^2 * 2^n),
            so sizes much above 20 are impractical.

    Returns:
        A tuple (tour, cost). tour lists the cities in visiting order,
        beginning with 0 (the closing edge back to 0 is implied); cost is
        the total length including that closing edge.
        ([], inf) is returned when n > max_n or no finite tour exists,
        ([], 0.0) for an empty matrix and ([0], 0.0) for a single city.

    Raises:
        ValueError: If n >= 2 and dist_matrix is not a square 2-D array.
    """
    n = dist_matrix.shape[0]

    # Practicality check: both tables below have 2^n rows, so refuse sizes
    # that would exhaust memory long before producing an answer.
    if n > max_n:
        print(f"   -> ⚠️ Warning: Problem size n={n} is too large for DP. Skipping.")
        return [], float('inf')

    # Handle base cases
    if n == 0:
        return [], 0.0
    if n == 1:
        return [0], 0.0

    if dist_matrix.ndim != 2 or dist_matrix.shape[1] != n:
        raise ValueError(
            f"Distance matrix must be square, got shape {dist_matrix.shape}."
        )

    # Plain nested lists index much faster than NumPy scalars in Python loops.
    distances = dist_matrix.tolist()
    num_states = 1 << n

    # dp[mask][last] = shortest path visiting cities in 'mask' ending at 'last'
    dp = [[math.inf] * n for _ in range(num_states)]
    parent = [[None] * n for _ in range(num_states)]

    # Base case: start at city 0. The mask '1' represents the set {0}.
    dp[1][0] = 0

    # Every reachable state contains city 0 (bit 0), so only odd masks can
    # hold a finite cost; even masks are skipped entirely.
    for mask in range(1, num_states, 2):
        costs = dp[mask]
        for last in range(n):
            base = costs[last]
            # Unreachable state. A finite cost also guarantees that 'last'
            # is a member of 'mask', so no separate membership test is needed.
            if base == math.inf:
                continue

            row = distances[last]
            # Try to extend the path to every city not yet in 'mask'.
            for next_city in range(n):
                bit = 1 << next_city
                if mask & bit:
                    continue
                new_mask = mask | bit
                new_dist = base + row[next_city]
                if new_dist < dp[new_mask][next_city]:
                    dp[new_mask][next_city] = new_dist
                    parent[new_mask][next_city] = last

    # Close the tour: pick the end city whose path plus return edge is cheapest.
    min_cost = math.inf
    end_city = None
    full_mask = num_states - 1
    final_costs = dp[full_mask]
    for last in range(1, n):
        cost = final_costs[last] + distances[last][0]
        if cost < min_cost:
            min_cost = cost
            end_city = last

    # No end city improved on infinity (e.g. non-finite distances).
    if end_city is None:
        if n == 2:
            min_cost = distances[0][1] + distances[1][0]
            return [0, 1], min_cost
        return [], math.inf  # No tour found for n > 2

    # Reconstruct the optimal tour by walking the parent pointers backwards.
    tour = []
    mask = full_mask
    last = end_city
    while last is not None:
        tour.append(last)
        prev_last = parent[mask][last]
        mask ^= (1 << last)  # Remove 'last' from mask to find its parent's state
        last = prev_last

    tour.reverse()  # Path was collected end-to-start
    return tour, min_cost


# --- Main Execution Block ---

# Batch driver: for every problem file in the folder, compare the distance
# stored in the file with the one computed by the DP solver, write the
# per-file results to JSON and print a summary.
if __name__ == "__main__":
    # ❗ **SET YOUR FOLDER PATH HERE**
    folder_path = 'all'

    if not os.path.isdir(folder_path):
        print(f"❌ Error: Folder not found at '{folder_path}'")
        # SystemExit instead of the site-provided exit(), which is meant for
        # the interactive shell; non-zero status because this is an error.
        raise SystemExit(1)

    errors = []
    results_data = []  # List to store results for JSON output
    total_start_time = time.time()

    print(f"🚀 Scanning for .txt files in '{os.path.abspath(folder_path)}'...")

    # Filter out result files from previous runs
    problem_files = sorted(
        f for f in os.listdir(folder_path)
        if f.endswith(".txt") and "result" not in f
    )

    if not problem_files:
        print("⚠️ No .txt problem files found in the specified directory.")
        # Nothing to do is not an error: exit with status 0.
        raise SystemExit(0)

    # --- Main Processing Loop ---
    for filename in problem_files:
        file_path = os.path.join(folder_path, filename)
        print("\n" + "-"*50)
        print(f"Processing: {filename}")

        # Top-level boundary: one bad file must not abort the whole batch,
        # so any failure is reported and the loop moves on.
        try:
            # Step 1: Get the known optimal result from the file
            optimal_result = parse_optimal_distance_from_file(file_path)
            if optimal_result is None:
                print("   -> ⚠️ Warning: No saved optimal result found in file. Skipping.")
                continue

            # Step 2: Read the distance matrix directly from the file
            dist_matrix = read_distance_matrix_from_file(file_path)

            # Step 3: Run the DP solver to get the exact result
            _tour, dp_result = solve_tsp_dp(dist_matrix)

            # An infinite cost means the solver skipped the problem or found
            # no tour; there is nothing to compare in that case.
            if dp_result == float('inf'):
                continue

            # Step 4: Calculate the error and store it
            if optimal_result > 1e-9:
                error = (dp_result - optimal_result) / optimal_result
            else:
                error = dp_result - optimal_result  # Use absolute error if optimal is zero

            errors.append(error)

            # Store results for JSON output
            results_data.append({
                "filename": filename,
                "optimal_distance": optimal_result,
                "calculated_distance": dp_result,
                "error": error
            })

            print(f"   -> Optimal Result (from file): {optimal_result:.4f}")
            print(f"   -> DP Result (this code):      {dp_result:.4f}")
            # DP is an exact algorithm, so the error should be near zero, subject to float precision
            print(f"   -> Error: {error:.4e} ({error:.4%})")

        except Exception as e:
            print(f"   -> ❌ An error occurred: {e}")

    # --- Save results to JSON file ---
    json_output_path = os.path.join(folder_path, 'dp_results.json')
    print("\n" + "-"*50)
    print(f"💾 Saving detailed results to {json_output_path}...")
    try:
        with open(json_output_path, 'w', encoding='utf-8') as f:
            json.dump(results_data, f, indent=4)
        print("   -> ✅ Successfully saved.")
    except Exception as e:
        print(f"   -> ❌ Failed to save JSON file: {e}")

    # --- Final Summary ---
    total_execution_time = time.time() - total_start_time

    print("\n" + "="*50)
    print("🎉 All files processed. Final Report:")
    print("="*50)

    if errors:
        average_error = np.mean(errors)
        print(f"📂 Files Processed: {len(errors)}")
        print(f"📊 **Average Error:** **{average_error:.4e}**")
    else:
        print("No files were successfully processed to calculate an average error.")

    print(f"⏱️ **Total Execution Time:** **{total_execution_time:.4f} seconds**")