import xgboost as xgb
import numpy as np
from typing import Optional, Sequence, Tuple, List, Dict, Any

class GPRegressor:
    """Gaussian-process regressor over 1-D inputs with a squared-exponential kernel.

    Training points are accumulated with :meth:`add_data`; every call rebuilds
    the cached inverse of the noisy Gram matrix, which :meth:`predict` then
    reuses for the posterior mean and standard deviation.
    """

    def __init__(self, lengthscale: float = 1.0, variance: float = 1.0, noise_var: float = 1.0,
                 jitter: float = 1e-6):
        """Store the kernel hyper-parameters and start with no training data.

        Args:
            lengthscale: Kernel length scale (``self.l``).
            variance: Kernel amplitude, i.e. the value of ``k(x, x)`` (``self.sigma_f``).
            noise_var: Observation-noise variance added to the Gram diagonal.
            jitter: Extra diagonal term added together with ``noise_var``.
        """
        self.l = float(lengthscale)
        self.sigma_f = float(variance)
        self.noise_var = float(noise_var)
        self.jitter = float(jitter)

        # training data containers (always 1-D float arrays)
        self.X = np.empty((0,))
        self.y = np.empty((0,))

        # cached inverse of (K + (noise_var + jitter) I); rebuilt by add_data
        self._K_inv = None

    def kernel(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
        """Return the squared-exponential Gram matrix of shape (len(x), len(y))."""
        x = np.atleast_1d(x).reshape(-1, 1)
        y = np.atleast_1d(y).reshape(-1, 1)
        d2 = (x - y.T) ** 2
        return (self.sigma_f * np.exp(-0.5 * d2 / (self.l ** 2)))

    def _recompute_inv(self):
        """Rebuild ``self._K_inv`` from the current training inputs."""
        n = self.X.shape[0]
        if n == 0:
            self._K_inv = None
            return
        K = self.kernel(self.X, self.X)
        K = K + (self.noise_var + self.jitter) * np.eye(n)
        # The explicit inverse is cached because predict() applies it to both
        # y and the cross-covariance on every call.
        try:
            self._K_inv = np.linalg.inv(K)
        except np.linalg.LinAlgError:
            # singular matrix: fall back to the pseudo-inverse
            self._K_inv = np.linalg.pinv(K)

    def add_data(self, x_new: Sequence[float], y_new: Sequence[float]):
        """Append observations and refresh the cached Gram inverse.

        Args:
            x_new: New input locations (flattened to 1-D).
            y_new: Matching targets (flattened to 1-D).

        Raises:
            ValueError: If the two sequences differ in length.
        """
        x_new = np.asarray(x_new, dtype=float).reshape(-1)
        y_new = np.asarray(y_new, dtype=float).reshape(-1)
        if x_new.shape[0] != y_new.shape[0]:
            raise ValueError("x_new and y_new must have same length")
        # concatenate always allocates, so the caller's arrays are never aliased
        self.X = np.concatenate([self.X, x_new])
        self.y = np.concatenate([self.y, y_new])
        self._recompute_inv()

    def predict(self, xs: Sequence[float]) -> Tuple[np.ndarray, np.ndarray]:
        """Predict posterior mean and std at points xs (1-D sequence).

        Returns (mu, std) as float arrays with one entry per query point.
        If no training data, returns mu=0 and the prior std sqrt(k(x,x)),
        which equals sqrt(variance).
        """
        # Cast to float so integer queries do not force integer outputs
        # (which would truncate the prior variance and the zero mean).
        xs = np.asarray(xs, dtype=float).reshape(-1)
        if self.X.size == 0:
            # prior
            return np.zeros_like(xs), np.full_like(xs, np.sqrt(self.sigma_f))

        Kx = self.kernel(self.X, xs)  # shape (n_train, n_query)
        # k(x, x) is sigma_f for every x with this kernel, so the prior
        # variance is built directly instead of via an n_query x n_query matrix.
        Kxx = np.full(xs.shape, self.sigma_f)

        # mu = Kx^T (K + sigma^2 I)^{-1} y
        alpha = self._K_inv.dot(self.y)
        mu = Kx.T.dot(alpha)

        # v = (K + sigma^2 I)^{-1} Kx
        v = self._K_inv.dot(Kx)  # shape (n_train, n_query)
        var = Kxx - np.sum(Kx * v, axis=0)
        # clamp tiny negative values caused by round-off before the sqrt
        var = np.maximum(var, 1e-12)
        std = np.sqrt(var)
        return mu, std

class LinUCB:
    """Linear bandit with a ridge-regression estimate and a confidence width.

    Keeps the design matrix ``A`` (``lamb * I`` plus the outer products of
    observed contexts) and the reward-weighted context sum ``b``.
    """

    def __init__(self, K = 4, dim = 5, lamb = 1, beta = 0.01):
        """Set up the statistics for ``K`` arms with ``dim``-dimensional contexts.

        Args:
            K: Number of arms scored by :meth:`predict`.
            dim: Context dimension.
            lamb: Ridge regularisation placed on the diagonal of ``A``.
            beta: Multiplier of the confidence width.
        """
        self.K = K
        self.dim = dim
        self.lamb = lamb
        self.beta = beta
        self.round = 0  # not modified anywhere in this class
        self.A = self.lamb * np.identity(self.dim)
        self.b = np.zeros((self.dim, 1))

    def predict(self, context):
        """Score the first ``K`` contexts.

        Args:
            context: Indexable collection whose first ``K`` entries are the
                per-arm context vectors (arrays or plain sequences).

        Returns:
            float32 array of shape ``(K,)`` holding ``mu - beta * bonus`` per arm.
        """
        A_inv = np.linalg.inv(self.A)
        theta_est = A_inv.dot(self.b).reshape(-1)
        scores = []
        for i in range(self.K):
            # flatten so (dim,), (dim, 1) and list contexts are all accepted
            x = np.asarray(context[i], dtype=float).reshape(-1)
            mu_est = theta_est.dot(x)
            bonus = np.sqrt(x.dot(A_inv).dot(x))
            # NOTE(review): the width is subtracted, so this is a lower
            # confidence bound despite the class name; the sign is kept as
            # found — confirm it is intended.
            scores.append(mu_est - self.beta * bonus)
        return np.array(scores, dtype=np.float32)

    def update(self, context, reward):
        """Fold one (context, reward) observation into ``A`` and ``b``."""
        x = np.asarray(context, dtype=float).reshape(-1)
        self.A += np.outer(x, x)
        self.b += reward * x.reshape(-1, 1)

    def reset(self):
        """Restore ``A`` and ``b`` to their initial values (``round`` is untouched)."""
        self.A = self.lamb * np.identity(self.dim)
        self.b = np.zeros((self.dim, 1))

class gb5:
    """XGBoost regressor that is refit after every ``update_time`` new samples."""

    def __init__(self, X_train, y_train, depth = 5, update_time = 10):
        """Create the model and fit it once on the initial data.

        Args:
            X_train: Initial feature matrix.
            y_train: Initial targets, either 1-D or a 2-D column.
            depth: ``max_depth`` passed to ``xgb.XGBRegressor``.
            update_time: Number of :meth:`update` calls between refits.
        """
        self.model = xgb.XGBRegressor(max_depth=depth)
        self.X_train = X_train
        self.y_train = y_train
        self.update_time = update_time
        self.time_cnt = 0
        self.pretrain()

    def pretrain(self):
        """Fit the model on all currently stored training data."""
        self.model.fit(self.X_train, self.y_train)

    def predict(self, context):
        """Return the model's predictions for ``context``."""
        return self.model.predict(context)

    def update(self, X, y):
        """Append one sample and refit once ``update_time`` samples have accumulated.

        Args:
            X: Feature vector of the new sample (array or plain sequence).
            y: Target of the new sample (array, sequence or scalar).
        """
        if not isinstance(X, np.ndarray):
            X = np.array(X, dtype=float)
        if not isinstance(y, np.ndarray):
            y = np.array(y, dtype=float)
        self.X_train = np.vstack((self.X_train, X.reshape(1, -1)))
        if np.ndim(self.y_train) == 1:
            # A 1-D target vector cannot be row-stacked with a (1, k) row,
            # so extend it along its only axis instead.
            self.y_train = np.concatenate((self.y_train, y.reshape(-1)))
        else:
            self.y_train = np.vstack((self.y_train, y.reshape(1, -1)))
        self.time_cnt += 1
        if self.time_cnt == self.update_time:
            self.time_cnt = 0
            self.pretrain()

    def reset(self, X_train, y_train):
        """Replace the stored training data.

        The fitted model and ``time_cnt`` are left as they are; the new data
        only takes effect at the next refit.
        """
        self.X_train = X_train
        self.y_train = y_train
        
class cost_LCB:
    """Per-arm running-mean tracker that reports a lower confidence bound."""

    def __init__(self, K):
        """Allocate zeroed statistics for ``K`` arms and start the clock at 1."""
        self.t = 1
        self.round = np.zeros(K)
        self.estimate = np.zeros(K)

    def predict(self, context):
        """Return ``estimate - sqrt(log(t) / max(round, 1))`` for every arm.

        ``context`` is accepted but not used.
        """
        # arms that were never updated are treated as having one observation
        pulls = np.maximum(self.round, 1.0)
        width = np.sqrt(np.log(self.t) / pulls)
        return self.estimate - width

    def update(self, action, value):
        """Advance the clock and fold ``value`` into the mean of ``action``."""
        self.t += 1
        self.round[action] += 1
        n = self.round[action]
        previous = self.estimate[action]
        # running mean: old mean re-weighted by (n - 1) / n plus the new share
        self.estimate[action] = previous * (n - 1) / n + value / n

    def reset(self):
        """Zero all statistics in place and restart the clock at 1."""
        self.estimate.fill(0)
        self.round.fill(0)
        self.t = 1
    
def save_results(filepath, filename, results):
    """Pickle ``results`` to ``filepath/filename``, creating the directory if needed.

    Args:
        filepath: Target directory; an empty string means the current
            working directory.
        filename: Name of the pickle file inside ``filepath``.
        results: Any picklable object.
    """
    import os
    import pickle

    # os.makedirs("") raises, so only create a directory when one was named;
    # exist_ok=True already covers the "directory is there" case.
    if filepath:
        os.makedirs(filepath, exist_ok=True)

    target = os.path.join(filepath, filename)
    with open(target, 'wb') as f:
        pickle.dump(results, f)

    print(f"Results saved to {target}")
