"""Shared helpers for TSP GLS testing utilities."""

from __future__ import annotations

import csv
import json
import os
import pickle
from dataclasses import dataclass
from typing import Dict, List, Optional

import numpy as np


@dataclass
class DatasetPack:
    """Instances loaded from one dataset file, stored as parallel lists.

    Entry ``i`` of each list refers to the same problem instance.
    """

    # One coordinate array per instance.
    coordinates: List[np.ndarray]
    # One distance matrix per instance.
    distance_matrices: List[np.ndarray]
    # Reference cost per instance; ``None`` when no cost is available.
    optimal_costs: List[Optional[float]]


def _find_pkl_files(pkl_dir: str, recursive: bool) -> List[str]:
    """Return the sorted paths of every ``.pkl`` file found in *pkl_dir*."""
    if recursive:
        found = [
            os.path.join(root, fname)
            for root, _, files in os.walk(pkl_dir)
            for fname in files
            if fname.lower().endswith(".pkl")
        ]
    else:
        found = [
            os.path.join(pkl_dir, fname)
            for fname in os.listdir(pkl_dir)
            if fname.lower().endswith(".pkl")
        ]
    # Sorting by full path keeps the load order deterministic across runs.
    found.sort()
    return found


def load_tsplib_datasets(
    pkl_dir: str,
    max_instances: Optional[int] = None,
    max_size: Optional[int] = None,
    recursive: bool = True,
) -> Dict[str, DatasetPack]:
    """
    Load TSPLIB-style PKL files (coordinates, distance matrix, opt cost).

    Each pickle is expected to hold a dict with the keys ``"coordinate"``,
    ``"distance_matrix"`` and, optionally, ``"cost"``; each value is a
    sequence with one entry per instance. Files that cannot be read or do
    not have this layout are reported with a printed warning and skipped,
    so a single bad file never aborts the whole load.

    Args:
        pkl_dir: Directory that contains the ``.pkl`` files.
        max_instances: If given, keep at most this many instances per file.
        max_size: If given, skip files whose first coordinate set has more
            than this many entries.
        recursive: Search sub-directories as well when true; otherwise only
            the entries directly inside ``pkl_dir`` are considered.

    Returns:
        A dict mapping a dataset label to its ``DatasetPack``. The label is
        the file name without extension. When that name is already taken by
        a file from another sub-directory, the extension-less path relative
        to ``pkl_dir`` is used instead, so no dataset is overwritten.

    Raises:
        FileNotFoundError: If ``pkl_dir`` is not an existing directory.
    """
    if not os.path.isdir(pkl_dir):
        raise FileNotFoundError(f"PKL directory not found: {pkl_dir}")

    datasets: Dict[str, DatasetPack] = {}

    for path in _find_pkl_files(pkl_dir, recursive):
        # NOTE(security): pickle.load can execute arbitrary code while
        # unpickling. Only point this loader at trusted dataset files.
        try:
            with open(path, "rb") as f:
                data = pickle.load(f)
        except (
            pickle.UnpicklingError,
            EOFError,
            ValueError,
            AttributeError,
            ImportError,
            IndexError,
            OSError,
        ) as e:
            # Besides UnpicklingError, unpickling is documented to raise
            # AttributeError/ImportError/IndexError; OSError covers files
            # that cannot be opened or read.
            print(f"  Warning: Failed to load {path}: {e}. Skipping this file.")
            continue

        if not isinstance(data, dict):
            print(
                f"  Warning: Failed to process {path}: expected a dict payload, "
                f"got {type(data).__name__}. Skipping this file."
            )
            continue

        try:
            # Compare against None explicitly: ``value or []`` raises
            # ValueError when the stored value is a multi-element ndarray.
            coords_list = data.get("coordinate")
            if coords_list is None:
                coords_list = []
            dist_list = data.get("distance_matrix")
            if dist_list is None:
                dist_list = []
            opt_list = data.get("cost")

            count = len(dist_list)
            if max_instances is not None:
                # Clamp at zero so a negative limit yields an empty pack.
                count = max(0, min(count, max_instances))

            # Check problem size (number of cities) from first instance
            if max_size is not None and len(coords_list) > 0:
                first = coords_list[0]
                n_cities = len(first) if isinstance(first, (list, tuple, np.ndarray)) else 0
                if n_cities > max_size:
                    print(f"  Skipping {path}: size {n_cities} > {max_size}")
                    continue

            if len(coords_list) < count:
                raise ValueError(
                    f"only {len(coords_list)} coordinate sets for {count} distance matrices"
                )

            optimal_costs: List[Optional[float]] = []
            for i in range(count):
                # Missing or short cost lists are padded with None.
                val = opt_list[i] if opt_list is not None and i < len(opt_list) else None
                optimal_costs.append(None if val is None else float(val))

            pack = DatasetPack(
                coordinates=[np.asarray(c, dtype=float) for c in coords_list[:count]],
                distance_matrices=[np.asarray(d, dtype=float) for d in dist_list[:count]],
                optimal_costs=optimal_costs,
            )
        except (KeyError, IndexError, TypeError, ValueError) as e:
            print(f"  Warning: Failed to process {path}: {e}. Skipping this file.")
            continue

        label = os.path.splitext(os.path.basename(path))[0]
        if label in datasets:
            # Same file name in another sub-directory: key the dataset by
            # its relative path instead of overwriting the earlier one.
            stem = os.path.splitext(os.path.relpath(path, pkl_dir))[0]
            candidate = stem
            suffix = 2
            while candidate in datasets:
                candidate = f"{stem}#{suffix}"
                suffix += 1
            print(
                f"  Warning: Duplicate dataset name '{label}'; "
                f"storing {path} as '{candidate}'."
            )
            label = candidate

        datasets[label] = pack

    return datasets


def _format_cell(val: object) -> str:
    """Render one CSV cell: empty for None, 6-decimal float when numeric, else str()."""
    if val is None:
        return ""
    try:
        return f"{float(val):.6f}"
    except (ValueError, TypeError):
        return str(val)


def save_table(out_dir: str, table: Dict[str, Dict[str, float]], basename: str) -> None:
    """Save per-solver tables into JSON/CSV.

    Writes ``<basename>.json`` (the table as-is, indented) and
    ``<basename>.csv`` into ``out_dir``, creating the directory if needed.
    The CSV has one row per solver; its columns are ``solver`` followed by
    the sorted union of all row keys.

    Args:
        out_dir: Target directory; created when it does not exist.
        table: Mapping of solver name to a mapping of column name to value.
        basename: File name (without extension) shared by both outputs.
    """
    os.makedirs(out_dir, exist_ok=True)
    json_path = os.path.join(out_dir, f"{basename}.json")
    csv_path = os.path.join(out_dir, f"{basename}.csv")

    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(table, f, indent=2)

    columns = sorted({col for row in table.values() for col in row})

    # newline="" hands line-ending control to the csv writer, and
    # lineterminator="\n" keeps the one-line-per-row "\n" layout. The writer
    # quotes any field containing a comma, quote or newline, so such solver
    # names or string cells can no longer shift columns.
    with open(csv_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f, lineterminator="\n")
        writer.writerow(["solver", *columns])
        for solver, row in table.items():
            writer.writerow([str(solver), *(_format_cell(row.get(col)) for col in columns)])


