"""
Kernel Regression for CDF Estimation

Implements Nadaraya-Watson estimation of F(u) and F'(u).
"""

import numpy as np
from scipy.interpolate import interp1d
from kernel_functions import Kernel


class KernelCDFEstimator:
    """
    Estimate CDF F(u) and its derivative F'(u) using kernel regression.
    
    Given residuals w_t and outcomes y_t, estimates:
        F(u) = 1 - E[y_t | w_t = u]
        F'(u) = derivative of F
    
    The regression function r(u) = E[y_t | w_t = u] is estimated by the
    Nadaraya-Watson ratio h(u) / f(u), with
        h(u) = 1/(n b) * sum_t K((w_t - u) / b) * y_t
        f(u) = 1/(n b) * sum_t K((w_t - u) / b)
    and r'(u) is obtained from the quotient rule applied to h and f.
    """
    
    def __init__(self, kernel_order, bandwidth):
        """
        Parameters:
        -----------
        kernel_order : int
            Order of kernel (2, 4, or 6)
        bandwidth : float
            Bandwidth parameter b_k
        """
        self.kernel = Kernel(kernel_order)
        self.bandwidth = bandwidth
        self.order = kernel_order
    
    def fit(self, w_hat, y, u_grid):
        """
        Estimate F and F' on a grid of u values.
        
        Parameters:
        -----------
        w_hat : np.ndarray, shape (n,)
            Residuals w_t = p_t - Θ^T x_t
        y : np.ndarray, shape (n,)
            Binary outcomes
        u_grid : np.ndarray, shape (m,)
            Grid of u values to evaluate
        
        Returns:
        --------
        F_hat : np.ndarray, shape (m,)
            Estimated CDF values, clipped to [0, 1]
        F_prime_hat : np.ndarray, shape (m,)
            Estimated derivative values (not clipped)
        
        Raises:
        -------
        ValueError
            If w_hat and y do not contain the same number of observations.
        """
        w_hat = np.asarray(w_hat, dtype=float)
        y = np.asarray(y, dtype=float)
        u_grid = np.asarray(u_grid, dtype=float)
        
        n = len(w_hat)
        if len(y) != n:
            # zip() below would silently drop the unmatched tail while the
            # normalisation still divided by len(w_hat).
            raise ValueError(
                f"w_hat and y must have the same length, got {n} and {len(y)}"
            )
        
        b = self.bandwidth
        m = len(u_grid)
        
        # Un-normalised kernel sums, one entry per grid point
        h = np.zeros(m)
        f = np.zeros(m)
        h_prime = np.zeros(m)
        f_prime = np.zeros(m)
        
        # Accumulate the sums for each u in grid.  The kernel is called one
        # sample at a time, so it only needs to support scalar arguments.
        for i, u in enumerate(u_grid):
            for w_t, y_t in zip(w_hat, y):
                z = (w_t - u) / b
                
                # Only samples with |z| <= 1 contribute (kernel support)
                if abs(z) <= 1:
                    K_val = self.kernel.K(z)
                    K_prime_val = self.kernel.K_prime(z)
                    
                    h[i] += K_val * y_t
                    f[i] += K_val
                    h_prime[i] += K_prime_val * y_t
                    f_prime[i] += K_prime_val
        
        # Normalize
        h = h / (n * b)
        f = f / (n * b)
        # Chain rule: z = (w_t - u) / b, so d/du K(z) = -K'(z) / b.  The
        # leading minus sign is required for h'(u) and f'(u) to be the
        # derivatives of h and f with respect to u.
        # NOTE(review): assumes Kernel.K_prime(z) is dK/dz - confirm against
        # kernel_functions.
        h_prime = -h_prime / (n * b**2)
        f_prime = -f_prime / (n * b**2)
        
        # Compute r(u) = E[y | w = u]; 0.5 is used where the estimated
        # density f is too small to divide by.
        r_hat = self._safe_divide(h, f, default=0.5)
        
        # Quotient rule: r'(u) = (h' * f - h * f') / f²
        numerator = h_prime * f - h * f_prime
        denominator = f**2
        r_prime_hat = self._safe_divide(numerator, denominator, default=0.0)
        
        # F(u) = 1 - r(u), F'(u) = -r'(u)
        F_hat = 1 - r_hat
        F_prime_hat = -r_prime_hat
        
        # Clip F to [0, 1]; the raw ratio can leave that range
        F_hat = np.clip(F_hat, 0, 1)
        
        return F_hat, F_prime_hat
    
    def _safe_divide(self, numerator, denominator, default=0.0, epsilon=1e-10):
        """
        Safe division with default value for small denominators.
        
        Parameters:
        -----------
        numerator : np.ndarray
            Numerator values
        denominator : np.ndarray
            Denominator values
        default : float
            Value to use when denominator is small
        epsilon : float
            Threshold for "small" denominator (compared against |denominator|)
        
        Returns:
        --------
        np.ndarray
            Float array holding numerator / denominator where
            |denominator| > epsilon and `default` elsewhere
        """
        result = np.full_like(numerator, default, dtype=float)
        # NaN denominators compare False here and therefore get `default`
        mask = np.abs(denominator) > epsilon
        result[mask] = numerator[mask] / denominator[mask]
        return result
    
    def create_interpolators(self, u_grid, F_hat, F_prime_hat):
        """
        Create interpolation functions for F and F'.
        
        Both interpolators are piecewise linear on the grid.  Queries outside
        the grid do not raise; they return the value at the nearest end of
        the grid (constant extrapolation).
        
        Parameters:
        -----------
        u_grid : np.ndarray
            Grid points
        F_hat : np.ndarray
            F values
        F_prime_hat : np.ndarray
            F' values
        
        Returns:
        --------
        F_interp : callable
            Interpolation function for F
        F_prime_interp : callable
            Interpolation function for F'
        """
        F_interp = interp1d(
            u_grid, F_hat,
            kind='linear',
            bounds_error=False,
            fill_value=(F_hat[0], F_hat[-1])
        )
        
        F_prime_interp = interp1d(
            u_grid, F_prime_hat,
            kind='linear',
            bounds_error=False,
            fill_value=(F_prime_hat[0], F_prime_hat[-1])
        )
        
        return F_interp, F_prime_interp


def test_kernel_regression():
    """
    Run the estimator on synthetic Gaussian data and print a comparison.
    
    Data-generating process: residuals w_t and an independent noise term are
    both drawn from N(0, σ²) with σ = 0.1, and y_t = 1{noise_t >= w_t}.
    Hence E[y | w = u] = 1 - Φ(u/σ), so the targets are F(u) = Φ(u/σ) and
    F'(u) = φ(u/σ)/σ.  For kernel orders 2, 4 and 6 the function prints the
    bandwidth, mean absolute errors against those targets, a few sample
    values and one interpolated value.  Nothing is asserted or returned.
    """
    banner = "="*60
    print(banner)
    print("Testing Kernel CDF Estimation")
    print(banner)
    
    # Fixed seed so the printed numbers are reproducible
    np.random.seed(42)
    n_obs = 1000
    scale = 0.1
    
    # Draw order matters for reproducibility: residuals first, then noise
    residuals = np.random.normal(0, scale, size=n_obs)
    shocks = np.random.normal(0, scale, size=n_obs)
    outcomes = np.asarray(shocks >= residuals, dtype=float)
    
    from scipy.stats import norm
    
    # Evaluation grid and the closed-form targets on it
    grid = np.linspace(-0.3, 0.3, 100)
    F_true = norm.cdf(grid / scale)
    F_prime_true = norm.pdf(grid / scale) / scale
    
    sample_indices = (25, 50, 75)
    u_test = 0.05
    
    for order in (2, 4, 6):
        print(f"\n--- Kernel order {order} ---")
        
        # Bandwidth rule used by this demo: 3 * n^(-1 / (2 * order + 1))
        exponent = -1 / (2 * order + 1)
        bandwidth = 3 * n_obs ** exponent
        print(f"Bandwidth: {bandwidth:.6f}")
        
        model = KernelCDFEstimator(order, bandwidth)
        F_hat, F_prime_hat = model.fit(residuals, outcomes, grid)
        
        # Mean absolute deviation from the closed-form targets
        F_error = np.abs(F_hat - F_true).mean()
        F_prime_error = np.abs(F_prime_hat - F_prime_true).mean()
        
        print(f"Mean absolute error in F: {F_error:.6f}")
        print(f"Mean absolute error in F': {F_prime_error:.6f}")
        
        # Spot-check a few grid points
        print(f"\nSample values:")
        for i in sample_indices:
            u = grid[i]
            print(f"  u={u:6.3f}: F_hat={F_hat[i]:.4f} (true={F_true[i]:.4f}), "
                  f"F'_hat={F_prime_hat[i]:.4f} (true={F_prime_true[i]:.4f})")
        
        # Exercise the interpolators at an off-grid point
        interp_F, interp_F_prime = model.create_interpolators(
            grid, F_hat, F_prime_hat
        )
        F_interp_val = interp_F(u_test)
        F_prime_interp_val = interp_F_prime(u_test)
        print(f"\nInterpolation at u={u_test}:")
        print(f"  F({u_test}) = {F_interp_val:.4f}")
        print(f"  F'({u_test}) = {F_prime_interp_val:.4f}")


# Run the synthetic-data demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    test_kernel_regression()
