"""TSP-specific initialization logic."""

from __future__ import annotations

from .data_augmentation import get_da_generator_code


def get_eoh_optimal_code() -> str:
    """
    Return the source code of the EoH paper's best-performing TSP heuristic.

    The returned text is a small Python module. It imports numpy and defines
    ``update_edge_distance(edge_distance, local_opt_tour, edge_n_used)``. The
    EoH original names this function ``heuristic``; here it is renamed so it
    matches our solver interface.

    The heuristic draws noise from ``np.random.uniform`` (numpy's global RNG),
    so its output is only reproducible if the caller seeds that RNG. It also
    replaces a zero max-usage count or a zero mean distance with 1.0 before
    dividing by them.
    """
    source_lines = (
        "import numpy as np",
        "def update_edge_distance(edge_distance, local_opt_tour, edge_n_used):",
        "    updated_edge_distance = np.copy(edge_distance)",
        "    edge_count = np.zeros_like(edge_distance)",
        "    for i in range(len(local_opt_tour) - 1):",
        "        start = local_opt_tour[i]",
        "        end = local_opt_tour[i + 1]",
        "        edge_count[start][end] += 1",
        "        edge_count[end][start] += 1",
        "    # penalize local optimal route",
        "    edge_n_used_max = np.max(edge_n_used)",
        "    # calculate the average edge used",
        "    decay_factor = 0.1  # decay fast or",
        "    mean_distance = np.mean(edge_distance)",
        "    # calculate the average distance",
        "    # Protect against division by zero",
        "    if edge_n_used_max == 0:",
        "        edge_n_used_max = 1.0  # Avoid division by zero",
        "    if mean_distance == 0:",
        "        mean_distance = 1.0  # Avoid division by zero",
        "    for i in range(edge_distance.shape[0]):",
        "        for j in range(edge_distance.shape[1]):",
        "            if edge_count[i][j] > 0:",
        "                noise_factor = (np.random.uniform(0.7, 1.3) / edge_count[i][j]) + (",
        "                    edge_distance[i][j] / mean_distance) - (0.3 / edge_n_used_max) * edge_n_used[i][j]",
        "                # calculate a hybrid noise factor",
        "                updated_edge_distance[i][j] += noise_factor * (1 + edge_count[i][j]) - decay_factor * updated_edge_distance[i][j]",
        "    # The new guiding edge distance matrix is calculated based on both a noise term and a decayed original distance matrix",
        "    return updated_edge_distance",
    )
    # Every line, including the final one, is newline-terminated.
    return "\n".join(source_lines) + "\n"


def initialize_tsp_strategies(controller) -> None:
    """
    Initialize TSP-specific strategy pools.

    Registers two seed strategies with ``controller.pools``:

    - ``h0``: a basic heuristic solver. Its ``update_edge_distance`` returns an
      unchanged copy of the edge-distance matrix.
    - ``g0``: an instance generator. When ``controller.cfg.da`` is truthy, this
      is the data-augmentation generator returned by ``get_da_generator_code()``.
      Otherwise, including when ``cfg`` has no ``da`` attribute, it is a
      uniform random generator that yields ``n_cities`` points drawn from
      ``[0, 1)^2`` per seed.

    Progress messages are printed to stdout.

    Args:
        controller: HeuPSROController instance. Reads ``controller.cfg.da``
            (optional), ``controller.cfg.n_cities`` and
            ``controller.cfg.min_simple_ratio``. Calls
            ``controller.pools.add_solver`` and
            ``controller.pools.add_generator``.
    """
    print(" Initializing HeuPSRO for TSP...")

    # Read the flag once so the generator choice and the summary message
    # below can never disagree.
    use_da = bool(getattr(controller.cfg, "da", False))

    # Identity heuristic: leaves the guiding edge distances untouched.
    basic_heuristic = (
        "def update_edge_distance(edge_distance, local_opt_tour, edge_n_used):\n"
        "    return edge_distance.copy()\n"
    )
    controller.pools.add_solver(
        "h0", basic_heuristic, "Basic heuristic", {}, {"source": "init"}
    )

    if use_da:
        # Use data augmentation distributions.
        generator_code = get_da_generator_code()
        generator_name = "Data augmentation generator"
        generator_label = "data augmentation generator"
    else:
        # Uniform random generator (default).
        # NOTE(review): the generated code calls np.random.seed(), which
        # reseeds numpy's *global* RNG. Later np.random.* calls (e.g. in
        # heuristics) are affected. Left unchanged because switching to a
        # local RandomState would alter that downstream state.
        generator_code = (
            "def generate_instances(seeds, n_cities):\n"
            "    import numpy as np\n"
            "    instances = []\n"
            "    for seed in seeds:\n"
            "        np.random.seed(seed)\n"
            "        coords = np.random.rand(n_cities, 2)\n"
            "        instances.append(coords)\n"
            "    return instances\n"
        )
        generator_name = "Uniform generator"
        generator_label = "uniform generator"

    controller.pools.add_generator(
        "g0",
        generator_code,
        generator_name,
        {"n_cities": controller.cfg.n_cities},
        {"min_ratio": controller.cfg.min_simple_ratio, "source": "init"},
    )

    print(f"   Initialized with basic heuristic and {generator_label}")

