from .basic_bandit import BasicBandit
import numpy as np


class EGE(BasicBandit):
    """
    A step-based Empirical Gap Elimination (EGE) bandit for multi-objective rewards,
    intended for use with an external environment that provides step-by-step samples.

    In each 'round':
      - Every active arm is sampled until it reaches the round's quota
        (the increment of the Successive-Rejects schedule ``n_r``).
      - Once all active arms met the quota, exactly one arm is eliminated
        using the (simplified) EGE gap rule and classified as accepted-optimal
        or accepted-suboptimal.
      - Repeat until K-1 rounds are finished or the budget is exhausted.

    Typical driver loop::

        arm = bandit.choose_action()
        while arm is not None:
            bandit.update(arm, env.pull(arm))
            arm = bandit.choose_action()
        front = bandit.pareto_front
    """

    def __init__(
        self,
        num_arms: int,
        num_objectives: int,
        total_budget: int,
        sigma: float = 1.0
    ):
        """
        Parameters
        ----------
        num_arms : int
            Number of arms K.
        num_objectives : int
            Dimensionality (D) of the reward vector.
        total_budget : int
            Total budget T of pulls; we do one pull per step.
        sigma : float
            Subgaussian parameter. Stored only; no computation in this class
            uses it.
        """
        super().__init__(num_arms=num_arms)
        self.num_arms = num_arms
        self.num_objectives = num_objectives
        self.total_budget = total_budget
        self.sigma = sigma

        # Running sums and pull counts; empirical mean = sum_rewards / num_plays.
        self.sum_rewards = np.zeros((num_arms, num_objectives), dtype=np.float64)
        self.num_plays = np.zeros(num_arms, dtype=int)

        # Active arms = arms we have not yet definitively eliminated or accepted.
        self.active_arms = list(range(num_arms))
        self.accepted_optimal = set()   # B in the EGE paper notation
        self.accepted_subopt = set()    # D in the EGE paper notation

        # Rounds are 1-based; one arm is removed per round, hence K-1 rounds.
        self.current_round = 1
        self.rounds = num_arms - 1
        self.done = False  # True once no further pull will be requested/accepted

        # n_r[r] = cumulative pulls per surviving arm by the end of round r.
        self.n_r = self._sr_schedule()
        # Pulls received by each active arm within the current round only.
        self.round_pulls_for_arm = {arm: 0 for arm in self.active_arms}

        # Total number of pulls recorded so far.
        self.t = 0

    def _sr_schedule(self):
        """
        Build the 'Successive Rejects' style schedule.

        ``n_r[r]`` is the cumulative number of pulls every surviving arm should
        have received by the end of round ``r`` (``n_r[0] == 0``):

            n_r[r] = floor( c / (K - r + 1) ),   r = 1..K-1

        with ``c = (T - K) / logK`` when ``T < K * logK`` and ``c = T / logK``
        otherwise, where ``logK = 1/2 + sum_{i=2..K} 1/i``.

        The schedule consumes ``sum(n_r) + n_r[-1]`` pulls in total (the arm
        removed in round r has n_r[r] pulls, the last survivor has n_r[K-1]).
        Pulls lost to flooring are handed back by adding one pull to the last
        ``leftover - 1`` rounds.

        Returns
        -------
        np.ndarray
            Integer array of length K.
        """
        K = self.num_arms
        T = self.total_budget
        logK = 0.5 + np.sum([1.0 / i for i in range(2, K + 1)])

        # Below K * logK the schedule is built from T - K; otherwise the whole
        # budget is split directly.
        effective_budget = T - K if T < K * logK else T
        c = effective_budget / logK

        n_r = np.zeros(K, dtype=int)
        for r in range(1, K):
            # A negative effective budget must not yield negative pull counts.
            n_r[r] = max(int(np.floor(c / (K - r + 1))), 0)

        leftover = T - int(np.sum(n_r)) - int(n_r[-1])
        if leftover > 1:
            # Raising the last m cumulative targets by one costs m + 1 pulls
            # (m + 1 arms are still active when the first of those rounds starts).
            # Clamp to the K-1 real rounds so that n_r[0] stays 0.
            bumped_rounds = min(leftover - 1, K - 1)
            n_r[K - bumped_rounds:] += 1

        return n_r

    def _next_pending_arm(self):
        """
        Close every round whose quota is fully met and return the first active
        arm that still needs a pull in the (new) current round.

        Rounds with a zero quota are closed immediately, without any pull.
        When all K-1 rounds are finished (or no arm is left), ``self.done`` is
        set and ``None`` is returned.
        """
        while self.active_arms and self.current_round <= self.rounds:
            r = self.current_round
            quota = self.n_r[r] - self.n_r[r - 1]
            for arm in self.active_arms:
                if self.round_pulls_for_arm[arm] < quota:
                    return arm

            # Every active arm met the quota: eliminate and open the next round.
            self._round_elimination()
            self.current_round += 1
            for arm in self.active_arms:
                self.round_pulls_for_arm[arm] = 0

        self.done = True
        return None

    def choose_action(self):
        """
        In a single step, choose the next arm to pull.

        Active arms are served in order until each one reached the quota of
        the current round; a finished round triggers the elimination step
        before the next arm is chosen.

        Returns
        -------
        arm : int or None
            The index of the arm to pull next, or None if the run is over
            (budget exhausted or all rounds completed).
        """
        if self.done:
            return None
        return self._next_pending_arm()

    def update(self, arm: int, rewards):
        """
        Record the new multi-objective reward for the chosen arm.

        Parameters
        ----------
        arm : int
            Index of the arm just played. ``None`` is ignored.
        rewards : dict[str, float] or array-like
            Multi-objective reward. For a dict, the values are used in
            insertion order as the reward vector; otherwise it is converted
            to a float array directly.

        Notes
        -----
        Calls made after the run is over (``self.done``) are ignored.
        When this pull exhausts the budget, any round it completes is closed
        here, so the elimination is not lost if ``choose_action`` is never
        called again.
        """
        if arm is None or self.done:
            return

        if isinstance(rewards, dict):
            reward_array = np.array(list(rewards.values()), dtype=float)
        else:
            reward_array = np.array(rewards, dtype=float)

        self.sum_rewards[arm] += reward_array
        self.num_plays[arm] += 1

        # Only active arms carry a per-round counter.
        if arm in self.round_pulls_for_arm:
            self.round_pulls_for_arm[arm] += 1

        self.t += 1
        if self.t >= self.total_budget:
            # The schedule is tuned to spend the budget exactly, so the final
            # pull usually also completes a round: run its elimination now.
            self._next_pending_arm()
            self.done = True

    def _round_elimination(self):
        """
        After finishing the round's allocated samples for every active arm,
        eliminate exactly one arm using EGE's gap-based rule.

        With ``m(i, j) = min_d (mean_j[d] - mean_i[d])`` and
        ``M(i, j) = max_d (mean_i[d] - mean_j[d]) = -m(i, j)``:

          - empirically dominated arm i:  gap(i) = max_j m(i, j)
          - empirically Pareto arm i:     gap(i) = min_{j != i} max(0, M(j, i))
            (a simplification of the paper's formula)

        The arm with the largest gap is removed; among ties (within 1e-12) a
        dominated arm is removed first. The removed arm is added to
        ``accepted_optimal`` if it was empirically Pareto-optimal, otherwise
        to ``accepted_subopt``.
        """
        n_active = len(self.active_arms)
        if n_active <= 1:
            return  # nothing to eliminate

        # Empirical means of the active arms; never-pulled arms default to zeros.
        plays = self.num_plays[self.active_arms]
        sums = self.sum_rewards[self.active_arms]
        means = np.zeros((n_active, self.num_objectives))
        pulled = plays > 0
        means[pulled] = sums[pulled] / plays[pulled][:, np.newaxis]

        # diff[i, j, d] = means[j, d] - means[i, d]
        diff = means[np.newaxis, :, :] - means[:, np.newaxis, :]

        # Arm j dominates arm i when it is >= everywhere and > somewhere.
        dominated_by = np.all(diff >= 0, axis=2) & np.any(diff > 0, axis=2)
        is_pareto = ~np.any(dominated_by, axis=1)

        m_matrix = diff.min(axis=2)   # m(i, j); diagonal is 0
        M_matrix = -m_matrix          # M(i, j)

        # Dominated arms: max_j m(i, j) (the diagonal 0 is included on purpose).
        subopt_gaps = m_matrix.max(axis=1)
        # Pareto arms: min over j != i of max(0, M(j, i)).
        clipped = np.maximum(M_matrix, 0.0)
        np.fill_diagonal(clipped, np.inf)
        pareto_gaps = clipped.min(axis=0)
        gaps = np.where(is_pareto, pareto_gaps, subopt_gaps)

        # Stable sort keeps the choice among exact ties deterministic.
        order = np.argsort(gaps, kind="stable")
        max_gap = gaps[order[-1]]
        tied = [int(i) for i in order[::-1] if abs(gaps[i] - max_gap) < 1e-12]

        # Among tied arms, prefer removing an empirically dominated one.
        idx_remove_local = next((i for i in tied if not is_pareto[i]), tied[0])
        arm_remove = self.active_arms[idx_remove_local]

        if is_pareto[idx_remove_local]:
            self.accepted_optimal.add(arm_remove)
        else:
            self.accepted_subopt.add(arm_remove)

        self.active_arms.remove(arm_remove)
        self.round_pulls_for_arm.pop(arm_remove, None)

    def best_arm(self):
        """
        Return a single recommended arm.

        A single best arm is not well-defined for multi-objective rewards, so:
          - if exactly one arm is still active, return it;
          - otherwise return the smallest index among ``accepted_optimal``
            (an arbitrary but deterministic pick);
          - return None when neither is available.
        """
        if len(self.active_arms) == 1:
            return self.active_arms[0]
        if self.accepted_optimal:
            return min(self.accepted_optimal)
        return None

    @property
    def pareto_front(self):
        """
        The final recommended Pareto set is the union of:
          - remaining active arms
          - accepted_optimal arms
        """
        return set(self.active_arms).union(self.accepted_optimal)

    def reset(self):
        """Reset all counters for a fresh run with the same configuration."""
        self.sum_rewards.fill(0.0)
        self.num_plays.fill(0)
        self.active_arms = list(range(self.num_arms))
        self.accepted_optimal.clear()
        self.accepted_subopt.clear()
        self.current_round = 1
        self.done = False
        self.n_r = self._sr_schedule()
        self.round_pulls_for_arm = {arm: 0 for arm in self.active_arms}
        self.t = 0
    # end class
