"""
Phi Function and its Inverse

Implements the pricing function construction from Fan-Guo-Yu 2022.
"""

import numpy as np
from scipy.interpolate import interp1d


class PhiInverter:
    """
    Compute φ(u) and its inverse φ^{-1}(s).

    φ(u) = u - (1 - F(u)) / F'(u)

    The inverse is computed using Newton's method with logistic reparametrization.

    Notes
    -----
    ``u_min`` and ``u_max`` are stored on the instance, but
    :meth:`solve_phi_inverse` does not read them: its logistic map always
    searches x ∈ (-1, 1).
    """

    def __init__(self, F_interp, F_prime_interp, u_min=-0.5, u_max=0.5):
        """
        Parameters:
        -----------
        F_interp : callable
            Interpolation function for F(u)
        F_prime_interp : callable
            Interpolation function for F'(u)
        u_min, u_max : float
            Support of the noise distribution. Stored for reference only;
            no method of this class currently reads them.
        """
        self.F_interp = F_interp
        self.F_prime_interp = F_prime_interp
        self.u_min = u_min
        self.u_max = u_max

    def phi(self, u):
        """
        Evaluate φ(u) = u - (1 - F(u)) / F'(u).

        Where |F'(u)| < 1e-10 the ratio is numerically undefined, so φ(u)
        falls back to ``u`` itself. Scalar and array inputs use the same rule.

        Parameters:
        -----------
        u : float or array_like
            Point(s) to evaluate. Lists and integer arrays are accepted on the
            vectorized path and are evaluated in floating point.

        Returns:
        --------
        float or np.ndarray
            φ(u) value(s). Non-scalar input yields a new float array with the
            same shape as ``u``.
        """
        F_val = self.F_interp(u)
        F_prime_val = self.F_prime_interp(u)

        # F'(u) below this magnitude is treated as zero.
        epsilon = 1e-10
        if np.isscalar(u):
            if abs(F_prime_val) < epsilon:
                return u  # Fallback
            return u - (1 - F_val) / F_prime_val

        # Take a float copy so results are not truncated into an integer dtype
        # and boolean-mask indexing also works for list inputs.
        result = np.array(u, dtype=float)
        # Broadcast so callables that return a plain scalar (e.g. a constant
        # F') still line up with the mask.
        F_val = np.broadcast_to(np.asarray(F_val, dtype=float), result.shape)
        F_prime_val = np.broadcast_to(
            np.asarray(F_prime_val, dtype=float), result.shape
        )
        # `>=` mirrors the scalar branch, which only falls back when |F'| is
        # strictly below epsilon.
        mask = np.abs(F_prime_val) >= epsilon
        result[mask] = result[mask] - (1 - F_val[mask]) / F_prime_val[mask]
        return result

    def phi_prime(self, u, delta=0.005):
        """
        Estimate φ'(u) using central finite differences.

        Parameters:
        -----------
        u : float
            Point to evaluate
        delta : float
            Half-width of the finite-difference stencil

        Returns:
        --------
        float
            (φ(u + delta) - φ(u - delta)) / (2 * delta)
        """
        phi_plus = self.phi(u + delta)
        phi_minus = self.phi(u - delta)
        return (phi_plus - phi_minus) / (2 * delta)

    def solve_phi_inverse(self, s, max_iter=20, tol=1e-8, verbose=False):
        """
        Solve φ(x) = s for x using Newton's method with logistic transform.

        We reparametrize x ∈ (-1, 1) as:
            x(y) = -1 + 2/(1 + exp(y))

        Then solve g(y) = φ(x(y)) - s = 0 using Newton's method. Each Newton
        step in y is capped at magnitude 5. φ'(x) comes from
        :meth:`phi_prime` (finite differences).

        Parameters:
        -----------
        s : float
            Target value
        max_iter : int
            Maximum iterations
        tol : float
            Convergence tolerance on |φ(x) - s|
        verbose : bool
            Print iteration details

        Returns:
        --------
        float
            x such that |φ(x) - s| < tol when converged. Otherwise:
            - the current x if |g'(y)| drops below 1e-12;
            - after ``max_iter`` iterations, the visited x with the smallest
              residual |φ(x) - s|.
        """
        # Start at y = 0 (which gives x = 0)
        y = 0.0
        # Lowest-residual iterate seen so far, used if Newton does not converge.
        best_x = None
        best_abs_g = np.inf

        for i in range(max_iter):
            # Compute x from y
            exp_y = np.exp(y)
            x = -1 + 2 / (1 + exp_y)

            # Derivative of x w.r.t. y
            x_prime = -2 * exp_y / (1 + exp_y) ** 2

            # Evaluate g(y) = φ(x(y)) - s
            phi_val = self.phi(x)
            g = phi_val - s

            if verbose:
                print(f"Iter {i}: y={y:.6f}, x={x:.6f}, φ(x)={phi_val:.6f}, g={g:.6f}")

            # Check convergence
            if abs(g) < tol:
                if verbose:
                    print(f"Converged in {i+1} iterations")
                return x

            if abs(g) < best_abs_g:
                best_x, best_abs_g = x, abs(g)

            # Compute g'(y) = φ'(x) * x'(y)
            phi_prime_val = self.phi_prime(x)
            g_prime = phi_prime_val * x_prime

            # Check for small derivative
            if abs(g_prime) < 1e-12:
                if verbose:
                    print("Small derivative, returning current x")
                return x

            # Newton step
            y_new = y - g / g_prime

            # Safeguard against large steps
            if abs(y_new - y) > 5:
                y_new = y - 5 * np.sign(g / g_prime)

            y = y_new

        # Not converged. The final y has not been evaluated yet, so compare it
        # against the best earlier iterate before choosing what to return.
        # `not (a <= b)` also prefers best_x when the final residual is NaN.
        x = -1 + 2 / (1 + np.exp(y))
        final_abs_g = abs(self.phi(x) - s)
        if best_x is not None and not (final_abs_g <= best_abs_g):
            x = best_x
        if verbose:
            print(f"Did not converge in {max_iter} iterations, returning x={x:.6f}")
        return x

    def build_g_function(self):
        """
        Build the pricing function g(u) = u + φ^{-1}(-u).

        Returns:
        --------
        callable
            Function that computes optimal price shift
        """
        def g(u):
            """
            Optimal price function.

            Parameters:
            -----------
            u : float or array_like
                Base value (Θ^T x). 0-d arrays are treated as scalars.

            Returns:
            --------
            float or np.ndarray
                Optimal price shift. Non-scalar input yields a float array
                with the same shape as ``u``.
            """
            if np.ndim(u) == 0:
                phi_inv_val = self.solve_phi_inverse(-u)
                return u + phi_inv_val
            # Reuse this closure per element instead of rebuilding it.
            u_arr = np.asarray(u, dtype=float)
            values = [g(ui) for ui in u_arr.ravel()]
            return np.array(values, dtype=float).reshape(u_arr.shape)

        return g


def test_phi_inverter():
    """
    Test phi inversion with a distribution that has a closed-form answer.

    Uses the uniform CDF F(u) = u + 0.5 on [-0.5, 0.5], so F'(u) = 1 and
        φ(u)      = u - (1 - (u + 0.5)) / 1 = 2u - 0.5
        φ^{-1}(s) = (s + 0.5) / 2
        g(u)      = u + φ^{-1}(-u) = (u + 0.5) / 2

    Prints comparison tables. Raises AssertionError if any numerical result
    differs from its closed form by ``tol`` or more.
    """
    tol = 1e-6

    print("="*60)
    print("Testing Phi Inversion")
    print("="*60)

    def F_simple(u):
        u = np.clip(u, -0.5, 0.5)
        return u + 0.5

    def F_prime_simple(u):
        return np.ones_like(u)

    # Create interpolators (interp1d is imported at module level)
    u_grid = np.linspace(-0.5, 0.5, 100)
    F_interp = interp1d(u_grid, F_simple(u_grid), kind='linear',
                        bounds_error=False, fill_value=(0, 1))
    F_prime_interp = interp1d(u_grid, F_prime_simple(u_grid), kind='linear',
                               bounds_error=False, fill_value=1)

    # Create inverter
    inverter = PhiInverter(F_interp, F_prime_interp)

    # Test phi function
    print("\n--- Testing φ(u) ---")
    u_test = np.array([-0.3, -0.1, 0.0, 0.1, 0.3])
    phi_vals = inverter.phi(u_test)
    phi_true = 2 * u_test - 0.5

    print("u\t\tφ(u)\t\tTrue\t\tError")
    for u, phi, phi_t in zip(u_test, phi_vals, phi_true):
        print(f"{u:.2f}\t\t{phi:.4f}\t\t{phi_t:.4f}\t\t{abs(phi-phi_t):.6f}")
        assert abs(phi - phi_t) < tol, f"φ({u}) = {phi}, expected {phi_t}"

    # Test phi inverse
    print("\n--- Testing φ^{-1}(s) ---")
    s_test = np.array([-0.8, -0.5, -0.2, 0.0, 0.2])

    # The last column is the round-trip check φ(φ^{-1}(s)), which should equal s.
    print("\ns\t\tφ^{-1}(s)\tTrue\t\tError\t\tφ(φ^{-1}(s))")
    for s in s_test:
        phi_inv = inverter.solve_phi_inverse(s, verbose=False)
        phi_inv_true = (s + 0.5) / 2

        # Verify: φ(φ^{-1}(s)) ≈ s
        verification = inverter.phi(phi_inv)

        print(f"{s:.2f}\t\t{phi_inv:.4f}\t\t{phi_inv_true:.4f}\t\t"
              f"{abs(phi_inv-phi_inv_true):.6f}\t\tφ(φ^{{-1}})={verification:.4f}")
        assert abs(phi_inv - phi_inv_true) < tol, \
            f"φ^-1({s}) = {phi_inv}, expected {phi_inv_true}"
        assert abs(verification - s) < tol, \
            f"φ(φ^-1({s})) = {verification}, expected {s}"

    # Test g function
    print("\n--- Testing g(u) = u + φ^{-1}(-u) ---")
    g_func = inverter.build_g_function()

    u_test = np.array([-0.2, -0.1, 0.0, 0.1, 0.2])
    print("\nu\t\tg(u)\t\tφ^{-1}(-u)\tu + φ^{-1}(-u)")
    for u in u_test:
        g_val = g_func(u)
        phi_inv = inverter.solve_phi_inverse(-u)
        expected = u + phi_inv
        print(f"{u:.2f}\t\t{g_val:.4f}\t\t{phi_inv:.4f}\t\t{expected:.4f}")
        assert abs(g_val - expected) < tol, f"g({u}) = {g_val}, expected {expected}"
        assert abs(g_val - (u + 0.5) / 2) < tol, \
            f"g({u}) = {g_val}, closed form {(u + 0.5) / 2}"

    print("\n✓ All tests completed")


if __name__ == '__main__':
    # Running the module as a script runs the self-check with the uniform distribution.
    test_phi_inverter()
