"""Acquisition function module - q-EI implementation"""
import numpy as np
from scipy.stats import norm
from typing import Tuple, List


def compute_ei(mu: np.ndarray, sigma: np.ndarray, y_best: float, xi: float = 0.01) -> np.ndarray:
    """Compute Expected Improvement (EI) for a maximization problem.

    With ``I = mu - y_best - xi`` and ``Z = I / sigma``::

        EI = I * Phi(Z) + sigma * phi(Z)

    where ``Phi``/``phi`` are the standard normal CDF/PDF.

    Args:
        mu: Predicted mean, shape (n,). Array-likes (e.g. lists) are accepted.
        sigma: Predicted standard deviation, shape (n,). Values below 1e-9
            (including 0) are floored to 1e-9.
        y_best: Current best (largest) observed value
        xi: Exploration parameter; larger values demand a bigger margin over
            ``y_best`` before a point counts as an improvement.

    Returns:
        Non-negative EI values, shape (n,)
    """
    mu = np.asarray(mu, dtype=float)
    # Floor sigma so Z stays finite. As sigma -> 0 the formula tends to
    # max(I, 0), which is the deterministic improvement.
    sigma = np.maximum(np.asarray(sigma, dtype=float), 1e-9)

    improvement = mu - y_best - xi
    z = improvement / sigma
    ei = improvement * norm.cdf(z) + sigma * norm.pdf(z)

    # EI is non-negative by definition. For strongly negative Z the two terms
    # nearly cancel, and rounding can leave tiny negative values; clip them.
    # np.maximum propagates NaN, so bad inputs are not silently hidden.
    return np.maximum(ei, 0.0)


def compute_q_ei_greedy(
    candidates: np.ndarray,
    mu: np.ndarray,
    sigma: np.ndarray,
    y_best: float,
    q: int = 2,
    xi: float = 0.01,
    update_y_best: bool = True
) -> List[int]:
    """Greedily select a batch of up to q points by sequential EI (q-EI).

    Strategy:
    1. Select the not-yet-selected point with the highest EI.
    2. Assume its observed value equals its predicted mean.
    3. Update y_best with that assumed value (if ``update_y_best``).
    4. Repeat until q points are selected or candidates run out.

    Only y_best changes between steps. sigma is NOT reconditioned on the
    assumed observations, so this is a cheap approximation of joint q-EI.

    Args:
        candidates: Candidate points, shape (n, d). Only ``len(candidates)``
            is used.
        mu: Predicted mean, shape (n,)
        sigma: Predicted standard deviation, shape (n,)
        y_best: Current best value
        q: Number of points to select. Capped at the number of candidates;
            values <= 0 select nothing.
        xi: Exploration parameter
        update_y_best: Whether to update y_best during iteration (default True,
            set False for comparison)

    Returns:
        Distinct selected point indices as Python ints, in selection order.
    """
    n_candidates = len(candidates)
    selected_indices: List[int] = []
    available_mask = np.ones(n_candidates, dtype=bool)
    current_y_best = y_best

    # Once every candidate is taken, argmax over an all -inf array would
    # return 0 again and produce duplicate picks, so never iterate past n.
    n_select = max(0, min(q, n_candidates))

    for _ in range(n_select):
        # Selected points get -inf so argmax can never pick them again.
        ei_values = np.full(n_candidates, -np.inf)
        ei_values[available_mask] = compute_ei(
            mu[available_mask],
            sigma[available_mask],
            current_y_best,
            xi
        )

        best_idx = int(np.argmax(ei_values))
        selected_indices.append(best_idx)
        available_mask[best_idx] = False

        # Assume the observed value equals the predicted mean.
        if update_y_best:
            current_y_best = max(current_y_best, mu[best_idx])

    return selected_indices


def select_next_points_q_ei(
    candidates: np.ndarray,
    mu: np.ndarray,
    variance: np.ndarray,
    y_best: float,
    q: int = 2,
    update_y_best: bool = True,
    xi: float = 0.01
) -> Tuple[List[int], np.ndarray]:
    """Select next batch of experiment points using greedy q-EI.

    Args:
        candidates: Candidate point array, shape (n, d)
        mu: Predicted mean, shape (n,)
        variance: Predicted variance, shape (n,). Negative entries (e.g. from
            floating-point round-off) are treated as 0.
        y_best: Current best value
        q: Number of points to select
        update_y_best: Whether to update y_best during iteration
        xi: Exploration parameter passed to EI (default 0.01, as before)

    Returns:
        (selected_indices, ei_values): selected indices and the initial EI
        values (computed against the original y_best) for all points.
    """
    # sqrt of a slightly negative variance is NaN, and np.argmax treats NaN
    # as the maximum, so clamp first to keep such points from winning.
    sigma = np.sqrt(np.maximum(variance, 0.0))

    selected_indices = compute_q_ei_greedy(
        candidates, mu, sigma, y_best, q, xi=xi, update_y_best=update_y_best
    )

    # Initial EI for all points (for analysis).
    ei_values = compute_ei(mu, sigma, y_best, xi)

    return selected_indices, ei_values


def compute_ucb(mu: np.ndarray, sigma: np.ndarray, beta: float = 2.0) -> np.ndarray:
    """Return the Upper Confidence Bound ``mu + beta * sigma`` elementwise.

    Args:
        mu: Predicted mean, shape (n,)
        sigma: Predicted standard deviation, shape (n,); floored at 1e-9
            before use.
        beta: Exploration weight on sigma (default 2.0, roughly a 95%
            two-sided normal interval).

    Returns:
        UCB values, shape (n,)
    """
    floored_sigma = np.maximum(sigma, 1e-9)
    exploration_bonus = beta * floored_sigma
    return mu + exploration_bonus


def compute_q_ucb_greedy(
    candidates: np.ndarray,
    mu: np.ndarray,
    sigma: np.ndarray,
    q: int = 2,
    beta: float = 2.0
) -> List[int]:
    """Greedily select a batch of up to q points by UCB (q-UCB).

    Strategy:
    1. Select the not-yet-selected point with the highest UCB.
    2. Mark it as unavailable.
    3. Repeat until q points are selected or candidates run out.

    UCB values do not depend on earlier picks, so this is equivalent to
    taking the q largest UCB values. Ties go to the lowest index, because
    np.argmax returns the first occurrence.

    Args:
        candidates: Candidate points, shape (n, d). Only ``len(candidates)``
            is used.
        mu: Predicted mean, shape (n,)
        sigma: Predicted standard deviation, shape (n,)
        q: Number of points to select. Capped at the number of candidates;
            values <= 0 select nothing.
        beta: Exploration parameter

    Returns:
        Distinct selected point indices as Python ints, in selection order.
    """
    n_candidates = len(candidates)
    selected_indices: List[int] = []
    available_mask = np.ones(n_candidates, dtype=bool)

    # UCB is elementwise and independent of the selection state, so compute
    # it once instead of on every iteration.
    ucb_all = compute_ucb(mu, sigma, beta)

    # Never iterate past n: an all -inf array would make argmax return 0
    # again and yield duplicate picks.
    n_select = max(0, min(q, n_candidates))

    for _ in range(n_select):
        ucb_values = np.where(available_mask, ucb_all, -np.inf)
        best_idx = int(np.argmax(ucb_values))
        selected_indices.append(best_idx)
        available_mask[best_idx] = False

    return selected_indices


def select_next_points_q_ucb(
    candidates: np.ndarray,
    mu: np.ndarray,
    variance: np.ndarray,
    q: int = 2,
    beta: float = 2.0
) -> Tuple[List[int], np.ndarray]:
    """Select next batch of experiment points using greedy q-UCB.

    Args:
        candidates: Candidate point array, shape (n, d)
        mu: Predicted mean, shape (n,)
        variance: Predicted variance, shape (n,). Negative entries (e.g. from
            floating-point round-off) are treated as 0.
        q: Number of points to select
        beta: Exploration parameter

    Returns:
        (selected_indices, ucb_values): selected indices and UCB values for
        all points.
    """
    # sqrt of a slightly negative variance is NaN, and np.argmax treats NaN
    # as the maximum, so clamp first to keep such points from winning.
    sigma = np.sqrt(np.maximum(variance, 0.0))

    selected_indices = compute_q_ucb_greedy(candidates, mu, sigma, q, beta)

    # UCB values for all points (for analysis).
    ucb_values = compute_ucb(mu, sigma, beta)

    return selected_indices, ucb_values
