import numpy as np
from src.bandits.NonStationaryBandit import NonStationaryBandit
from src.utils import sigmoid

class NSLinearBanditBall(NonStationaryBandit):
    """Non-stationary linear bandit whose parameter vector moves inside a ball.

    The mean reward of arm ``a`` is ``<arms[a], theta>``. ``theta`` is kept
    inside the Euclidean ball of radius ``theta_bound`` and every arm vector
    inside the ball of radius ``action_bound``. ``gradual_change`` moves
    ``theta`` to a random point within ``radius`` of its current value or,
    with probability ``prob``, redraws it from scratch; ``abrupt_change``
    replaces it in one step.
    """

    def __init__(
        self,
        num_actions,
        noise_variance,
        d,
        theta,
        mean_low,
        mean_high,
        reward_bound,
        theta_bound=1,
        action_bound=1,
        continuous=False,
        radius=0.1,
        prob=0.1
    ):
        """Create the bandit.

        Args:
            num_actions: number of arms (forwarded to the base class).
            noise_variance: forwarded to the base class.
            d: dimension of ``theta`` and of every arm vector.
            theta: initial parameter vector. Must have length ``d`` and
                Euclidean norm at most ``theta_bound``. When ``continuous``
                is true it is validated but then replaced by a random vector.
            mean_low, mean_high, reward_bound: forwarded to the base class.
            theta_bound: maximum Euclidean norm of ``theta``.
            action_bound: maximum Euclidean norm of each arm vector.
            continuous: forwarded to the base class; when true, ``theta`` is
                redrawn uniformly from ``[-1, 1]^d`` and clipped to
                ``theta_bound``.
            radius: radius of the local ball used by ``gradual_change``.
            prob: probability, per ``gradual_change`` call, of redrawing
                ``theta`` from scratch instead of moving it locally.

        Raises:
            AssertionError: if ``theta`` has the wrong length or too large a norm.
        """
        self.d = d
        assert len(theta) == d, f"Error: Theta should have dimension {self.d}"
        assert np.linalg.norm(theta) <= theta_bound, "Error: Theta should satisfy the norm constraint"
        self.theta_bound = theta_bound
        self.action_bound = action_bound
        self.theta = np.array(theta, dtype=np.float64)
        # Endpoint used by get_P_T; never set to a vector inside this class.
        self.target_theta = None

        # theta and the bounds are set before the base constructor runs,
        # presumably because it calls set_arms()/get_mean_reward()
        # (NOTE(review): confirm in NonStationaryBandit).
        super().__init__(
            arm_type="normal",
            num_actions=num_actions,
            noise_variance=noise_variance,
            mean_low=mean_low,
            mean_high=mean_high,
            reward_bound=reward_bound,
            continuous=continuous
        )

        # When False, abrupt_change draws a dense uniform theta; when True,
        # a random 0/1 indicator vector.
        self.cyclic = False

        if self.continuous:
            self.theta = self._clip_norm(
                np.random.uniform(-1, 1, self.d), self.theta_bound
            )
            # theta changed after the base constructor ran. Refresh any cached
            # means so they match the theta actually in use, the same way
            # abrupt_change/gradual_change refresh them.
            if getattr(self, "reward_means", None) is not None:
                self.reward_means = self.get_reward_means()

        # Snapshot of the starting parameter (after any continuous redraw).
        self.init_theta = self.theta.copy()

        self.radius = radius

        self.prob = prob

        # Constructor arguments as passed by the caller. 'theta' is the
        # caller's vector, not the redrawn one in continuous mode.
        self.init_params = {
            'num_actions': num_actions,
            'noise_variance': noise_variance,
            'd': d,
            'theta': theta,
            'mean_low': mean_low,
            'mean_high': mean_high,
            'reward_bound': reward_bound,
            'theta_bound': theta_bound,
            'action_bound': action_bound,
            'continuous': continuous,
            'radius': radius,
            'prob': prob,
        }

    @staticmethod
    def _clip_norm(vector, bound):
        """Return ``vector`` radially scaled onto the ball of radius ``bound`` if outside it.

        Vectors already inside the ball are returned unchanged (same object).
        """
        norm = np.linalg.norm(vector)
        if norm > bound:
            # Divide then multiply, matching the original in-place /= then *=
            # so results are bit-identical.
            vector = vector / norm * bound
        return vector

    def set_arms(self):
        """Draw ``num_actions`` standard-normal arm vectors clipped to ``action_bound``.

        Returns:
            list of 1-D numpy arrays, each of length ``d``.
        """
        return [
            self._clip_norm(np.random.randn(self.d), self.action_bound)
            for _ in range(self.num_actions)
        ]

    def get_mean_reward(self, action):
        """Return the mean reward of arm ``action``: ``<arms[action], theta>``."""
        return np.dot(self.arms[action], self.theta)

    def abrupt_change(self, new_theta=None):
        """Replace ``theta`` in one step and refresh the cached reward means.

        Args:
            new_theta: explicit replacement. It must have length ``d`` and
                norm at most ``theta_bound``. When None, a random theta is
                drawn: uniform on ``[0, 1]^d`` if not ``cyclic``, otherwise a
                0/1 indicator vector with a random non-empty support. Either
                draw is clipped to ``theta_bound``.

        Raises:
            AssertionError: if ``new_theta`` has the wrong length or too large a norm.

        Side effects:
            Resets ``target_theta`` to None.
        """
        if new_theta is not None:
            assert len(new_theta) == self.d, f"Error: New theta should have dimension {self.d}"
            assert np.linalg.norm(new_theta) <= self.theta_bound, "Error: New theta should satisfy the norm constraint"
            self.theta = np.array(new_theta, dtype=np.float64)
        elif not self.cyclic:
            # NOTE: samples from [0, 1]^d, unlike the [-1, 1]^d draw used
            # for the continuous initialisation in __init__.
            self.theta = self._clip_norm(
                np.random.uniform(0, 1, self.d), self.theta_bound
            )
        else:
            # Support size is drawn before the support itself, preserving the
            # RNG call order of the original implementation.
            support_size = np.random.randint(1, self.d + 1)
            support = np.random.choice(np.arange(self.d), size=support_size, replace=False)
            indicator = np.zeros(self.d)
            indicator[support] = 1
            self.theta = self._clip_norm(indicator, self.theta_bound)

        self.reward_means = self.get_reward_means()
        self.target_theta = None

    def get_P_T(self, T):
        """Return the path length of linearly interpolating ``init_theta`` -> ``target_theta``.

        Sums ``||theta_t - theta_{t-1}||`` over ``t = 1 .. T-1``, where
        ``theta_t = (1 - t/T) * init_theta + (t/T) * target_theta`` and
        ``theta_0 = init_theta``. Returns 0 when ``T <= 1``.

        Raises:
            ValueError: if ``target_theta`` has not been set (it is None).
        """
        if self.target_theta is None:
            raise ValueError("target_theta must be set before calling get_P_T")
        total = 0
        prev = self.init_theta.copy()
        for t in range(1, T):
            cur = (1 - t / T) * self.init_theta + t / T * self.target_theta
            total += np.linalg.norm(cur - prev)
            prev = cur
        return total

    def sample_y_in_ball_with_constraint(self, x, r, L):
        """Sample a point uniformly from the ball of radius ``r`` around ``x``, then clip to norm ``L``.

        The direction is a normalised Gaussian vector. The radius is
        ``r * u**(1/d)`` with ``u ~ U[0, 1)``, which makes the sample uniform
        in the ball. If the result has norm above ``L`` it is scaled radially
        onto the sphere of radius ``L``.

        Args:
            x: 1-D numpy array, the ball centre.
            r: ball radius.
            L: maximum Euclidean norm of the returned point.

        Returns:
            1-D numpy array of the same length as ``x``.
        """
        dim = x.shape[0]

        direction = np.random.normal(size=dim)
        direction /= np.linalg.norm(direction)

        u = np.random.random()
        random_radius = r * (u ** (1.0 / dim))

        return self._clip_norm(x + random_radius * direction, L)

    def gradual_change(self, change_rate=0.01):
        """Perturb ``theta`` and refresh the cached reward means.

        With probability ``prob`` the parameter is redrawn uniformly from
        ``[0, 1]^d`` and clipped to ``theta_bound``. Otherwise it moves to a
        uniform point within ``radius`` of its current value, clipped to
        ``theta_bound``.

        Args:
            change_rate: unused; kept for interface compatibility.
        """
        if np.random.random() < self.prob:
            self.theta = self._clip_norm(
                np.random.uniform(0, 1, self.d), self.theta_bound
            )
        else:
            self.theta = self.sample_y_in_ball_with_constraint(
                self.theta, self.radius, self.theta_bound
            )

        self.reward_means = self.get_reward_means()

    def re_init(self):
        """Delegate to the base-class ``re_init`` and clear ``target_theta``."""
        super().re_init()
        self.target_theta = None
