#!/usr/bin/env python
"""
Improved kernel-target alignment utilities.
Implements centered, weighted KTA with proper normalization.
"""

import numpy as np
from typing import Dict


def centered_weighted_alignment(K: np.ndarray, y: np.ndarray, w: np.ndarray, eps: float = 1e-12) -> float:
    """
    Compute the weighted, centered kernel-target alignment (KTA).

    With probability weights ``pi = w / sum(w)`` and ``H = I - 1 pi^T``:

    * the kernel is centered as ``Kc = H K H^T`` (pi-weighted row/column
      means removed),
    * the target is ``Yc = (H y)(H y)^T``,
    * the score is ``<Kc, Yc>_pi / (||Kc||_pi * ||Yc||_pi + eps)`` where
      ``<A, B>_pi = sum_ij pi_i pi_j A_ij B_ij``.

    Centering first and weighting afterwards keeps the centered kernel
    symmetric, makes the score invariant to adding a constant to ``K`` and to
    rescaling ``w``, and makes zero-weight samples drop out entirely.

    Args:
        K: (n, n) kernel matrix (expected symmetric, PSD).
        y: n labels in {-1,+1} or {0,1}; flattened before use.
        w: n sample weights; negative entries are clipped to 0 and an
            all-zero vector is replaced by uniform weights. Flattened before use.
        eps: Small constant added to the denominator for numerical stability.

    Returns:
        Normalized alignment scalar in [-1, 1]; 0.0 when either centered
        matrix vanishes (e.g. all labels identical).

    Raises:
        ValueError: If ``K`` is not square or ``y`` / ``w`` do not have n entries.
    """
    K = np.asarray(K)
    if K.ndim != 2 or K.shape[0] != K.shape[1]:
        raise ValueError(f"K must be a square matrix, got shape {K.shape}")
    n = K.shape[0]

    # asarray may alias the caller's array; nothing below mutates it in place.
    y = np.asarray(y, dtype=float).ravel()
    w = np.asarray(w, dtype=float).ravel()
    if y.size != n or w.size != n:
        raise ValueError(
            f"y and w must each have {n} entries, got {y.size} and {w.size}"
        )

    # Convert {0,1} labels to {-1,+1}. Centering makes the score invariant to
    # this recoding up to eps; it is kept so both encodings give equal values.
    if set(np.unique(y)) == {0.0, 1.0}:
        y = 2 * y - 1.0

    # Non-negative weights; fall back to uniform when nothing carries weight.
    w = np.clip(w, 0.0, None)
    if not np.any(w):
        w = np.ones(n)

    # Exact probability weights: w.sum() > 0 here, so no eps is needed and
    # H @ 1 == 0 holds, which is what makes constant shifts of K vanish.
    pi = w / w.sum()

    # Kc = H K H^T expanded, avoiding the O(n^3) matrix products:
    # K - 1 (pi^T K) - (K pi) 1^T + (pi^T K pi) 1 1^T
    col_mean = pi @ K
    row_mean = K @ pi
    grand_mean = pi @ row_mean
    Kc = K - row_mean[:, None] - col_mean[None, :] + grand_mean

    # Centered labels; Yc = yc yc^T is never materialized.
    yc = y - pi @ y

    # pi-weighted Frobenius inner product and norms.
    u = pi * yc
    numerator = u @ Kc @ u
    k_norm = np.sqrt(pi @ np.square(Kc) @ pi)
    y_norm = pi @ np.square(yc)  # == ||Yc||_pi because Yc is rank one

    return float(numerator / (k_norm * y_norm + eps))


def try_kta_update(params: np.ndarray, kernel_func, X: np.ndarray, y: np.ndarray, 
                   w: np.ndarray, current_K: np.ndarray, delta: float = 2e-3):
    """
    Gate a parameter proposal on kernel-target alignment.

    The proposal is accepted only when the alignment of the kernel it
    produces beats the alignment of ``current_K`` by at least ``delta``.

    Args:
        params: Proposed parameter vector.
        kernel_func: Callable invoked as ``kernel_func(X, X, params)`` to
            build the candidate kernel matrix.
        X: Training data.
        y: Training labels.
        w: Sample weights.
        current_K: Kernel matrix of the parameters currently in use.
        delta: Minimum alignment gain required to accept the proposal.

    Returns:
        Tuple ``(params, K, alignment, updated)``:
        - accepted: ``(params, K_new, A_new, True)``
        - rejected, or any exception while evaluating the proposal:
          ``(None, current_K, A_old, False)``. The previous parameters are
          not known to this function, hence ``None``.
    """
    # The baseline is scored outside the try: a failure here is not a
    # problem with the proposal and is allowed to propagate.
    baseline = centered_weighted_alignment(current_K, y, w)
    rejection = (None, current_K, baseline, False)

    try:
        candidate_K = kernel_func(X, X, params)
        candidate = centered_weighted_alignment(candidate_K, y, w)
        if candidate >= baseline + delta:
            return params, candidate_K, candidate, True
    except Exception:
        # Best-effort gate: any failure while evaluating the proposal
        # counts as a rejection.
        return rejection

    # Improvement did not reach the threshold.
    return rejection


def evaluate_alignment_quality(K: np.ndarray, y: np.ndarray, w: np.ndarray) -> Dict[str, float]:
    """
    Evaluate alignment quality and related diagnostics.

    Args:
        K: (n, n) kernel matrix (expected symmetric, PSD).
        y: (n,) labels in {-1,+1} or {0,1}.
        w: (n,) sample weights; treated like in centered_weighted_alignment
            (negatives clipped to 0, all-zero replaced by uniform).

    Returns:
        Dictionary with:
        - 'alignment': weighted centered kernel-target alignment
        - 'kernel_condition': largest eigenvalue divided by the smallest
          eigenvalue (clamped at 0) plus 1e-12
        - 'effective_sample_size': 1 / sum(normalized_weights ** 2)
        - 'label_balance': unweighted fraction of labels > 0
        - 'kernel_trace': trace of K
        - 'kernel_frobenius': Frobenius norm of K
    """
    alignment = centered_weighted_alignment(K, y, w)

    # eigvalsh returns real eigenvalues for a symmetric matrix, whereas the
    # general solver may return complex values. Symmetrizing first guards
    # against small asymmetries, since eigvalsh reads only one triangle.
    K_arr = np.asarray(K)
    eigenvals = np.linalg.eigvalsh(0.5 * (K_arr + K_arr.T))
    # Round-off can push the smallest eigenvalue of a PSD kernel slightly
    # below zero; clamp it so the ratio cannot flip sign or explode.
    smallest = max(np.min(eigenvals), 0.0)
    condition_number = np.max(eigenvals) / (smallest + 1e-12)

    # Effective sample size, using the same weight clean-up as the alignment
    # so both numbers describe the same weighting.
    w_clean = np.clip(np.asarray(w, dtype=float), 0.0, None)
    if not np.any(w_clean):
        w_clean = np.ones_like(w_clean)
    w_norm = w_clean / (w_clean.sum() + 1e-12)
    ess = 1.0 / (np.sum(w_norm**2) + 1e-12)

    # Fraction of positive labels; the same test works for both encodings.
    label_balance = np.mean(np.asarray(y) > 0)

    return {
        'alignment': float(alignment),
        'kernel_condition': float(condition_number),
        'effective_sample_size': float(ess),
        'label_balance': float(label_balance),
        'kernel_trace': float(np.trace(K)),
        'kernel_frobenius': float(np.linalg.norm(K, 'fro')),
    }


def _demo() -> None:
    """Smoke-test the alignment utilities on synthetic data and print the results."""
    print("Testing centered weighted alignment...")

    # Synthetic two-class problem: the label depends on the first two features.
    np.random.seed(42)
    n = 100
    X = np.random.randn(n, 4)
    y = (X[:, 0] + X[:, 1] > 0).astype(int)

    # RBF test kernel built from pairwise squared distances.
    gamma = 1.0
    K = np.exp(-gamma * ((X[:, None] - X[None, :])**2).sum(axis=2))

    w_uniform = np.ones(n)
    w_random = np.random.exponential(1.0, n)
    w_extreme = np.ones(n)
    w_extreme[:n//4] = 10.0  # High weight for first quarter

    # Same computation under three weighting schemes.
    scenarios = (
        ("Uniform", w_uniform),
        ("Random", w_random),
        ("Extreme", w_extreme),
    )
    for label, weights in scenarios:
        score = centered_weighted_alignment(K, y, weights)
        print(f"{label} weights alignment: {score:.4f}")

    # Alignment quality diagnostics with uniform weights.
    quality = evaluate_alignment_quality(K, y, w_uniform)
    print("\nAlignment quality diagnostics:")
    for key, value in quality.items():
        print(f"  {key}: {value:.4f}")

    print("\n✅ Centered weighted alignment tested successfully!")


if __name__ == "__main__":
    _demo()