import numpy as np
from . import arms
import math


class MAB:
    """Multi-armed bandit environment.

    Args:
        arms (list of arms.Arm): the arms of the bandit. Each arm is read
            for its ``mean`` attribute at construction and is later used
            through its ``sample()`` and ``to_latex()`` methods.

    Attributes:
        arms: the list given at construction.
        nb_arms: number of arms.
        means: NumPy array of the arms' means, in the same order as ``arms``.
        sorted_means: the same means ordered from largest to smallest.
    """

    def __init__(self, arms):
        self.arms = arms
        self.nb_arms = len(arms)
        self.means = np.array([one_arm.mean for one_arm in arms])
        # np.sort is ascending; reversing gives the best-first ordering.
        ascending = np.sort(self.means)
        self.sorted_means = ascending[::-1]

    def m_best_arms_means(self, n_best_arms):
        """Return the first ``n_best_arms`` entries of the descending means."""
        head = slice(None, n_best_arms)
        return self.sorted_means[head]

    def m_worst_arms_means(self, n_best_arms):
        """Return the descending means that come after the first ``n_best_arms``."""
        tail = slice(n_best_arms, None)
        return self.sorted_means[tail]

    def last_best_arm_mean(self, n_best_arms):
        """Return the descending mean at position ``n_best_arms - 1``."""
        position = n_best_arms - 1
        return self.sorted_means[position]

    def generate_reward(self, arm):
        """Sample the arm stored at index ``arm`` and return the result."""
        selected = self.arms[arm]
        return selected.sample()

    def __repr__(self):
        return f"MAB({self.arms})"

    def to_latex(self):
        """Join every arm's ``to_latex()`` output with ``", "``."""
        fragments = (one_arm.to_latex() for one_arm in self.arms)
        return ", ".join(fragments)

class BernoulliMAB(MAB):
    """Bandit whose arms are all ``arms.Bernoulli`` instances.

    Args:
        means (list of float): one mean per arm; each value is passed as the
            single argument to ``arms.Bernoulli``.
    """

    def __init__(self, means):
        # One Bernoulli arm per entry, in the order given.
        bernoulli_arms = list(map(arms.Bernoulli, means))
        super().__init__(bernoulli_arms)

    def __repr__(self):
        return f"BernoulliMAB({self.means})"


class GaussianMAB(MAB):
    """Gaussian MAB

    Args:
        means (list of float): list of means for each Gaussian arm
        variances (list of float): list of variances for each Gaussian arm

    Raises:
        AssertionError: if ``means`` and ``variances`` differ in length.
    """

    def __init__(self, means, variances):
        # Raised explicitly instead of using an ``assert`` statement: asserts
        # are removed under ``python -O``, and ``zip`` below would then
        # silently drop the unmatched trailing entries. The exception type and
        # message are kept so existing callers observe the same failure.
        if len(means) != len(variances):
            raise AssertionError("Each mean must have a corresponding variance.")
        super().__init__(
            [arms.Gaussian(mu, var) for mu, var in zip(means, variances)]
        )

    def __repr__(self):
        return f"GaussianMAB(means={self.means}, variances={[arm.variance for arm in self.arms]})"
