import numpy as np
from math import log
from numpy.linalg import pinv
from src.utils import sigmoid, dsigmoid
from scipy.optimize import minimize, NonlinearConstraint


class Base_GLB(object):
    """Logistic (generalized-linear) bandit learner with UCB-style arm selection.

    The learner keeps:

    * ``theta_hat``   -- regularized maximum-likelihood estimate of the logistic
      parameter, refit by Newton's method in :meth:`estimator`;
    * ``theta_tilde`` -- ``theta_hat`` projected onto the ball of radius ``S``,
      used for scoring arms in :meth:`select_arm`;
    * ``V`` / ``inv_V`` -- regularized design matrix ``r_lambda * I + sum x x^T``
      and its pseudo-inverse;
    * ``H`` -- Hessian of the regularized negative log-likelihood at the last
      Newton iterate.

    Parameters
    ----------
    d : int
        Dimension of the arm feature vectors.
    delta : float
        Confidence parameter; enters the exploration radius as ``log(1/delta)``.
    r_lambda : float
        Ridge regularization strength.
    S : float
        Norm bound for ``theta_tilde``.
    L : float
        Stored as an attribute; not used by this class.
    R, k_mu, c_mu : float
        Constants of the UCB bonus. ``c_mu`` also scales the regularizer
        (``r_lambda * c_mu``).
    s, e : int
        Stored as-is and used only to compute ``len = e - s + 1``
        (presumably an inclusive round range -- confirm with callers).
    """

    def __init__(self, d, delta, r_lambda, S, L, R, k_mu, c_mu, s, e):
        self.d = d
        self.delta = delta
        self.r_lambda = r_lambda
        self.S = S
        self.L = L
        self.R = R
        self.k_mu = k_mu
        self.c_mu = c_mu
        self.s = s
        self.e = e
        self.len = self.e - self.s + 1

        # Constant parts of the confidence radius beta_t computed in select_arm.
        self.c_1 = np.sqrt(self.r_lambda) * c_mu * self.S
        self.c_2 = 2 * log(1 / self.delta)

        self._reset_state()

    def _reset_state(self):
        """Set every learned quantity back to its prior (no-data) value."""
        self.t = 0  # number of observations incorporated by update_state
        self.ctr = 0  # number of estimator() calls; drives the lazy refit schedule
        self.lazy_update_fr = 5
        self.theta_hat = np.zeros(self.d)
        self.theta_tilde = np.zeros(self.d)
        self.V = self.r_lambda * np.identity(self.d)
        self.inv_V = np.identity(self.d) / self.r_lambda
        self.H = self.r_lambda * self.c_mu * np.eye(self.d)
        self.arms = []
        self.rewards = []

    def select_arm(self, arms):
        """Return the arm with the largest optimistic score.

        Each arm ``x`` is scored as
        ``sigmoid(x . theta_tilde) + (2 k_mu / c_mu) * beta_t * ||x||_{V^-1}``.
        Ties are broken uniformly at random.

        Parameters
        ----------
        arms : sequence of np.ndarray
            Candidate feature vectors of length ``d``.

        Returns
        -------
        tuple
            ``(chosen_arm, max_ub, theta_hat)``: index of the chosen arm, its
            score, and the current estimate ``theta_hat``.
        """
        beta_t = self.c_1 + self.R * np.sqrt(
            self.c_2 + self.d * np.log(1 + self.t / (self.r_lambda * self.d)))
        bonus_scale = (2 * self.k_mu / self.c_mu) * beta_t

        ucb_s = np.zeros(len(arms))
        for i, x in enumerate(arms):
            ucb_s[i] = sigmoid(np.dot(x, self.theta_tilde)) + bonus_scale * np.sqrt(
                np.dot(x.T, np.dot(self.inv_V, x)))

        # lexsort uses the LAST key as the primary one: sort by score, then by a
        # random key, so the final index is a max-score arm with random tie-break.
        mixer = np.random.random(ucb_s.size)
        chosen_arm = np.lexsort((mixer, ucb_s))[-1]
        max_ub = ucb_s.max()
        return chosen_arm, max_ub, self.theta_hat

    def update_state(self, x, y):
        """Record the observation ``(x, y)`` and refresh the model.

        Parameters
        ----------
        x : np.ndarray
            Feature vector of the played arm.
        y : float
            Observed reward.
        """
        assert isinstance(x, np.ndarray), 'np.array required'
        self.arms.append(x)
        self.rewards.append(y)
        self.V = self.V + np.outer(x, x)
        self.inv_V = pinv(self.V)
        self.estimator()
        self.t += 1

    def re_init(self):
        """Discard all observations and reset the learner to its initial state."""
        self._reset_state()

    def estimator(self):
        """Refit ``theta_hat`` and update its norm-bounded version ``theta_tilde``.

        ``theta_hat`` is refit with 5 Newton steps on the regularized logistic
        negative log-likelihood while fewer than 200 rewards are stored, and
        afterwards only on every ``lazy_update_fr``-th call.
        """
        if self.ctr % self.lazy_update_fr == 0 or len(self.rewards) < 200:
            features = np.asarray(self.arms)
            rewards = np.asarray(self.rewards)[:, None]
            reg = self.r_lambda * self.c_mu
            reg_eye = reg * np.eye(self.d)
            # Work on a fresh array: updating self.theta_hat in place would also
            # alter arrays already returned by select_arm or aliased by theta_tilde.
            theta_hat = np.array(self.theta_hat, dtype=float)
            for _ in range(5):
                coeffs = sigmoid(np.dot(features, theta_hat)[:, None])
                residuals = coeffs - rewards
                grad = reg * theta_hat + np.sum(residuals * features, axis=0)
                hessian = np.dot(features.T, coeffs * (1 - coeffs) * features) + reg_eye
                theta_hat = theta_hat - np.linalg.solve(hessian, grad)
            self.theta_hat = theta_hat
            self.H = hessian
        self.ctr += 1

        # projection_new returns theta_hat unchanged when it is inside the ball.
        self.theta_tilde = self.projection_new()

    def gt(self, theta, arms):
        """Return ``sum_i sigmoid(x_i . theta) x_i + r_lambda * c_mu * theta``."""
        arms = np.asarray(arms)
        coeffs = sigmoid(np.dot(arms, theta))[:, None]
        return np.sum(arms * coeffs, axis=0) + self.r_lambda * self.c_mu * theta

    def hessian(self, theta, arms):
        """Return the Jacobian of :meth:`gt` at ``theta``.

        This is ``sum_i dsigmoid(x_i . theta) x_i x_i^T + r_lambda * c_mu * I``.
        """
        arms = np.asarray(arms)
        coeffs = dsigmoid(np.dot(arms, theta))[:, None]
        return np.dot(arms.T, coeffs * arms) + self.r_lambda * self.c_mu * np.eye(self.d)

    def proj_fun(self, theta, arms):
        """Projection objective ``||gt(theta) - gt(theta_hat)||^2`` in the ``inv_V`` norm."""
        diff_gt = self.gt(theta, arms) - self.gt(self.theta_hat, arms)
        return np.dot(diff_gt, np.dot(self.inv_V, diff_gt))

    def proj_grad(self, theta, arms):
        """Gradient of :meth:`proj_fun` with respect to ``theta``.

        ``d/dtheta [D^T A D] = 2 J^T A D``, with ``D = gt(theta) - gt(theta_hat)``,
        ``A = inv_V`` and ``J = hessian(theta)``. ``J`` is symmetric, so this is
        ``2 J A D``. The factor order matters because ``J`` and ``A`` generally
        do not commute.
        """
        diff_gt = self.gt(theta, arms) - self.gt(self.theta_hat, arms)
        return 2 * np.dot(self.hessian(theta, arms), np.dot(self.inv_V, diff_gt))

    def projection(self, arms):
        """Minimize :meth:`proj_fun` subject to ``||theta|| <= S`` with SLSQP.

        Not called elsewhere in this class; :meth:`estimator` uses
        :meth:`projection_new` instead.
        """
        fun = lambda t: self.proj_fun(t, arms)
        grads = lambda t: self.proj_grad(t, arms)
        constraint = NonlinearConstraint(np.linalg.norm, 0, self.S)
        opt = minimize(fun, x0=np.zeros(self.d), method='SLSQP', jac=grads, constraints=constraint)
        return opt.x

    def projection_new(self):
        """Euclidean projection of ``theta_hat`` onto the ball of radius ``S``."""
        norm_theta = np.linalg.norm(self.theta_hat)
        if norm_theta <= self.S:
            return self.theta_hat
        return (self.theta_hat / norm_theta) * self.S

    def __str__(self):
        return 'Base-GLB'
