"""Decision module - mismatch judgment and HF/LF decision"""
from typing import List, Optional, Tuple

import numpy as np


class MismatchDecision:
    """Mismatch decision maker - decides whether to perform HF experiment based on mismatch ratio r.

    The decision is stateful: every call to ``decide`` updates counters of
    consecutive HF and consecutive LF decisions. The LF counter drives the
    optional forced-HF mechanism.
    """

    def __init__(self, threshold: float = 0.8, consecutive_high_limit: int = 3,
                 force_hf_after_n_lf: Optional[int] = None):
        """Initialize decision maker.

        Args:
            threshold: Mismatch ratio threshold (r_max >= threshold → HF experiment).
            consecutive_high_limit: Limit on consecutive high r counts. It is stored
                for compatibility/introspection only and does not affect decisions.
            force_hf_after_n_lf: Force HF once N consecutive LF decisions have
                been made (None means no forcing). Note that 0 forces HF on
                every call, because the LF counter is always >= 0.
        """
        self.threshold = threshold
        # Not consulted by decide(); kept so callers passing it can read it back.
        self.consecutive_high_limit = consecutive_high_limit
        self.consecutive_high_count = 0  # Consecutive HF decisions (forced or not)
        self.consecutive_lf_count = 0  # Consecutive LF decisions
        self.force_hf_after_n_lf = force_hf_after_n_lf

    def compute_mismatch_ratio(
        self,
        sigma2_delta: np.ndarray,
        sigma2_H: np.ndarray
    ) -> np.ndarray:
        """Compute mismatch ratio r = σ²_δ / σ²_H, clipped to [0, 1].

        Args:
            sigma2_delta: Residual variance, shape (n,) or scalar.
            sigma2_H: Posterior high-fidelity variance, shape (n,) or scalar.
                Values below 1e-9 are floored to 1e-9 before dividing.

        Returns:
            Mismatch ratio r in [0, 1], shape (n,) or scalar.
        """
        # Floor the denominator to avoid division by zero.
        sigma2_H = np.maximum(sigma2_H, 1e-9)
        r = sigma2_delta / sigma2_H

        # A ratio above 1 is treated as maximal mismatch.
        return np.clip(r, 0, 1)

    def decide(
        self,
        selected_indices: List[int],
        sigma2_delta: np.ndarray,
        sigma2_H: np.ndarray
    ) -> Tuple[bool, float, List[float]]:
        """Decide whether to perform HF experiment.

        The decision is HF when the forced-HF condition holds, or when the
        largest mismatch ratio among the selected points is >= ``threshold``.
        The consecutive HF/LF counters are updated after each decision.

        Args:
            selected_indices: Indices of selected points (must be non-empty).
            sigma2_delta: Residual variance array, indexed by ``selected_indices``.
            sigma2_H: Posterior high-fidelity variance array, indexed by
                ``selected_indices``.

        Returns:
            (do_hf, r_max, r_values)
            - do_hf: Whether to perform HF experiment
            - r_max: Maximum mismatch ratio
            - r_values: Mismatch ratios for all selected points

        Raises:
            ValueError: If ``selected_indices`` is empty. This check runs before
                any counter is modified.
        """
        # Mismatch ratio of each selected point, as plain Python floats.
        r_values = [
            float(self.compute_mismatch_ratio(sigma2_delta[idx], sigma2_H[idx]))
            for idx in selected_indices
        ]
        if not r_values:
            raise ValueError(
                "decide() requires at least one selected index; got none"
            )

        r_max = max(r_values)

        # Forcing: after N consecutive LF decisions, the next decision is HF.
        # Because an HF decision resets the LF counter, with persistently low
        # r this yields N LF decisions followed by one HF decision, repeating.
        forced = (
            self.force_hf_after_n_lf is not None
            and self.consecutive_lf_count >= self.force_hf_after_n_lf
        )
        do_hf = forced or r_max >= self.threshold

        # Update consecutive counts
        if do_hf:
            self.consecutive_high_count += 1
            self.consecutive_lf_count = 0  # Reset LF count
        else:
            self.consecutive_high_count = 0
            self.consecutive_lf_count += 1  # Increment LF count

        return do_hf, r_max, r_values
    

