import numpy as np
from numpy.linalg import inv, cholesky, solve
from scipy.stats import norm, uniform


def efficient_matrix_inv(invcov, x):
    """
    Sherman-Morrison rank-1 update of an already inverted matrix.

    With ``invcov = A^{-1}`` and a 1-D vector ``x`` this returns

        (A + x x^T)^{-1} = A^{-1} - (A^{-1} x)(A^{-1} x)^T / (1 + x^T A^{-1} x)

    The outer-product numerator used here matches the textbook form
    ``A^{-1} x x^T A^{-1}`` when ``invcov`` is symmetric.  ``invcov`` is not
    modified; a new array is returned.
    """
    projected = invcov @ x
    denominator = 1 + x @ projected
    correction = np.outer(projected, projected) / denominator
    return invcov - correction


class LinBanditAlg:
    """
    Common state for the linear bandit algorithms defined in this file.

    Maintains a ridge-regularised Gram matrix ``Gram`` and a reward-weighted
    feature sum ``B`` (both scaled by 1 / sigma^2), plus the hyperparameters
    the subclasses read when choosing an arm.
    """

    def __init__(self, env, n, params):
        """
        Args:
            env: object exposing ``rng``, ``K`` and ``d``.
            n: stored as ``self.n`` (presumably the total number of rounds;
                confirm against the caller).
            params: mapping of optional hyperparameters; missing keys fall
                back to the defaults below.
        """
        option = params.get

        # Problem description comes from the environment
        self.env, self.rng = env, env.rng
        self.K, self.d = env.K, env.d
        self.n = n

        # Exploration controls
        self.forced_exploration = option('forced_exploration', 0)
        self.sigma = option('sigma', 1.0)    # R: sub-Gaussian noise bound
        self.crs = option('crs', 1.0)
        self.delta = option('delta', 0.01)
        # Bounds used in the confidence width
        self.lam0 = option('lambda0', 1.0)   # lambda: ridge regulariser
        self.L = option('L', 1.0)            # L: max ||x|| bound
        self.S = option('S', 1.0)            # S: max ||theta*|| bound
        self.beta_flag = option('beta', True)

        # Ridge statistics: Gram starts at lambda * I, B at the zero vector
        self.B = np.zeros(self.d)
        self.Gram = self.lam0 * np.eye(self.d)

    def get_arms(self, X):
        """Remember the arm feature matrix; row ``i`` is read as arm ``i``."""
        self.X = X

    def update(self, t, arm_idx, reward):
        """Add the played arm's features and reward to ``Gram`` and ``B`` in place."""
        feature = self.X[arm_idx]
        noise_var = self.sigma ** 2
        self.Gram += np.outer(feature, feature) / noise_var
        self.B += feature * reward / noise_var


class LinUCB(LinBanditAlg):
    """Upper-confidence-bound arm selection on top of the ridge estimate."""

    @staticmethod
    def print():
        """Return the display name of this algorithm."""
        return "LinUCB"

    def _confidence_radius(self, t):
        """
        Theoretical width
        beta_t = R * sqrt(d * ln(1 + t * L^2 / lambda) + 2 ln(1/delta)) + sqrt(lambda) * S
        with R = ``sigma`` and lambda = ``lam0``.
        """
        log_growth = self.d * np.log(1 + (t * self.L**2) / self.lam0)
        log_conf = 2 * np.log(1 / self.delta)
        return self.sigma * np.sqrt(log_growth + log_conf) + np.sqrt(self.lam0) * self.S

    def get_arm(self, t):
        """Return the index of the arm with the largest UCB score at round ``t``."""
        # Uniformly random pulls while t is below the forced-exploration count
        if t < self.forced_exploration:
            return int(self.rng.integers(0, self.K))

        # Ridge estimate through the explicit inverse (reused for the widths)
        gram_inverse = inv(self.Gram)
        theta_hat = gram_inverse.dot(self.B)

        radius = self._confidence_radius(t)

        # Per-arm width sqrt(x^T Gram^{-1} x), computed row-wise
        weighted = self.X.dot(gram_inverse)
        widths = np.sqrt(np.sum(weighted * self.X, axis=1))

        ucb = self.X.dot(theta_hat) + radius * widths
        return int(np.argmax(ucb))


class Greedy(LinBanditAlg):
    """Pure exploitation: plays the arm the ridge estimate ranks highest."""

    @staticmethod
    def print():
        """Return the display name of this algorithm."""
        return "greedy"

    def get_arm(self, t):
        """Return a random arm during forced exploration, else the greedy arm."""
        in_forced_phase = t < self.forced_exploration
        if in_forced_phase:
            return int(self.rng.integers(0, self.K))

        # Ridge estimate: solve Gram @ theta = B
        theta_hat = solve(self.Gram, self.B)
        predicted = self.X.dot(theta_hat)
        return int(np.argmax(predicted))


class EpsilonGreedy(LinBanditAlg):
    """Greedy selection with a decaying probability of a uniformly random pull."""

    @staticmethod
    def print():
        """Return the display name of this algorithm."""
        return "e-greedy"

    def get_arm(self, t):
        """
        Return a random arm during forced exploration or with probability
        0.05 * sqrt(n / (t + 1)) / 2; otherwise return the greedy arm.
        """
        if t < self.forced_exploration:
            explore = True
        else:
            # The uniform draw is only consumed once the forced phase is over
            draw = self.rng.random()
            explore = draw < 0.05 * np.sqrt(self.n / (t + 1)) / 2

        if explore:
            return int(self.rng.integers(0, self.K))

        # Ridge estimate: solve Gram @ theta = B
        theta_hat = solve(self.Gram, self.B)
        predicted = self.X.dot(theta_hat)
        return int(np.argmax(predicted))


class LinTS(LinBanditAlg):
    """Thompson-sampling style selection: plays the best arm under a sampled parameter."""

    @staticmethod
    def print():
        """Return the display name of this algorithm."""
        return "LinTS"

    def _confidence_radius(self, t):
        """
        Theoretical width
        beta_t = R * sqrt(d * ln(1 + t * L^2 / lambda) + 2 ln(1/delta)) + sqrt(lambda) * S
        with R = ``sigma`` and lambda = ``lam0``.
        """
        log_growth = self.d * np.log(1 + (t * self.L**2) / self.lam0)
        log_conf = 2 * np.log(1 / self.delta)
        return self.sigma * np.sqrt(log_growth + log_conf) + np.sqrt(self.lam0) * self.S

    def get_arm(self, t):
        """Return a random arm during forced exploration, else the sampled-parameter argmax."""
        if t < self.forced_exploration:
            return int(self.rng.integers(0, self.K))

        # Ridge estimate: solve Gram @ theta = B
        theta_hat = solve(self.Gram, self.B)

        # The sample is inflated by crs, and additionally by beta_t when beta_flag is set
        radius = self._confidence_radius(t)
        multiplier = radius if self.beta_flag else 1.0
        inflation = self.crs * multiplier

        # With Gram = C C^T, solving C^T u = z for standard normal z gives
        # u with covariance Gram^{-1}
        chol = cholesky(self.Gram)
        gaussian = self.rng.standard_normal(self.d)
        theta_sample = theta_hat + inflation * solve(chol.T, gaussian)

        return int(np.argmax(self.X.dot(theta_sample)))


class RandLinUCB(LinBanditAlg):
    """
    LinUCB variant whose confidence bonus is multiplied by a random factor
    drawn from a discretised distribution on [-1, 1].
    """

    @staticmethod
    def print():
        """Return the display name of this algorithm."""
        return "RandLinUCB"

    def __init__(self, env, n, params):
        """Read the perturbation settings in addition to the base hyperparameters."""
        super().__init__(env, n, params)
        option = params.get
        self.M = option('M', 20)                      # number of grid points
        self.pdist = option('pdist', 'normal')        # 'normal', anything else -> uniform
        self.pnormal_std = option('pnormal_std', 0.125)
        # NOTE(review): is_optimistic is stored but never read in this class.
        self.is_optimistic = option('is_optimistic', True)
        self.is_coupled = option('is_coupled', True)  # one multiplier shared by all arms

    def _confidence_radius(self, t):
        """
        Theoretical width
        beta_t = R * sqrt(d * ln(1 + t * L^2 / lambda) + 2 ln(1/delta)) + sqrt(lambda) * S
        with R = ``sigma`` and lambda = ``lam0``.
        """
        log_growth = self.d * np.log(1 + (t * self.L**2) / self.lam0)
        log_conf = 2 * np.log(1 / self.delta)
        return self.sigma * np.sqrt(log_growth + log_conf) + np.sqrt(self.lam0) * self.S

    def get_arm(self, t):
        """Return the arm maximising the randomly scaled UCB score at round ``t``."""
        # No forced-exploration phase is applied here (the check is disabled).

        # Ridge estimate through the explicit inverse (reused for the widths)
        gram_inverse = inv(self.Gram)
        theta_hat = gram_inverse.dot(self.B)

        # Grid of candidate multipliers and their normalised weights
        grid = np.linspace(-1, 1, self.M)
        use_normal = self.pdist.lower() == 'normal'
        if use_normal:
            weights = norm.pdf(grid, 0, self.pnormal_std)
        else:
            weights = uniform.pdf(grid, -1, 2)
        weights = weights / weights.sum()

        # Coupled: a single scalar multiplier; otherwise one multiplier per arm
        draw_shape = None if self.is_coupled else self.K
        picked = self.rng.choice(self.M, size=draw_shape, p=weights)
        multiplier = grid[picked]

        radius = self._confidence_radius(t)

        # Per-arm width sqrt(x^T Gram^{-1} x), computed row-wise
        weighted = self.X.dot(gram_inverse)
        widths = np.sqrt(np.sum(weighted * self.X, axis=1))

        perturbed = self.X.dot(theta_hat) + multiplier * radius * widths
        return int(np.argmax(perturbed))


class LinPHE(LinBanditAlg):
    """
    Selection under a Gaussian perturbation of the ridge estimate, with the
    perturbation scaled by the ``a`` hyperparameter.
    """

    @staticmethod
    def print():
        """Return the display name of this algorithm."""
        return "LinPHE"

    def __init__(self, env, n, params):
        """Read the perturbation scale ``a`` (default 1.0) on top of the base setup."""
        super().__init__(env, n, params)
        self.a = params.get('a', 1.0)

    def get_arm(self, t):
        """Return the arm that is best under the perturbed parameter."""
        # No forced-exploration phase is applied here (the check is disabled).

        # Ridge estimate: solve Gram @ theta = B
        theta_hat = solve(self.Gram, self.B)

        # With Gram = C C^T, solving C^T u = z for standard normal z gives
        # u with covariance Gram^{-1}
        chol = cholesky(self.Gram)
        gaussian = self.rng.standard_normal(self.d)
        perturbation = solve(chol.T, gaussian)

        theta_tilde = theta_hat + self.a * perturbation
        return int(np.argmax(self.X.dot(theta_tilde)))


class LinFP(LinBanditAlg):
    """
    Feature-perturbation selection: every arm's features are shifted along one
    shared Gaussian direction, scaled by that arm's confidence width.

    Construction is inherited unchanged from ``LinBanditAlg``.
    """

    @staticmethod
    def print():
        """Return the display name of this algorithm."""
        return "LinFP"

    def _confidence_radius(self, t):
        """
        Theoretical width
        beta_t = R * sqrt(d * ln(1 + t * L^2 / lambda) + 2 ln(1/delta)) + sqrt(lambda) * S
        with R = ``sigma`` and lambda = ``lam0``.
        """
        log_growth = self.d * np.log(1 + (t * self.L**2) / self.lam0)
        log_conf = 2 * np.log(1 / self.delta)
        return self.sigma * np.sqrt(log_growth + log_conf) + np.sqrt(self.lam0) * self.S

    def get_arm(self, t):
        """Return a random arm during forced exploration, else the perturbed-feature argmax."""
        if t < self.forced_exploration:
            return int(self.rng.integers(0, self.K))

        # Ridge estimate through the explicit inverse (reused for the widths)
        gram_inverse = inv(self.Gram)
        theta_hat = gram_inverse.dot(self.B)

        # Perturbation size: crs, additionally times beta_t when beta_flag is set
        radius = self._confidence_radius(t)
        multiplier = radius if self.beta_flag else 1.0
        inflation = self.crs * multiplier

        # Per-arm width sqrt(x^T Gram^{-1} x), computed row-wise
        weighted = self.X.dot(gram_inverse)
        widths = np.sqrt(np.sum(weighted * self.X, axis=1))

        # One d-dimensional Gaussian direction shared by all arms
        direction = self.rng.standard_normal((self.d,))
        per_arm_scale = (widths * inflation)[:, None]
        shifted_features = self.X + per_arm_scale * direction

        return int(np.argmax(shifted_features.dot(theta_hat)))

class FGTS:
    """
    Feel-Good Thompson Sampling for a linear model, with the posterior
    approximated by stochastic gradient Langevin dynamics (SGLD).

    A single parameter vector ``theta`` is carried across rounds; each call to
    ``get_arm`` runs a few SGLD steps on the stored history and then plays the
    arm that ``theta`` ranks highest.
    """

    def __init__(self, env, n, params):
        """
        Args:
            env: object exposing ``K`` (number of arms) and ``d`` (dimension).
            n: total number of time steps; stored as ``self.n``.
            params: mapping of optional hyperparameters (defaults below).
        """
        self.env = env
        self.X = None   # Arm feature matrix; set by get_arms
        self.K = env.K  # Number of arms
        self.d = env.d  # Dimension of the parameter space
        self.n = n      # Total number of time steps

        self.X_buffer = []  # Features of the arms that were played
        self.r_buffer = []  # Rewards observed for those plays

        self.eta = params.get('eta', 1.0)       # Weight of the squared-error term
        self.lam = params.get('lam', 0.0)       # Weight of the feel-good term
        self.b = params.get('b', float('inf'))  # Cap on the feel-good value
        self.delta = params.get('delta', 0.01)  # SGLD step size
        self.rho = params.get('rho', 100.0)     # Precision of the Gaussian prior
        self.rng = np.random.default_rng(params.get('seed', 42))
        # Start from a draw of the N(0, I / rho) prior
        self.theta = self.rng.normal(0, np.sqrt(1.0 / self.rho), self.d)

    def update(self, t, arm, reward):
        """Append the played arm's features and its reward to the history."""
        x = self.X[arm]
        self.X_buffer.append(x)
        self.r_buffer.append(reward)

    def get_arms(self, X):
        """Remember the arm feature matrix; row ``i`` is read as arm ``i``."""
        self.X = X

    def sgld_update(self, t):
        """
        Run ``min(100, t)`` SGLD steps on ``theta``; a no-op when ``t == 0``.

        Each step samples one stored observation with index in ``[0, t)``, so
        the history must hold at least ``t`` entries.  The feel-good term is
        evaluated on the current arm set ``self.X`` and contributes
        ``-lam * x_best`` only while the best predicted value is below the cap
        ``b``: the gradient of ``min(b, max_a x_a . theta)`` is zero once the
        cap is reached.  With the default ``b = inf`` it is always active.
        """
        if t == 0:
            return

        n_steps = min(100, t)
        for _ in range(n_steps):
            i = self.rng.integers(0, t)  # Index of a stored observation in [0, t)
            x_i = self.X_buffer[i]
            r_i = self.r_buffer[i]

            # Best arm under the current theta (theta changes every step)
            scores = self.X @ self.theta
            best = np.argmax(scores)

            grad_likelihood = 2 * self.eta * (x_i @ self.theta - r_i) * x_i
            if scores[best] < self.b:
                grad_feelgood = -self.lam * self.X[best]
            else:
                # Capped: min(b, .) is flat in theta here
                grad_feelgood = np.zeros(self.d)
            total_grad = grad_likelihood + grad_feelgood

            # Gradient of the Gaussian log-prior, down-weighted by 1 / t
            grad_log_prior = -self.rho * self.theta
            noise = self.rng.normal(0, 1, self.d)
            self.theta = (
                self.theta
                - self.delta * (total_grad - (1 / t) * grad_log_prior)
                + np.sqrt(2 * self.delta / t) * noise
            )

    def get_arm(self, t):
        """Refresh ``theta`` with SGLD, then return the arm it scores highest."""
        self.sgld_update(t)
        mu_values = self.X @ self.theta

        return int(np.argmax(mu_values))

    @staticmethod
    def print():
        """Return the display name of this algorithm."""
        return "FGTS"
    
class OPAS_FGP:
    """
    Two-phase arm selection: a ridge estimate ``theta_hat`` of the reward
    model, followed by a per-arm feature perturbation whose mean is fitted
    with a few SGD steps on that arm's own history.
    """

    def __init__(self, env, n, params):
        """
        Args:
            env: object exposing ``K`` and ``d``.
            n: stored as ``self.n``.
            params: mapping of optional hyperparameters (defaults below).
        """
        option = params.get

        self.env = env
        self.X = None  # arm feature matrix; set by get_arms
        self.K, self.d = env.K, env.d
        self.n = n

        # Phase 1: ridge statistics for the reward model
        self.lam0 = option('lambda0', 1.0)
        self.sigma_model = option('sigma', 1.0)
        self.Gram = self.lam0 * np.eye(self.d)
        self.B = np.zeros(self.d)
        self.theta_hat = np.zeros(self.d)

        # Phase 2: arm-wise perturbation objective
        self.eta = option('eta', 1.0)          # weight of the squared-error term
        self.lam = option('lam', 0.01)         # weight of the feel-good term
        self.b = option('b', float('inf'))     # feel-good term is active below this value
        self.rho = option('rho', 1.0)          # weight of the quadratic prior on the shift
        self.rng = np.random.default_rng(option('seed', 42))

        # Optimiser and sampling settings
        self.optim_lr = option('optim_lr', 0.01)
        self.optim_steps = option('optim_steps', 20)  # number of SGD iterations
        self.sigma_pert = option('sigma_pert', 0.1)

        # Per-arm history of {'x': features, 'r': reward} records used by SGD
        self.arm_histories = {arm: [] for arm in range(self.K)}

    def get_arms(self, X):
        """Remember the arm feature matrix; row ``i`` is read as arm ``i``."""
        self.X = X

    def update(self, t, arm, reward):
        """Fold the observation into the ridge statistics and the arm's history."""
        feature = self.X[arm]
        model_var = self.sigma_model ** 2
        self.Gram += np.outer(feature, feature) / model_var
        self.B += feature * reward / model_var
        # Keep both the features and the reward for this arm
        self.arm_histories[arm].append({'x': feature, 'r': reward})

    def estimate_theta_hat(self):
        """Recompute ``theta_hat = Gram^{-1} B`` through the explicit inverse."""
        self.theta_hat = inv(self.Gram).dot(self.B)

    def find_optimal_zeta_for_arm(self, arm_idx):
        """
        Fit the mean feature shift for one arm with ``optim_steps`` SGD steps.

        Returns the zero vector when the arm has no history.  Each step draws
        one stored record of this arm and descends the sum of a squared-error
        gradient, a feel-good gradient (only while the perturbed prediction is
        below ``b``) and a quadratic-prior gradient.
        """
        records = self.arm_histories[arm_idx]
        mean_shift = np.zeros(self.d)
        if not records:
            return mean_shift

        theta = self.theta_hat
        # The step count is fixed; it does not grow with the history size
        for _ in range(self.optim_steps):
            # One record sampled uniformly from this arm's history
            record = records[self.rng.integers(0, len(records))]
            feature, observed = record['x'], record['r']

            # Prediction for the shifted features under the current estimate
            prediction = (feature + mean_shift) @ theta
            grad_fit = 2 * self.eta * (prediction - observed) * theta

            # Feel-good term pushes the shift along theta_hat while below the cap
            grad_feelgood = -self.lam * theta if prediction < self.b else 0

            grad_prior = self.rho * mean_shift

            mean_shift -= self.optim_lr * (grad_fit + grad_feelgood + grad_prior)

        return mean_shift

    def get_arm(self, t):
        """Refresh ``theta_hat``, sample one shift per arm, return the best perturbed arm."""
        self.estimate_theta_hat()
        values = np.zeros(self.K)

        for arm in range(self.K):
            center = self.find_optimal_zeta_for_arm(arm)
            # Isotropic spread that shrinks with the number of pulls of this arm
            pulls = len(self.arm_histories[arm])
            spread = self.sigma_pert / np.sqrt(1 + pulls)
            shift = self.rng.normal(center, spread)

            values[arm] = (self.X[arm] + shift) @ self.theta_hat

        return int(np.argmax(values))

    @staticmethod
    def print():
        """Return the display name of this algorithm."""
        return "OPAS-FGP"
