"""BP Online data augmentation distributions for initialization."""

from __future__ import annotations

import numpy as np


# Sampling weight of each distribution family, keyed by family name.
MIX_WEIGHTS = dict(
    base_weibull_clipped=0.40,
    heavy_tail=0.20,
    mixture_two_modes=0.20,
    near_capacity_spikes=0.10,
    discrete_clusters=0.10,
)

# Parallel lists (same order as MIX_WEIGHTS) of family names and weights.
DISTRIBUTION_NAMES = [*MIX_WEIGHTS]
DISTRIBUTION_WEIGHTS = [*MIX_WEIGHTS.values()]

# Supported arrival-order types.
ORDER_TYPES = [
    "random",
    "ascending",
    "descending",
    "blocky",
    "almost_sorted",
]


def get_da_generator_code() -> str:
    """Return the source code of a data-augmentation instance generator.

    The returned string is a standalone Python module that defines
    ``generate_instances(seeds, capacity, num_items)`` plus its helpers.
    For every seed, the generated code:

    1. builds ``rng = np.random.default_rng(seed)``,
    2. picks one distribution family according to the embedded weights,
    3. samples ``num_items`` integer item sizes clipped to
       ``[1, capacity - 1]``,
    4. picks one arrival-order type and reorders the items accordingly.

    Both random picks (family and arrival order) are drawn from the
    per-seed ``rng`` rather than the global ``np.random`` state, so the
    same seed always yields the same instance.

    Returns:
        The complete generator source code as a string.
    """
    # NOTE: the family names, weights and order types embedded below
    # duplicate the module-level MIX_WEIGHTS / ORDER_TYPES values; keep
    # them in sync when either side changes.
    return (
        "import numpy as np\n"
        "\n"
        "def generate_instances(seeds, capacity, num_items):\n"
        "    instances = []\n"
        "    \n"
        "    distribution_names = ['base_weibull_clipped', 'heavy_tail', 'mixture_two_modes', 'near_capacity_spikes', 'discrete_clusters']\n"
        "    distribution_weights = [0.40, 0.20, 0.20, 0.10, 0.10]\n"
        "    \n"
        "    order_types = ['random', 'ascending', 'descending', 'blocky', 'almost_sorted']\n"
        "    \n"
        "    for seed in seeds:\n"
        "        rng = np.random.default_rng(seed)\n"
        "        \n"
        # Fix: draw from the seeded rng (was the global np.random.choice),
        # otherwise the chosen family is not reproducible per seed.
        "        dist_type = rng.choice(distribution_names, p=distribution_weights)\n"
        "        \n"
        "        if dist_type == 'base_weibull_clipped':\n"
        "            items = generate_base_weibull_clipped(num_items, capacity, rng)\n"
        "        elif dist_type == 'heavy_tail':\n"
        "            items = generate_heavy_tail(num_items, capacity, rng)\n"
        "        elif dist_type == 'mixture_two_modes':\n"
        "            items = generate_mixture_two_modes(num_items, capacity, rng)\n"
        "        elif dist_type == 'near_capacity_spikes':\n"
        "            items = generate_near_capacity_spikes(num_items, capacity, rng)\n"
        "        elif dist_type == 'discrete_clusters':\n"
        "            items = generate_discrete_clusters(num_items, capacity, rng)\n"
        "        else:\n"
        "            items = generate_base_weibull_clipped(num_items, capacity, rng)\n"
        "        \n"
        # Fix: same as above, the arrival order must also come from rng.
        "        order_type = rng.choice(order_types)\n"
        "        items = apply_arrival_order(items, order_type, rng)\n"
        "        \n"
        "        # Return only items array; capacity and num_items will be added externally\n"
        "        instances.append(items)\n"
        "    \n"
        "    return instances\n"
        "\n"
        "\n"
        "def generate_base_weibull_clipped(n, capacity, rng):\n"
        "    shape = 1.4\n"
        "    scale = 30.0\n"
        "    raw = rng.weibull(shape, size=n) * scale\n"
        "    \n"
        "    # linear scale to [1, capacity-1], then round and convert to int\n"
        "    # first clip to reasonable range\n"
        "    raw_clipped = np.clip(raw, 0, capacity * 2)\n"
        "    # linear mapping to [1, capacity-1]\n"
        "    if raw_clipped.max() > raw_clipped.min():\n"
        "        items = 1 + (raw_clipped - raw_clipped.min()) / (raw_clipped.max() - raw_clipped.min()) * (capacity - 2)\n"
        "    else:\n"
        "        items = np.ones(n) * (capacity // 2)\n"
        "    \n"
        "    items = np.round(items).astype(int)\n"
        "    items = np.clip(items, 1, capacity - 1)\n"
        "    \n"
        "    return items\n"
        "\n"
        "\n"
        "def generate_heavy_tail(n, capacity, rng):\n"
        "    \"\"\"\n"
        "    heavy tail distribution (Lognormal)\n"
        "    - y ~ LogNormal(μ, σ), then normalize to [1, C-1]\n"
        "    - σ ∈ [0.8, 1.6] (heavier tail)\n"
        "    \"\"\"\n"
        "    # randomly select σ (control tail heaviness)\n"
        "    sigma = rng.uniform(0.8, 1.6)\n"
        "    # μ set to 0 (standard lognormal)\n"
        "    mu = 0.0\n"
        "    \n"
        "    # generate lognormal distribution\n"
        "    raw = rng.lognormal(mu, sigma, size=n)\n"
        "    \n"
        "    # normalize to [1, capacity-1]\n"
        "    if raw.max() > raw.min():\n"
        "        items = 1 + (raw - raw.min()) / (raw.max() - raw.min()) * (capacity - 2)\n"
        "    else:\n"
        "        items = np.ones(n) * (capacity // 2)\n"
        "    \n"
        "    items = np.round(items).astype(int)\n"
        "    items = np.clip(items, 1, capacity - 1)\n"
        "    \n"
        "    return items\n"
        "\n"
        "\n"
        "def generate_mixture_two_modes(n, capacity, rng):\n"
        "    \"\"\"\n"
        "    two-mode mixture distribution\n"
        "    - 50% small items: Uniform[1, 0.3C]\n"
        "    - 50% large items: Uniform[0.6C, 0.95C]\n"
        "    - This will force solver to learn 'reserve space/avoid碎片'\n"
        "    \"\"\"\n"
        "    n_small = n // 2\n"
        "    n_large = n - n_small\n"
        "    \n"
        "    items = []\n"
        "    \n"
        "    # 50% small items\n"
        "    small_items = rng.integers(1, int(0.3 * capacity) + 1, size=n_small)\n"
        "    items.extend(small_items)\n"
        "    \n"
        "    # 50% large items\n"
        "    large_items = rng.integers(int(0.6 * capacity), int(0.95 * capacity) + 1, size=n_large)\n"
        "    items.extend(large_items)\n"
        "    \n"
        "    items = np.array(items, dtype=int)\n"
        "    items = np.clip(items, 1, capacity - 1)\n"
        "    \n"
        "    return items\n"
        "\n"
        "\n"
        "def generate_near_capacity_spikes(n, capacity, rng):\n"
        "    \"\"\"\n"
        "    near capacity spikes distribution\n"
        "    - 80%：Uniform[1, 0.5C]\n"
        "    - 20%：Uniform[0.85C, 0.99C]\n"
        "    - It is essentially a 'rare large item interference', very online\n"
        "    \"\"\"\n"
        "    n_normal = int(n * 0.8)\n"
        "    n_spikes = n - n_normal\n"
        "    \n"
        "    items = []\n"
        "    \n"
        "    # 80% normal size\n"
        "    normal_items = rng.integers(1, int(0.5 * capacity) + 1, size=n_normal)\n"
        "    items.extend(normal_items)\n"
        "    \n"
        "    # 20% spikes (near capacity)\n"
        "    spike_items = rng.integers(int(0.85 * capacity), int(0.99 * capacity) + 1, size=n_spikes)\n"
        "    items.extend(spike_items)\n"
        "    \n"
        "    items = np.array(items, dtype=int)\n"
        "    items = np.clip(items, 1, capacity - 1)\n"
        "    \n"
        "    return items\n"
        "\n"
        "\n"
        "def generate_discrete_clusters(n, capacity, rng):\n"
        "    \"\"\"\n"
        "    discrete cluster distribution\n"
        "    - First sample K cluster centers (e.g., K=5), then sample each item from a center ± noise\n"
        "    - This will force solver to handle 'repeated size/structural input'\n"
        "    \"\"\"\n"
        "    # randomly select cluster number\n"
        "    k = rng.integers(3, 8)  # K=3~7\n"
        "    \n"
        "    # generate K cluster centers (in [1, capacity-1] range)\n"
        "    centers = rng.integers(1, capacity, size=k)\n"
        "    \n"
        "    items = []\n"
        "    for i in range(n):\n"
        "        # randomly select a cluster\n"
        "        cluster_idx = rng.integers(0, k)\n"
        "        center = centers[cluster_idx]\n"
        "        \n"
        "        # generate point from the cluster center (add noise)\n"
        "        # noise range: ±10% of capacity\n"
        "        noise_range = max(1, int(0.1 * capacity))\n"
        "        noise = rng.integers(-noise_range, noise_range + 1)\n"
        "        \n"
        "        item = center + noise\n"
        "        item = np.clip(item, 1, capacity - 1)\n"
        "        items.append(item)\n"
        "    \n"
        "    items = np.array(items, dtype=int)\n"
        "    \n"
        "    return items\n"
        "\n"
        "\n"
        "def apply_arrival_order(items, order_type, rng):\n"
        "    \"\"\"\n"
        "    apply arrival order augmentation\n"
        "    \n"
        "    Args:\n"
        "        items: item sizes array\n"
        "        order_type: order type ('random', 'ascending', 'descending', 'blocky', 'almost_sorted')\n"
        "        rng: numpy random generator\n"
        "    \n"
        "    Returns:\n"
        "        reordered items array\n"
        "    \"\"\"\n"
        "    if order_type == 'random':\n"
        "        # randomly shuffle\n"
        "        indices = rng.permutation(len(items))\n"
        "        return items[indices]\n"
        "    \n"
        "    elif order_type == 'ascending':\n"
        "        # from small to large\n"
        "        return np.sort(items)\n"
        "    \n"
        "    elif order_type == 'descending':\n"
        "        # from large to small\n"
        "        return np.sort(items)[::-1]\n"
        "    \n"
        "    elif order_type == 'blocky':\n"
        "        # first large then small then large (人为制造阶段性)\n"
        "        sorted_items = np.sort(items)[::-1]  # from large to small sort\n"
        "        n = len(sorted_items)\n"
        "        \n"
        "        # split into 3 blocks: large-small-large\n"
        "        n1 = n // 3\n"
        "        n2 = n // 3\n"
        "        n3 = n - n1 - n2\n"
        "        \n"
        "        block1 = sorted_items[:n1]  # large\n"
        "        block2 = sorted_items[n1:n1+n2][::-1]  # small (reversed)\n"
        "        block3 = sorted_items[n1+n2:]  # large\n"
        "        \n"
        "        return np.concatenate([block1, block2, block3])\n"
        "    \n"
        "    elif order_type == 'almost_sorted':\n"
        "        # first sort then do少量swap (more realistic flow)\n"
        "        sorted_items = np.sort(items)\n"
        "        \n"
        "        # randomly select some positions to swap\n"
        "        n_swaps = max(1, len(items) // 10)  # 10% of positions to swap\n"
        "        \n"
        "        for _ in range(n_swaps):\n"
        "            i = rng.integers(0, len(sorted_items))\n"
        "            j = rng.integers(0, len(sorted_items))\n"
        "            if i != j:\n"
        "                sorted_items[i], sorted_items[j] = sorted_items[j], sorted_items[i]\n"
        "        \n"
        "        return sorted_items\n"
        "    \n"
        "    else:\n"
        "        # default: randomly shuffle\n"
        "        indices = rng.permutation(len(items))\n"
        "        return items[indices]\n"
    )








