

import numpy as np
from scipy.optimize import minimize
from typing import Tuple, List, Dict, Optional
import warnings
warnings.filterwarnings('ignore')


class LocalPolyBin:
    """State of one index bin used by the local polynomial estimator.

    Stores the bin edges, the indices of the epoch samples routed to the bin,
    the fitted direction and centring point, the sufficient statistics built
    from them, and the counters that drive forced exploration.
    """

    def __init__(self, bin_range: Tuple[float, float], ell: int):
        lower_edge, upper_edge = bin_range
        self.a = lower_edge  # lower edge of the bin
        self.b = upper_edge  # upper edge of the bin
        self.ell = ell  # local polynomial degree

        # Indices into the epoch's sample list that fall in this bin.
        self.sample_indices: List[int] = []
        self.n_samples = 0

        # Mean of the bin's x = (c, p) vectors (same length as x).
        self.x_bar: Optional[np.ndarray] = None

        # Fitted index direction (same length as x).
        self.theta_hat: Optional[np.ndarray] = None

        # Sufficient statistics, filled in after fitting:
        #   Lambda, Lambda_inv : (ell + 1, ell + 1)
        #   S_yU               : (ell + 1,)
        #   S_UX               : (ell + 1, ell * len(x))
        #   Sigma_v            : (ell * len(x), ell * len(x))
        self.Lambda: Optional[np.ndarray] = None
        self.Lambda_inv: Optional[np.ndarray] = None
        self.S_yU: Optional[np.ndarray] = None
        self.S_UX: Optional[np.ndarray] = None
        self.Sigma_v: Optional[np.ndarray] = None

        # Forced-exploration counters: s_j counts visits to the bin and
        # L_j is the current exploration round (it fires when s_j == L_j**2).
        self.s_j = 0
        self.L_j = 1


class EpochOperator:
    """Frozen per-epoch estimator used to shrink price intervals later on.

    The index range ``[index_min, index_max]`` is split into ``M_tau``
    equal-width bins, each backed by a ``LocalPolyBin``. Once the bins have
    been fitted, ``g_hat`` returns a local-polynomial estimate at ``(c, p)``
    clipped to ``[0, 1]``, ``CB`` an uncertainty bound for it, and
    ``apply_operator`` uses both to narrow a candidate price interval.

    ``K`` (number of interval segments) and ``kappa`` (half-width
    multiplier) start as ``None`` and must be assigned before
    ``apply_operator`` is called.
    """

    def __init__(self, tau: int, N_tau: int, h_tau: float, M_tau: int,
                 index_min: float, index_max: float, ell: int, beta: float):
        self.tau = tau
        self.N_tau = N_tau
        self.h_tau = h_tau
        self.M_tau = M_tau
        self.ell = ell
        self.beta = beta

        self.bins = self._create_bins(index_min, index_max, M_tau)
        self.bin_objects = [LocalPolyBin(b, ell) for b in self.bins]
        # Bin edges as one array, bins[j] == (_edges[j], _edges[j + 1]);
        # ascending when index_min <= index_max. Lets find_bin binary-search
        # instead of scanning every bin on each call.
        self._edges = np.array(
            [a for a, _ in self.bins] + [b for _, b in self.bins[-1:]],
            dtype=float,
        )

        self.K = None
        self.kappa = None

    def _create_bins(self, s_min: float, s_max: float, M: int) -> List[Tuple[float, float]]:
        """Split ``[s_min, s_max]`` into ``M`` consecutive equal-width bins."""
        edges = np.linspace(s_min, s_max, M + 1)
        return [(edges[i], edges[i+1]) for i in range(M)]

    def find_bin(self, s: float) -> Optional[int]:
        """Return ``j`` with ``bins[j][0] <= s < bins[j][1]``, else ``None``.

        The right edge of the last bin is also accepted (within 1e-10) so
        the upper end of the index range maps to the last bin. NaN and
        out-of-range values return ``None``.
        """
        edges = self._edges
        n_bins = len(edges) - 1
        if n_bins < 1:
            return None

        j = int(np.searchsorted(edges, s, side='right')) - 1
        # The explicit interval check keeps the result exact even if the
        # edges are not ascending (index_min > index_max).
        if 0 <= j < n_bins and edges[j] <= s < edges[j + 1]:
            return j
        if abs(s - edges[-1]) < 1e-10:
            return n_bins - 1
        return None

    def g_hat(self, c: np.ndarray, p: float, theta_bar0: np.ndarray) -> float:
        """Local polynomial estimate at ``x = (c, p)``, clipped to ``[0, 1]``.

        Returns 0.5 when ``x @ theta_bar0`` falls outside every bin or the
        bin has not been fitted / has no samples.
        """
        x = np.concatenate([c, [p]])
        s = x @ theta_bar0
        j = self.find_bin(s)

        if j is None or self.bin_objects[j].theta_hat is None:
            return 0.5

        bin_obj = self.bin_objects[j]
        if bin_obj.n_samples == 0:
            return 0.5

        U_j = self._compute_U_j(x, bin_obj)
        g_val = U_j @ bin_obj.Lambda_inv @ bin_obj.S_yU
        return float(np.clip(g_val, 0, 1))

    def CB(self, c: np.ndarray, p: float, theta_bar0: np.ndarray,
           delta: float, c0: float, C1: float, zeta: float) -> float:
        """Confidence bound for ``g_hat`` at ``x = (c, p)``.

        Computes ``C1 * Err_j`` where ``Err_j`` combines the Mahalanobis
        norms of ``v_j`` (under ``Sigma_v + zeta * I``) and ``U_j`` (under
        ``Lambda_inv``) with a bias term. Returns the maximal bound 1.0
        when the point is outside every bin, the bin is unfitted or empty,
        ``ell == 0``, the scaled Gram matrix fails the eigenvalue test
        ``min_eig >= c0 * sqrt(n_samples)``, or a numerical error occurs.
        """
        x = np.concatenate([c, [p]])
        s = x @ theta_bar0
        j = self.find_bin(s)

        if j is None or self.bin_objects[j].theta_hat is None:
            return 1.0

        bin_obj = self.bin_objects[j]
        if bin_obj.n_samples == 0 or self.ell == 0:
            return 1.0

        try:
            U_j = self._compute_U_j(x, bin_obj)
            X_j = self._compute_X_j(x, bin_obj)
            v_j = self._compute_v_j(U_j, X_j, bin_obj)

            if bin_obj.Sigma_v.shape[0] > 0:
                v_dim = bin_obj.Sigma_v.shape[0]
                # Ridge-regularize before inverting.
                Sigma_reg = bin_obj.Sigma_v + zeta * np.eye(v_dim)
                Sigma_inv = np.linalg.inv(Sigma_reg)
            else:
                Sigma_inv = np.zeros((0, 0))
                v_j = np.array([])

            Lambda_inv = bin_obj.Lambda_inv

            # max(0, .) guards against tiny negative quadratic forms.
            if len(v_j) > 0:
                norm_v = np.sqrt(max(0, v_j @ Sigma_inv @ v_j))
            else:
                norm_v = 0.0
            norm_U = np.sqrt(max(0, U_j @ Lambda_inv @ U_j))

            # Bandwidth recomputed from N_tau: N_tau ** (-1 / (2 beta + 1)).
            h = self.N_tau ** (-1.0 / (2 * self.beta + 1))
            sqrt_term = np.sqrt(np.log(1/delta)) + np.sqrt(bin_obj.n_samples) * (h ** self.beta)
            bias_term = self.N_tau ** (-self.beta / (2*self.beta + 1))
            Err_j = sqrt_term * (norm_v + norm_U) + bias_term

            # Scale the Gram matrix by H = diag(M^0, M^-1, ..., M^-ell)
            # before checking its smallest eigenvalue.
            H = np.diag([self.M_tau**(-i) for i in range(self.ell + 1)])
            Lambda_scaled = H @ bin_obj.Lambda @ H
            min_eig = np.linalg.eigvalsh(Lambda_scaled)[0]

            if min_eig >= c0 * np.sqrt(bin_obj.n_samples):
                return float(C1 * Err_j)
            else:
                return 1.0
        except (np.linalg.LinAlgError, ValueError, ArithmeticError):
            # Numerical failure (singular matrix, bad shapes, division by
            # zero): report maximal uncertainty rather than crash. Other
            # exceptions, including KeyboardInterrupt, now propagate.
            return 1.0

    def apply_operator(self, c: np.ndarray, interval: Tuple[float, float],
                       theta_bar0: np.ndarray, p_max: float,
                       delta: float, c0: float, C1: float, zeta: float) -> Tuple[float, float]:
        """Shrink the price interval for context ``c``.

        The interval is clipped to ``[0, p_max]`` and split into ``K`` equal
        segments. Each segment is scored by the mean estimated revenue
        ``p * g_hat`` over a grid of ``min(int(sqrt(N_tau)), 50)`` prices.
        The result is centred on the midpoint of the best segment with
        half-width ``kappa * sqrt(w**2 + U**2)``. Here ``w`` is the segment width
        and ``U`` the largest segment-average ``CB``, again clipped to
        ``[0, p_max]``.

        Degenerate inputs (empty interval, entirely at or below 0, or
        starting at or above ``p_max``) reset to ``(0.0, p_max)``. An
        interval narrower than 1e-6 after clipping is returned unchanged.
        """
        pi, pi_bar = interval

        if pi >= pi_bar or pi_bar <= 0 or pi >= p_max:
            return (0.0, p_max)

        pi = max(0, pi)
        pi_bar = min(p_max, pi_bar)

        L = pi_bar - pi
        if L < 1e-6:
            return (pi, pi_bar)

        w = L / self.K

        best_reward = -np.inf
        best_k = 0
        max_uncertainty = 0

        n_grid = min(int(np.sqrt(self.N_tau)), 50)

        for k in range(self.K):
            J_k_start = pi + k * w
            J_k_end = pi + (k + 1) * w

            prices = np.linspace(J_k_start, J_k_end, n_grid)

            rewards = []
            uncertainties = []
            for p_test in prices:
                if 0 <= p_test <= p_max:
                    g = self.g_hat(c, p_test, theta_bar0)
                    cb = self.CB(c, p_test, theta_bar0, delta, c0, C1, zeta)
                    rewards.append(p_test * g)
                    uncertainties.append(cb)

            if len(rewards) > 0:
                avg_reward = np.mean(rewards)
                avg_uncertainty = np.mean(uncertainties)

                if avg_reward > best_reward:
                    best_reward = avg_reward
                    best_k = k

                max_uncertainty = max(max_uncertainty, avg_uncertainty)

        J_best_start = pi + best_k * w
        J_best_end = pi + (best_k + 1) * w
        m = (J_best_start + J_best_end) / 2

        Delta_hat = self.kappa * np.sqrt(w**2 + max_uncertainty**2)

        pi_new = max(0, m - Delta_hat)
        pi_bar_new = min(p_max, m + Delta_hat)

        return (pi_new, pi_bar_new)

    def _compute_U_j(self, x: np.ndarray, bin_obj: LocalPolyBin) -> np.ndarray:
        """Polynomial basis ``(1, D, ..., D**ell)`` with ``D = (x - x_bar) @ theta_hat``."""
        Delta_j = (x - bin_obj.x_bar) @ bin_obj.theta_hat
        U_j = np.array([Delta_j**i for i in range(self.ell + 1)])
        return U_j

    def _compute_X_j(self, x: np.ndarray, bin_obj: LocalPolyBin) -> np.ndarray:
        """Concatenate ``i * D**(i-1) * (x - x_bar)`` for ``i = 1..ell``.

        Each piece is the derivative of ``D**i`` with respect to the
        direction, so the result has length ``ell * len(x)``. Returns an
        empty array when ``ell == 0``.
        """
        if self.ell == 0:
            return np.array([])

        Delta_j = (x - bin_obj.x_bar) @ bin_obj.theta_hat
        diff = x - bin_obj.x_bar

        X_rows = []
        for i in range(1, self.ell + 1):
            row = i * (Delta_j**(i-1)) * diff
            X_rows.append(row)

        X_j = np.concatenate(X_rows)
        return X_j

    def _compute_v_j(self, U_j: np.ndarray, X_j: np.ndarray,
                     bin_obj: LocalPolyBin) -> np.ndarray:
        """Residual ``X_j - S_UX^T Lambda^{-1} U_j``; empty when ``ell == 0``."""
        if self.ell == 0:
            return np.array([])

        v_j = X_j - bin_obj.S_UX.T @ bin_obj.Lambda_inv @ U_j
        return v_j


class LazyContextCache:
    """Per-context memo of price intervals, advanced lazily epoch by epoch.

    Contexts are keyed after rounding to ``precision`` decimals. Each entry
    remembers the interval reached so far and the last epoch whose operator
    has been applied, so a lookup only runs the operators it has not seen.
    """

    def __init__(self, p_max: float, precision: int = 6):
        self.p_max = p_max
        self.precision = precision
        # rounded-context tuple -> {'lower', 'upper', 'last_epoch'}
        self.cache = {}

    def _context_key(self, c: np.ndarray) -> tuple:
        """Return ``c`` rounded to ``precision`` decimals, as a hashable tuple."""
        rounded = np.round(c, decimals=self.precision)
        return tuple(rounded)

    def get_interval(self, c: np.ndarray, current_epoch: int,
                     epoch_operators: List[EpochOperator],
                     theta_bar0: np.ndarray, delta: float,
                     c0: float, C1: float, zeta: float) -> Tuple[float, float]:
        """Return the price interval for ``c`` during ``current_epoch``.

        Starting from the stored interval (``(0.0, p_max)`` for a new
        context), applies ``epoch_operators[e - 1].apply_operator`` for each
        epoch ``e`` from the entry's ``last_epoch + 1`` through
        ``current_epoch - 1``. It then stores the result and sets
        ``last_epoch`` to ``current_epoch - 1``.
        """
        entry = self.cache.setdefault(self._context_key(c), {
            'lower': 0.0,
            'upper': self.p_max,
            'last_epoch': 0
        })

        lower = entry['lower']
        upper = entry['upper']
        first_pending = entry['last_epoch'] + 1

        for epoch_idx in range(first_pending, current_epoch):
            # The operator of epoch e is stored at position e - 1.
            operator = epoch_operators[epoch_idx - 1]
            lower, upper = operator.apply_operator(
                c, (lower, upper), theta_bar0, self.p_max,
                delta, c0, C1, zeta
            )

        entry.update(lower=lower, upper=upper, last_epoch=current_epoch - 1)
        return lower, upper


class ConstrainedLeastSquares:
    """Constrained least-squares fit of a bin's index direction."""

    @staticmethod
    def solve(bin_obj: 'LocalPolyBin', theta_bar0: np.ndarray, eta: float,
              all_samples: List[Tuple[np.ndarray, int]], ell: int) -> np.ndarray:
        """Fit theta inside the ball ``||theta - theta_bar0|| <= eta`` by SLSQP.

        For a candidate theta, every sample of the bin is mapped to
        ``Delta = (x - bin_obj.x_bar) @ theta``. The outcomes are regressed on
        ``(1, Delta, ..., Delta**ell)`` with a 1e-6 ridge, and the objective
        is the residual sum of squares of that fit.

        Args:
            bin_obj: provides ``sample_indices`` (indices into
                ``all_samples``) and ``x_bar`` (centring point).
            theta_bar0: centre of the feasible ball and SLSQP start point.
            eta: radius of the feasible ball.
            all_samples: ``(x, y)`` pairs.
            ell: polynomial degree.

        Returns:
            The optimised theta when SLSQP reports success. Otherwise (no
            samples, ``eta <= 0``, optimizer failure or numerical error) it
            returns a copy of ``theta_bar0``.
        """
        if len(bin_obj.sample_indices) == 0:
            return theta_bar0.copy()
        if eta <= 0:
            # The feasible set is {theta_bar0} (eta == 0) or empty (eta < 0).
            return theta_bar0.copy()

        # Gather the bin's samples once; every loss evaluation reuses them.
        xs = np.array([all_samples[idx][0] for idx in bin_obj.sample_indices],
                      dtype=float)
        ys = np.array([all_samples[idx][1] for idx in bin_obj.sample_indices],
                      dtype=float)
        centered = xs - bin_obj.x_bar
        powers = np.arange(ell + 1)
        ridge = 1e-6 * np.eye(ell + 1)

        def loss(theta):
            delta = centered @ theta
            U = delta[:, None] ** powers  # row t: (1, D_t, ..., D_t**ell)
            S_yU = U.T @ ys
            try:
                Lambda_inv = np.linalg.inv(U.T @ U + ridge)
            except np.linalg.LinAlgError:
                return 1e10
            fitted = U @ Lambda_inv @ S_yU
            return float(np.sum((ys - fitted) ** 2))

        def constraint(theta):
            # Squared form of eta - ||theta - theta_bar0||. It has the same
            # feasible set for eta > 0 and, unlike the norm, is differentiable
            # at theta_bar0, which is where SLSQP starts.
            diff = theta - theta_bar0
            return eta ** 2 - diff @ diff

        try:
            result = minimize(
                loss,
                x0=theta_bar0.copy(),
                method='SLSQP',
                constraints={'type': 'ineq', 'fun': constraint},
                options={'maxiter': 100, 'ftol': 1e-6},
            )
        except (ValueError, np.linalg.LinAlgError, ArithmeticError):
            return theta_bar0.copy()

        return result.x if result.success else theta_bar0.copy()


class LPSPAlgorithm:
    """Epoch-based pricing with local polynomial single-index estimation.

    Epoch ``tau`` runs ``min(2**tau * n0, T - t_global)`` rounds. Each round
    draws a context and samples a price uniformly from the interval obtained
    by applying all previous epochs' operators (through a lazy per-context
    cache). The price is occasionally replaced by a forced-exploration price.
    At the end of the epoch every bin is fitted and the epoch operator is
    stored for later epochs.

    ``env`` must provide ``p_max``, ``sample_context()``,
    ``generate_outcome(c, p)``, ``compute_expected_revenue(c, p)`` and
    ``compute_optimal_price(c)`` (whose second return value is used as the
    optimal expected revenue).
    """

    def __init__(self, env, theta_pilot: np.ndarray, hyperparams: dict):
        """Store the environment, pilot direction and hyperparameters.

        ``hyperparams`` keys: required ``T``, ``eta``, ``n0``, ``index_min``,
        ``index_max``, ``K``, ``kappa``, ``zeta``; optional ``beta``
        (default 2), ``c0`` (0.1), ``C1`` (1.0), ``delta`` (1 / T).
        """
        self.env = env

        self.theta_pilot = theta_pilot
        # For x = (c, p): x @ theta_bar0 == p - c @ theta_pilot.
        self.theta_bar0 = np.concatenate([-self.theta_pilot, [1.0]])

        self.T = hyperparams['T']
        # Length of x = (c, p).
        self.d = len(self.theta_pilot) + 1
        self.beta = hyperparams.get('beta', 2)
        # Local polynomial degree floor(beta) - 1, clamped at 0: for beta < 1
        # the raw value is negative and yields empty polynomial bases.
        self.ell = max(0, int(np.floor(self.beta)) - 1)

        self.eta = hyperparams['eta']
        self.N0 = hyperparams['n0']
        self.index_min = hyperparams['index_min']
        self.index_max = hyperparams['index_max']

        self.K = hyperparams['K']
        self.kappa = hyperparams['kappa']

        self.c0 = hyperparams.get('c0', 0.1)
        self.C1 = hyperparams.get('C1', 1.0)
        self.zeta = hyperparams['zeta']
        self.delta = hyperparams.get('delta', 1.0 / self.T)

        self.epoch_operators = []
        self.lazy_cache = LazyContextCache(self.env.p_max)

        self.history = {
            'contexts': [],
            'prices': [],
            'outcomes': [],
            'rewards': [],
            'optimal_rewards': []
        }

        self.t_global = 0

    def run(self):
        """Run epochs until ``T`` rounds are played.

        Returns:
            ``(sum(optimal_rewards) - sum(rewards)) / T``, the average
            per-round regret in expected revenue.
        """
        tau = 1
        while self.t_global < self.T:
            N_tau = min(2**tau * self.N0, self.T - self.t_global)
            self.run_epoch(tau, N_tau)
            tau += 1

        total_reward = np.sum(self.history['rewards'])
        total_optimal = np.sum(self.history['optimal_rewards'])
        regret = total_optimal - total_reward
        relative_regret = regret / self.T

        return relative_regret

    def run_epoch(self, tau: int, N_tau: int):
        """Play ``N_tau`` rounds of epoch ``tau`` and fit its operator.

        Appends one entry per round to every list in ``self.history``,
        advances ``t_global``, and appends the fitted ``EpochOperator`` to
        ``self.epoch_operators``.
        """
        h_tau = N_tau ** (-1 / (2 * self.beta + 1))
        M_tau = max(3, int(np.ceil((self.index_max - self.index_min) / h_tau)))

        epoch_op = EpochOperator(
            tau, N_tau, h_tau, M_tau,
            self.index_min, self.index_max, self.ell, self.beta
        )
        epoch_op.K = self.K
        epoch_op.kappa = self.kappa

        # A degree-ell local polynomial needs ell + 1 distinct design points,
        # so forced exploration cycles over that many points per bin.
        n_explore_points = self.ell + 1

        epoch_samples = []
        for _ in range(N_tau):
            c_t = self.env.sample_context()

            lower, upper = self.lazy_cache.get_interval(
                c_t, tau, self.epoch_operators,
                self.theta_bar0, self.delta,
                self.c0, self.C1, self.zeta
            )

            p_t = np.random.uniform(lower, upper)

            x_t = np.concatenate([c_t, [p_t]])
            s_t = x_t @ self.theta_bar0  # s = p - c^T theta_pilot
            j = epoch_op.find_bin(s_t)

            if j is not None:
                bin_obj = epoch_op.bin_objects[j]
                bin_obj.s_j += 1

                # Forced exploration on the bin's 1st, 4th, 9th, ... visit.
                if bin_obj.s_j == bin_obj.L_j ** 2:
                    a_j, b_j = epoch_op.bins[j]
                    r = bin_obj.L_j % n_explore_points
                    # Midpoint of the r-th of n_explore_points equal
                    # sub-intervals of the bin (the bin midpoint if ell == 0).
                    s_explore = a_j + (r + 0.5) / n_explore_points * (b_j - a_j)

                    p_t = s_explore + c_t @ self.theta_pilot
                    p_t = np.clip(p_t, 0, self.env.p_max)
                    bin_obj.L_j += 1

                    x_t = np.concatenate([c_t, [p_t]])

            y_t = self.env.generate_outcome(c_t, p_t)

            r_t_expected = self.env.compute_expected_revenue(c_t, p_t)

            epoch_samples.append((x_t, y_t))
            self.history['contexts'].append(c_t)
            self.history['prices'].append(p_t)
            self.history['outcomes'].append(y_t)
            self.history['rewards'].append(r_t_expected)

            _, r_opt = self.env.compute_optimal_price(c_t)
            self.history['optimal_rewards'].append(r_opt)

            self.t_global += 1

        self.joint_estimation(epoch_op, epoch_samples)

        self.epoch_operators.append(epoch_op)

    def joint_estimation(self, epoch_op: EpochOperator,
                         samples: List[Tuple[np.ndarray, int]]):
        """Route samples to bins and fit each non-empty bin.

        For each bin: records ``n_samples`` and ``x_bar`` (mean of its x
        vectors). Then it fits ``theta_hat`` by constrained least squares and
        computes the bin's sufficient statistics.
        """
        for t, (x_t, y_t) in enumerate(samples):
            s_t = x_t @ self.theta_bar0
            j = epoch_op.find_bin(s_t)
            if j is not None:
                epoch_op.bin_objects[j].sample_indices.append(t)

        for j, bin_obj in enumerate(epoch_op.bin_objects):
            if len(bin_obj.sample_indices) == 0:
                continue

            bin_obj.n_samples = len(bin_obj.sample_indices)

            xs = [samples[idx][0] for idx in bin_obj.sample_indices]
            bin_obj.x_bar = np.mean(xs, axis=0)

            bin_obj.theta_hat = ConstrainedLeastSquares.solve(
                bin_obj, self.theta_bar0, self.eta, samples, self.ell
            )

            self._compute_bin_matrices(bin_obj, samples, epoch_op)

    def _compute_bin_matrices(self, bin_obj: LocalPolyBin,
                              samples: List[Tuple[np.ndarray, int]],
                              epoch_op: EpochOperator):
        """Compute ``Lambda``, ``Lambda_inv``, ``S_yU``, ``S_UX`` and ``Sigma_v``.

        ``Lambda = sum U U^T`` (inverted with a 1e-6 ridge), ``S_yU = sum y U``,
        ``S_UX = sum U X^T`` and ``Sigma_v = sum v v^T`` over the bin's
        samples; ``S_UX``/``Sigma_v`` are empty-width when ``ell == 0``.
        """
        Lambda = np.zeros((self.ell + 1, self.ell + 1))
        S_yU = np.zeros(self.ell + 1)

        if self.ell > 0:
            S_UX = np.zeros((self.ell + 1, self.ell * (self.d)))
        else:
            S_UX = np.zeros((1, 0))

        for idx in bin_obj.sample_indices:
            x_t, y_t = samples[idx]

            U_j = epoch_op._compute_U_j(x_t, bin_obj)
            X_j = epoch_op._compute_X_j(x_t, bin_obj)

            Lambda += np.outer(U_j, U_j)
            S_yU += y_t * U_j

            if self.ell > 0 and len(X_j) > 0:
                S_UX += np.outer(U_j, X_j)

        bin_obj.Lambda = Lambda
        bin_obj.Lambda_inv = np.linalg.inv(Lambda + 1e-6 * np.eye(self.ell + 1))
        bin_obj.S_yU = S_yU
        bin_obj.S_UX = S_UX

        # Sigma_v needs S_UX and Lambda_inv (via _compute_v_j), so it is
        # accumulated in a second pass after they are stored.
        if self.ell > 0:
            Sigma_v = np.zeros((self.ell * (self.d), self.ell * (self.d)))
            for idx in bin_obj.sample_indices:
                x_t, _ = samples[idx]
                U_j = epoch_op._compute_U_j(x_t, bin_obj)
                X_j = epoch_op._compute_X_j(x_t, bin_obj)
                v_j = epoch_op._compute_v_j(U_j, X_j, bin_obj)
                if len(v_j) > 0:
                    Sigma_v += np.outer(v_j, v_j)
            bin_obj.Sigma_v = Sigma_v
        else:
            bin_obj.Sigma_v = np.zeros((0, 0))
