import numpy as np
from math import log
from numpy.linalg import pinv
from src.utils import sigmoid, dsigmoid
from scipy.optimize import minimize, NonlinearConstraint
from src.core.bandit import BanditAlgorithm

class SCB_WeightUCB(BanditAlgorithm):
    """Discount-weighted UCB learner with a sigmoid reward model.

    Past observations are weighted by powers of a discount factor ``gamma``
    (the newest observation has weight 1, the oldest ``gamma ** (n - 1)``).
    A ridge-regularised weighted estimate ``theta_hat`` is refined with a few
    Newton steps, rescaled into the ball of radius ``S`` to give
    ``theta_tilde``, and arms are ranked by an upper-confidence index built
    from ``sigmoid(x @ theta_tilde)`` plus an exploration bonus.

    ``sigmoid`` and ``dsigmoid`` are project helpers imported from
    ``src.utils``; they are assumed to be the element-wise logistic function
    and its derivative -- TODO confirm against ``src.utils``.
    """

    def __init__(self, num_actions, horizon, d, delta, r_lambda, S, L, R, k_mu, c_mu):
        """Store the hyper-parameters and initialise the learner state.

        Args:
            num_actions: Forwarded to ``BanditAlgorithm.__init__``.
            horizon: Number of rounds; forwarded to the base class and also
                stored as ``self.horizon``.
            d: Dimension of the arm feature vectors / parameter vector.
            delta: Confidence level used in the ``log(1 / delta)`` term.
            r_lambda: Ridge regularisation strength.
            S: Radius of the ball ``theta_tilde`` is constrained to.
            L: Stored but not used by any method of this class.
            R: Stored but not used by any method of this class.
            k_mu: Constant entering the exploration bonus (presumably an
                upper bound on the link derivative -- verify with caller).
            c_mu: Constant scaling the regulariser (presumably a lower bound
                on the link derivative -- verify with caller).
        """
        super().__init__(num_actions, horizon)

        self.d = d
        self.delta = delta
        self.r_lambda = r_lambda
        self.S = S
        self.L = L
        self.R = R
        self.k_mu = k_mu
        self.c_mu = c_mu
        self.horizon = horizon

        # Snapshot of the constructor arguments, keyed by parameter name.
        self.init_params = {
            'num_actions': num_actions,
            'horizon': horizon,
            'd': d,
            'delta': delta,
            'r_lambda': r_lambda,
            'S': S,
            'L': L,
            'R': R,
            'k_mu': k_mu,
            'c_mu': c_mu
        }

        # Round-independent parts of the confidence radius used in select_arm.
        self.c_1 = np.sqrt(self.r_lambda * self.c_mu) * (self.S + 0.5)
        self.c_2 = (2 / np.sqrt(self.r_lambda * self.c_mu)) * log(1 / self.delta)

        self._reset_learner_state()

    def _reset_learner_state(self):
        """Reset every per-run quantity to its value at round zero."""
        self.ctr = 0                    # number of estimator() calls so far
        self.lazy_update_fr = 5         # Newton refit period once enough data is seen
        self.theta_hat = np.zeros(self.d)
        self.theta_tilde = np.zeros(self.d)
        self.V = self.r_lambda * np.identity(self.d)
        self.inv_V = (1 / self.r_lambda) * np.identity(self.d)
        self.H = self.r_lambda * self.c_mu * np.eye(self.d)
        self.arms = []                  # feature vectors of the arms played
        self.rewards = []               # rewards observed, aligned with self.arms
        self.gamma_2t = 1               # running product of gamma ** 2
        self.w = None                   # per-observation weights, set by estimator()

    def select_arm(self, arms, change_points):
        """Return the index of the arm with the largest UCB index.

        Also caches ``arms`` in ``self.all_arms`` and recomputes the discount
        factor ``self.gamma``; ``update_statistics`` relies on both.

        Args:
            arms: Sequence of feature vectors, one per arm.
            change_points: Either a list (its length is used) or a number
                used directly in the discount-factor formula.

        Returns:
            Integer index into ``arms``; ties are broken uniformly at random.
        """
        arms_array = np.array(arms)

        if isinstance(change_points, list):
            PT = len(change_points)
        else:
            PT = change_points

        self.all_arms = arms_array

        # `T` is not assigned anywhere in this class; it is presumably set by
        # BanditAlgorithm. Fall back to the horizon stored in __init__ so the
        # method still works when the base class does not provide it.
        horizon = getattr(self, 'T', self.horizon)
        self.gamma = 1 - max(1 / horizon, (PT / (self.d * horizon)) ** 0.5)

        beta_t = self.c_1 + self.c_2 + (self.d / np.sqrt(self.r_lambda * self.c_mu)) * np.log(
            4 * (1 + (self.k_mu * (1 - self.gamma_2t)) / (self.r_lambda * self.c_mu * self.d * (1 - self.gamma**2)))
        )

        dot_products = arms_array @ self.theta_tilde
        # Row-wise x^T V^{-1} x without forming the full (n_arms x n_arms) product.
        arms_inv_V = arms_array @ self.inv_V
        quadratic_forms = np.sum(arms_array * arms_inv_V, axis=1)
        sqrt_terms = np.sqrt(quadratic_forms)
        ucb_s = sigmoid(dot_products) + (np.sqrt(4 + 8 * self.S) * self.k_mu / np.sqrt(self.c_mu)) * beta_t * sqrt_terms

        # lexsort sorts by its last key first: order by UCB value, with a random
        # secondary key so that equal UCB values are ordered at random.
        mixer = np.random.random(ucb_s.size)
        ucb_indices = np.lexsort((mixer, ucb_s))
        chosen_arm = ucb_indices[-1]

        return chosen_arm

    def update_statistics(self, x, y):
        """Record the outcome of playing arm ``x`` and refresh the estimates.

        Must be called after ``select_arm``, which provides ``self.all_arms``
        and ``self.gamma``.

        Args:
            x: Index of the played arm in the most recent ``select_arm`` call.
            y: Observed reward.
        """
        x_arm = self.all_arms[x]
        self.arms.append(x_arm)
        self.rewards.append(y)
        aat = np.outer(x_arm, x_arm)
        self.gamma_2t *= self.gamma ** 2
        # Discount the old design matrix, add the new rank-one term, and top the
        # regulariser back up so its weight stays at r_lambda.
        self.V = self.gamma * self.V + aat + (1 - self.gamma) * self.r_lambda * np.identity(self.d)
        self.inv_V = pinv(self.V)

        self.estimator()

    def re_init(self):
        """Reset the learner (base-class state first, then this class's state)."""
        super().re_init()
        self._reset_learner_state()

    def estimator(self):
        """Refresh the weights, ``theta_hat``/``H`` and ``theta_tilde``.

        The weights are rebuilt on every call. The Newton refit runs on every
        call while fewer than 200 observations are stored, and afterwards only
        on every ``lazy_update_fr``-th call. ``theta_tilde`` is always
        recomputed from the current ``theta_hat``.
        """
        num = len(self.rewards)
        if num > 0:
            # gamma ** (num - 1), ..., gamma ** 0 as a column: oldest sample first.
            self.w = np.logspace(num - 1, 0, num, base=self.gamma)[:, None]
        else:
            self.w = np.empty((0, 1))

        if self.ctr % self.lazy_update_fr == 0 or num < 200:
            # With no observations np.array([]) would have shape (0,), which
            # breaks the matrix products below; use an explicit (0, d) matrix.
            arms_array = np.array(self.arms) if num > 0 else np.empty((0, self.d))
            rewards_array = np.array(self.rewards)
            # Work on a copy: the stored vector may be shared with
            # self.theta_tilde (or held by callers), and must not be modified
            # behind their back or left half-updated if a solve fails.
            theta_hat = self.theta_hat.copy()
            hessian = self.H

            # Fixed number of Newton steps, warm-started from the last estimate.
            for _ in range(5):
                coeffs = sigmoid(arms_array @ theta_hat)[:, None]
                y = coeffs - rewards_array[:, None]
                grad = self.r_lambda * self.c_mu * theta_hat + np.sum(self.w * y * arms_array, axis=0)
                hessian = (arms_array.T @ ((self.w * coeffs * (1 - coeffs)) * arms_array)) + self.r_lambda * self.c_mu * np.eye(self.d)
                theta_hat = theta_hat - np.linalg.solve(hessian, grad)

            self.theta_hat = theta_hat
            self.H = hessian

        self.ctr += 1

        if np.linalg.norm(self.theta_hat) < self.S:
            self.theta_tilde = self.theta_hat
        else:
            self.theta_tilde = self.projection_new()

    def gt(self, theta, arms):
        """Return ``sum_i w_i * sigmoid(a_i @ theta) * a_i + r_lambda * c_mu * theta``.

        Uses the weights currently stored in ``self.w``, so ``arms`` must have
        one row per stored weight.
        """
        coeffs = sigmoid(arms @ theta)[:, None]
        gt = np.sum(self.w * arms * coeffs, axis=0) + self.r_lambda * self.c_mu * theta
        return gt

    def hessian(self, theta, arms):
        """Return ``arms^T diag(w * dsigmoid(arms @ theta)) arms + r_lambda * c_mu * I``."""
        coeffs = dsigmoid(arms @ theta)[:, None]
        ht = arms.T @ (self.w * coeffs * arms) + self.r_lambda * self.c_mu * np.eye(self.d)
        return ht

    def proj_fun(self, theta, arms):
        """Objective for ``projection``: squared distance between ``gt(theta)``
        and ``gt(theta_hat)`` measured with the pseudo-inverse of
        ``hessian(theta)``."""
        H = self.hessian(theta, arms)
        inv_H = pinv(H)
        diff_gt = self.gt(theta, arms) - self.gt(self.theta_hat, arms)
        return diff_gt @ (inv_H @ diff_gt)

    def projection(self, arms):
        """Minimise ``proj_fun`` over ``||theta|| <= S`` with SLSQP, starting at 0.

        Not called by any other method of this class; ``estimator`` uses
        ``projection_new`` instead.
        """
        fun = lambda t: self.proj_fun(t, arms)
        norm = lambda t: np.linalg.norm(t)
        constraint = NonlinearConstraint(norm, 0, self.S)
        opt = minimize(fun, x0=np.zeros(self.d), method='SLSQP', constraints=constraint)
        return opt.x

    def projection_new(self):
        """Return ``theta_hat`` rescaled to norm ``S`` if it lies outside the ball.

        When ``theta_hat`` is already inside the ball the stored array itself
        is returned (not a copy).
        """
        norm_theta = np.linalg.norm(self.theta_hat)
        if norm_theta <= self.S:
            return self.theta_hat
        else:
            return (self.theta_hat / norm_theta) * self.S

    def __str__(self):
        """Return the display name of the algorithm."""
        return 'SCB-WeightUCB'
