import numpy as np
from src.bandits.NonStationaryBandit import NonStationaryBandit
from src.utils import sigmoid

class NSSCBBandit(NonStationaryBandit):
    """Non-stationary linear bandit with a sigmoid link and Bernoulli arms.

    Every arm is a fixed vector of Euclidean norm ``action_bound``; the mean
    reward of arm ``a`` is ``sigmoid(<arms[a], theta>)``. The parameter
    vector ``theta`` is kept inside the ball of radius ``theta_bound`` and can
    be changed abruptly (:meth:`abrupt_change`) or moved along the straight
    line from ``init_theta`` towards a target (:meth:`gradual_change`).
    """

    def __init__(
        self,
        num_actions,
        noise_variance,
        d,
        theta,
        mean_low,
        mean_high,
        reward_bound,
        theta_bound=1,
        action_bound=1,
        continuous=False
    ):
        """Create the bandit.

        Args:
            num_actions: number of arms.
            noise_variance, mean_low, mean_high, reward_bound, continuous:
                forwarded unchanged to ``NonStationaryBandit.__init__``.
            d: dimension of ``theta`` and of every arm vector.
            theta: initial parameter vector; must have length ``d`` and
                Euclidean norm at most ``theta_bound``.
            theta_bound: radius of the ball ``theta`` is kept in.
            action_bound: Euclidean norm given to every arm vector.

        When ``self.continuous`` (set by the parent, presumably from
        ``continuous``) is truthy, the given ``theta`` is replaced by a random
        draw from ``[-theta_bound, theta_bound]^d`` projected onto the
        ``theta_bound`` ball.

        Raises:
            AssertionError: if ``theta`` has the wrong length or norm.
        """
        self.d = d
        assert len(theta) == d, f"Error: Theta should have dimension {self.d}"
        assert np.linalg.norm(theta) <= theta_bound, "Error: Theta should satisfy the norm constraint"
        self.theta_bound = theta_bound
        self.action_bound = action_bound
        # theta / target_theta must exist before the parent constructor runs;
        # it presumably builds arms and reward means from them (TODO confirm).
        self.theta = np.array(theta, dtype=np.float64)
        self.target_theta = None
        super().__init__(
            arm_type="bernoulli",
            num_actions=num_actions,
            noise_variance=noise_variance,
            mean_low=mean_low,
            mean_high=mean_high,
            reward_bound=reward_bound,
            continuous=continuous
        )

        # Not read anywhere in this class; presumably consumed by the parent.
        self.cyclic = True
        if self.continuous:
            self.theta = self._project_to_ball(
                np.random.uniform(-theta_bound, theta_bound, d)
            )
            # Keep the reward means consistent with the re-sampled theta, as
            # every other theta update in this class does.
            self.reward_means = self.get_reward_means()

        # Anchor of the gradual-change interpolation (see gradual_change).
        self.init_theta = self.theta.copy()

        # NOTE(review): 'continuous' is not recorded here — confirm whether
        # consumers of init_params need it.
        self.init_params = {
            'num_actions': num_actions,
            'noise_variance': noise_variance,
            'd': d,
            'theta': theta,
            'mean_low': mean_low,
            'mean_high': mean_high,
            'reward_bound': reward_bound,
            'theta_bound': theta_bound,
            'action_bound': action_bound,
        }
        # Best arms accepted by random abrupt changes, oldest first.
        self.best_history = []

    def _project_to_ball(self, vec):
        """Return ``vec`` rescaled onto the ``theta_bound`` ball if it lies outside it."""
        norm = np.linalg.norm(vec)
        if norm > self.theta_bound:
            vec = vec / norm
            vec = vec * self.theta_bound
        return vec

    def set_arms(self):
        """Return ``num_actions`` random arm vectors of norm ``action_bound``.

        Each arm is a standard-normal draw rescaled to norm ``action_bound``,
        i.e. a uniformly random direction on the sphere.
        """
        arms = []
        for _ in range(self.num_actions):
            arm_vector = np.random.randn(self.d)
            arm_vector /= np.linalg.norm(arm_vector)
            arm_vector *= self.action_bound
            arms.append(arm_vector)
        return arms

    def get_mean_reward(self, action):
        """Return ``sigmoid(<arms[action], theta>)``, the mean reward of arm ``action``."""
        return sigmoid(np.dot(self.arms[action], self.theta))

    def get_P_T(self, T):
        """Return the path length of the linear schedule from ``init_theta`` to ``target_theta``.

        The schedule is ``theta_t = (1 - t/T) * init_theta + (t/T) * target_theta``
        for ``t = 0..T-1``. The quantity is
        ``sum_{t=1}^{T-1} ||theta_t - theta_{t-1}||``. Every step has the same
        length ``||target_theta - init_theta|| / T``, so the sum has the closed
        form ``(T - 1) / T * ||target_theta - init_theta||``. This equals the
        step-by-step sum up to floating-point rounding.

        Returns 0 when ``T <= 1``. Raises ``TypeError`` for ``T > 1`` while
        ``target_theta`` is still None.
        """
        if T <= 1:
            return 0
        return (T - 1) / T * np.linalg.norm(self.target_theta - self.init_theta)

    def _sample_sign_pattern_theta(self):
        """Draw a theta whose non-zero coordinates are all ``±theta_bound``.

        A random number ``k`` of coordinates (1..d) is set to ``theta_bound``
        with random signs. The result is scaled down onto the
        ``theta_bound`` ball when its norm exceeds it.
        """
        candidate = np.zeros(self.d)
        k = np.random.randint(1, self.d + 1)
        idx = np.random.choice(self.d, size=k, replace=False)
        candidate[idx] = self.theta_bound
        candidate *= np.random.choice([-1, +1], size=self.d)
        candidate /= max(1., np.linalg.norm(candidate) / self.theta_bound)
        return candidate

    def _is_fresh_best(self, new_best):
        """Whether ``new_best`` differs from the recent entries of ``best_history``.

        The window is the last 3 entries once at least 3 are recorded. It is
        shrunk to ``num_actions - 1`` so that the condition stays satisfiable
        with fewer than 4 arms; the original fixed window of 3 could make it
        impossible and hang ``abrupt_change``.
        """
        if len(self.best_history) < 3:
            return True
        window = max(0, min(3, self.num_actions - 1))
        # Slice from an explicit start: history[-0:] would be the whole list.
        recent = self.best_history[len(self.best_history) - window:]
        return new_best not in recent

    def _search_new_best_theta(self, old_best, max_tries):
        """Propose random thetas until the best arm moves to a fresh arm.

        Each proposal is installed on ``self`` (theta and reward_means) so
        that ``get_best_arm`` reflects it. On success the new best arm is
        appended to ``best_history`` and the proposal is returned. After
        ``max_tries`` failures the previous theta is restored and
        ``RuntimeError`` is raised.
        """
        prev_theta = self.theta
        for _ in range(max_tries):
            proposal = self._sample_sign_pattern_theta()
            self.theta = proposal
            self.reward_means = self.get_reward_means()
            new_best, _ = self.get_best_arm()
            if new_best != old_best and self._is_fresh_best(new_best):
                self.best_history.append(new_best)
                return proposal
        self.theta = prev_theta
        self.reward_means = self.get_reward_means()
        raise RuntimeError(
            f"abrupt_change: no theta changing the best arm found in {max_tries} tries"
        )

    def abrupt_change(self, new_theta=None, max_tries=10000):
        """Jump ``theta`` to a new value and refresh the reward means.

        Args:
            new_theta: explicit new parameter vector of length ``d``. It is
                used as given: no norm check is applied, and
                ``best_history`` is not updated.
            max_tries: maximum number of random proposals when ``new_theta``
                is None. It guards against looping forever when no proposal
                can change the best arm, e.g. with a single arm.

        When ``new_theta`` is None, random sign-pattern vectors are proposed
        until one makes the best arm differ from the current best arm and
        from the recent entries of ``best_history``. The accepted best arm is
        appended to ``best_history``. In all cases ``target_theta`` is reset,
        so the next ``gradual_change`` draws a new target.

        Raises:
            ValueError: if ``new_theta`` does not have shape ``(d,)``.
            RuntimeError: if no acceptable proposal is found within
                ``max_tries`` draws. ``theta`` and ``reward_means`` are
                restored first.
        """
        old_best, _ = self.get_best_arm()

        if new_theta is not None:
            candidate = np.asarray(new_theta, dtype=float)
            if candidate.shape != (self.d,):
                raise ValueError(f"Error: Theta should have dimension {self.d}")
        else:
            candidate = self._search_new_best_theta(old_best, max_tries)

        self.theta = candidate.copy()
        self.reward_means = self.get_reward_means()
        self.target_theta = None

    def _draw_gradual_target(self):
        """Draw a target from ``[-1, 1]^d`` projected onto the ``theta_bound`` ball.

        The draw is repeated while it lies within 1e-3 of the current theta.
        """
        target = self._project_to_ball(np.random.uniform(-1, 1, self.d))
        while np.linalg.norm(self.theta - target) < 1e-3:
            target = self._project_to_ball(np.random.uniform(-1, 1, self.d))
        return target

    def gradual_change(self, change_rate=0.01):
        """Place ``theta`` at fraction ``change_rate`` of the way from ``init_theta`` to the target.

        The target is chosen when ``target_theta`` is None: after
        construction, ``abrupt_change`` or ``re_init``. It is then kept
        fixed. ``change_rate`` is therefore an absolute interpolation weight,
        not an increment.

        The random target is always projected onto the ``theta_bound`` ball.
        Previously only redrawn targets were projected, so the first draw
        could violate the norm constraint.
        """
        if self.target_theta is None:
            if self.continuous == 'AGGRESIVE':
                # optimal_x_end is not defined in this class; presumably
                # provided by the parent — TODO confirm.
                self.target_theta = self.optimal_x_end(self.init_theta, self.theta_bound)
            else:
                self.target_theta = self._draw_gradual_target()

        # Safety clip. The AGGRESIVE target comes from outside this class, so
        # the interpolated point is not guaranteed to lie in the ball.
        self.theta = self._project_to_ball(
            (1 - change_rate) * self.init_theta + change_rate * self.target_theta
        )
        self.reward_means = self.get_reward_means()

    def re_init(self):
        """Re-initialise via the parent and forget the gradual-change target."""
        super().re_init()
        self.target_theta = None
