"""
Simple Principled Ensemble Sizing Strategy

Core principle: Each true pair needs at least 1 vote per iteration with high confidence.

Formula: n_ens = ceil(-log(1 - confidence) / (r * w))
Where:
  - r = subset_ratio (fraction of reference points per ensemble)
  - w = win_probability (probability true match wins NN when in subset)
  - confidence = target probability of getting at least 1 vote (default 0.95)

For 95% confidence: n_ens ≈ ceil(3 / (r * w))
"""

import math
from typing import Optional, Dict, List
import numpy as np
from scipy import sparse


def compute_n_ensembles_principled(
    subset_ratio: float,
    win_probability: float,
    confidence: float = 0.95,
    min_ensembles: int = 5,
    max_ensembles: int = 50
) -> int:
    """
    Principled ensemble count: ensure true pairs get ≥1 vote/iteration with confidence.

    Derivation:
        P(at least 1 vote) = 1 - (1-p)^n_ens ≥ confidence
        (1-p)^n_ens ≤ 1 - confidence
        n_ens ≥ log(1 - confidence) / log(1 - p)
        For small p: n_ens ≈ -log(1 - confidence) / p

    The small-p approximation is what is actually evaluated. Because
    -log(1 - p) ≥ p for 0 < p < 1, it never asks for fewer ensembles than
    the exact expression would.

    Args:
        subset_ratio: r, fraction of reference points per ensemble
        win_probability: w, probability true match wins NN when in subset
        confidence: target probability of getting at least 1 vote (default 0.95).
            Values ≥ 1 cannot be reached with finitely many ensembles and
            therefore saturate at max_ensembles.
        min_ensembles: minimum ensemble count (default 5)
        max_ensembles: maximum ensemble count (default 50)

    Returns:
        n_ensembles: number of ensembles needed, clamped to
        [min_ensembles, max_ensembles]. When r * w ≤ 0 no number of
        ensembles can produce a vote, so max_ensembles is returned as is.
    """
    p = subset_ratio * win_probability
    if p <= 0:
        return max_ensembles

    # Small-p approximation: n_ens = ceil(-log(1 - confidence) / p)
    # For 95%: -log(0.05) ≈ 2.996 ≈ 3
    # For 99%: -log(0.01) ≈ 4.605 ≈ 5
    if confidence >= 1:
        # log(0) is undefined; certainty would need infinitely many ensembles.
        required = math.inf
    else:
        required = -math.log(1 - confidence) / p

    # Saturate before calling ceil(): ceil(inf) raises OverflowError, and any
    # value at or above the cap is clamped to the cap anyway.
    if required >= max_ensembles:
        n_ens = max_ensembles
    else:
        n_ens = math.ceil(required)

    return max(min_ensembles, min(max_ensembles, n_ens))


def estimate_initial_win_probability(d_intrinsic: float = 30) -> float:
    """
    Give a conservative first-iteration guess for the win probability w.

    The guess starts from w = 0.25 at a reference dimensionality of 20 and
    shrinks with the square root of the dimensionality, since NN matching
    is assumed to get harder as the intrinsic dimension grows. The result
    is never allowed to drop below 0.1.

    Args:
        d_intrinsic: intrinsic dimensionality of embedding space
            (values below 1 are treated as 1)

    Returns:
        w: estimated win probability (conservative)
    """
    reference_w = 0.25
    floor_w = 0.1

    # Dimensionalities below 1 are lifted to 1 before scaling.
    effective_dim = max(1, d_intrinsic)

    # Penalty is 1.0 at d=20 and roughly 1.22 at d=30.
    penalty = math.sqrt(effective_dim / 20)
    scaled_w = reference_w / penalty

    return max(floor_w, scaled_w)


class SimplePrincipledStrategy:
    """
    Simple principled ensemble sizing based on vote confidence.

    Key principle: Each true pair should receive at least 1 vote per iteration
    with high probability (95% by default).

    The number of ensembles is determined by (at 95% confidence):
        n_ens = ceil(3 / (r * w))

    Where:
        r = subset_ratio (constructor argument, 0.4 by default)
        w = win_probability (estimated from MNN ratio after iteration 1)
    """

    def __init__(
        self,
        d_intrinsic: float = 30,
        confidence: float = 0.95,
        subset_ratio: float = 0.4
    ):
        """
        Initialize principled strategy.

        Args:
            d_intrinsic: intrinsic dimensionality of embeddings
            confidence: target confidence for vote reliability (default 0.95)
            subset_ratio: fraction of reference points per ensemble, reported
                unchanged by get_params (default 0.4)

        Raises:
            ValueError: if subset_ratio is not in the interval (0, 1].
        """
        if not 0 < subset_ratio <= 1:
            raise ValueError(
                f"subset_ratio must be in (0, 1], got {subset_ratio!r}"
            )
        self.d_intrinsic = d_intrinsic
        self.confidence = confidence
        self.subset_ratio = subset_ratio
        # Observed metrics, appended by update(); the two lists are
        # independent and may have different lengths.
        self.mnn_ratio_history: List[float] = []
        self.accuracy_history: List[float] = []

    def estimate_win_probability(self, iteration: int) -> float:
        """
        Estimate the win probability w.

        For iteration 1 (or while no MNN ratio has been recorded): use the
        conservative estimate based on intrinsic dim.
        Afterwards: use the mean of the last three observed MNN ratios as a
        direct proxy, floored at 0.1.

        Args:
            iteration: current iteration number (1-based)

        Returns:
            w: estimated win probability
        """
        if iteration == 1 or not self.mnn_ratio_history:
            # Conservative initial estimate
            return estimate_initial_win_probability(self.d_intrinsic)

        # Use observed MNN ratio (average of recent 3)
        recent = self.mnn_ratio_history[-3:]
        return max(0.1, float(np.mean(recent)))

    def get_params(self, n_ref: int, iteration: int) -> Dict:
        """
        Get ensemble parameters for this iteration.

        Args:
            n_ref: current number of reference points (accepted for interface
                compatibility; not used in the computation)
            iteration: current iteration number (1-based)

        Returns:
            dict with:
            - n_ensembles: number of ensembles to use
            - subset_ratio: fraction of references per ensemble
            - win_prob_estimate: estimated w value
        """
        r = self.subset_ratio

        # Estimate win probability
        w = self.estimate_win_probability(iteration)

        # Compute ensemble count from principled formula
        n_ens = compute_n_ensembles_principled(
            subset_ratio=r,
            win_probability=w,
            confidence=self.confidence
        )

        return {
            'n_ensembles': n_ens,
            'subset_ratio': r,
            'win_prob_estimate': w
        }

    def update(self, mnn_ratio: float, accuracy: Optional[float] = None) -> None:
        """
        Update strategy with observed metrics after an iteration.

        Args:
            mnn_ratio: observed MNN ratio (mutual_nn / total_points); ignored
                when None or not strictly positive
            accuracy: observed accuracy this iteration (optional, for tracking)
        """
        if mnn_ratio is not None and mnn_ratio > 0:
            self.mnn_ratio_history.append(float(mnn_ratio))

        if accuracy is not None:
            self.accuracy_history.append(float(accuracy))

    def get_summary(self) -> Dict:
        """
        Get summary of strategy state for logging.

        'iterations' counts the recorded MNN ratios, so iterations whose
        ratio was skipped by update() are not included. Averages are 0 while
        the corresponding history is empty.
        """
        return {
            'iterations': len(self.mnn_ratio_history),
            'avg_mnn_ratio': float(np.mean(self.mnn_ratio_history)) if self.mnn_ratio_history else 0,
            'avg_accuracy': float(np.mean(self.accuracy_history)) if self.accuracy_history else 0,
            'd_intrinsic': self.d_intrinsic,
            'confidence': self.confidence
        }


# ============================================================================
# Legacy compatibility - helpers kept so existing callers keep working
# ============================================================================

def compute_mnn_consistency(vote_matrix: sparse.spmatrix, n_ensembles: int) -> Dict[str, float]:
    """
    Extract observable consistency metrics from voting results.
    (Kept for backward compatibility)

    Args:
        vote_matrix: sparse matrix whose stored entries are vote counts;
            may be None. Any scipy sparse format is accepted.
        n_ensembles: number of ensembles that produced the votes

    Returns:
        dict with:
        - consistency_ratio: mean stored vote count divided by
          max(1, n_ensembles)
        - high_confidence_fraction: fraction of stored entries whose vote
          count is at least half of n_ensembles
        Both are 0.0 when the matrix is None or has no stored entries.
    """
    if vote_matrix is None or vote_matrix.nnz == 0:
        return {
            'consistency_ratio': 0.0,
            'high_confidence_fraction': 0.0
        }

    # Go through COO so that every sparse format yields a flat numeric array
    # of stored values; `.data` is a per-row list structure for LIL and is
    # not a value array for DOK.
    vote_data = np.asarray(vote_matrix.tocoo().data)
    consistency_ratio = np.mean(vote_data) / max(1, n_ensembles)
    threshold = n_ensembles * 0.5
    high_conf_fraction = np.mean(vote_data >= threshold)

    return {
        'consistency_ratio': float(consistency_ratio),
        'high_confidence_fraction': float(high_conf_fraction)
    }


# ============================================================================
# Principled Fixed Strategy (Theoretically Justified)
# ============================================================================

def compute_principled_fixed_strategy_params(
    growth_ratio: float,
    initial_subset_ratio: float,
    initial_n_ensembles: int,
    initial_ref_size: int,
    d_intrinsic: float = 30.0,
    gamma: float = 2.0,
    confidence: float = 0.95,
    fixed_c: Optional[float] = None
) -> Dict:
    """
    Principled fixed strategy parameters derived from theory.

    Computes adaptive ensemble parameters as reference set grows:
        r(g) = r₀ / s^β_r     (subset ratio decays)
        B(g) = B₀ × s^β_B     (ensemble count grows)

    where s = 1 + c·log(g) is the scale factor and g = n/n₀ is growth ratio.

    Args:
        growth_ratio: g = n/n₀, how much reference set has grown; must be a
            finite, strictly positive number
        initial_subset_ratio: r₀, base subset ratio (e.g. 0.4)
        initial_n_ensembles: B₀, base ensemble count (e.g. 5)
        initial_ref_size: n₀, initial reference set size
        d_intrinsic: Intrinsic dimensionality of embeddings (default 30)
        gamma: F-score preference parameter (default 2.0 = F2, recall-weighted);
            sets the exponent β_r applied to the scale factor
        confidence: Target vote confidence level (default 0.95); values ≥ 1
            leave only the computational cap of 100 ensembles
        fixed_c: Optional fixed value for log coefficient c, clamped to
            [0.1, 2.0] (default None = derive c from r₀ and n₀)

    Returns:
        dict with:
        - subset_ratio: Adaptive r(g) = min(1, max(min_r, r₀/s^β_r))
        - n_ensembles: Adaptive B(g) = max(1, min(max_B, int(B₀·s^β_B)))
        - scale_factor: s = 1 + c·log(g), floored at a small positive value
        - coefficient_c: Log coefficient actually used
        - beta_r: Ratio scaling exponent (≈0.9 for γ=2)
        - beta_B: Ensemble scaling exponent (=1.0)
        - min_ratio: Lower bound on the subset ratio
        - max_ensembles: Vote confidence upper bound (at most 100)
        - gamma, confidence, d_intrinsic: the inputs, echoed back

    Raises:
        ValueError: if growth_ratio is not a finite, strictly positive number.
    """
    c_lower, c_upper = 0.1, 2.0   # Reasonable range for the log coefficient
    ensemble_cap = 100            # Computational cap
    min_scale = 1e-6              # Keeps s positive so s^β stays well defined

    growth = float(growth_ratio)
    if not math.isfinite(growth) or growth <= 0:
        raise ValueError(
            f"growth_ratio must be a finite positive number, got {growth_ratio!r}"
        )

    # 1. Log coefficient
    # c = |log(r₀)| / log(n₀)
    # For r₀=0.4, n₀=100: c = 0.916/4.605 ≈ 0.2
    if fixed_c is not None:
        c = float(fixed_c)
    elif initial_subset_ratio > 0:
        c = abs(math.log(initial_subset_ratio)) / math.log(max(initial_ref_size, 2))
    else:
        # |log(r₀)| diverges as r₀ → 0, which saturates at the upper bound.
        c = c_upper
    c = max(c_lower, min(c_upper, c))

    # 2. Scale factor: s = 1 + c * log(g)
    # For g well below 1 the raw value can reach zero or go negative; the
    # floor keeps the power and the division below meaningful.
    scale_factor = max(min_scale, 1.0 + c * math.log(growth))

    # 3. Scaling exponents from F-score optimization
    # β* = 0.5 + 0.5γ² / (1 + γ²)
    # For γ=0.5 (precision): β=0.6
    # For γ=1.0 (balanced): β=0.75
    # For γ=2.0 (recall): β=0.9
    beta_r = 0.5 + 0.5 * gamma**2 / (1 + gamma**2)

    # For ensembles: always β=1 due to error asymmetry
    # (false negatives persist, false positives self-correct)
    beta_B = 1.0

    # 4. Minimum ratio: a subset should hold roughly k ≥ 2d points, i.e.
    # r ≥ 2d / n. The bound is clipped to [0.02, 0.1]: it follows 2d / n
    # only while that lies in the range, is capped at 0.1 for small
    # reference sets (n < 20·d) and floored at 0.02 for very large ones.
    current_ref_size = initial_ref_size * growth
    theoretical_min = 2 * d_intrinsic / max(current_ref_size, 1)
    min_ratio = max(0.02, min(0.1, theoretical_min))

    # 5. Maximum ensembles from vote confidence
    # n_ens = ceil(-log(1 - confidence) / (r * w))
    w_estimate = 0.15  # Conservative win probability
    if confidence < 1:
        log_term = -math.log(1 - confidence)  # ~3 for 95%
        from_confidence = math.ceil(log_term / (min_ratio * w_estimate))
        # At least one ensemble is always allowed, even for confidence ≤ 0.
        max_ens = min(ensemble_cap, max(1, from_confidence))
    else:
        # log(0) is undefined: certainty would need unboundedly many
        # ensembles, so only the computational cap applies.
        max_ens = ensemble_cap

    # 6. Compute adaptive parameters
    # Subset ratio decays as r₀ / s^β_r; it is a fraction, so never above 1.
    adaptive_ratio = min(
        1.0,
        max(min_ratio, initial_subset_ratio / scale_factor ** beta_r)
    )

    # Ensemble count grows as B₀ * s^β_B, truncated to an integer and kept
    # within [1, max_ens].
    adaptive_n_ens = max(
        1,
        min(max_ens, int(initial_n_ensembles * scale_factor ** beta_B))
    )

    return {
        'subset_ratio': adaptive_ratio,
        'n_ensembles': adaptive_n_ens,
        'scale_factor': scale_factor,
        'coefficient_c': c,
        'beta_r': beta_r,
        'beta_B': beta_B,
        'min_ratio': min_ratio,
        'max_ensembles': max_ens,
        'gamma': gamma,
        'confidence': confidence,
        'd_intrinsic': d_intrinsic
    }


class AdaptiveEnsembleStrategy:
    """
    Legacy adaptive ensemble strategy.
    Now wraps SimplePrincipledStrategy for backward compatibility.
    """

    def __init__(self, initial_seed_quality: float, d_intrinsic: float):
        """
        Initialize with legacy parameters.

        Args:
            initial_seed_quality: stored and echoed in get_summary(); it does
                not influence the computed parameters
            d_intrinsic: intrinsic dimensionality, passed to the wrapped strategy
        """
        self.initial_seed_quality = initial_seed_quality
        self.d_intrinsic = d_intrinsic
        self._strategy = SimplePrincipledStrategy(d_intrinsic=d_intrinsic)

    def update(
        self,
        vote_matrix: Optional[sparse.spmatrix],
        n_ensembles: int,
        accuracy: float,
        mnn_ratio: Optional[float] = None
    ) -> None:
        """
        Update with observed metrics.

        Args:
            vote_matrix: voting results; only consulted when mnn_ratio is None
            n_ensembles: number of ensembles that produced vote_matrix
            accuracy: observed accuracy this iteration
            mnn_ratio: observed MNN ratio; takes precedence over vote_matrix
        """
        ratio = mnn_ratio
        if ratio is None and vote_matrix is not None:
            # Fall back to computing from vote_matrix
            stats = compute_mnn_consistency(vote_matrix, n_ensembles)
            ratio = stats['consistency_ratio']

        # Always forward the accuracy, even when no ratio is available; the
        # wrapped strategy tracks the two metrics independently and ignores
        # a ratio of None.
        self._strategy.update(mnn_ratio=ratio, accuracy=accuracy)

    def get_params(self, n_ref: int, iteration: int) -> Dict:
        """
        Get ensemble parameters.

        Returns the wrapped strategy's parameters plus the legacy keys
        'mnn_consistency' (always None) and 'confidence_level'.
        """
        params = self._strategy.get_params(n_ref, iteration)
        # Add legacy fields
        params['mnn_consistency'] = None
        params['confidence_level'] = self._strategy.confidence
        return params

    def get_summary(self) -> Dict:
        """Get summary for logging, including the stored initial_seed_quality."""
        summary = self._strategy.get_summary()
        summary['initial_seed_quality'] = self.initial_seed_quality
        return summary
