"""Sampling utilities."""
import math
import random
import numpy as np
from typing import List, TypeVar, Callable, Optional

import logging

# Module-level logger; epsilon_greedy_select reports its selections through it.
logger = logging.getLogger(__name__)
# Generic candidate type used in the selection helpers' signatures.
T = TypeVar('T')


def fix_sample(ops: List, seed: int = 2026, k: int = 50) -> List:
    """Randomly sample up to k items from ops, reproducibly for a given seed.

    The draw uses a private ``random.Random(seed)`` instance. The result
    therefore depends only on ``ops``, ``seed`` and ``k``, and the global
    ``random`` state is left untouched. Calling ``random.seed`` here would
    silently make every later ``random`` call in the process deterministic.
    The items drawn are the same ones the module-level ``random.sample`` would
    return right after ``random.seed(seed)``.

    Args:
        ops: List of items to sample from
        seed: Random seed
        k: Number of items to sample; if it exceeds ``len(ops)``, every item is
            returned in shuffled order. A negative ``k`` raises ``ValueError``
            (from ``random.sample``).

    Returns:
        New list of sampled items, in selection order.
    """
    rng = random.Random(seed)
    return rng.sample(ops, min(k, len(ops)))


def softmax_sample(exp_notes: List, k: int, temperature: float = 0.1) -> List:
    """Sample k items from exp_notes using softmax sampling based on scores.

    Each note is weighted by ``exp(score / temperature)``. ``k`` distinct notes
    are drawn without replacement with ``np.random.choice``, so results follow
    the global NumPy RNG state.

    Args:
        exp_notes: List of experience notes; each note's ``score`` attribute is
            its weight. A missing attribute, ``None`` or NaN counts as 0.0.
        k: Number of items to sample. It is clamped to ``len(exp_notes)``, and
            ``k <= 0`` yields an empty list.
        temperature: Temperature for softmax (lower = more deterministic);
            clamped to at least 1e-6.

    Returns:
        Sampled list of experience notes, in draw order.
    """
    if not exp_notes or k <= 0:
        return []
    k = min(k, len(exp_notes))
    scores = np.array([getattr(n, "score", 0.0) for n in exp_notes], dtype=np.float64)
    # None becomes NaN above. NaN would poison every probability, so treat it
    # like a missing score.
    scores = np.where(np.isnan(scores), 0.0, scores)
    t = max(temperature, 1e-6)
    scaled = scores / t
    scaled = scaled - scaled.max()  # Prevent overflow
    probs = np.exp(scaled)
    probs = probs / probs.sum()

    nonzero = np.flatnonzero(probs)
    if nonzero.size >= k:
        indices = np.random.choice(len(exp_notes), size=k, replace=False, p=probs)
    else:
        # For large score gaps at low temperature, exp() underflows to exactly 0.
        # np.random.choice(replace=False) refuses to draw more items than there
        # are non-zero probabilities. Sequential draws would take every non-zero
        # note first, so draw those in random order. Then fill the remainder
        # with the highest-scoring zero-probability notes, which is the
        # low-temperature limit of the same process.
        head_p = probs[nonzero] / probs[nonzero].sum()
        head = np.random.choice(nonzero, size=nonzero.size, replace=False, p=head_p)
        zero = np.flatnonzero(probs == 0)
        tail = zero[np.argsort(-scores[zero], kind="stable")][: k - nonzero.size]
        indices = np.concatenate([head, tail])
    return [exp_notes[int(i)] for i in indices]


def epsilon_greedy_select(
    sorted_candidates: List[T],
    top_k: int,
    epsilon: float
) -> List[T]:
    """ε-greedy selection strategy.

    Select up to ``top_k`` elements from a candidate list sorted high to low:

    - With probability ``epsilon``: exploration, i.e. a uniform random sample
      (``random.sample``) of ``min(top_k, len(sorted_candidates))`` distinct
      candidates drawn from the whole list.
    - Otherwise: exploitation, i.e. the first ``top_k`` candidates of the
      sorted list.

    Both the exploration coin flip and the sample use the global ``random``
    module state.

    Args:
        sorted_candidates: Candidate list sorted by some metric (high to low).
        top_k: Number of elements to select; ``top_k <= 0`` selects nothing.
        epsilon: Exploration probability, range [0, 1] (0 = always exploit,
            1 = always explore).

    Returns:
        New list of selected elements (count <= top_k).
    """
    if not sorted_candidates or top_k <= 0:
        # Guarding top_k <= 0 matters: a negative k would make the exploit
        # slice ``[:k]`` drop items from the end instead of selecting none.
        return []

    k = min(top_k, len(sorted_candidates))

    if random.random() < epsilon:
        # Exploration: random sampling from the whole candidate list.
        selected = random.sample(sorted_candidates, k)
        message = "[epsilon_greedy_select] Exploration: Selected from random: %s"
    else:
        # Exploitation: select the top k from the sorted list (highest Q values).
        selected = sorted_candidates[:k]
        message = "[epsilon_greedy_select] Exploitation: Selected from sorted: %s"

    if logger.isEnabledFor(logging.INFO):
        # Log what was actually selected. Candidates are generic (T), so fall
        # back to the item itself when it has no ``id`` attribute.
        logger.info(message, [getattr(c, "id", c) for c in selected])

    return selected


def boltzmann_select(
    sorted_candidates: List[T],
    top_k: int,
    temperature: float,
    get_value: Optional[Callable[[T], float]] = None,
    default_value: float = 0.0
) -> List[T]:
    """Boltzmann exploration selection strategy.

    Draw ``top_k`` distinct candidates without replacement, weighting each by
    ``exp(value / temperature)``, i.e. a softmax over values such as Q values.
    The temperature controls the degree of exploration:

    - High temperature: more uniform distribution (more exploration)
    - Low temperature: more concentrated on high values (more exploitation)

    Sampling uses ``np.random.choice`` and therefore the global NumPy RNG state.

    Args:
        sorted_candidates: Candidate list sorted by some metric (high to low).
        top_k: Number of elements to select; ``top_k <= 0`` selects nothing.
        temperature: Temperature parameter. It must be > 0 whenever sampling is
            actually needed (high = more exploration, low = more exploitation).
        get_value: Optional function to extract a value (e.g. Q value) from a
            candidate. If None, each candidate's ``q_value`` attribute is used.
        default_value: Value used by the default extractor when ``q_value`` is
            missing, None or non-finite.

    Returns:
        New list of selected elements (count <= top_k). If there are no more
        than ``top_k`` candidates, all of them are returned in their original
        order.

    Raises:
        ValueError: If sampling is needed and ``temperature`` is not > 0.
    """
    if not sorted_candidates or top_k <= 0:
        return []

    if len(sorted_candidates) <= top_k:
        # Nothing to choose between. Return a copy so the result never aliases
        # the caller's list.
        return list(sorted_candidates)

    # Division by a non-positive (or NaN) temperature would produce NaN
    # probabilities; fail with a clear message instead.
    if not temperature > 0:
        raise ValueError(f"temperature must be > 0, got {temperature!r}")

    if get_value is None:
        # Default behavior: read the q_value attribute, falling back to
        # default_value when it is absent, None or non-finite.
        def default_get_value(item: T) -> float:
            if hasattr(item, 'q_value'):
                q_val = getattr(item, 'q_value')
                if q_val is not None and math.isfinite(q_val):
                    return float(q_val)
            return default_value
        get_value = default_get_value

    value_array = np.array([get_value(item) for item in sorted_candidates], dtype=np.float64)

    # Softmax: exp(value / temperature) / sum(exp(value / temperature)).
    # Subtract the maximum first so exp() cannot overflow.
    scaled_values = value_array / temperature
    scaled_values = scaled_values - scaled_values.max()
    exp_values = np.exp(scaled_values)
    probs = exp_values / exp_values.sum()

    # Here top_k < len(sorted_candidates), so top_k is the sample size.
    nonzero = np.flatnonzero(probs)
    if nonzero.size >= top_k:
        selected_indices = np.random.choice(
            len(sorted_candidates),
            size=top_k,
            replace=False,
            p=probs
        )
    else:
        # For large value gaps at low temperature, exp() underflows to exactly 0.
        # np.random.choice(replace=False) refuses to draw more items than there
        # are non-zero probabilities. Sequential draws would take every non-zero
        # candidate first, so draw those in random order. Then fill the
        # remainder with the highest-valued zero-probability candidates (ties
        # keep list order), which is the low-temperature limit of the same
        # process.
        head_p = probs[nonzero] / probs[nonzero].sum()
        head = np.random.choice(nonzero, size=nonzero.size, replace=False, p=head_p)
        zero = np.flatnonzero(probs == 0)
        tail = zero[np.argsort(-value_array[zero], kind="stable")][: top_k - nonzero.size]
        selected_indices = np.concatenate([head, tail])

    return [sorted_candidates[int(idx)] for idx in selected_indices]

