import numpy as np
from src.bandits.NonStationaryBandit import NonStationaryBandit
from src.utils import sigmoid

class NSGenLinearBandit(NonStationaryBandit):
    """Non-stationary bandit whose arm means are ``sigmoid(<arm, theta>)``.

    Each arm is a fixed random vector in R^d (see ``set_arms``) and the mean
    reward of arm ``a`` is ``sigmoid(arms[a] . theta)`` (``sigmoid`` comes from
    ``src.utils``). Non-stationarity comes from moving ``theta``, either
    abruptly (``abrupt_change``) or along a straight line from ``init_theta``
    towards a target (``gradual_change``). Every randomly generated theta is
    kept inside the Euclidean ball of radius ``theta_bound``.
    """

    def __init__(
        self,
        num_actions,
        noise_variance,
        d,
        theta,
        mean_low,
        mean_high,
        reward_bound,
        theta_bound=1,
        action_bound=1,
        continuous=False
    ):
        """Create the bandit.

        Args:
            num_actions: number of arms.
            noise_variance: forwarded to the base class.
            d: dimension of the arm and parameter vectors.
            theta: initial parameter; must have length ``d`` and norm at most
                ``theta_bound`` (checked with ``assert``). Ignored in favour of
                a random draw when ``continuous`` is truthy.
            mean_low, mean_high, reward_bound: forwarded to the base class.
            theta_bound: radius of the ball theta is kept in.
            action_bound: arm scale; arms get norm ``sqrt(action_bound)``.
            continuous: forwarded to the base class. Any truthy value randomises
                the initial theta. The string ``'AGGRESIVE'`` also changes how
                ``gradual_change`` picks its target.
        """
        # These attributes are set before the base constructor, presumably
        # because it calls set_arms / get_reward_means, which read them.
        self.d = d
        assert len(theta) == d, f"Error: Theta should have dimension {self.d}"
        assert np.linalg.norm(theta) <= theta_bound, "Error: Theta should satisfy the norm constraint"
        self.theta_bound = theta_bound
        self.action_bound = action_bound
        self.theta = np.array(theta, dtype=np.float64)
        self.target_theta = None  # drawn lazily by gradual_change
        super().__init__(
            arm_type="normal",
            num_actions=num_actions,
            noise_variance=noise_variance,
            mean_low=mean_low,
            mean_high=mean_high,
            reward_bound=reward_bound,
            continuous=continuous
        )

        self.cyclic = False
        if self.continuous:
            # NOTE(review): if the base constructor caches reward_means from
            # self.theta, they are stale after this re-sample until the next
            # change refreshes them — confirm against NonStationaryBandit.
            self.theta = self._project_to_ball(
                np.random.uniform(-1, 1, self.d), self.theta_bound
            )

        # Starting point of every gradual_change interpolation.
        self.init_theta = self.theta.copy()

        self.init_params = {
            'num_actions': num_actions,
            'noise_variance': noise_variance,
            'd': d,
            'theta': theta,
            'mean_low': mean_low,
            'mean_high': mean_high,
            'reward_bound': reward_bound,
            'theta_bound': theta_bound,
            'action_bound': action_bound,
            'continuous': continuous
        }

    @staticmethod
    def _project_to_ball(vector, radius):
        """Rescale ``vector`` in place onto the sphere of ``radius`` if it is longer.

        Vectors already inside the ball are left untouched. The same array
        object is returned so the call can be used inline.
        """
        norm = np.linalg.norm(vector)
        if norm > radius:
            vector /= norm
            vector *= radius
        return vector

    def set_arms(self):
        """Return ``num_actions`` random arm vectors of norm ``sqrt(action_bound)``.

        Each arm is a standard-normal draw in R^d, normalised to unit length
        and then scaled.
        """
        # NOTE(review): arms get norm sqrt(action_bound), while theta is bounded
        # by theta_bound itself — confirm action_bound is meant as a squared bound.
        arms = []
        for _ in range(self.num_actions):
            arm_vector = np.random.randn(self.d)
            arm_vector /= np.linalg.norm(arm_vector)
            arm_vector *= np.sqrt(self.action_bound)
            arms.append(arm_vector)
        return arms

    def get_mean_reward(self, action):
        """Return ``sigmoid(<arms[action], theta>)``, the mean reward of ``action``."""
        return sigmoid(np.dot(self.arms[action], self.theta))

    def abrupt_change(self, new_theta=None):
        """Jump theta to a new value and refresh ``reward_means``.

        Args:
            new_theta: explicit new parameter. It must have length ``d`` and
                norm at most ``theta_bound`` (checked with ``assert``). When
                ``None``, a random parameter is drawn instead:
                - non-cyclic: uniform in ``[-theta_bound, theta_bound]^d``,
                  projected onto the ``theta_bound`` ball;
                - cyclic: a random non-empty subset of coordinates is set to
                  ``theta_bound`` (the rest are 0), then projected onto the ball.

        Also clears ``target_theta``, so the next ``gradual_change`` draws a
        fresh target. ``init_theta`` is not modified.
        """
        if new_theta is not None:
            assert len(new_theta) == self.d, f"Error: New theta should have dimension {self.d}"
            assert np.linalg.norm(new_theta) <= self.theta_bound, "Error: New theta should satisfy the norm constraint"
            self.theta = np.array(new_theta, dtype=np.float64)
        elif not self.cyclic:
            self.theta = self._project_to_ball(
                np.random.uniform(-self.theta_bound, self.theta_bound, self.d),
                self.theta_bound,
            )
        else:
            theta = np.zeros(self.d)
            support = np.random.choice(
                np.arange(self.d),
                size=np.random.randint(1, self.d + 1),
                replace=False,
            )
            theta[support] = self.theta_bound
            self.theta = self._project_to_ball(theta, self.theta_bound)
        self.reward_means = self.get_reward_means()
        self.target_theta = None

    def get_P_T(self, T):
        """Return the path length of the linear gradual-change schedule over ``T`` steps.

        Sums ``||theta_t - theta_{t-1}||`` for ``t = 1, ..., T-1``, where
        ``theta_0 = init_theta`` and
        ``theta_t = (1 - t/T) * init_theta + t/T * target_theta``.
        Returns 0 when ``T <= 1``.

        Raises:
            ValueError: if ``target_theta`` has not been drawn yet (it is set
                by ``gradual_change`` and cleared by ``abrupt_change`` /
                ``re_init``).
        """
        if self.target_theta is None:
            raise ValueError("target_theta is not set; call gradual_change first")
        res = 0
        previous = self.init_theta.copy()
        for t in range(1, T):
            current = (1 - t / T) * self.init_theta + t / T * self.target_theta
            res += np.linalg.norm(current - previous)
            previous = current
        return res

    def gradual_change(self, change_rate=0.01):
        """Place theta on the segment from ``init_theta`` to ``target_theta``.

        If no target exists yet, one is drawn first:
        - when ``continuous == 'AGGRESIVE'``, via ``self.optimal_x_end`` (not
          defined in this class — presumably provided by the base class);
        - otherwise, uniformly from ``[-1, 1]^d``, projected onto the
          ``theta_bound`` ball. It is resampled until it lies at least 1e-3
          from the current theta.

        ``change_rate`` is the absolute interpolation fraction measured from
        ``init_theta``, not an increment: theta becomes
        ``(1 - change_rate) * init_theta + change_rate * target_theta``.
        Repeated calls with the same value therefore give the same theta. The
        result is projected onto the ``theta_bound`` ball and ``reward_means``
        is refreshed.
        """
        if self.target_theta is None:
            # The literal spelling 'AGGRESIVE' must match what callers pass.
            if self.continuous == 'AGGRESIVE':
                self.target_theta = self.optimal_x_end(self.init_theta, self.theta_bound)
            else:
                # Project every draw (including the first) so the target honours
                # the same norm bound as every other theta.
                target = self._project_to_ball(
                    np.random.uniform(-1, 1, self.d), self.theta_bound
                )
                while np.linalg.norm(self.theta - target) < 1e-3:
                    target = self._project_to_ball(
                        np.random.uniform(-1, 1, self.d), self.theta_bound
                    )
                self.target_theta = target

        self.theta = self._project_to_ball(
            (1 - change_rate) * self.init_theta + change_rate * self.target_theta,
            self.theta_bound,
        )
        self.reward_means = self.get_reward_means()

    def re_init(self):
        """Delegate to the base-class ``re_init`` and clear the gradual-change target."""
        super().re_init()
        self.target_theta = None
