"""Utility functions module"""
import numpy as np
from typing import List, Tuple, Dict


def sample_candidates(bounds: np.ndarray, n_samples: int, types: List[str] = None) -> np.ndarray:
    """Randomly sample candidate points inside per-dimension bounds.

    Sampling uses NumPy's global random state (``np.random``), so results are
    reproducible after ``np.random.seed``.

    Args:
        bounds: Bounds, shape (d, 2), each row is [min, max].
        n_samples: Number of points to sample.
        types: Per-dimension type list ["int", "float", ...]. None (or an
            empty list) means every dimension is sampled as float.

    Returns:
        Candidate point array of floats, shape (n_samples, d). Dimensions
        typed "int" hold integer values drawn uniformly from the integers in
        the closed interval [min, max]; all other dimensions are drawn
        uniformly from [min, max).

    Raises:
        ValueError: If an "int" dimension contains no integer between its
            min and max.
    """
    d = len(bounds)
    candidates = np.zeros((n_samples, d))
    # Explicit length check so array-like `types` does not hit NumPy's
    # ambiguous truth-value error.
    has_types = types is not None and len(types) > 0

    for i, (low, high) in enumerate(bounds):
        if has_types and types[i] == "int":
            # Shrink fractional bounds inward; truncating them (what randint
            # does with floats) could yield integers outside [low, high].
            int_low = int(np.ceil(low))
            int_high = int(np.floor(high))
            if int_low > int_high:
                raise ValueError(
                    f"No integer lies within bounds [{low}, {high}] for dimension {i}"
                )
            # randint's upper limit is exclusive, hence the + 1.
            candidates[:, i] = np.random.randint(int_low, int_high + 1, size=n_samples)
        else:
            candidates[:, i] = np.random.uniform(low, high, size=n_samples)

    return candidates


def numpy_to_dict_list(X: np.ndarray, feature_names: List[str]) -> List[dict]:
    """Turn each row of an array into a ``{feature_name: value}`` dictionary.

    Args:
        X: Data array, shape (n, d). A 1-D array is treated as a single row.
        feature_names: Names paired positionally with the columns of X.

    Returns:
        One dictionary per row, e.g. [{"feat1": val1, "feat2": val2}, ...].
    """
    # Promote a single point to a one-row matrix so iteration yields rows.
    rows = X.reshape(1, -1) if X.ndim == 1 else X
    return [dict(zip(feature_names, row)) for row in rows]


def dict_list_to_numpy(X_dict: List[dict], feature_names: List[str]) -> np.ndarray:
    """Convert list of dictionaries to numpy array.

    Args:
        X_dict: List of dictionaries; each must contain every name in
            feature_names (extra keys are ignored).
        feature_names: List of feature names, defining the column order.

    Returns:
        Data array, shape (n, d). An empty X_dict yields shape (0, d).

    Raises:
        KeyError: If a dictionary is missing one of the feature names.
    """
    rows = [[x[name] for name in feature_names] for x in X_dict]
    if not rows:
        # np.array([]) would be 1-D with shape (0,); keep the documented
        # (n, d) layout so column indexing still works on empty input.
        return np.empty((0, len(feature_names)))
    return np.array(rows)


def compute_relative_error(y_true: np.ndarray, y_pred: np.ndarray, abs_threshold: float = 1e-6) -> np.ndarray:
    """Return the element-wise relative error ``|y_true - y_pred| / |y_true|``.

    Args:
        y_true: True values.
        y_pred: Predicted values.
        abs_threshold: Smallest denominator allowed; ``|y_true|`` is raised
            to this value wherever it is smaller, avoiding division by zero.

    Returns:
        Relative error, element-wise.
    """
    abs_diff = np.abs(y_true - y_pred)
    # Floor the scale so near-zero targets do not blow up the ratio.
    safe_scale = np.maximum(np.abs(y_true), abs_threshold)
    return abs_diff / safe_scale


def find_high_error_points(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    rel_threshold: float = 0.1,
    abs_threshold: float = 0.06,
    top_k: int = 5
) -> List[int]:
    """Find indices of high error points.

    Condition: relative error > rel_threshold AND absolute error > abs_threshold

    Args:
        y_true: True values (array-like).
        y_pred: Predicted values (array-like).
        rel_threshold: Relative error threshold.
        abs_threshold: Absolute error threshold. This is NOT forwarded to
            compute_relative_error, which keeps its own default denominator
            protection.
        top_k: Maximum number of points to return. Values <= 0 return an
            empty list.

    Returns:
        List of indices of high error points. If more than top_k points
        qualify, the top_k with the largest relative error are returned in
        descending order of relative error; otherwise all qualifying indices
        are returned in ascending index order.
    """
    # A non-positive top_k would otherwise be used as a slice end below,
    # where a negative value silently drops points from the tail.
    if top_k <= 0:
        return []

    # Normalise to plain arrays so positional indexing is used throughout
    # (lists and label-indexed containers are accepted).
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    # Compute errors
    rel_error = compute_relative_error(y_true, y_pred)
    abs_error = np.abs(y_true - y_pred)

    # Both criteria must hold for a point to count as high-error.
    mask = (rel_error > rel_threshold) & (abs_error > abs_threshold)
    high_error_indices = np.where(mask)[0]

    # If number of points exceeds top_k, keep the k largest relative errors.
    if len(high_error_indices) > top_k:
        errors_at_indices = rel_error[high_error_indices]
        sorted_indices = high_error_indices[np.argsort(errors_at_indices)[::-1]]
        high_error_indices = sorted_indices[:top_k]

    return high_error_indices.tolist()


def apply_types_to_array(X: np.ndarray, types: List[str]) -> np.ndarray:
    """Round the columns marked "int" to whole numbers.

    The array's dtype is left unchanged: a float array stays float, with
    integer-valued entries in the "int" columns. The input is not modified.

    Args:
        X: Data array, shape (n, d). A single point of shape (d,) is also
            accepted.
        types: Type list ["int", "float", ...], one entry per feature.

    Returns:
        Copy of X with every "int" feature rounded to the nearest integer
        value (NumPy rounding, halves go to the nearest even). NaN and
        infinite entries are passed through unchanged.
    """
    X_typed = X.copy()
    int_columns = [i for i, kind in enumerate(types) if kind == "int"]
    if int_columns:
        # Index the last axis so both (n, d) matrices and (d,) points work.
        # No integer cast: the result is stored back into X's dtype anyway,
        # and casting would turn NaN/inf into a meaningless huge integer.
        # Adding 0.0 normalises -0.0 (from rounding e.g. -0.4) to 0.0.
        X_typed[..., int_columns] = np.round(X_typed[..., int_columns]) + 0.0

    return X_typed


def filter_sandwich_constraints(candidates: np.ndarray, feature_names: List[str]) -> np.ndarray:
    """Filter candidate points that violate hard constraints in Sandwich task.

    Constraints (all bounds inclusive):
    - Bread total: 60-140 g (multigrain_bread + whole_wheat_bread + sourdough_bread)
    - Protein total: 60-150 g (chicken_protein + tuna_protein + tofu_protein + hummus_protein + egg_protein)
    - Dairy total ≤ 40 g (low_fat_cheese_dairy + cheddar_cheese + swiss_cheese_dairy)
    - Vegetable total ≥ 100 g (collards + cabbage + onion_vegetables + tomato_vegetables)
    - Sauce+oil total ≤ 25 g (mayo_sauce + olive_oil)

    The fruit features (apples, orange, banana) are unconstrained.

    Args:
        candidates: Candidate point array, shape (n, 20)
        feature_names: Feature names (list, tuple or array), used to verify
            that this is the Sandwich task. They must match the expected 20
            names in the expected order.

    Returns:
        Filtered candidate point array, shape (m, 20), where m <= n. Negative
        values are clamped to 0 before the constraints are checked, and the
        returned rows contain the clamped values. If feature_names does not
        match the Sandwich task, or candidates is empty, the input array is
        returned unchanged.
    """
    # (member features, inclusive lower bound, inclusive upper bound);
    # None means that side is unbounded.
    constraint_groups = [
        (["multigrain_bread", "whole_wheat_bread", "sourdough_bread"], 60.0, 140.0),
        (["chicken_protein", "tuna_protein", "tofu_protein", "hummus_protein", "egg_protein"], 60.0, 150.0),
        (["low_fat_cheese_dairy", "cheddar_cheese", "swiss_cheese_dairy"], None, 40.0),
        (["collards", "cabbage", "onion_vegetables", "tomato_vegetables"], 100.0, None),
        (["mayo_sauce", "olive_oil"], None, 25.0),
    ]
    unconstrained_names = ["apples", "orange", "banana"]
    expected_names = [
        name for members, _, _ in constraint_groups for name in members
    ] + unconstrained_names

    # Compare as lists: a tuple never equals a list, and comparing a NumPy
    # array to a list is element-wise, which cannot be used in a condition.
    if list(feature_names) != expected_names:
        # Not Sandwich task, return original array
        return candidates

    if len(candidates) == 0:
        return candidates

    # Ensure all values are non-negative
    candidates = np.maximum(candidates, 0.0)

    mask = np.ones(len(candidates), dtype=bool)
    for members, lower, upper in constraint_groups:
        columns = [expected_names.index(name) for name in members]
        totals = candidates[:, columns].sum(axis=1)
        if lower is not None:
            mask &= totals >= lower
        if upper is not None:
            mask &= totals <= upper

    return candidates[mask]
