import numpy as np

class ClusterBanditSelector:
    """Bandit that chooses which cluster ("arm") to draw from next.

    Each cluster accumulates the raw values passed to ``update`` (or its
    alias ``update_thompson``). Lower average values are treated as better,
    so the selectors negate averages before taking an ``argmax``.

    The epsilon-greedy strategy (``select_cluster``) and the Thompson-sampling
    strategy (``select_cluster_thompson``) read the same running statistics.
    """

    def __init__(self, n_clusters, epsilon=0.3):
        """Create a selector over ``n_clusters`` clusters.

        Args:
            n_clusters: Number of clusters to choose between; must be >= 1.
            epsilon: Probability that ``select_cluster`` returns a uniformly
                random cluster instead of the current best one.

        Raises:
            ValueError: If ``n_clusters`` is less than 1 (previously this only
                surfaced later, from inside numpy, on the first selection).
        """
        if n_clusters < 1:
            raise ValueError(f"n_clusters must be >= 1, got {n_clusters}")
        self.n_clusters = n_clusters
        self.epsilon = epsilon
        # Running sum of observed values per cluster (shared by both strategies).
        self.cluster_rewards = [0.0] * n_clusters
        # Observation counts; the tiny seed avoids division by zero for clusters
        # that have never been updated.
        self.cluster_counts = [1e-6] * n_clusters

    def select_cluster(self):
        """Choose a cluster using epsilon-greedy exploration.

        With probability ``epsilon`` a uniformly random cluster is returned;
        otherwise the cluster with the lowest average value is returned.
        Clusters that were never updated have an average of 0.0. Ties go to
        the lowest index (``np.argmax`` returns the first maximum).

        Returns:
            The selected cluster index as a Python ``int``.
        """
        if np.random.rand() < self.epsilon:
            return int(np.random.randint(self.n_clusters))  # Explore
        return int(np.argmax(self.get_negated_avg_values()))  # Exploit lowest average

    def select_cluster_thompson(self):
        """Choose a cluster by Gaussian Thompson sampling.

        For each cluster, one sample is drawn from a normal distribution
        centred on its negated average value, with standard deviation
        ``sqrt(1 / count)``. The index of the largest sample is returned.
        Rarely-updated clusters have wide distributions and are therefore
        explored more often.

        Returns:
            The selected cluster index as a Python ``int``.
        """
        # Draw in cluster order so the RNG stream is consumed deterministically.
        samples = [
            np.random.normal(loc=-total / count, scale=np.sqrt(1.0 / count))
            for total, count in zip(self.cluster_rewards, self.cluster_counts)
        ]
        return int(np.argmax(samples))

    def update_thompson(self, cluster_id, reward):
        """Record an observed value for ``cluster_id``.

        This is an alias of ``update``: both strategies share the same
        statistics.

        Args:
            cluster_id: Index of the cluster that produced the observation.
            reward: Raw observed value; lower is considered better.
        """
        self.update(cluster_id, reward)

    def get_negated_avg_values(self):
        """Return the negated average value of every cluster.

        Lower average value is better, so the averages are negated to allow
        selection with ``argmax``.

        Returns:
            A list of ``n_clusters`` floats, one per cluster in index order.
        """
        return [
            -(total / count)
            for total, count in zip(self.cluster_rewards, self.cluster_counts)
        ]

    def update(self, cluster_id, value):
        """Record a raw observed value (e.g. loss or novelty score) for a cluster.

        Args:
            cluster_id: Index of the cluster that produced the observation.
            value: Raw observed value; lower is considered better.
        """
        self.cluster_rewards[cluster_id] += value
        self.cluster_counts[cluster_id] += 1
