import numpy as np
from math import log
from numpy.linalg import pinv
from src.utils import sigmoid, dsigmoid
from scipy.optimize import minimize, NonlinearConstraint

class Base_SCB(object):
    """Bandit agent with a regularized logistic estimate and an optimistic arm score.

    After each observation the agent updates a design matrix ``V`` (and its
    pseudo-inverse), refits ``theta_hat`` with Newton steps on an L2-regularized
    loss, and projects it onto the ball of radius ``S`` to get ``theta_tilde``.
    Arms are scored by ``sigmoid(x @ theta_tilde)`` plus a bonus proportional to
    ``beta_t * sqrt(x^T V^-1 x)``.

    Args:
        d: Dimension of the arm feature vectors and of the parameter vector.
        delta: Confidence parameter; enters the bonus through ``log(1 / delta)``.
        r_lambda: Regularization weight; ``V`` starts at ``r_lambda * I``.
        S: Radius of the ball that ``theta_tilde`` is projected onto.
        L, R: Stored but not used inside this class (presumably read by
            subclasses or callers -- confirm).
        k_mu, c_mu: Constants of the confidence width; ``r_lambda * c_mu`` is
            also the L2 penalty weight of the fit. Presumably bounds on the link
            function's derivative -- confirm.
        s, e: Stored together with ``len = e - s + 1``; presumably an inclusive
            start/end index pair -- confirm against callers.
    """

    def __init__(self, d, delta, r_lambda, S, L, R, k_mu, c_mu, s, e):
        self.d = d
        self.delta = delta
        self.r_lambda = r_lambda
        self.S = S
        self.L = L
        self.R = R
        self.k_mu = k_mu
        self.c_mu = c_mu
        self.s = s
        self.e = e
        self.len = self.e - self.s + 1

        # Parts of the confidence width that depend only on hyper-parameters.
        self.c_1 = np.sqrt(self.r_lambda * self.c_mu) * (self.S + 0.5)
        self.c_2 = (2 / np.sqrt(self.r_lambda * self.c_mu)) * log(1 / self.delta)
        self.inv_r_lambda_c_mu = 1 / (self.r_lambda * self.c_mu)

        # Private helper (not re_init) so an overriding re_init in a subclass is
        # never invoked before the subclass has finished its own __init__.
        self._reset_state()

    def _reset_state(self):
        """Set every piece of state learned from observations to its initial value."""
        reg = self.r_lambda * self.c_mu
        self.t = 0
        self.ctr = 0
        self.lazy_update_fr = 5
        self.theta_hat = np.zeros(self.d)
        self.theta_tilde = np.zeros(self.d)
        self.V = self.r_lambda * np.identity(self.d)
        self.inv_V = (1 / self.r_lambda) * np.identity(self.d)
        self.H = reg * np.eye(self.d)
        self.arms = []
        self.rewards = []

    def select_arm(self, arms):
        """Pick the arm with the largest optimistic score.

        Args:
            arms: List of feature vectors, each of length ``d``.

        Returns:
            Tuple ``(chosen_arm, max_ub, theta_hat)``: the index of the chosen
            arm, its score, and the current unprojected estimate. Ties are
            broken uniformly at random.
        """
        assert isinstance(arms, list), 'List of arms as input required'

        arms_array = np.array(arms)
        reg = self.r_lambda * self.c_mu

        beta_t = self.c_1 + self.c_2 + (self.d / np.sqrt(reg)) * np.log(
            4 * (1 + ((self.k_mu * self.t) / (reg * self.d)))
        )

        dot_products = arms_array @ self.theta_tilde
        # Row-wise x^T V^-1 x. Round-off from pinv can make this slightly
        # negative; clip so sqrt never yields NaN (NumPy sorts NaN last, so
        # lexsort would otherwise treat a NaN score as the maximum).
        quadratic_forms = np.sum(arms_array * (arms_array @ self.inv_V), axis=1)
        sqrt_terms = np.sqrt(np.maximum(quadratic_forms, 0.0))
        width_scale = np.sqrt(4 + 8 * self.S) * self.k_mu / np.sqrt(self.c_mu)
        ucb_s = sigmoid(dot_products) + width_scale * beta_t * sqrt_terms

        # Primary key ucb_s, secondary key a random draw -> random tie-breaking.
        mixer = np.random.random(ucb_s.size)
        chosen_arm = np.lexsort((mixer, ucb_s))[-1]
        max_ub = ucb_s[chosen_arm]

        return chosen_arm, max_ub, self.theta_hat

    def update_state(self, x, y):
        """Record the observation ``(x, y)``, update ``V``, and refresh the estimates.

        Args:
            x: Feature vector of the pulled arm (``np.ndarray``).
            y: Observed reward.
        """
        assert isinstance(x, np.ndarray), 'np.array required'

        self.arms.append(x)
        self.rewards.append(y)
        self.V += np.outer(x, x)
        self.inv_V = pinv(self.V)

        self.estimator()
        self.t += 1

    def re_init(self):
        """Discard all observations and learned state; hyper-parameters are kept."""
        self._reset_state()

    def estimator(self):
        """Refit ``theta_hat`` (possibly lazily) and refresh ``theta_tilde``.

        Runs five Newton steps, starting from the current ``theta_hat``, with
        gradient ``reg * theta + sum_i (sigmoid(x_i @ theta) - y_i) x_i`` and
        Hessian ``X^T diag(p (1 - p)) X + reg * I``, where ``reg = r_lambda * c_mu``.
        This is regularized logistic regression assuming ``sigmoid`` is the
        logistic function. The final Hessian is stored in ``self.H``.
        """
        # Refit on every call while fewer than 200 rewards are stored, and
        # afterwards only on every `lazy_update_fr`-th call.
        if self.ctr % self.lazy_update_fr == 0 or len(self.rewards) < 200:
            arms_array = np.array(self.arms)
            rewards_array = np.array(self.rewards)
            reg = self.r_lambda * self.c_mu
            theta_hat = self.theta_hat

            for _ in range(5):
                coeffs = sigmoid(arms_array @ theta_hat)[:, None]
                residuals = coeffs - rewards_array[:, None]
                grad = reg * theta_hat + np.sum(residuals * arms_array, axis=0)
                hessian = arms_array.T @ (coeffs * (1 - coeffs) * arms_array) + reg * np.eye(self.d)
                # Rebind rather than update in place: theta_tilde and arrays
                # previously returned by select_arm may reference the old vector.
                theta_hat = theta_hat - np.linalg.solve(hessian, grad)

            self.theta_hat = theta_hat
            self.H = hessian
        self.ctr += 1

        # Euclidean projection onto the ball of radius S (identity inside it).
        self.theta_tilde = self.projection_new()

    def gt(self, theta, arms):
        """Return ``sum_i sigmoid(x_i @ theta) x_i + r_lambda * c_mu * theta``."""
        coeffs = sigmoid(arms @ theta)[:, None]
        return np.sum(arms * coeffs, axis=0) + self.r_lambda * self.c_mu * theta

    def hessian(self, theta, arms):
        """Return ``sum_i dsigmoid(x_i @ theta) x_i x_i^T + r_lambda * c_mu * I``."""
        coeffs = dsigmoid(arms @ theta)[:, None]
        return arms.T @ (coeffs * arms) + self.r_lambda * self.c_mu * np.eye(self.d)

    def proj_fun(self, theta, arms):
        """Return the squared distance ``||gt(theta) - gt(theta_hat)||^2`` in the ``H(theta)^+`` metric."""
        inv_H = pinv(self.hessian(theta, arms))
        diff_gt = self.gt(theta, arms) - self.gt(self.theta_hat, arms)
        return diff_gt @ (inv_H @ diff_gt)

    def projection(self, arms):
        """Minimize ``proj_fun`` over ``||theta|| <= S`` with SLSQP, starting from zero."""
        constraint = NonlinearConstraint(np.linalg.norm, 0, self.S)
        opt = minimize(
            lambda theta: self.proj_fun(theta, arms),
            x0=np.zeros(self.d),
            method='SLSQP',
            constraints=constraint,
        )
        return opt.x

    def projection_new(self):
        """Return ``theta_hat`` rescaled onto the ball of radius ``S`` if it lies outside.

        When ``||theta_hat|| <= S`` the stored array itself is returned.
        """
        norm_theta = np.linalg.norm(self.theta_hat)
        if norm_theta <= self.S:
            return self.theta_hat
        return (self.theta_hat / norm_theta) * self.S

    def __str__(self):
        return 'Base-SCB'
