

import numpy as np
from scipy.integrate import quad
from scipy.stats import rv_continuous
from typing import Tuple, Callable


class PolynomialNoise(rv_continuous):
    """
    m-smooth polynomial noise distribution

    Density: f_Z^(m)(z) = C_m * (1/4 - z²)^(m/2) * 1_{|z| ≤ 1/2}

    This is the demand noise ξ_t in the valuation model:
        u_t = c_t^T θ + ξ_t

    Because m is even, the unnormalized density is a polynomial of degree m,
    so the normalizing constant and the CDF are evaluated in closed form
    (no numerical quadrature) and the CDF is fully vectorized.

    Attributes:
    - m: smoothness parameter
    - normalization: ∫_{-1/2}^{1/2} (1/4 - u²)^(m/2) du
    - C_m: 1 / normalization
    """

    def __init__(self, m: int = 2):
        """
        Parameters:
        - m: smoothness parameter (must be even, typically 2, 4, or 6)

        Raises:
        - ValueError: if m is odd or smaller than 2
        """
        # Explicit raise instead of assert: asserts are stripped under `python -O`.
        if m % 2 != 0 or m < 2:
            raise ValueError("m must be even and >= 2")
        self.m = m

        # (1/4 - u²)^(m/2) expanded as a polynomial; its antiderivative is
        # anchored to vanish at the lower end of the support, u = -1/2.
        kernel = np.polynomial.Polynomial([0.25, 0.0, -1.0]) ** (int(m) // 2)
        self._kernel_integral = kernel.integ(lbnd=-0.5)

        # C_m = 1 / ∫_{-1/2}^{1/2} (1/4 - u²)^(m/2) du
        self.normalization = float(self._kernel_integral(0.5))
        self.C_m = 1.0 / self.normalization

        # Initialize parent with support [-1/2, 1/2]
        super().__init__(a=-0.5, b=0.5, name=f'polynomial_m{m}')

    def _pdf(self, z):
        """Probability density function on the support (scalar or array z)"""
        # f_Z^(m)(z) = C_m * (1/4 - z²)^(m/2)
        return self.C_m * (0.25 - z**2)**(self.m/2)

    def _cdf(self, z):
        """
        Cumulative distribution function F_m(z) = ∫_{-1/2}^z f_Z^(m)(u) du

        Accepts scalars and arrays of any shape; values at or beyond the
        support edges map to exactly 0 and 1.
        """
        z = np.asarray(z, dtype=float)
        inside = self.C_m * self._kernel_integral(np.clip(z, -0.5, 0.5))
        # Guard against round-off pushing the polynomial slightly outside [0, 1].
        inside = np.clip(inside, 0.0, 1.0)
        return np.where(z <= -0.5, 0.0, np.where(z >= 0.5, 1.0, inside))

    def survival(self, s: float) -> float:
        """
        Survival function g_m(s) = P(ξ ≥ s) = 1 - F_m(s)

        This is the demand curve in the single-index model:
            P(y=1|c,p) = g_m(p - c^T θ)

        Parameters:
        - s: evaluation point; an array is also accepted

        Returns: a Python float for scalar input, otherwise an array with
        the same shape as s
        """
        tail = 1.0 - self._cdf(s)
        return float(tail) if np.ndim(tail) == 0 else tail

    def sample(self) -> float:
        """Sample a single realization"""
        return self.rvs()


class PolynomialContext(rv_continuous):
    """
    m-smooth polynomial context distribution

    Density: f_X^(m)(x) = C'_m * (2/3 - x²)^(m+1) * 1_{|x| ≤ √(2/3)}

    Each coordinate of c_t is drawn i.i.d. from this distribution.

    The unnormalized density is a polynomial of degree 2(m+1), so the
    normalizing constant and the CDF are evaluated in closed form. Supplying
    `_cdf` spares scipy from integrating the pdf numerically on every
    cdf/ppf/rvs call.

    Attributes:
    - m: smoothness parameter
    - support_bound: √(2/3), half-width of the support
    - normalization: ∫ (2/3 - u²)^(m+1) du over the support
    - C_prime_m: 1 / normalization
    """

    def __init__(self, m: int = 2):
        """
        Parameters:
        - m: smoothness parameter (must be even, typically 2, 4, or 6)

        Raises:
        - ValueError: if m is odd or smaller than 2
        """
        # Explicit raise instead of assert: asserts are stripped under `python -O`.
        if m % 2 != 0 or m < 2:
            raise ValueError("m must be even and >= 2")
        self.m = m
        self.support_bound = np.sqrt(2.0/3.0)  # √(2/3) ≈ 0.8165

        # (2/3 - u²)^(m+1) expanded as a polynomial; its antiderivative is
        # anchored to vanish at the lower end of the support.
        kernel = np.polynomial.Polynomial([2.0/3.0, 0.0, -1.0]) ** (int(m) + 1)
        self._kernel_integral = kernel.integ(lbnd=-self.support_bound)

        # C'_m = 1 / ∫_{-√(2/3)}^{√(2/3)} (2/3 - u²)^(m+1) du
        self.normalization = float(self._kernel_integral(self.support_bound))
        self.C_prime_m = 1.0 / self.normalization

        # Initialize parent with support [-√(2/3), √(2/3)]
        super().__init__(a=-self.support_bound, b=self.support_bound,
                         name=f'polynomial_context_m{m}')

    def _pdf(self, x):
        """Probability density function on the support (scalar or array x)"""
        # f_X^(m)(x) = C'_m * (2/3 - x²)^(m+1)
        return self.C_prime_m * (2.0/3.0 - x**2)**(self.m + 1)

    def _cdf(self, x):
        """
        Cumulative distribution function, closed form and vectorized

        Values at or beyond the support edges map to exactly 0 and 1.
        """
        x = np.asarray(x, dtype=float)
        bound = self.support_bound
        inside = self.C_prime_m * self._kernel_integral(np.clip(x, -bound, bound))
        # Guard against round-off pushing the polynomial slightly outside [0, 1].
        inside = np.clip(inside, 0.0, 1.0)
        return np.where(x <= -bound, 0.0, np.where(x >= bound, 1.0, inside))


class FGYEnvironment:
    """
    Fan-Guo-Yu experimental environment built on m-smooth polynomial
    distributions.

    Model: E[y|c,p] = g_m(p - c^T θ), with
    - c ∈ R^(d-1): context whose coordinates are i.i.d. draws from f_X^(m)
    - p ∈ [0, p_max]: price
    - θ ∈ R^(d-1): true parameter (this is β_0 in your notation)
    - α_0 = 1 (fixed)
    - g_m: survival function of the noise ξ ~ f_Z^(m)

    Compatible interface with TrueEnvironment from experiment_environment.py
    """

    def __init__(
        self,
        theta_0: np.ndarray,
        m: int = 2,
        p_min: float = 0.0,
        p_max: float = 1.0
    ):
        """
        Set up the environment and print a short configuration summary.

        Parameters:
        - theta_0: true parameter vector (d-1 dimensional, this is β_0)
        - m: smoothness parameter (2, 4, or 6)
        - p_min, p_max: price bounds
        """
        # Model parameters; the price coefficient α₀ is pinned to 1.
        self.theta_0 = theta_0
        self.alpha_0 = 1.0
        self.m = m
        self.p_min = p_min
        self.p_max = p_max
        # Total dimension = number of context coordinates + 1.
        self.d = len(theta_0) + 1

        # Noise and context laws share the same smoothness level.
        self.noise_dist = PolynomialNoise(m=m)
        self.context_dist = PolynomialContext(m=m)

        # The context support is symmetric around zero.
        self.context_min = -self.context_dist.support_bound
        self.context_max = self.context_dist.support_bound

        # Configuration summary
        print(f"FGY Environment (m={m}):")
        print("  α₀ = 1.0 (fixed)")
        print(f"  θ₀ = {theta_0}")
        print(f"  Dimension d = {self.d}")
        print(f"  Price range: [{p_min:.4f}, {p_max:.4f}]")
        print(f"  Context support: [{self.context_min:.4f}, {self.context_max:.4f}]")
        print("  Noise support: [-0.5, 0.5]")
        print(f"  Smoothness: m = {m} (β = m in Hölder sense)")

    def sample_context(self) -> np.ndarray:
        """
        Draw one context vector c ~ f_X^(m).

        The d-1 coordinates are requested from the context distribution in a
        single `rvs` call.

        Returns: c ∈ R^(d-1)
        """
        n_coords = self.d - 1
        return self.context_dist.rvs(size=n_coords)

    def generate_outcome(self, c: np.ndarray, p: float) -> int:
        """
        Simulate a purchase decision y ∈ {0, 1}.

        Steps:
        1. Draw noise ξ ~ f_Z^(m)
        2. Form the index s_true = p - c^T θ₀  (α=1)
        3. Return 1 if ξ ≥ s_true, else 0

        Parameters:
        - c: context vector (d-1 dimensional)
        - p: price

        Returns: y ∈ {0, 1}
        """
        # The noise is drawn before the index is computed.
        noise = self.noise_dist.rvs()
        threshold = p - np.dot(c, self.theta_0)
        if noise >= threshold:
            return 1
        return 0

    def compute_expected_revenue(self, c: np.ndarray, p: float) -> float:
        """
        Expected revenue r(c,p) = p * g_m(p - c^T θ₀),
        where g_m(s) = P(ξ ≥ s) is the noise survival function.

        Parameters:
        - c: context vector
        - p: price

        Returns: E[p*y | c, p]
        """
        index = p - np.dot(c, self.theta_0)
        return p * self.noise_dist.survival(index)

    def compute_optimal_price(
        self,
        c: np.ndarray,
        n_grid: int = 100
    ) -> Tuple[float, float]:
        """
        Grid-search the revenue-maximizing price for a context:
            p* = argmax_{p ∈ [p_min, p_max]} r(c, p)

        Parameters:
        - c: context vector
        - n_grid: number of evenly spaced candidate prices

        Returns: (p*, r*); ties resolve to the lowest candidate price
        """
        candidates = np.linspace(self.p_min, self.p_max, n_grid)

        payoffs = []
        for price in candidates:
            payoffs.append(self.compute_expected_revenue(c, price))
        payoffs = np.array(payoffs)

        winner = np.argmax(payoffs)
        return candidates[winner], payoffs[winner]

    def get_g_function(self) -> Callable[[float], float]:
        """Expose the demand curve g_m(s) = P(ξ ≥ s) as a callable"""
        return self.noise_dist.survival


def create_fgy_environment(
    d: int = 3,
    m: int = 2,
    theta_scale: float = 1.0
) -> FGYEnvironment:
    """
    Create FGY environment with standard configuration

    Parameters:
    - d: dimension (d=3 matches FGY paper); θ₀ has d-1 coordinates
    - m: smoothness parameter (2, 4, or 6)
    - theta_scale: scaling factor for θ₀ (default: all 0.5 like your current setup)

    Configuration:
    - α₀ = 1 (fixed)
    - θ₀ = [0.5, ..., 0.5] * theta_scale  (d-1 entries; [0.5, 0.5] for d=3)
    - Price range: [0, 1]
    - Context: c_j ~ f_X^(m), support [-√(2/3), √(2/3)]
    - Noise: ξ ~ f_Z^(m), support [-1/2, 1/2]

    Raises:
    - ValueError: if d < 1
    """
    if d < 1:
        raise ValueError(f"d must be at least 1, got {d}")
    if d != 3:
        print(f"Warning: d={d} specified, but FGY paper uses d=3")

    # Honor the requested dimension: previously θ₀ was always 2-dimensional,
    # so any d other than 3 was silently ignored after the warning.
    theta_0 = theta_scale * np.full(d - 1, 0.5)
    p_min = 0.0
    p_max = 1.0

    return FGYEnvironment(
        theta_0=theta_0,
        m=m,
        p_min=p_min,
        p_max=p_max
    )


def get_fgy_pilot(
    theta_0: np.ndarray,
    T: int
) -> np.ndarray:
    """
    Generate pilot parameter (fixed error version)

    Method: θ_pilot = θ₀ + (1/√T) * direction
    where direction is the fixed unit vector [1, 1, ...] / √len(θ₀)

    Intended to match get_paper_pilot() from experiment_environment.py

    Prints the target error, the realized error, θ₀ and θ_pilot.

    Parameters:
    - theta_0: true parameter
    - T: total time steps (must be positive)

    Returns: theta_pilot

    Raises:
    - ValueError: if T is not positive
    """
    # T = 0 would give an infinite offset and T < 0 a NaN one; fail loudly
    # instead of returning an unusable pilot.
    if T <= 0:
        raise ValueError(f"T must be positive, got {T}")

    error_magnitude = 1.0 / np.sqrt(T)

    # Use fixed direction (normalized [1, 1, ...])
    d = len(theta_0)
    direction = np.ones(d) / np.sqrt(d)

    theta_pilot = theta_0 + error_magnitude * direction

    # Recomputed from the result as a sanity check against the target.
    actual_error = np.linalg.norm(theta_pilot - theta_0)

    print(f"\nPilot parameter generation:")
    print(f"  Target error: {error_magnitude:.6f}")
    print(f"  Actual error: {actual_error:.6f}")
    print(f"  θ₀:      {theta_0}")
    print(f"  θ_pilot: {theta_pilot}")

    return theta_pilot


# ============ Testing and Comparison ============

if __name__ == "__main__":
    # Smoke-test script: exercises every public piece of the module for each
    # supported smoothness level and prints the results for eyeballing.
    rule = "=" * 80

    print(rule)
    print("FGY Environment Testing")
    print(rule)

    for smoothness in (2, 4, 6):
        print(f"\n{rule}")
        print(f"Testing m = {smoothness} (Hölder β = {smoothness})")
        print(rule)

        # Build the standard d=3 environment for this smoothness level.
        env = create_fgy_environment(d=3, m=smoothness)

        # Context draw and support check.
        print("\nSampling test:")
        context_draw = env.sample_context()
        within_support = np.all(np.abs(context_draw) <= env.context_max)
        print(f"  Context sample: {context_draw}")
        print(f"  Context bounds check: all in [{env.context_min:.4f}, {env.context_max:.4f}]? "
              f"{within_support}")

        # Noise draws, taken one at a time, with summary statistics.
        print("\nNoise test:")
        noise_draws = []
        for _ in range(100):
            noise_draws.append(env.noise_dist.rvs())
        noise_mean = np.mean(noise_draws)
        noise_std = np.std(noise_draws)
        print(f"  100 noise samples: mean={noise_mean:.4f}, std={noise_std:.4f}")
        print(f"  Support check: all in [-0.5, 0.5]? {np.all(np.abs(noise_draws) <= 0.5)}")

        # Demand curve at the support edges and the centre.
        print("\nDemand curve g_m(s) = P(ξ ≥ s):")
        for point in (-0.5, 0.0, 0.5):
            demand = env.noise_dist.survival(point)
            print(f"  g_m({point:+.1f}) = {demand:.4f}")

        # Grid-search pricing for a fixed context.
        print("\nOptimal pricing test:")
        fixed_context = np.array([0.3, 0.3])
        best_price, best_revenue = env.compute_optimal_price(fixed_context)
        print(f"  Context: {fixed_context}")
        print(f"  Optimal price: {best_price:.4f}")
        print(f"  Optimal revenue: {best_revenue:.4f}")

        # Pilot estimate for a horizon of 1000 steps (the helper prints details).
        print("\nPilot generation test:")
        horizon = 1000
        pilot = get_fgy_pilot(env.theta_0, horizon)

    print("\n" + rule)
    print("All tests completed!")
    print(rule)
