from .basic_bandit import BasicBandit
import numpy as np

class ConstrainedBandit(BasicBandit):
    """
    Step-based, fixed-budget constrained bandit with round-wise elimination.

    Intended for use with an external environment that provides one sample
    per step: call ``choose_action()`` to get the next arm, pull it, then
    report the observation through ``update()``.

    In each round:
      - every active arm is pulled until it reaches the round's quota, which
        is derived from a Successive-Rejects style schedule (``self.n_r``);
      - once all active arms met the quota, one arm is eliminated using the
        empirical-gap rule implemented in ``_round_elimination``;
      - this repeats for at most ``num_arms - 1`` rounds or until the budget
        is exhausted.

    Reward vectors are indexed as: dimension 0 = objective (higher is
    better), dimension 1 = constrained quantity.  An arm is considered
    feasible when its empirical mean in dimension 1 is ``>= constraints[0]``.
    """

    def __init__(
        self,
        num_arms: int,
        num_objectives: int = 2,
        constraints: list = None,
        total_budget: int = 0,
        sigma: float = 1.0
    ):
        """
        Parameters
        ----------
        num_arms : int
            Number of arms K.
        num_objectives : int
            Dimensionality (D) of the reward vector.  The elimination rule
            reads dimensions 0 and 1, so D must be at least 2 for it to work.
        constraints : list, optional
            Constraint thresholds; only ``constraints[0]`` is used, as the
            threshold on reward dimension 1.  Defaults to ``[1.0]``.
        total_budget : int
            Total budget T of pulls; one pull is consumed per ``update()``.
        sigma : float
            Subgaussian parameter.  Stored but not used by this class.
        """
        super().__init__(num_arms=num_arms)
        self.num_arms = num_arms
        self.num_objectives = num_objectives
        self.total_budget = total_budget
        # A fresh list per instance: a mutable default in the signature would
        # be shared (and mutated) across all instances.
        self.constraints = [1.0] if constraints is None else constraints
        self.sigma = sigma

        # Running sums and counts of observations per arm
        self.sum_rewards = np.zeros((num_arms, num_objectives), dtype=np.float64)
        self.num_plays = np.zeros(num_arms, dtype=int)

        # Arms that have not been eliminated yet
        self.active_arms = list(range(num_arms))

        # Rounds are 1-based; one arm is eliminated per round, so K-1 rounds
        self.current_round = 1
        self.rounds = num_arms - 1
        self.done = False  # set to True when we can no longer proceed

        # n_r[r] = cumulative pulls per active arm by the end of round r
        self.n_r = self._sr_schedule()
        # Pulls received by each arm within the current round only
        self.round_pulls_for_arm = {arm: 0 for arm in self.active_arms}

        # Total number of recorded pulls so far
        self.t = 0

    def _sr_schedule(self):
        """
        Build the 'Successive Rejects' style sampling schedule.

        Returns
        -------
        numpy.ndarray of int, shape (K,)
            ``n_r[r]`` is the cumulative number of pulls each active arm
            should have received by the end of round ``r`` (r = 1..K-1).
            Index 0 is a placeholder for "before the first round" and is
            never read by ``choose_action``.
        """
        K = self.num_arms
        T = self.total_budget
        if K < 1:
            # No arms: nothing to schedule (avoids indexing an empty array)
            return np.zeros(0, dtype=int)

        # Normalisation constant 1/2 + sum_{i=2..K} 1/i (always >= 0.5)
        logK = 0.5 + np.sum([1.0 / i for i in range(2, K + 1)])

        if T < K * logK:
            # Small budget: guarantee one pull per arm per schedule entry,
            # then spread whatever remains of (T - K) proportionally.
            n_r = np.ones(K, dtype=int)
            n_r[0] = 0
            c = (T - K) / logK
            for r in range(1, K):
                portion = c / (K - r + 1)
                if portion > 0:
                    n_r[r] += int(np.floor(portion))
        else:
            # Regular budget: n_r[r] = floor(T / (logK * (K - r + 1)))
            n_r = np.zeros(K, dtype=int)
            c = T / logK
            for r in range(1, K):
                n_r[r] = int(np.floor(c / (K - r + 1)))

        # Pulls consumed by the schedule: each round's eliminated arm stops at
        # n_r[r], and the final survivor also holds n_r[-1] pulls.
        leftover = int(self.total_budget - np.sum(n_r) - n_r[-1])
        # Raising the last m entries by one costs m + 1 pulls, so m is
        # leftover - 1.  Clamp to the K-1 real rounds so the placeholder
        # n_r[0] is never touched.
        extra = min(leftover - 1, K - 1)
        if extra > 0:
            n_r[K - extra:] += 1

        return n_r

    def _round_quota(self, r):
        """Return how many pulls each active arm must receive within round r."""
        previous = self.n_r[r - 1] if r > 1 else 0
        return self.n_r[r] - previous

    def choose_action(self):
        """
        Choose the next arm to pull.

        Active arms are served in order until each reached the current
        round's quota.  When the round is complete, one arm is eliminated and
        the next round starts; rounds whose quota is zero are resolved
        immediately without spending a pull.

        Returns
        -------
        arm : int or None
            Index of the arm to pull next, or None once the procedure is
            finished (budget exhausted, all rounds done, or no arms left).
        """
        while not self.done:
            if not self.active_arms or self.current_round > self.rounds:
                # Nothing left to sample or all K-1 rounds completed
                self.done = True
                break

            quota = self._round_quota(self.current_round)
            for arm in self.active_arms:
                if self.round_pulls_for_arm[arm] < quota:
                    return arm

            # Every active arm met the quota: eliminate and open a new round
            self._round_elimination()
            self.current_round += 1
            for arm in self.active_arms:
                self.round_pulls_for_arm[arm] = 0

        return None

    def update(self, arm: int, rewards):
        """
        Record the multi-objective reward observed for the arm just played.

        Parameters
        ----------
        arm : int
            Index of the arm just played.  ``None`` is ignored.
        rewards : dict[str, float] or array-like
            Reward vector of length ``num_objectives``.  For a dict, the
            values are taken in the dict's iteration order, so that order
            must match the reward dimensions.

        Notes
        -----
        Observations are ignored once ``self.done`` is set.  The bandit marks
        itself done as soon as ``total_budget`` pulls have been recorded.
        """
        if arm is None or self.done:
            return

        if isinstance(rewards, dict):
            reward_array = np.array(list(rewards.values()), dtype=float)
        else:
            reward_array = np.array(rewards, dtype=float)

        self.sum_rewards[arm] += reward_array
        self.num_plays[arm] += 1

        # Per-round counter used by choose_action to enforce the quota
        if arm in self.round_pulls_for_arm:
            self.round_pulls_for_arm[arm] += 1

        self.t += 1
        if self.t >= self.total_budget:
            self.done = True

    def _round_elimination(self):
        """
        Eliminate one active arm using the empirical-gap (EGE) rule.

        First the empirically best arm J is identified among the active arms:
          - if at least one arm looks feasible (mean of dimension 1 is
            ``>= constraints[0]``), J is the feasible arm with the highest
            mean objective (dimension 0);
          - otherwise J is the arm closest to feasibility, i.e. with the
            largest mean in dimension 1.
        Then an empirical gap Delta(J, i) is computed for every other arm i,
        and the arm with the largest gap is removed from ``active_arms``.

        Arms that were never pulled are treated as having all-zero means.
        With fewer than two active arms nothing is eliminated.
        """
        n_active = len(self.active_arms)
        if n_active <= 1:
            return

        # Empirical means, one row per active arm (position == row index)
        active = np.asarray(self.active_arms, dtype=int)
        plays = self.num_plays[active]
        pulled = plays > 0
        means = np.zeros((n_active, self.num_objectives))
        means[pulled] = self.sum_rewards[active[pulled]] / plays[pulled, np.newaxis]

        tau = self.constraints[0]
        objective = means[:, 0]
        constraint = means[:, 1]
        feasible = constraint >= tau

        if feasible.any():
            candidates = np.flatnonzero(feasible)
            J = int(candidates[np.argmax(objective[candidates])])
        else:
            J = int(np.argmax(constraint))

        J_feasible = bool(feasible[J])
        J_obj = objective[J]
        J_con = constraint[J]

        gaps = np.zeros(n_active)
        for i in range(n_active):
            if i == J:
                # J must never be the arm that gets eliminated
                gaps[i] = -np.inf
                continue

            if not J_feasible:
                # J infeasible implies no arm is feasible: rank by how far
                # each arm is below J in the constrained dimension.
                gaps[i] = J_con - constraint[i]
            elif feasible[i]:
                gaps[i] = J_obj - objective[i]
            elif objective[i] >= J_obj:
                # "Deceiver": better objective than J but infeasible
                gaps[i] = tau - constraint[i]
            else:
                gaps[i] = max(J_obj - objective[i], tau - constraint[i])

        # Positions in `gaps` line up with positions in active_arms
        self.active_arms.pop(int(np.argmax(gaps)))

    def best_arm(self):
        """
        Return the first remaining active arm, or None if there is none.

        After all K-1 rounds a single arm remains; if the run stopped early
        several arms may still be active and the lowest-positioned one is
        returned.
        """
        return self.active_arms[0] if len(self.active_arms) > 0 else None

    def reset(self):
        """Reset all statistics and round state for a fresh run."""
        self.sum_rewards.fill(0.0)
        self.num_plays.fill(0)
        self.active_arms = list(range(self.num_arms))
        self.current_round = 1
        self.done = False
        self.n_r = self._sr_schedule()
        self.round_pulls_for_arm = {arm: 0 for arm in self.active_arms}
        self.t = 0
    # end class
