from .basic_bandit import BasicBandit
import numpy as np

class UniformBandit(BasicBandit):
    """
    A baseline bandit that pulls each arm uniformly.

    Arms are chosen in round-robin order (0, 1, ..., num_arms - 1, 0, ...).
    For every arm the bandit accumulates the per-objective reward sums and the
    number of plays; empirical mean rewards are derived from those two arrays.
    An arm that has never been played has a mean reward of 0 for every
    objective, because its reward sum is 0 and the divisor is clamped to 1.
    """

    def __init__(self, num_arms: int, num_objectives: int = 2):
        """
        Parameters
        ----------
        num_arms : int
            Number of arms.
        num_objectives : int, optional
            Number of objectives (default is 2).
        """
        # NOTE(review): ``self.num_arms`` is read by the methods below but is
        # never assigned in this class; presumably BasicBandit.__init__ sets it.
        super().__init__(num_arms)
        self.num_objectives = num_objectives
        # Index of the arm that the next call to choose_action() will return.
        self.current_arm = 0
        # Running per-arm, per-objective reward totals.
        self.sum_rewards = np.zeros((num_arms, num_objectives), dtype=np.float64)
        # Number of times each arm has been updated.
        self.num_plays = np.zeros(num_arms, dtype=int)

    def _average_rewards(self) -> np.ndarray:
        """
        Return the empirical mean reward of every arm for every objective.

        Returns
        -------
        np.ndarray
            Array with the same shape as ``sum_rewards``. Rows of arms that
            were never played are all zeros.
        """
        # Clamp the divisor to 1 so unplayed arms do not divide by zero.
        return self.sum_rewards / np.maximum(self.num_plays[:, None], 1)

    def update(self, play: int, rewards: np.ndarray):
        """
        Update the reward statistics for the given arm.

        Parameters
        ----------
        play : int
            The index of the arm that was played.
        rewards : np.ndarray
            The reward received from playing the arm (multi-objective).
            A dict is also accepted, in which case its values are used in
            iteration order. A scalar is treated as a single-objective reward.

        Raises
        ------
        ValueError
            If the reward is not one-dimensional or does not contain exactly
            ``num_objectives`` entries.
        """
        if isinstance(rewards, dict):
            rewards = list(rewards.values())
        # atleast_1d turns a scalar reward into a length-1 vector so that the
        # shape checks below report a ValueError instead of failing with an
        # IndexError on ``shape[0]`` of a 0-d array.
        reward_array = np.atleast_1d(np.asarray(rewards, dtype=float))
        if reward_array.ndim != 1:
            raise ValueError(
                f"Expected a 1-D reward with {self.num_objectives} objectives, "
                f"but got an array of shape {reward_array.shape}."
            )
        if reward_array.shape[0] != self.num_objectives:
            raise ValueError(f"Expected reward with {self.num_objectives} objectives, but got {reward_array.shape[0]}.")
        self.sum_rewards[play] += reward_array
        self.num_plays[play] += 1

    def choose_action(self) -> int:
        """
        Choose the next arm to pull uniformly.

        Arms are returned in round-robin order; the internal pointer wraps
        back to 0 after the last arm.

        Returns
        -------
        int
            The index of the arm to pull next.
        """
        arm = self.current_arm
        self.current_arm = (self.current_arm + 1) % self.num_arms
        return arm

    def best_arm(self) -> int:
        """
        Return the arm with the highest average reward for the first objective.

        Ties are resolved in favour of the lowest arm index (``np.argmax``
        returns the first maximum).

        Returns
        -------
        int
            The index of the best arm.
        """
        # Only the first objective (column 0) is used for the ranking.
        return int(np.argmax(self._average_rewards()[:, 0]))

    @property
    def pareto_front(self):
        """
        Calculate the Pareto front based on the current estimated rewards.

        An arm is excluded when some other arm has a mean reward that is at
        least as high on every objective and strictly higher on at least one.

        Returns
        -------
        set
            A set of arm indices that form the Pareto front.
        """
        avg_rewards = self._average_rewards()
        is_pareto = np.ones(self.num_arms, dtype=bool)  # Initially assume all arms are Pareto optimal

        for i in range(self.num_arms):
            candidate = avg_rewards[i]
            # Compare the candidate against all arms at once. An arm can never
            # dominate itself because the strict inequality fails, so no
            # explicit ``i == j`` skip is required.
            no_worse = np.all(avg_rewards >= candidate, axis=1)
            strictly_better = np.any(avg_rewards > candidate, axis=1)
            if np.any(no_worse & strictly_better):
                is_pareto[i] = False

        # tolist() yields plain Python ints rather than NumPy integer scalars.
        return set(np.flatnonzero(is_pareto).tolist())

    def reset(self):
        """
        Reset the bandit to its initial state.

        The round-robin pointer returns to arm 0 and all accumulated reward
        sums and play counts are zeroed in place.
        """
        self.current_arm = 0
        self.sum_rewards.fill(0.0)
        self.num_plays.fill(0)


class ParetoFrontUniformBandit(UniformBandit):
    """
    A uniform bandit that reports its Pareto front as a list.

    The dominance rule is the one implemented by ``UniformBandit``; this
    subclass only changes the container type of the result.
    """

    @property
    def pareto_front(self):
        """
        Return the estimated Pareto front as a list of arm indices.

        The non-dominated arms are computed by ``UniformBandit.pareto_front``
        instead of repeating the dominance loop here. No additional
        constraints are applied to the result.

        Returns
        -------
        list
            Indices of the non-dominated arms in ascending order.
        """
        # sorted() converts the inherited result to a list in ascending
        # index order.
        return sorted(super().pareto_front)


class ConstrainedBestArmUniformBandit(UniformBandit):
    """
    A uniform bandit whose best arm is selected subject to upper bounds.

    Every stored constraint value is applied as an upper bound on the mean
    reward of the second objective (column index 1); among the arms that pass
    all bounds, the one with the highest mean first objective wins.
    """

    def __init__(self, num_arms: int, num_objectives: int = 2, constraints: list = None):
        """
        Parameters
        ----------
        num_arms : int
            Number of arms.
        num_objectives : int, optional
            Number of objectives (default is 2).
        constraints : list, optional
            Upper bounds checked against the mean of the second objective.
            ``None`` or an empty list means that no bound is applied.
        """
        super().__init__(num_arms, num_objectives)
        # A falsy argument (None or an empty list) becomes a fresh empty list.
        self.constraints = constraints or []

    def best_arm(self) -> int:
        """
        Return the best arm among those that satisfy every constraint.

        Returns
        -------
        int
            Index of the admissible arm with the highest mean reward for the
            first objective, or -1 when no arm satisfies all constraints.
        """
        # Clamp the play counts to 1 so unplayed arms do not divide by zero.
        plays = np.maximum(self.num_plays[:, None], 1)
        means = self.sum_rewards / plays

        # Begin with every arm and narrow the candidate set bound by bound.
        candidates = np.arange(len(means))
        for bound in self.constraints:
            within_bound = means[candidates, 1] <= bound
            candidates = candidates[within_bound]

        if len(candidates) == 0:
            return -1  # No valid arm found

        # Rank the remaining arms by the first objective.
        winner = candidates[np.argmax(means[candidates, 0])]
        return int(winner)