import numpy as np
from typing import Union, Callable
import matplotlib.pyplot as plt
from scipy import stats
from scipy.stats import truncnorm
from scipy.optimize import brentq
import os


def F_Y_upper(y: Union[float, np.ndarray],
              F_Y: Callable[[Union[float, np.ndarray]], Union[float, np.ndarray]],
              a_X: float,
              b_X: float,
              b_Y: float,
              b_min: float,
              epsilon: float,
              *,
              n_grid: int = 1000) -> Union[float, np.ndarray]:
    """
    Evaluate the piecewise upper-bound function F_Y^U(y).

    With q = F_Y^{-1}(epsilon) (the epsilon-quantile of F_Y), the function is

    * 0                  for y <  max(a_X, q)
    * F_Y(y) - epsilon   for max(a_X, q) <= y < b_min
    * 1 - epsilon        for b_min <= y < b_X
    * 1                  for y >= max(b_X, b_Y)

    Choosing q as the epsilon-quantile makes F_Y(y) - epsilon start at 0, so the
    first two cases join continuously. NOTE(review): the derivation labels this
    point q_{1-epsilon}; numerically it is F_Y^{-1}(epsilon) — confirm naming.

    NOTE(review): when b_Y > b_X (so b_min == b_X) no case covers
    [b_X, b_Y); those points stay 0, which is not monotone. Confirm the
    intended value against the derivation.

    Parameters
    ----------
    y : float or array_like
        Point(s) at which to evaluate the bound.
    F_Y : callable
        CDF of Y; must accept a NumPy array and return an array of the same shape.
    a_X, b_X : float
        Lower / upper end of X's support.
    b_Y : float
        Upper end of Y's support (right end of the quantile search grid).
    b_min : float
        Boundary between cases 2 and 3; callers in this file pass min(b_X, b_Y).
    epsilon : float
        Risk parameter.
    n_grid : int, keyword-only
        Number of grid points on [a_X, b_Y] used to invert F_Y.

    Returns
    -------
    float or numpy.ndarray
        A Python float for scalar input (including NumPy scalars / 0-d arrays),
        otherwise a float array with the shape of ``y``.
    """
    # Invert F_Y by linear interpolation on a grid. np.interp clamps to the
    # grid end-points when epsilon lies outside F_Y's range on the grid.
    y_grid = np.linspace(a_X, b_Y, n_grid)
    q_eps = np.interp(epsilon, F_Y(y_grid), y_grid)

    y_arr = np.asarray(y)
    scalar_input = y_arr.ndim == 0
    y_arr = np.atleast_1d(y_arr)
    result = np.zeros(y_arr.shape, dtype=float)

    lower_edge = np.maximum(a_X, q_eps)

    # Case 1: y < max(a_X, q) -> 0 (already zero-initialised)
    result[y_arr < lower_edge] = 0.0

    # Case 2: y in [max(a_X, q), b_min) -> F_Y(y) - epsilon
    mask2 = (y_arr >= lower_edge) & (y_arr < b_min)
    result[mask2] = F_Y(y_arr[mask2]) - epsilon

    # Case 3: y in [b_min, b_X) -> 1 - epsilon
    mask3 = (y_arr >= b_min) & (y_arr < b_X)
    result[mask3] = 1.0 - epsilon

    # Case 4: y >= max(b_X, b_Y) -> 1
    result[y_arr >= np.maximum(b_X, b_Y)] = 1.0

    return float(result[0]) if scalar_input else result


def F_Y_lower(y: Union[float, np.ndarray],
              F_Y: Callable[[Union[float, np.ndarray]], Union[float, np.ndarray]],
              a_X: float,
              b_X: float,
              a_Y: float,
              b_Y: float,
              a_min: float,
              epsilon: float,
              *,
              n_grid: int = 1000) -> Union[float, np.ndarray]:
    """
    Evaluate the piecewise lower-bound function F_Y^L(y).

    With q = F_Y^{-1}(1 - epsilon) (the (1 - epsilon)-quantile of F_Y), the
    function is

    * 0                  for y <  a_min
    * epsilon            for a_min <= y < a_Y
    * F_Y(y) + epsilon   for a_Y <= y < min(q, b_X)
    * 1                  for y >= min(q, b_X)

    Cases are applied in this order, so a later case overrides an earlier one
    where their intervals overlap. Choosing q as the (1 - epsilon)-quantile
    makes F_Y(y) + epsilon reach 1 exactly at q.

    Parameters
    ----------
    y : float or array_like
        Point(s) at which to evaluate the bound.
    F_Y : callable
        CDF of Y; must accept a NumPy array and return an array of the same shape.
    a_X, b_X : float
        Lower / upper end of X's support (``a_X`` is not used by the formula).
    a_Y, b_Y : float
        Lower / upper end of Y's support.
    a_min : float
        Start of the epsilon plateau and left end of the quantile search grid;
        callers in this file pass min(a_X, a_Y).
    epsilon : float
        Risk parameter.
    n_grid : int, keyword-only
        Number of grid points on [a_min, b_Y] used to invert F_Y.

    Returns
    -------
    float or numpy.ndarray
        A Python float for scalar input (including NumPy scalars / 0-d arrays),
        otherwise a float array with the shape of ``y``.
    """
    # Invert F_Y at 1 - epsilon by linear interpolation on a grid; np.interp
    # clamps to the grid end-points outside F_Y's range on the grid.
    y_grid = np.linspace(a_min, b_Y, n_grid)
    q_upper = np.interp(1 - epsilon, F_Y(y_grid), y_grid)

    y_arr = np.asarray(y)
    scalar_input = y_arr.ndim == 0
    y_arr = np.atleast_1d(y_arr)
    result = np.zeros(y_arr.shape, dtype=float)

    upper_edge = np.minimum(q_upper, b_X)

    # Case 1: y < a_min -> 0 (already zero-initialised)
    result[y_arr < a_min] = 0.0

    # Case 2: y in [a_min, a_Y) -> epsilon
    result[(y_arr >= a_min) & (y_arr < a_Y)] = epsilon

    # Case 3: y in [a_Y, min(q, b_X)) -> F_Y(y) + epsilon
    mask3 = (y_arr >= a_Y) & (y_arr < upper_edge)
    result[mask3] = F_Y(y_arr[mask3]) + epsilon

    # Case 4: y >= min(q, b_X) -> 1
    result[y_arr >= upper_edge] = 1.0

    return float(result[0]) if scalar_input else result


def create_truncated_normal_cdf(a: float = -1.0, b: float = 1.0, loc: float = 0.0, scale: float = 0.5):
    """
    Build the CDF of a Normal(loc, scale) distribution truncated to [a, b].

    Parameters:
    -----------
    a : float
        Lower bound of truncation
    b : float
        Upper bound of truncation
    loc : float
        Location parameter (mean of underlying normal)
    scale : float
        Scale parameter (std of underlying normal); must be > 0

    Returns:
    --------
    callable
        CDF function ``F_Y(y)``. Returns a Python float for scalar input
        (including NumPy scalars and 0-d arrays), otherwise an ndarray with
        the shape of ``y``. Values are clipped to [0, 1].

    Raises:
    -------
    ValueError
        If ``scale`` is not positive or ``a`` is not strictly less than ``b``.
    """
    # `not x > 0` / `not a < b` also reject NaN parameters.
    if not scale > 0:
        raise ValueError(f"scale must be positive, got {scale!r}")
    if not a < b:
        raise ValueError(f"truncation bounds must satisfy a < b, got a={a!r}, b={b!r}")

    # Freeze the distribution once instead of on every CDF call.
    # truncnorm expects its bounds in standard-normal units, hence the rescaling.
    dist = truncnorm(a=(a - loc) / scale, b=(b - loc) / scale, loc=loc, scale=scale)

    def F_Y(y):
        y_arr = np.asarray(y)
        # Clip guards against tiny floating-point excursions outside [0, 1].
        values = np.clip(dist.cdf(y_arr), 0.0, 1.0)
        if y_arr.ndim == 0:
            return float(values)
        return values

    return F_Y


def plot_F_Y_functions(plot_dir: str = "./dist_disc_plots",
                       epsilon: float = 0.1) -> str:
    """
    Plot F_Y with its bounds F_Y^U / F_Y^L, plus the CDF F_X, and save the figure.

    Y and X are truncated normals on [-1, 1] centred at 0, with underlying
    scales 0.3 and 0.4 respectively. The figure is written to
    ``<plot_dir>/F_Y_bounds_plot.png``, and a short range summary is printed.

    Parameters
    ----------
    plot_dir : str
        Directory in which the PNG is saved; created if missing.
    epsilon : float
        Risk parameter passed to ``F_Y_upper`` / ``F_Y_lower``.

    Returns
    -------
    str
        Path of the saved PNG.
    """
    # Supports of Y and X
    a_Y, b_Y = -1.0, 1.0
    a_X, b_X = -1.0, 1.0
    a_min = min(a_X, a_Y)  # Minimum of lower bounds
    b_min = min(b_X, b_Y)  # Minimum of upper bounds

    # Standard deviations of the underlying normals before truncation
    scale_Y = 0.3
    scale_X = 0.4

    # Range extends past the supports so the flat 0/1 tails are visible
    y_plot = np.linspace(-1.5, 1.5, 2000)

    F_Y = create_truncated_normal_cdf(a=a_Y, b=b_Y, loc=0.0, scale=scale_Y)
    F_Y_original = F_Y(y_plot)
    F_Y_U_values = F_Y_upper(y_plot, F_Y, a_X, b_X, b_Y, b_min, epsilon)
    F_Y_L_values = F_Y_lower(y_plot, F_Y, a_X, b_X, a_Y, b_Y, a_min, epsilon)

    F_X = create_truncated_normal_cdf(a=a_X, b=b_X, loc=0.0, scale=scale_X)
    F_X_original = F_X(y_plot)

    fig = plt.figure(figsize=(16, 12))  # Large canvas for the large fonts
    try:
        plt.plot(y_plot, F_Y_original, 'b-', linewidth=4, label=r'$F_Y$')
        plt.plot(y_plot, F_Y_U_values, 'r-', linewidth=4, label=r'$F_Y^U$')
        plt.plot(y_plot, F_Y_L_values, 'g-', linewidth=4, label=r'$F_Y^L$')
        # F_X is shown without its bounds
        plt.plot(y_plot, F_X_original, 'b--', linewidth=4, label=r'$F_X$')

        plt.xlabel('y', fontsize=42, fontweight='bold')
        plt.ylabel('Cumulative Probability', fontsize=42, fontweight='bold')
        plt.title('$F_Y$ and $F_X$ with Upper and Lower Bounds', fontsize=47, fontweight='bold', pad=20)
        plt.grid(True, alpha=0.3)
        plt.legend(prop={'size': 36, 'weight': 'bold'})
        plt.xlim(-1.5, 1.5)
        plt.ylim(-0.1, 1.1)
        plt.xticks(fontsize=31, fontweight='bold')
        plt.yticks(fontsize=31, fontweight='bold')

        os.makedirs(plot_dir, exist_ok=True)
        plot_path = os.path.join(plot_dir, 'F_Y_bounds_plot.png')
        plt.savefig(plot_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if drawing or saving raises
        plt.close(fig)
    print(f"Plot saved as '{plot_path}'")

    # Print summary statistics
    print("\nSummary:")
    print(f"F_Y (σ={scale_Y}):")
    print(f"  Original F_Y range: [{F_Y_original.min():.3f}, {F_Y_original.max():.3f}]")
    print(f"  Upper bound F_Y^U range: [{F_Y_U_values.min():.3f}, {F_Y_U_values.max():.3f}]")
    print(f"  Lower bound F_Y^L range: [{F_Y_L_values.min():.3f}, {F_Y_L_values.max():.3f}]")
    print(f"F_X (σ={scale_X}):")
    print(f"  Original F_X range: [{F_X_original.min():.3f}, {F_X_original.max():.3f}]")

    return plot_path


# Example usage and testing
if __name__ == "__main__":
    # Run the plotting function
    plot_F_Y_functions()

    # Also run the original smoke test
    print("\n" + "="*50)
    print("Original test with simple parameters:")

    # Example parameters: X supported on [a_X, b_X], Y on [a_Y, b_Y]
    a_X = 0.0
    b_X = 10.0
    a_Y = 0.0
    b_Y = 8.0
    a_min = min(a_X, a_Y)
    b_min = min(b_X, b_Y)
    epsilon = 0.1

    # Example CDF: the standard logistic CDF, since
    # 0.5 * (1 + tanh(y / 2)) == 1 / (1 + exp(-y)). It is not a normal CDF.
    def F_Y(y):
        return 0.5 * (1 + np.tanh(y / 2))

    # Test the functions. Keyword arguments keep the long positional
    # signatures from being misused (e.g. passing b_min where a_min belongs).
    y_values = np.linspace(-2, 12, 100)
    F_Y_U_values = F_Y_upper(y_values, F_Y, a_X=a_X, b_X=b_X, b_Y=b_Y,
                             b_min=b_min, epsilon=epsilon)
    F_Y_L_values = F_Y_lower(y_values, F_Y, a_X=a_X, b_X=b_X, a_Y=a_Y, b_Y=b_Y,
                             a_min=a_min, epsilon=epsilon)

    print("Function implemented successfully!")
    print(f"Tested with {len(y_values)} values")
    print(f"Upper bound range: [{F_Y_U_values.min():.3f}, {F_Y_U_values.max():.3f}]")
    print(f"Lower bound range: [{F_Y_L_values.min():.3f}, {F_Y_L_values.max():.3f}]")
