"""BP Online-specific initialization logic."""

from __future__ import annotations

import numpy as np

from .data_augmentation import get_da_generator_code


def get_eoh_optimal_code() -> str:
    """
    Return the source of the EoH paper's optimal bin-packing heuristic.

    The returned text defines ``score(item, bins)``, which combines an
    exponential term, a square-root term and a utilization term computed
    from the remaining capacity ``bins - item``.

    Returns:
        The heuristic's Python source, one statement per line, ending with
        a trailing newline.
    """
    source_lines = (
        "import numpy as np",
        "def score(item, bins):",
        "    diff = bins - item  # remaining capacity",
        "    exp = np.exp(diff)  # exponent term",
        "    sqrt = np.sqrt(diff)  # square root term",
        "    ulti = 1 - diff / bins  # utilization term",
        "    comb = ulti * sqrt  # combination of utilization and square root",
        "    adjust = np.where(diff > (item * 3), comb + 0.8, comb + 0.3)",
        "    # hybrid adjustment term to penalize large bins",
        "    hybrid_exp = bins / ((exp + 0.7) * exp)",
        "    # hybrid score based on exponent term",
        "    scores = hybrid_exp + adjust",
        "    # sum of hybrid score and adjustment",
        "    return scores",
    )
    # Every line, including the last, is newline-terminated.
    return "\n".join(source_lines) + "\n"

def get_best_fit_code() -> str:
    """
    Return the source of the Best Fit baseline scorer for bin packing.

    The returned text defines ``score(item, bins)``: bins that can hold the
    item are scored by the negated leftover capacity, so the tightest fit
    scores highest; bins that cannot hold it are scored ``-inf``.

    Returns:
        The scorer's Python source, ending with a trailing newline.
    """
    source_lines = (
        "import numpy as np",
        "def score(item, bins):",
        "    # Best Fit: prefer bins with smallest remaining capacity that can fit the item",
        "    # We want to minimize remaining capacity after placing the item",
        "    # So we score by negative remaining capacity (larger score = smaller remaining capacity)",
        "    remaining = bins - item  # remaining capacity after placing item",
        "    # Only consider bins that can fit the item (remaining >= 0)",
        "    valid_mask = remaining >= 0",
        "    # For valid bins, score is negative remaining capacity (so smallest remaining gets highest score)",
        "    # For invalid bins, score is negative infinity",
        "    scores = np.where(valid_mask, -remaining, -np.inf)",
        "    return scores",
    )
    # Terminate each line (the last one included) with a newline.
    return "".join(line + "\n" for line in source_lines)


def get_first_fit_code() -> str:
    """
    Return the source of the First Fit baseline scorer for bin packing.

    The returned text defines ``score(item, bins)``: bins that can hold the
    item are scored ``len(bins) - index`` so the earliest feasible bin wins;
    bins that cannot hold it are scored ``-inf``.

    Returns:
        The scorer's Python source, ending with a trailing newline.
    """
    source_lines = (
        "import numpy as np",
        "def score(item, bins):",
        "    # First Fit: prefer the first bin that can fit the item",
        "    # We score by index position (earlier bins get higher scores)",
        "    remaining = bins - item  # remaining capacity after placing item",
        "    valid_mask = remaining >= 0",
        "    # Create scores based on position: first valid bin gets highest score",
        "    # Use (len(bins) - index) so earlier bins get higher scores",
        "    indices = np.arange(len(bins))",
        "    # For valid bins: score = len(bins) - index (so index 0 gets highest score)",
        "    # For invalid bins: score is negative infinity",
        "    scores = np.where(valid_mask, len(bins) - indices, -np.inf)",
        "    return scores",
    )
    # Join with newlines and add the final terminator.
    return "\n".join(source_lines) + "\n"


def get_worst_fit_code() -> str:
    """
    Return the source of the Worst Fit baseline scorer for bin packing.

    The returned text defines ``score(item, bins)``: bins that can hold the
    item are scored by their leftover capacity, so the roomiest bin scores
    highest; bins that cannot hold it are scored ``-inf``.

    Returns:
        The scorer's Python source, ending with a trailing newline.
    """
    source_lines = (
        "import numpy as np",
        "def score(item, bins):",
        "    # Worst Fit: prefer bins with largest remaining capacity that can fit the item",
        "    # We want to maximize remaining capacity after placing the item",
        "    # So we score by remaining capacity (larger remaining = higher score)",
        "    remaining = bins - item  # remaining capacity after placing item",
        "    # Only consider bins that can fit the item (remaining >= 0)",
        "    valid_mask = remaining >= 0",
        "    # For valid bins, score is remaining capacity (so largest remaining gets highest score)",
        "    # For invalid bins, score is negative infinity",
        "    scores = np.where(valid_mask, remaining, -np.inf)",
        "    return scores",
    )
    # Terminate each line (the last one included) with a newline.
    return "".join(line + "\n" for line in source_lines)


def get_weibull_generator_code(capacity: int, scale_ratio: float = 0.3, shape: float = 1.4) -> str:
    """
    Return Python code that generates Online Bin Packing instances
    using a Weibull(shape, scale) distribution.

    The Weibull scale is computed here as ``capacity * scale_ratio`` and is
    written into the returned source, together with ``shape``, as numeric
    literals. The generated ``generate_instances(seeds, capacity, num_items)``
    function takes its own ``capacity`` argument, which it uses only to clip
    item sizes into ``[1, capacity - 1]``.

    Args:
        capacity: Bin capacity (used to compute scale)
        scale_ratio: Scale ratio relative to capacity (scale = capacity * scale_ratio)
        shape: Weibull shape parameter k (defaults to 1.4)

    Returns:
        A string containing the generator code.

    Raises:
        ValueError: If the computed scale is not finite, or if ``shape`` is
            not finite or is negative.
    """
    scale = capacity * scale_ratio

    # Non-finite values would be interpolated as the bare tokens ``nan`` /
    # ``inf``, which are undefined names in the generated source and would
    # only fail later, when that source is executed.
    if not np.isfinite(scale):
        raise ValueError(
            f"Weibull scale must be finite, got {scale!r} "
            f"(capacity={capacity!r}, scale_ratio={scale_ratio!r})"
        )
    # rng.weibull rejects a negative shape; report it at generation time.
    if not np.isfinite(shape) or shape < 0:
        raise ValueError(f"Weibull shape must be finite and non-negative, got {shape!r}")

    return (
        "import numpy as np\n"
        "def generate_instances(seeds, capacity, num_items):\n"
        "    instances = []\n"
        "    for seed in seeds:\n"
        "        rng = np.random.default_rng(seed)\n"
        f"        # Generate items using Weibull(shape={shape}, scale={scale})\n"
        f"        raw = rng.weibull({shape}, size=num_items) * {scale}\n"
        "        \n"
        "        # Ensure items fit into bins\n"
        "        # (Online BP requires each item < capacity)\n"
        "        items = np.clip(raw, 1, capacity - 1).astype(int)\n"
        "        \n"
        "        # Return only items array; capacity and num_items will be added externally\n"
        "        instances.append(items)\n"
        "    return instances\n"
    )


def initialize_bp_online_strategies(controller) -> None:
    """
    Initialize BP Online-specific strategy pools.

    Creates:
    - Basic heuristic solver (h0), using the Best Fit baseline code
    - Initial generator (g0): the data augmentation generator when
      ``controller.cfg.da`` is truthy, otherwise the Weibull generator

    Args:
        controller: HeuPSROController instance. This function reads
            ``cfg.capacity``, ``cfg.num_items``, ``cfg.min_simple_ratio`` and
            the optional ``cfg.da`` flag (treated as False when absent), and
            calls ``pools.add_solver`` and ``pools.add_generator``.
    """
    print(" Initializing HeuPSRO for BP Online...")

    # Seed the solver pool with the Best Fit baseline. The value returned by
    # add_solver is not needed here, so it is not kept.
    controller.pools.add_solver(
        "h0", get_best_fit_code(), "Basic heuristic", {}, {"source": "init"}
    )

    capacity = controller.cfg.capacity
    num_items = controller.cfg.num_items

    # Read the flag once so the generator choice and the summary message
    # below cannot disagree.
    use_da = getattr(controller.cfg, 'da', False)
    if use_da:
        # Use data augmentation distributions
        generator_code = get_da_generator_code()
        generator_name = "Data augmentation generator"
        summary = "   Initialized with basic heuristic and data augmentation generator"
    else:
        # Use Weibull distribution generator (default) with the helper's
        # default scale ratio and shape.
        generator_code = get_weibull_generator_code(capacity)
        generator_name = "Weibull generator"
        summary = "   Initialized with basic heuristic and Weibull generator"

    controller.pools.add_generator(
        "g0",
        generator_code,
        generator_name,
        {"capacity": capacity, "num_items": num_items},
        {"min_ratio": controller.cfg.min_simple_ratio, "source": "init"},
    )

    print(summary)

