"""
ETC (Explore-Then-Commit) Algorithm Implementation
===================================================

Clean implementation of the ETC algorithm, separated from the runner.
"""
from kernel_regression import KernelCDFEstimator
from phi_inverter import PhiInverter
import numpy as np
from typing import Dict, List, Tuple


class ExploreCommitAlgorithm:
    """
    Explore-Then-Commit algorithm for contextual dynamic pricing.

    Episodes k = 1, 2, ...:
    - Episode length: ℓ_k = 2^(k-1) * ℓ_0 (truncated to the remaining horizon)
    - Exploration length: a_k = ⌈(ℓ_k * d)^((2β+1)/(4β-1))⌉, capped at ℓ_k
    - Bandwidth: b_k = c_b * a_k^(-1/(2β+1))

    Here β is the ``beta`` constructor argument and d is ``len(env.theta_0)``.
    Each episode explores with uniformly random prices, fits a linear model
    and a pricing function from the exploration data, and then exploits the
    fitted pricing function for the rest of the episode.
    """

    def __init__(self, env, beta: float, ell_0: int, c_b: float = 3.0):
        """
        Parameters:
        -----------
        env : FGYEnvironment
            Environment with attributes ``theta_0`` and ``p_max`` and methods
            sample_context, generate_outcome, compute_expected_revenue,
            compute_optimal_price
        beta : float
            Algorithm smoothness parameter (can be different from environment's m)
            Used for: exploration length a_k and bandwidth b_k
        ell_0 : int
            Initial episode length (must be at least 1)
        c_b : float
            Bandwidth constant

        Raises:
        -------
        ValueError
            If ``ell_0 < 1`` (episodes would make no progress) or if
            ``beta == 0.25`` (the exploration exponent divides by zero).
        """
        if ell_0 < 1:
            raise ValueError(f"ell_0 must be at least 1, got {ell_0}")
        if 4 * beta - 1 == 0:
            raise ValueError(
                "beta = 0.25 makes the exploration exponent "
                "(2*beta + 1) / (4*beta - 1) undefined"
            )

        self.env = env
        self.beta = beta  # Algorithm smoothness (may differ from env.m)
        self.ell_0 = ell_0
        self.c_b = c_b

        # Derived parameters
        self.d = len(env.theta_0)  # Context dimension
        self.B = env.p_max         # Price upper bound

        # Counters, history and per-episode records
        self._reset_state()

    def _reset_state(self):
        """Reset counters, step history and episode records to a fresh state."""
        self.t_global = 0
        self.k = 0  # Current episode (0 = no episode started yet)

        # New containers are created (rather than cleared in place) so that
        # result dictionaries returned by earlier runs keep their own data.
        self.history = {
            'contexts': [],
            'prices': [],
            'outcomes': [],
            'revenues': [],
            'optimal_revenues': [],
            'episodes': [],
            'phases': []
        }
        self.episode_info = []

    def run(self, T: int, verbose: bool = False) -> Dict:
        """
        Run the algorithm for T rounds.

        Any state left over from a previous call is discarded first, so every
        call reports results for its own horizon only.

        Parameters:
        -----------
        T : int
            Time horizon
        verbose : bool
            Print progress information

        Returns:
        --------
        dict
            Results dictionary with keys 'regret', 'relative_regret',
            'total_revenue', 'total_optimal', 'n_episodes', 'history',
            'episode_info', 'theta_true', 'beta', 'ell_0' and 'c_b'.
            'relative_regret' is 0.0 when T is not positive.
        """
        self._reset_state()
        self.k = 1

        while self.t_global < T:
            self._run_episode(T, verbose)
            self.k += 1

        # self.k now points one past the last episode that actually ran.
        n_episodes = self.k - 1

        # Compute results
        total_revenue = sum(self.history['revenues'])
        total_optimal = sum(self.history['optimal_revenues'])
        total_regret = total_optimal - total_revenue
        relative_regret = total_regret / T if T > 0 else 0.0

        return {
            'regret': total_regret,
            'relative_regret': relative_regret,
            'total_revenue': total_revenue,
            'total_optimal': total_optimal,
            'n_episodes': n_episodes,
            'history': self.history,
            'episode_info': self.episode_info,
            'theta_true': self.env.theta_0.copy(),
            'beta': self.beta,  # Algorithm smoothness
            'ell_0': self.ell_0,
            'c_b': self.c_b
        }

    def _exploration_length(self, ell_k: int) -> int:
        """
        Return a_k = ⌈(ℓ_k * d)^((2β+1)/(4β-1))⌉ clamped to [1, ℓ_k].

        For beta slightly above 0.25 the exponent is very large and the power
        overflows; since a_k is capped at ℓ_k anyway, overflow is treated as
        "explore for the whole episode".
        """
        exponent = (2 * self.beta + 1) / (4 * self.beta - 1)
        try:
            a_k = int(np.ceil(float(ell_k * self.d) ** exponent))
        except OverflowError:
            a_k = ell_k
        # The lower bound of 1 keeps the bandwidth a_k^(-1/(2β+1)) defined.
        return max(1, min(a_k, ell_k))

    def _run_episode(self, T: int, verbose: bool = False):
        """
        Run a single episode with three phases.

        Phase 1 explores for a_k rounds, phase 2 fits the linear parameter and
        the pricing function from the exploration data, and phase 3 exploits
        the fitted pricing function for the remaining ℓ_k - a_k rounds.
        A summary of the episode is appended to ``self.episode_info``.
        """
        # Episode length doubles every episode but never exceeds the horizon.
        ell_k = min((2 ** (self.k - 1)) * self.ell_0, T - self.t_global)

        # Use beta for exploration length and bandwidth computation
        a_k = self._exploration_length(ell_k)
        b_k = self.c_b * (a_k ** (-1 / (2 * self.beta + 1)))

        # Phase 1: Exploration
        X_explore, p_explore, y_explore = self._exploration_phase(a_k)

        # Phase 2: Estimation
        Theta_hat_k = self._estimate_parameters(X_explore, y_explore, verbose)
        g_hat_k = self._estimate_pricing_function(
            X_explore, p_explore, y_explore, Theta_hat_k, b_k
        )

        # Record episode info
        self.episode_info.append({
            'k': self.k,
            'ell_k': ell_k,
            'a_k': a_k,
            'b_k': b_k,
            'Theta_hat': Theta_hat_k.copy(),
            'exploration_rate': y_explore.mean()
        })

        # Phase 3: Exploitation
        exploit_length = ell_k - a_k
        if exploit_length > 0:
            if verbose:
                print(f"  利用阶段: {exploit_length} 轮")
            self._exploitation_phase(exploit_length, Theta_hat_k, g_hat_k, T)

    def _exploration_phase(self, a_k: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Phase 1: Random exploration.

        For a_k rounds, sample a context, post a price drawn uniformly from
        [0, B], observe the outcome and record the step in the history.

        Returns: (X_explore, p_explore, y_explore), where each row of
        X_explore is the context with a constant 1.0 appended.
        """
        X_explore = []
        p_explore = []
        y_explore = []

        for _ in range(a_k):
            c_t = self.env.sample_context()
            x_tilde = np.concatenate([c_t, [1.0]])  # append intercept term

            # Uniform random pricing
            p_t = np.random.uniform(0, self.B)

            # Observe outcome
            y_t = self.env.generate_outcome(c_t, p_t)

            # Store exploration data
            X_explore.append(x_tilde)
            p_explore.append(p_t)
            y_explore.append(y_t)

            # Record history (also advances the global round counter)
            self._record_step(c_t, p_t, y_t, 'explore')

        return (np.array(X_explore), np.array(p_explore), np.array(y_explore))

    def _estimate_parameters(self, X: np.ndarray, y: np.ndarray,
                            verbose: bool = False) -> np.ndarray:
        """
        Estimate Θ via OLS: Θ̂ = argmin Σ(B*y_t - Θ^T x_t)²

        When ``verbose`` is set and the environment exposes ``Theta_true``,
        the estimate and its Euclidean error are printed.
        """
        Y_tilde = self.B * y
        Theta_hat, _, _, _ = np.linalg.lstsq(X, Y_tilde, rcond=None)

        if verbose and hasattr(self.env, 'Theta_true'):
            error = np.linalg.norm(Theta_hat - self.env.Theta_true)
            print(f"  Θ̂_k = {Theta_hat}")
            print(f"  估计误差: {error:.6f}")

        return Theta_hat

    def _estimate_pricing_function(self, X: np.ndarray, p: np.ndarray,
                                   y: np.ndarray, Theta_hat: np.ndarray,
                                   b_k: float):
        """
        Build pricing function via kernel regression + phi inversion.

        Relies on the module-level imports of ``KernelCDFEstimator`` and
        ``PhiInverter``. Returns whatever ``PhiInverter.build_g_function``
        produces; the exploitation phase calls it with a single scalar.
        """
        # Step 1: Compute residuals of the posted prices w.r.t. the linear fit
        w_hat = p - X @ Theta_hat

        # Step 2: Kernel regression
        # Note: kernel order is chosen from beta (must be 2, 4, or 6)
        if self.beta <= 3:
            kernel_order = 2
        elif self.beta <= 5:
            kernel_order = 4
        else:
            kernel_order = 6
        u_grid = np.linspace(-0.5, 0.5, 200)
        estimator = KernelCDFEstimator(kernel_order, b_k)
        F_hat, F_prime_hat = estimator.fit(w_hat, y, u_grid)

        # Step 3: Create interpolators
        F_interp, F_prime_interp = estimator.create_interpolators(
            u_grid, F_hat, F_prime_hat
        )

        # Step 4: Build pricing function
        inverter = PhiInverter(F_interp, F_prime_interp)
        return inverter.build_g_function()

    def _exploitation_phase(self, exploit_length: int, Theta_hat: np.ndarray,
                           g_hat, T: int):
        """
        Phase 3: Exploit using learned pricing function.

        Runs for at most ``exploit_length`` rounds and stops early once the
        global round counter reaches the horizon T. Prices proposed by
        ``g_hat`` are clipped to [0, B] before being posted.
        """
        for _ in range(exploit_length):
            if self.t_global >= T:
                break

            c_t = self.env.sample_context()
            x_tilde = np.concatenate([c_t, [1.0]])  # append intercept term

            # Compute μ̂_t
            mu_hat_t = x_tilde @ Theta_hat

            # Use learned pricing function
            p_tilde = g_hat(mu_hat_t)
            p_t = np.clip(p_tilde, 0, self.B)

            # Observe outcome
            y_t = self.env.generate_outcome(c_t, p_t)

            # Record history (also advances the global round counter)
            self._record_step(c_t, p_t, y_t, 'exploit')

    def _record_step(self, c: np.ndarray, p: float, y: int, phase: str):
        """
        Record a single time step in history and advance the round counter.

        The stored revenue is the environment's expected revenue at price p,
        and the stored optimal revenue is the second element returned by
        ``env.compute_optimal_price(c)``.
        """
        _, r_opt = self.env.compute_optimal_price(c)
        r_actual = self.env.compute_expected_revenue(c, p)

        self.history['contexts'].append(c)
        self.history['prices'].append(p)
        self.history['outcomes'].append(y)
        self.history['revenues'].append(r_actual)
        self.history['optimal_revenues'].append(r_opt)
        self.history['episodes'].append(self.k)
        self.history['phases'].append(phase)

        self.t_global += 1