"""TSP data augmentation distributions for initialization."""

from __future__ import annotations

import numpy as np


# Sampling probability of each synthetic TSP coordinate distribution used for
# data augmentation. The weights sum to 1. Insertion order matters: it fixes
# the order of the two parallel lists derived below.
MIX_WEIGHTS = dict(
    uniform_square=0.30,
    clustered_gaussian=0.25,
    grid_jitter=0.15,
    annulus_ring=0.15,
    two_scale_mixture=0.15,
)

# Parallel lists (same order as MIX_WEIGHTS): family names and their weights.
DISTRIBUTION_NAMES = [family for family in MIX_WEIGHTS]
DISTRIBUTION_WEIGHTS = [MIX_WEIGHTS[family] for family in DISTRIBUTION_NAMES]


def get_da_generator_code(mix_weights: dict[str, float] | None = None) -> str:
    """Return the source code of a standalone TSP instance generator module.

    The returned code imports NumPy and defines:

    * ``generate_instances(seeds, n_cities)``: for every seed it calls
      ``np.random.seed(seed)`` (reseeding NumPy's global RNG), picks one
      distribution family with ``np.random.choice`` according to the mixture
      weights, and appends an ``(n_cities, 2)`` coordinate array. It returns
      the list of arrays.
    * one ``generate_<family>(n)`` helper per supported family. Each helper
      returns an ``(n, 2)`` array with coordinates in ``[0, 1]``, including
      ``(0, 2)`` when ``n == 0``.

    Args:
        mix_weights: Mapping from family name to sampling probability. If
            ``None`` (the default), the module-level ``MIX_WEIGHTS`` is used,
            so the generated code always matches it. Names must be supported
            families. Weights must be non-negative and sum to 1.

    Returns:
        The complete generator source code as a string.

    Raises:
        ValueError: If ``mix_weights`` is empty, names an unknown family,
            contains a negative weight, or its weights do not sum to 1.
    """
    supported = (
        "uniform_square",
        "clustered_gaussian",
        "grid_jitter",
        "annulus_ring",
        "two_scale_mixture",
    )
    if mix_weights is None:
        mix_weights = MIX_WEIGHTS

    names = list(mix_weights)
    weights = [float(w) for w in mix_weights.values()]

    # Validate here rather than letting np.random.choice (or the uniform
    # fallback) deal with bad input later, inside the generated code.
    if not names:
        raise ValueError("mix_weights must not be empty")
    unknown = [name for name in names if name not in supported]
    if unknown:
        raise ValueError(f"unknown distribution families: {unknown}")
    if any(w < 0 for w in weights):
        raise ValueError("mix_weights must be non-negative")
    total = sum(weights)
    if abs(total - 1.0) > 1e-8:
        raise ValueError(f"mix_weights must sum to 1, got {total}")

    # The random draws below are made per point and in a fixed order; changing
    # that order would change which instance a given seed produces.
    return (
        "import numpy as np\n"
        "\n"
        "\n"
        "def generate_instances(seeds, n_cities):\n"
        f"    distribution_names = {names!r}\n"
        f"    distribution_weights = {weights!r}\n"
        "    generators = {\n"
        "        'uniform_square': generate_uniform_square,\n"
        "        'clustered_gaussian': generate_clustered_gaussian,\n"
        "        'grid_jitter': generate_grid_jitter,\n"
        "        'annulus_ring': generate_annulus_ring,\n"
        "        'two_scale_mixture': generate_two_scale_mixture,\n"
        "    }\n"
        "\n"
        "    instances = []\n"
        "    for seed in seeds:\n"
        "        np.random.seed(seed)\n"
        "        dist_type = np.random.choice(distribution_names, p=distribution_weights)\n"
        "        generator = generators.get(dist_type, generate_uniform_square)\n"
        "        instances.append(generator(n_cities))\n"
        "    return instances\n"
        "\n"
        "\n"
        "def generate_uniform_square(n):\n"
        "    return np.random.rand(n, 2)\n"
        "\n"
        "\n"
        "def generate_clustered_gaussian(n):\n"
        "    k = np.random.randint(3, 9)  # k in [3, 8]\n"
        "    centers = np.random.rand(k, 2)\n"
        "    sigma = np.random.uniform(0.02, 0.08)\n"
        "    coords = []\n"
        "    for _ in range(n):\n"
        "        center = centers[np.random.randint(0, k)]\n"
        "        point = center + np.random.normal(0, sigma, size=2)\n"
        "        coords.append(np.clip(point, 0.0, 1.0))\n"
        "    # reshape keeps the (n, 2) shape when n == 0.\n"
        "    return np.array(coords).reshape(-1, 2)\n"
        "\n"
        "\n"
        "def generate_grid_jitter(n):\n"
        "    grid_size = int(np.sqrt(n))  # grid_size ** 2 <= n\n"
        "    axis = np.linspace(0.1, 0.9, grid_size)\n"
        "    coords = []\n"
        "    for x in axis:\n"
        "        for y in axis:\n"
        "            jitter = np.random.uniform(-0.05, 0.05, size=2)\n"
        "            coords.append(np.clip(np.array([x, y]) + jitter, 0.0, 1.0))\n"
        "    # Points that do not fit on the square grid are placed uniformly.\n"
        "    while len(coords) < n:\n"
        "        coords.append(np.random.rand(2))\n"
        "    return np.array(coords[:n]).reshape(-1, 2)\n"
        "\n"
        "\n"
        "def generate_annulus_ring(n):\n"
        "    r0 = np.random.uniform(0.2, 0.4)\n"
        "    r1 = np.random.uniform(0.5, 0.7)\n"
        "    # NOTE(review): r1 can exceed 0.5, so part of the ring falls outside\n"
        "    # the unit square and is clipped onto its edges; confirm intended.\n"
        "    coords = []\n"
        "    for _ in range(n):\n"
        "        theta = np.random.uniform(0, 2 * np.pi)\n"
        "        r = np.random.uniform(r0, r1)\n"
        "        x = np.clip(r * np.cos(theta) + 0.5, 0.0, 1.0)\n"
        "        y = np.clip(r * np.sin(theta) + 0.5, 0.0, 1.0)\n"
        "        coords.append([x, y])\n"
        "    return np.array(coords).reshape(-1, 2)\n"
        "\n"
        "\n"
        "def generate_two_scale_mixture(n):\n"
        "    n_large = int(n * 0.7)\n"
        "    n_small = n - n_large\n"
        "    # About 70% of the points are uniform over the unit square...\n"
        "    coords = [np.random.rand(2) for _ in range(n_large)]\n"
        "    # ...the rest form a small square cluster in a randomly chosen corner.\n"
        "    corner_x = np.random.choice([0.0, 1.0])\n"
        "    corner_y = np.random.choice([0.0, 1.0])\n"
        "    cluster_size = np.random.uniform(0.05, 0.15)\n"
        "    for _ in range(n_small):\n"
        "        if corner_x == 0.0:\n"
        "            x = np.random.uniform(0.0, cluster_size)\n"
        "        else:\n"
        "            x = np.random.uniform(1.0 - cluster_size, 1.0)\n"
        "        if corner_y == 0.0:\n"
        "            y = np.random.uniform(0.0, cluster_size)\n"
        "        else:\n"
        "            y = np.random.uniform(1.0 - cluster_size, 1.0)\n"
        "        coords.append([x, y])\n"
        "    # Shuffle so the cluster points are not all at the end.\n"
        "    coords = np.array(coords).reshape(-1, 2)\n"
        "    return coords[np.random.permutation(n)]\n"
    )

































