"""
Evaluates a modified Christofides algorithm for the Traveling Salesperson Problem (TSP)
by comparing its results against known optimal solutions from a folder of .txt files
containing pre-computed distance matrices.
"""
import numpy as np
from collections import defaultdict
import time
import os
import re
import json # Added for JSON output

# --- Helper Functions for Data Loading ---

def read_distance_matrix_from_file(file_path: str) -> np.ndarray:
    """
    Reads a pre-computed distance matrix from a file.

    Every line that looks like it starts with a number and whose
    whitespace-separated tokens all parse as floats becomes one matrix row.
    Everything else (blank lines, headers, the 'Total Distance: ...' line)
    is ignored.

    Args:
        file_path: Path to a UTF-8 encoded text file.

    Returns:
        A square ``(n, n)`` float array.

    Raises:
        OSError: If the file cannot be opened (e.g. FileNotFoundError).
        ValueError: If no numeric rows are found, or the rows do not form a
            square matrix.
    """
    # A data row must begin like a number: a digit, a sign, or a decimal point
    # (e.g. '-3', '+2', '.5'). This cheaply rejects text lines whose first word
    # happens to be float-parsable, such as 'inf' or 'nan'.
    number_prefixes = ('+', '-', '.')

    matrix_rows = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or not (line[0].isdigit() or line.startswith(number_prefixes)):
                continue
            try:
                row = [float(token) for token in line.split()]
            except ValueError:
                # Mixed text/number lines (e.g. '3 cities') are not matrix rows.
                continue
            # `line` is non-empty after strip(), so `row` is never empty here.
            matrix_rows.append(row)

    file_name = os.path.basename(file_path)
    if not matrix_rows:
        raise ValueError(f"Could not parse any matrix data from {file_name}.")

    # Basic validation: every row must have exactly as many entries as there are rows.
    num_rows = len(matrix_rows)
    if any(len(row) != num_rows for row in matrix_rows):
        row_lengths = sorted({len(row) for row in matrix_rows})
        raise ValueError(
            f"Parsed matrix from {file_name} is not square "
            f"({num_rows} rows, row lengths {row_lengths}). Please check file format."
        )

    return np.array(matrix_rows, dtype=float)


def parse_optimal_distance_from_file(file_path: str) -> float | None:
    """
    Parses a pre-saved optimal distance from a file using regex.
    It looks for a line formatted like 'Total Distance: 1234.56'.
    """
    pattern = re.compile(r"Total Distance:\s*([\d.]+)")
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in reversed(f.readlines()):
                match = pattern.search(line)
                if match:
                    return float(match.group(1))
    except (FileNotFoundError, IOError):
        return None
    return None

# --- Main TSP Solver (Christofides Heuristic) ---

def _prim_mst_parents(dist_matrix: np.ndarray) -> list[int]:
    """
    Computes a minimum spanning tree with O(n^2) Prim's algorithm rooted at node 0.

    Returns:
        ``parent`` where ``parent[v]`` is v's neighbour on the path toward
        node 0 (``parent[0] == -1``).

    Raises:
        ValueError: If some node cannot be reached through finite distances.
    """
    n = dist_matrix.shape[0]
    key = [float('inf')] * n
    parent = [-1] * n
    in_mst = [False] * n
    key[0] = 0.0

    for _ in range(n):
        # Select the cheapest node not yet in the tree (lowest index wins ties).
        min_key = float('inf')
        u = -1
        for v in range(n):
            if not in_mst[v] and key[v] < min_key:
                min_key = key[v]
                u = v
        if u == -1:
            # All remaining nodes have key == inf (or NaN distances). Carrying
            # on would leave parent == -1 entries that corrupt the tree edges.
            raise ValueError(
                "Distance matrix is disconnected: some nodes have no finite "
                "distance to the rest of the graph."
            )
        in_mst[u] = True
        for v in range(n):
            if not in_mst[v] and dist_matrix[u, v] < key[v]:
                key[v] = dist_matrix[u, v]
                parent[v] = u
    return parent


def _greedy_odd_matching(dist_matrix: np.ndarray, odd_nodes: list[int]) -> list[tuple[int, int]]:
    """Pairs each odd-degree node, in list order, with its closest still-unmatched odd node."""
    matching = []
    unmatched = list(odd_nodes)
    while unmatched:
        u = unmatched.pop(0)
        closest_dist = float('inf')
        closest_v = -1
        for v in unmatched:
            if dist_matrix[u, v] < closest_dist:
                closest_dist = dist_matrix[u, v]
                closest_v = v
        if closest_v != -1:
            matching.append((u, closest_v))
            unmatched.remove(closest_v)
    return matching


def _eulerian_circuit(adjacency: list[list[int]], start: int = 0) -> list[int]:
    """Returns a closed walk using every edge once (stack-based Hierholzer's algorithm)."""
    remaining = [list(neighbors) for neighbors in adjacency]  # consumed as edges are used
    stack = [start]
    circuit = []
    while stack:
        current = stack[-1]
        if remaining[current]:
            next_node = remaining[current].pop(0)
            remaining[next_node].remove(current)  # drop the reverse copy of the undirected edge
            stack.append(next_node)
        else:
            circuit.append(stack.pop())
    circuit.reverse()
    return circuit


def solve_tsp_approximate(dist_matrix):
    """
    Finds an approximate solution to the TSP for a given distance matrix
    using a modified Christofides algorithm with a greedy matching approach.

    The greedy matching replaces Christofides' minimum-weight perfect
    matching, so the classic 1.5x approximation guarantee does not apply.

    Args:
        dist_matrix: Square ``(n, n)`` array-like of pairwise distances.

    Returns:
        ``(tour, total_distance)``. ``tour`` is a numpy array holding a
        permutation of ``0..n-1`` starting at node 0. ``total_distance`` is a
        Python float that includes the closing edge back to the start.

    Raises:
        ValueError: If the matrix is not square, or some node has no finite
            distance to the others.
    """
    dist_matrix = np.asarray(dist_matrix, dtype=float)
    n = dist_matrix.shape[0]
    if n < 2:
        return np.array(range(n)), 0.0
    if dist_matrix.shape != (n, n):
        raise ValueError(f"Distance matrix must be square, got shape {dist_matrix.shape}.")

    # Step 1: Compute the MST.
    parent = _prim_mst_parents(dist_matrix)

    # Index-aligned adjacency lists: adjacency[v] always belongs to node v.
    adjacency = [[] for _ in range(n)]
    degree = [0] * n
    for child in range(1, n):
        par = parent[child]
        adjacency[par].append(child)
        adjacency[child].append(par)
        degree[par] += 1
        degree[child] += 1

    # Step 2: Nodes with odd degree (always an even count, by the handshake lemma).
    odd_degree_nodes = [v for v in range(n) if degree[v] % 2 != 0]

    # Steps 3-4: Greedy matching on the odd nodes, added to the MST as a multigraph.
    for u, v in _greedy_odd_matching(dist_matrix, odd_degree_nodes):
        adjacency[u].append(v)
        adjacency[v].append(u)

    # Step 5: Eulerian circuit starting at node 0.
    circuit = _eulerian_circuit(adjacency, start=0)

    # Step 6: Shortcut to a Hamiltonian cycle by keeping the first visit of each node.
    tour = list(dict.fromkeys(circuit))

    # Sum consecutive legs plus the closing leg back to tour[0].
    closing_order = tour[1:] + tour[:1]
    total_distance = float(sum(dist_matrix[a, b] for a, b in zip(tour, closing_order)))

    return np.array(tour), total_distance


# --- Main Execution Block ---

if __name__ == "__main__":
    # exit()/quit() are interactive-shell helpers injected by the site module;
    # scripts should use sys.exit() so the exit status is reliable.
    import sys

    # ❗ **SET YOUR FOLDER PATH HERE**
    folder_path = 'all'

    if not os.path.isdir(folder_path):
        print(f"❌ Error: Folder not found at '{folder_path}'")
        sys.exit(1)  # non-zero status: this is a failure, not a clean run

    errors = []
    results_data = []  # one dict per successfully evaluated file, dumped to JSON below
    total_start_time = time.perf_counter()  # monotonic clock for interval timing

    print(f"🚀 Scanning for .txt files in '{os.path.abspath(folder_path)}'...")

    # Process .txt files in a deterministic (sorted) order.
    problem_files = sorted(f for f in os.listdir(folder_path) if f.endswith(".txt"))

    if not problem_files:
        print("⚠️ No .txt files found in the specified directory.")
        sys.exit(0)

    # --- Main Processing Loop ---
    for filename in problem_files:
        file_path = os.path.join(folder_path, filename)
        print("\n" + "-"*50)
        print(f"Processing: {filename}")

        try:
            # Step 1: Get the known optimal result from the file
            optimal_result = parse_optimal_distance_from_file(file_path)
            if optimal_result is None:
                print("   -> ⚠️ Warning: No saved optimal result found in file. Skipping.")
                continue
            if optimal_result <= 0:
                # Relative error is undefined. A numpy float divided by 0.0 yields
                # inf without raising, which would corrupt the average and the JSON.
                print("   -> ⚠️ Warning: Saved optimal result is not positive; "
                      "relative error is undefined. Skipping.")
                continue

            # Step 2: Read the distance matrix directly from the file
            dist_matrix = read_distance_matrix_from_file(file_path)

            # Step 3: Run the heuristic. Cast to a plain float so the value is
            # always JSON-serializable, even if the solver returns a numpy scalar.
            _, heuristic_result = solve_tsp_approximate(dist_matrix)
            heuristic_result = float(heuristic_result)

            # Step 4: Calculate the relative error and store it
            error = (heuristic_result - optimal_result) / optimal_result
            errors.append(error)

            results_data.append({
                "filename": filename,
                "optimal_distance": optimal_result,
                "heuristic_distance": heuristic_result,
                "error": error,
            })

            print(f"   -> Optimal Result (from file): {optimal_result:.4f}")
            print(f"   -> Heuristic Result (this code): {heuristic_result:.4f}")
            print(f"   -> Error: {error:.2%}")

        except Exception as e:
            # Top-level boundary: report the failure and keep processing other files.
            print(f"   -> ❌ An error occurred: {e}")

    # --- Save results to JSON file ---
    json_output_path = os.path.join(folder_path, 'Christofides_result.json')
    print("\n" + "-"*50)
    print(f"💾 Saving detailed results to {json_output_path}...")
    try:
        with open(json_output_path, 'w', encoding='utf-8') as f:
            json.dump(results_data, f, indent=4)
        print("   -> ✅ Successfully saved.")
    except Exception as e:
        print(f"   -> ❌ Failed to save JSON file: {e}")

    # --- Final Summary ---
    total_execution_time = time.perf_counter() - total_start_time

    print("\n" + "="*50)
    print("🎉 All files processed. Final Report:")
    print("="*50)

    if errors:
        average_error = np.mean(errors)
        print(f"📂 Files Processed: {len(errors)}")
        print(f"📊 **Average Error:** **{average_error:.2%}**")
    else:
        print("No files were successfully processed to calculate an average error.")

    print(f"⏱️ **Total Execution Time:** **{total_execution_time:.4f} seconds**")