import numpy as np

class ContextualBandit:
    """Single-task linear contextual bandit environment.

    Currently suited to additive noise that is identically distributed
    across arms.
    """

    def __init__(self, theta, n_arms, noise_sampler, context_sampler, context_generator=None):
        """Store the bandit's parameters and sampling callables.

        Parameters
        ----------
        theta : sequence
            Preference vector; its length defines the context dimension.
        n_arms : int
            Number of arms.
        noise_sampler : callable
            Called as ``noise_sampler(size=(horizon, repetitions))`` by
            ``create_noise_table`` to draw additive noise.
        context_sampler : callable
            Called as ``context_sampler((dim, n_arms))`` for every round of
            the generator built by ``create_context_generator``.
        context_generator : iterator, optional
            Pre-built source of contexts. When left as None, call
            ``create_context_generator`` to build one from
            ``context_sampler``.
        """
        self.theta = theta
        self.n_arms = n_arms
        self.noise_sampler = noise_sampler
        self.dim = len(theta)
        self.context_generator = context_generator
        self.context_sampler = context_sampler

    def create_context_generator(self, horizon):
        """Replace ``self.context_generator`` with a lazy context stream.

        The stream yields ``horizon`` contexts. Each one is produced on
        demand by ``self.context_sampler((self.dim, self.n_arms))``.

        Parameters
        ----------
        horizon : int
            Number of contexts the generator yields.
        """
        def _draw_contexts(n_rounds):
            # Attributes are read at each step, so nothing is sampled until
            # the generator is advanced.
            for _ in range(n_rounds):
                yield self.context_sampler((self.dim, self.n_arms))

        self.context_generator = _draw_contexts(horizon)

    def create_noise_table(self, horizon, repetitions):
        """Draw a ``(horizon, repetitions)`` noise table via ``noise_sampler``.

        The result is stored in ``self.noise_table``. Returns ``self`` so
        calls can be chained.
        """
        table_shape = (horizon, repetitions)
        self.noise_table = self.noise_sampler(size=table_shape)
        return self

class MultiTaskContextualBandit():
    """Multi-task (multi-user) linear contextual bandit environment.

    Each entry along the first axis of ``Theta`` belongs to one user. All
    users share the context dimension ``Theta.shape[1]``.
    """

    def __init__(self, Theta, n_arms, std, context_sampler, context_generator= 0):
        """Store the environment parameters.

        Parameters
        ----------
        Theta : numpy.ndarray
            Matrix whose length gives the number of users and whose second
            axis gives the context dimension.
        n_arms : int
            Number of arms.
        std : float
            Scale applied to standard-normal draws in ``create_noise_table``.
        context_sampler : callable
            Called as ``context_sampler(t, dim, n_arms, rng)`` for each round
            ``t`` by the generator built in ``create_context_generator``.
        context_generator : int or iterator, default 0
            An int is treated as a seed. It is stored in ``context_seed`` and
            no generator is set (call ``create_context_generator``). Any
            other value is stored as the context generator, and
            ``context_seed`` is set to None.
        """
        self.Theta = Theta  # matrix
        self.n_arms = n_arms
        self.n_users = len(Theta)
        self.dim = Theta.shape[1]

        # Always define context_seed so the attribute exists whichever form
        # of context_generator was passed. Previously it was missing when a
        # generator object was given.
        if isinstance(context_generator, int):
            self.context_seed = context_generator
            self.context_generator = None
        else:
            self.context_seed = None
            self.context_generator = context_generator

        self.context_sampler = context_sampler
        self.std = std

    def create_context_generator(self, horizon, rng):
        """Replace ``self.context_generator`` with a lazy context stream.

        The stream yields ``horizon`` contexts. Context ``t`` (0-based) is
        produced on demand by
        ``self.context_sampler(t, self.dim, self.n_arms, rng)``.

        Parameters
        ----------
        horizon : int
            Number of contexts the generator yields.
        rng : object
            Random source forwarded unchanged to ``context_sampler``.
        """
        def generator(horizon):
            for t in range(horizon):
                yield self.context_sampler(t, self.dim, self.n_arms, rng)
        self.context_generator = generator(horizon)

    def create_noise_table(self, horizon, rng):
        """Draw ``horizon`` Gaussian noise values scaled by ``self.std``.

        The 1-D result is stored in ``self.noise_table``. Returns ``self``
        so calls can be chained.

        Parameters
        ----------
        horizon : int
            Number of noise samples.
        rng : numpy.random.Generator
            Source of the standard-normal draws.
        """
        self.noise_table = self.std * rng.standard_normal(size=horizon)
        return self