import numpy as np

class RandomMixingPolicy:
    """Stationary random policy: a mixture of a Dirichlet draw and a point mass.

    The action distribution is ``p = (1 - w) * p_dirichlet + w * p_point``,
    where ``p_dirichlet ~ Dirichlet(1, ..., 1)`` and ``p_point`` puts all of
    its mass on a single action:

    * ``cov > 0``: the point mass sits on ``opt_a`` and ``w = cov``.
    * otherwise: the point mass sits on a uniformly random action and ``w`` is
      drawn uniformly from ``{0.0, 0.1, ..., 1.0}`` (``opt_a`` is ignored).

    The history passed to :meth:`act` is ignored; :meth:`update` and
    :meth:`batch_update` are no-ops.

    Args:
        num_A: number of actions.
        opt_a: index of the action that receives the point mass when
            ``cov > 0``.
        cov: mixing weight placed on ``opt_a``; values ``<= 0`` select the
            fully random mode.

    Raises:
        ValueError: if ``cov > 0`` and ``opt_a`` is None, or if ``cov > 1``;
            either would produce an invalid probability vector.
    """

    def __init__(self, num_A, opt_a=None, cov=0.0):
        if cov > 0:
            # ``p_2[None] = 1`` would set every entry to 1, and ``cov > 1``
            # gives negative probabilities; both break ``act`` later on.
            if opt_a is None:
                raise ValueError("opt_a must be given when cov > 0")
            if cov > 1:
                raise ValueError(f"cov must be in (0, 1], got {cov}")

        # RNG draws happen in a fixed order (Dirichlet, then action, then
        # weight) so seeded experiments are reproducible.
        p_1 = np.random.dirichlet(np.ones(num_A))  # Dirichlet(1,...,1) policy
        p_2 = np.zeros(num_A)  # point-mass (deterministic) policy
        if cov > 0:
            p_2[opt_a] = 1
            w = cov
        else:
            p_2[np.random.choice(num_A)] = 1  # point mass on a random action
            w = np.random.choice(11) / 10  # mixing weight in {0.0, ..., 1.0}
        self.p = (1 - w) * p_1 + w * p_2
        self.num_A = num_A

    def act(self, As, Rs):
        '''
        Sample an action from the fixed distribution ``self.p``.

        As: list of history actions (unused)
        Rs: list of history rewards (unused)
        '''
        a = np.random.choice(self.num_A, p=self.p)
        return a

    def update(self, *args, **kargs):
        '''
        No-op; the policy does not learn. Kept for interface compatibility.
        '''
        pass

    def batch_update(self, *args, **kargs):
        '''
        No-op; the policy does not learn. Kept for interface compatibility.
        '''
        pass

class RandomOptimalPolicy:
    """Stationary random policy biased toward a known action ``opt_a``.

    The action distribution is a convex combination of a Dirichlet(1, ..., 1)
    draw and a point mass on ``opt_a``. The weight on the point mass is drawn
    uniformly from ``{0.0, 0.1, ..., 1.0}``. The history given to :meth:`act`
    is ignored, and the update hooks do nothing.

    Args:
        num_A: number of actions.
        opt_a: index of the action that receives the point mass.
    """

    def __init__(self, num_A, opt_a):
        dirichlet_part = np.random.dirichlet(np.ones(num_A))
        point_mass = np.zeros(num_A)
        point_mass[opt_a] = 1
        weight = np.random.choice(11) / 10  # one of 0.0, 0.1, ..., 1.0
        self.p = weight * point_mass + (1 - weight) * dirichlet_part
        self.num_A = num_A

    def act(self, As, Rs):
        '''
        Draw an action from ``self.p``.

        As: list of history actions (ignored)
        Rs: list of history rewards (ignored)
        '''
        return np.random.choice(self.num_A, p=self.p)

    def update(self, *args, **kargs):
        '''
        Placeholder hook: this policy never learns from feedback.
        '''

    def batch_update(self, *args, **kargs):
        '''
        Placeholder hook: this policy never learns from feedback.
        '''

class UCBPolicy:
    """Optimistic index policy: choose the arm maximising mean + bonus.

    The index of arm ``a`` is
    ``r_sums[a] / max(1, counts[a]) + const / max(1, sqrt(counts[a]))``,
    so an arm that was never pulled scores ``const``. Ties are broken toward
    the lowest arm index (``np.argmax``).

    Args:
        num_A: number of arms.
        const: scale of the exploration bonus.
    """

    def __init__(self, num_A, const=1.0):
        self.num_A = num_A
        self.counts = np.zeros(num_A)
        self.r_sums = np.zeros(num_A)
        self.const = const

    def act(self, As, Rs):
        """Return the arm with the largest optimistic index (history args unused)."""
        pulls = self.counts
        empirical_mean = self.r_sums / np.maximum(1, pulls)
        # Clip the denominator at 1 so unpulled arms get a finite bonus.
        exploration = self.const / np.maximum(1, np.sqrt(pulls))
        return np.argmax(empirical_mean + exploration)

    def update(self, chosen_arm, observed_reward):
        """Record one pull of ``chosen_arm`` yielding ``observed_reward``."""
        self.counts[chosen_arm] += 1
        self.r_sums[chosen_arm] += observed_reward

    def batch_update(self, As, Rs):
        """Discard current statistics and rebuild them from the full history."""
        self.counts = np.zeros(self.num_A)
        self.r_sums = np.zeros(self.num_A)
        for idx, arm in enumerate(As):
            reward = Rs[idx]
            self.counts[arm] += 1
            self.r_sums[arm] += reward

class LCBPolicy:
    """Pessimistic index policy: choose the arm maximising mean - penalty.

    The index of arm ``a`` is
    ``r_sums[a] / max(1, counts[a]) - const / max(1, sqrt(counts[a]))``,
    so an arm that was never pulled scores ``-const``. Ties are broken toward
    the lowest arm index (``np.argmax``).

    Args:
        num_A: number of arms.
        const: scale of the uncertainty penalty.
    """

    def __init__(self, num_A, const=1.0):
        self.num_A = num_A
        self.counts = np.zeros(num_A)
        self.r_sums = np.zeros(num_A)
        self.const = const

    def act(self, As, Rs):
        """Return the arm with the largest pessimistic index (history args unused)."""
        pulls = self.counts
        empirical_mean = self.r_sums / np.maximum(1, pulls)
        # Clip the denominator at 1 so unpulled arms get a finite penalty.
        penalty = self.const / np.maximum(1, np.sqrt(pulls))
        return np.argmax(empirical_mean - penalty)

    def update(self, chosen_arm, observed_reward):
        """Record one pull of ``chosen_arm`` yielding ``observed_reward``."""
        self.counts[chosen_arm] += 1
        self.r_sums[chosen_arm] += observed_reward

    def batch_update(self, As, Rs):
        """Discard current statistics and rebuild them from the full history."""
        self.counts = np.zeros(self.num_A)
        self.r_sums = np.zeros(self.num_A)
        for idx, arm in enumerate(As):
            reward = Rs[idx]
            self.counts[arm] += 1
            self.r_sums[arm] += reward

class EmpMeanPolicy:
    """Greedy policy: pick the arm with the highest empirical mean reward.

    Empirical means are ``r_sums / max(1, counts)``, so unpulled arms have
    mean 0. Ties are broken toward the lowest arm index (``np.argmax``).

    Args:
        num_A: number of arms.
        const: unused; accepted so the constructor signature matches the
            other count-based policies.
        online: if True, an arm that has never been pulled (the lowest such
            index) is chosen before falling back to the greedy choice.
    """

    def __init__(self, num_A, const=1.0, online=False):
        self.num_A = num_A
        self.counts = np.zeros(self.num_A)
        self.r_sums = np.zeros(self.num_A)
        self.const = const
        # Honour the caller's flag (it used to be hard-coded to False, which
        # made the exploration branch in ``act`` unreachable).
        self.online = online

    def act(self, As, Rs):
        """Return the greedy arm, or an unpulled arm first when ``online``.

        As, Rs: history actions / rewards (unused; statistics come from
        ``update``/``batch_update``).
        """
        b_mean = self.r_sums / np.maximum(1, self.counts)
        i = np.argmax(b_mean)
        if self.online:
            j = np.argmin(self.counts)
            if self.counts[j] == 0:
                i = j
        return i

    def update(self, chosen_arm, observed_reward):
        """Record one pull of ``chosen_arm`` yielding ``observed_reward``."""
        self.counts[chosen_arm] += 1
        self.r_sums[chosen_arm] += observed_reward

    def batch_update(self, As, Rs):
        """Discard current statistics and rebuild them from the full history."""
        self.counts = np.zeros(self.num_A)
        self.r_sums = np.zeros(self.num_A)
        for i in range(len(As)):
            a = As[i]
            r = Rs[i]
            self.counts[a] += 1
            self.r_sums[a] += r


class TompsonSamplingPolicy:
    """Gaussian Thompson sampling with a known reward-noise variance.

    (Class name kept as spelled for backward compatibility.)

    Each arm's mean reward has prior ``N(prior_mean, prior_var)``, and rewards
    are modelled as Normal with known variance ``variance``. For an arm with
    ``n`` observations and sample mean ``xbar`` the posterior is Normal with::

        w    = variance / (variance + n * prior_var)
        mean = w * prior_mean + (1 - w) * xbar
        var  = 1 / (1 / prior_var + n / variance)

    Args:
        num_A: number of arms.
        variance: assumed reward-noise variance (default ``0.3 ** 2``).
        prior_mean: prior mean shared by every arm (default 0).
        prior_var: prior variance shared by every arm (default 1).
    """

    def __init__(self, num_A, variance=0.3 ** 2, prior_mean=0, prior_var=1):
        self.num_A = num_A
        self.variance = variance
        self.prior_mean = prior_mean
        self.prior_var = prior_var
        self._reset()

    def _reset(self):
        """Put every arm back at the prior and forget all stored rewards."""
        self.means = np.ones(self.num_A) * self.prior_mean
        self.variances = np.ones(self.num_A) * self.prior_var
        self.counts = np.zeros(self.num_A)
        self.arm_rewards = [[] for _ in range(self.num_A)]

    def _set_posterior(self, arm):
        """Recompute ``means[arm]``/``variances[arm]`` from the stored rewards."""
        count = self.counts[arm]
        rewards = self.arm_rewards[arm]
        arm_mean = np.mean(rewards) if len(rewards) > 0 else 0.
        prior_weight = self.variance / (self.variance + (count * self.prior_var))
        self.means[arm] = prior_weight * self.prior_mean + (1 - prior_weight) * arm_mean
        self.variances[arm] = 1 / (1 / self.prior_var + count / self.variance)

    def act(self, As, Rs):
        """Sample one mean per arm from its posterior and return the argmax.

        As, Rs: history actions / rewards (unused; the posterior is maintained
        by ``update``/``batch_update``).
        """
        sampled_rewards = np.random.normal(self.means, np.sqrt(self.variances))
        return np.argmax(sampled_rewards)

    def update(self, chosen_arm, observed_reward):
        """Add one observation for ``chosen_arm`` and refresh its posterior."""
        self.counts[chosen_arm] += 1
        self.arm_rewards[chosen_arm].append(observed_reward)
        self._set_posterior(chosen_arm)

    def batch_update(self, As, Rs):
        """Rebuild every posterior from scratch using the full history.

        The configured hyperparameters (``variance``, ``prior_mean``,
        ``prior_var``) are kept; only the per-arm statistics are reset.
        """
        self._reset()
        for i in range(len(As)):
            a, r = As[i], Rs[i]
            self.counts[a] += 1
            self.arm_rewards[a].append(r)
        for a in range(self.num_A):
            self._set_posterior(a)

