from abc import ABC, abstractmethod
import numpy as np
from scipy.stats import truncnorm


class NonStationaryBandit(ABC):
    """Abstract base class for a multi-armed bandit whose reward means can change.

    Subclasses supply how the arms are built (:meth:`set_arms`), how an
    arm's mean reward is computed (:meth:`get_mean_reward`), and how the
    environment changes over time (:meth:`abrupt_change`,
    :meth:`gradual_change`). This base class caches the per-arm means in
    ``self.reward_means`` and samples noisy rewards from them.

    Args:
        arm_type: Reward distribution used by :meth:`get_reward`;
            ``'bernoulli'`` or ``'normal'``.
        num_actions: Number of arms.
        noise_variance: Variance of the Gaussian reward noise for
            ``'normal'`` arms.
        mean_low: Stored for subclasses; not used by this base class
            (presumably a lower bound on arm means — confirm in subclasses).
        mean_high: Stored for subclasses; not used by this base class
            (presumably an upper bound on arm means — confirm in subclasses).
        reward_bound: Stored for subclasses; not used by this base class.
        continuous: Stored for subclasses; not used by this base class.
    """

    def __init__(self, arm_type, num_actions, noise_variance, mean_low, mean_high, reward_bound, continuous):
        self.reward_bound = reward_bound
        self.mean_low = mean_low
        self.continuous = continuous
        self.mean_high = mean_high
        self.num_actions = num_actions
        self.noise_variance = noise_variance
        self.arm_type = arm_type
        # All configuration must be set before these calls: the subclass
        # hooks set_arms / get_mean_reward may read any of it.
        self.arms = self.set_arms()
        self.reward_means = self.get_reward_means()

    def optimal_x_end(self, x_begin, S):
        """Return a point of norm ``S`` pointing away from ``x_begin``.

        For non-zero ``x_begin`` the result is ``-S * x_begin / ||x_begin||``.
        For zero ``x_begin`` there is no preferred direction, so a uniformly
        random direction (a normalised standard-normal draw) of the same
        shape is scaled to norm ``S``. This keeps the magnitude consistent
        with the non-zero case.

        Args:
            x_begin: Array-like starting point (lists are accepted).
            S: Desired norm of the returned point.

        Returns:
            numpy.ndarray with the same shape as ``x_begin``.
        """
        x_begin = np.asarray(x_begin)
        norm = np.linalg.norm(x_begin)
        if norm == 0:
            direction = np.random.randn(*x_begin.shape)
            # Normalise so the result has norm S, like the non-zero branch.
            return S * direction / np.linalg.norm(direction)
        return -S * x_begin / norm

    @abstractmethod
    def set_arms(self):
        """Build and return the arm representation stored in ``self.arms``."""
        pass

    @abstractmethod
    def get_mean_reward(self, action):
        """Return the current mean reward of arm ``action``."""
        pass

    def get_reward_means(self):
        """Return a list with the mean reward of every arm, indexed by action."""
        return [self.get_mean_reward(action) for action in range(self.num_actions)]

    def get_reward(self, action):
        """Sample a stochastic reward for ``action`` around its cached mean.

        Bernoulli arms return ``1`` with probability equal to the mean, and
        ``0`` otherwise. Normal arms return a Gaussian sample centred on the
        mean, with standard deviation ``sqrt(noise_variance)``.

        Raises:
            ValueError: If ``arm_type`` is neither ``'bernoulli'`` nor
                ``'normal'``.
        """
        deterministic_reward = self.reward_means[action]
        if self.arm_type == 'bernoulli':
            return int(np.random.random() < deterministic_reward)
        if self.arm_type == 'normal':
            # numpy's normal() takes a standard deviation, not a variance.
            return np.random.normal(deterministic_reward, np.sqrt(self.noise_variance))
        raise ValueError(
            f"unsupported arm_type {self.arm_type!r}; expected 'bernoulli' or 'normal'"
        )

    @abstractmethod
    def abrupt_change(self):
        """Apply an abrupt change to the environment (defined by subclasses)."""
        pass

    @abstractmethod
    def gradual_change(self):
        """Apply a gradual change to the environment (defined by subclasses)."""
        pass

    def get_best_arm(self):
        """Return ``(best_arm, best_reward)`` from the cached reward means.

        Ties are broken in favour of the lowest index, because ``np.argmax``
        returns the first maximum.
        """
        current_rewards = self.reward_means
        best_arm = np.argmax(current_rewards)
        best_reward = current_rewards[best_arm]
        return best_arm, best_reward

    def re_init(self):
        """Rebuild the arms and recompute the cached reward means."""
        self.arms = self.set_arms()
        self.reward_means = self.get_reward_means()
