"""
Kernel Functions for Nonparametric Estimation

Implements 2nd, 4th, and 6th order kernels from Fan-Guo-Yu 2022.
"""

import numpy as np


class Kernel:
    """
    Higher-order kernel functions for nonparametric estimation.

    The second-order kernel is the triweight kernel

        K_2(u) = (35/32) * (1 - u²)³ * I{|u| ≤ 1},

    and the higher-order kernels are polynomial multiples of it:

        K_4(u) = A_4(u) * K_2(u),  A_4(u) = (27/16) * (1 - 11u²/3)
        K_6(u) = A_6(u) * K_2(u),  A_6(u) = (297/128) * (1 - 26u²/3 + 13u⁴)

    All kernels have support [-1, 1] and integrate to 1; a kernel of
    order r additionally satisfies ∫ u^j K(u) du = 0 for 1 ≤ j < r.

    Attributes:
    -----------
    order : int
        Kernel order (2, 4, or 6).
    """

    # Triweight normalising constant: ∫_{-1}^{1} (1 - u²)³ du = 32/35.
    # With this value K_2, K_4 and K_6 all integrate to exactly 1.
    _C2 = 35.0 / 32.0

    def __init__(self, order):
        """
        Parameters:
        -----------
        order : int
            Kernel order (2, 4, or 6)

        Raises:
        -------
        ValueError
            If `order` is not 2, 4, or 6.
        """
        if order not in [2, 4, 6]:
            raise ValueError("Kernel order must be 2, 4, or 6")

        self.order = order

    def K(self, u):
        """
        Evaluate kernel function.

        Parameters:
        -----------
        u : float or np.ndarray
            Point(s) to evaluate

        Returns:
        --------
        float or np.ndarray
            Kernel value(s): a scalar for scalar input, otherwise an array
            with the shape of `u`. Zero wherever |u| > 1.
        """
        u = np.asarray(u)

        if self.order == 2:
            return self._K2(u)
        if self.order == 4:
            return self._K4(u)
        return self._K6(u)  # order == 6 (validated in __init__)

    def K_prime(self, u):
        """
        Evaluate kernel derivative.

        Parameters:
        -----------
        u : float or np.ndarray
            Point(s) to evaluate

        Returns:
        --------
        float or np.ndarray
            Kernel derivative value(s): a scalar for scalar input, otherwise
            an array with the shape of `u`. Zero wherever |u| > 1.
        """
        u = np.asarray(u)

        if self.order == 2:
            return self._K2_prime(u)
        if self.order == 4:
            return self._K4_prime(u)
        return self._K6_prime(u)  # order == 6 (validated in __init__)

    # ===== Second-order kernel =====

    def _K2(self, u):
        """
        K_2(u) = (35/32) * (1 - u²)³ * I{|u| ≤ 1}   (triweight kernel)
        """
        u = np.asarray(u)
        result = np.zeros_like(u, dtype=float)
        mask = np.abs(u) <= 1

        if np.any(mask):
            u_valid = u[mask]
            result[mask] = self._C2 * (1 - u_valid**2)**3

        return result if u.shape else float(result)

    def _K2_prime(self, u):
        """
        K'_2(u) = -(105/16) * u * (1 - u²)² * I{|u| ≤ 1}

        i.e. -6 * (35/32) * u * (1 - u²)², the derivative of K_2.
        """
        u = np.asarray(u)
        result = np.zeros_like(u, dtype=float)
        mask = np.abs(u) <= 1

        if np.any(mask):
            u_valid = u[mask]
            result[mask] = -6.0 * self._C2 * u_valid * (1 - u_valid**2)**2

        return result if u.shape else float(result)

    # ===== Fourth-order kernel =====

    @staticmethod
    def _A4(u):
        """A_4(u) = (27/16) * (1 - 11u²/3)."""
        return (27.0 / 16.0) * (1 - (11.0 / 3.0) * u**2)

    @staticmethod
    def _A4_prime(u):
        """A'_4(u) = -(99/8) * u."""
        return -(99.0 / 8.0) * u

    def _K4(self, u):
        """
        K_4(u) = A_4(u) * K_2(u),  A_4(u) = (27/16) * (1 - 11u²/3)
        """
        u = np.asarray(u)
        # K_2 is zero for |u| > 1, so evaluating A_4 at the clipped point
        # does not change the product there, but it keeps A_4 finite and
        # avoids inf * 0 = nan for huge or infinite |u|.
        u_in = np.clip(u, -1.0, 1.0)
        return self._A4(u_in) * self._K2(u)

    def _K4_prime(self, u):
        """
        K'_4(u) = A'_4(u) * K_2(u) + A_4(u) * K'_2(u)   (product rule)
        """
        u = np.asarray(u)
        # Same clipping as in _K4: both K_2 and K'_2 vanish for |u| > 1.
        u_in = np.clip(u, -1.0, 1.0)
        return (self._A4_prime(u_in) * self._K2(u)
                + self._A4(u_in) * self._K2_prime(u))

    # ===== Sixth-order kernel =====

    @staticmethod
    def _A6(u):
        """A_6(u) = (297/128) * (1 - 26u²/3 + 13u⁴)."""
        return (297.0 / 128.0) * (1 - (26.0 / 3.0) * u**2 + 13 * u**4)

    @staticmethod
    def _A6_prime(u):
        """A'_6(u) = (297/128) * 52 * (u³ - u/3)."""
        return (297.0 / 128.0) * 52 * (u**3 - u / 3.0)

    def _K6(self, u):
        """
        K_6(u) = A_6(u) * K_2(u),  A_6(u) = (297/128) * (1 - 26u²/3 + 13u⁴)
        """
        u = np.asarray(u)
        # K_2 is zero for |u| > 1, so clipping u inside A_6 leaves the product
        # unchanged there while preventing overflow / inf * 0 = nan.
        u_in = np.clip(u, -1.0, 1.0)
        return self._A6(u_in) * self._K2(u)

    def _K6_prime(self, u):
        """
        K'_6(u) = A'_6(u) * K_2(u) + A_6(u) * K'_2(u)   (product rule)
        """
        u = np.asarray(u)
        # Same clipping as in _K6: both K_2 and K'_2 vanish for |u| > 1.
        u_in = np.clip(u, -1.0, 1.0)
        return (self._A6_prime(u_in) * self._K2(u)
                + self._A6(u_in) * self._K2_prime(u))


def test_kernels():
    """
    Test kernel functions and their derivatives, printing a short report.

    For each order r in (2, 4, 6) this checks that:
      1. K vanishes outside [-1, 1];
      2. K is symmetric;
      3. K' is anti-symmetric;
      4. K' matches a central finite difference of K;
      5. K integrates to 1 and has vanishing moments
         ∫ u^j K(u) du = 0 for 1 ≤ j < r (the defining property of an
         order-r kernel).

    Raises:
    -------
    AssertionError
        If any property fails.
    """
    print("="*60)
    print("Testing Kernel Functions")
    print("="*60)

    u_test = np.array([-1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5])
    u_sym = np.array([-0.5, -0.25, 0.0, 0.25, 0.5])

    # 20-point Gauss-Legendre quadrature on [-1, 1] is exact for polynomials
    # of degree <= 39. On the support, u^j * K_r(u) is a polynomial of degree
    # <= 15 for every check below, so the integrals are exact up to rounding.
    gl_nodes, gl_weights = np.polynomial.legendre.leggauss(20)

    for order in [2, 4, 6]:
        print(f"\n--- Order {order} Kernel ---")
        kernel = Kernel(order)

        # Evaluate kernel
        K_vals = kernel.K(u_test)
        print(f"\nK_{order}(u):")
        for u, K in zip(u_test, K_vals):
            print(f"  u={u:5.2f}: K={K:8.4f}")

        # Evaluate derivative
        K_prime_vals = kernel.K_prime(u_test)
        print(f"\nK'_{order}(u):")
        for u, Kp in zip(u_test, K_prime_vals):
            print(f"  u={u:5.2f}: K'={Kp:8.4f}")

        # 1. Support: K(u) = 0 for |u| > 1
        assert np.all(K_vals[np.abs(u_test) > 1] == 0), "Kernel should be 0 outside [-1,1]"

        # 2. Symmetry: K(u) = K(-u)
        K_sym = kernel.K(u_sym)
        assert np.allclose(K_sym, K_sym[::-1]), "Kernel should be symmetric"

        # 3. Anti-symmetry of derivative: K'(u) = -K'(-u)
        Kp_sym = kernel.K_prime(u_sym)
        assert np.allclose(Kp_sym, -Kp_sym[::-1], atol=1e-10), "K' should be anti-symmetric"

        # 4. Numerical derivative check
        eps = 1e-6
        u_check = 0.3
        K_numerical = (kernel.K(u_check + eps) - kernel.K(u_check - eps)) / (2 * eps)
        K_analytical = kernel.K_prime(u_check)
        print(f"\nDerivative check at u={u_check}:")
        print(f"  Numerical: {K_numerical:.6f}")
        print(f"  Analytical: {K_analytical:.6f}")
        print(f"  Error: {abs(K_numerical - K_analytical):.2e}")
        assert abs(K_numerical - K_analytical) < 1e-4, "Derivative doesn't match numerical"

        # 5. Normalisation and moment conditions of an order-r kernel
        K_gl = kernel.K(gl_nodes)
        integral = gl_weights @ K_gl
        print(f"\nNormalisation: integral of K = {integral:.10f}")
        assert abs(integral - 1.0) < 1e-8, "Kernel should integrate to 1"
        for j in range(1, order):
            moment = gl_weights @ (gl_nodes**j * K_gl)
            assert abs(moment) < 1e-8, (
                f"Moment of degree {j} should vanish for order-{order} kernel"
            )

        print(f"\n✓ All tests passed for order {order}")


if __name__ == "__main__":
    # Running this module as a script executes the kernel self-tests.
    test_kernels()
