from .policy import Policy
import numpy as np

class delayedMMAB(Policy):

    def __init__(self, narms, internal_rank, NUMOFPLAYERS, T, verbose=False):
        """Initialise the per-player state of the delayedMMAB policy.

        Args:
            narms: Number of arms; also kept as ``K0``.
            internal_rank: Rank of this player. The player whose rank equals
                ``NUMOFPLAYERS - 1`` is flagged as the leader.
            NUMOFPLAYERS: Number of players.
            T: Horizon; it sizes the arm tables and ``count_final``.
            verbose: Stored on the instance as ``verbose``.
        """
        # Base-class setup runs first. Note the argument order: T is passed
        # second here although it is the fourth constructor parameter.
        # NOTE(review): other methods read self.K, self.T, self.t and
        # self.npulls, which are not assigned here -- presumably the base
        # class provides them; confirm against Policy.
        Policy.__init__(self, narms, T, internal_rank, NUMOFPLAYERS)

        # --- identity and role -------------------------------------------
        self.K0 = narms
        self.name = 'delayedMMAB'
        self.int_rank = internal_rank
        self.num_players = NUMOFPLAYERS
        self.is_leader = bool(internal_rank == NUMOFPLAYERS - 1)
        self.verbose = verbose

        # --- phase bookkeeping -------------------------------------------
        self.phase = 'exploration'
        self.t_phase = 0
        self.t_phase_block = 0
        self.end = 0
        self.index = 0
        self.p = 0

        # --- arm tables ---------------------------------------------------
        # The global NumPy RNG is reseeded with a fixed constant right before
        # the draw, so every instance built with the same narms/NUMOFPLAYERS
        # starts from the same initial "best" arm set.
        np.random.seed(42)
        self.best_arms_set = [[-1] * NUMOFPLAYERS for _ in range(T)]
        self.best_arms_set[0] = list(np.random.choice(range(narms), NUMOFPLAYERS, replace=False))
        # Exploration arms are the arms left out of the initial best set.
        self.expl_arms_set = [[-1] * (narms - NUMOFPLAYERS) for _ in range(T)]
        self.expl_arms_set[0] = list(set(range(narms)) - set(self.best_arms_set[0]))
        self.use = 0  # row of the two tables that is currently in use
        self.active_arms = np.arange(narms)

        # --- statistics ---------------------------------------------------
        self.sums = np.zeros(narms)  # means*npulls
        self.arm_emp_reward = [0] * narms
        self.count_final = np.full(T, '', dtype=object)

        # --- pending swaps / eliminations ---------------------------------
        self.arms_to_remove = []
        self.arms_to_add = []
        self.com_stamp = []  # (begin, end) pairs, one per communication phase
        self.arm_rm_temp = []
        self.bad_arm = []
        self.temp_arm = 0
        self.add = 0
        self.remove = 0

    def play(self):
        a = 0
        if self.phase == 'exploration':
            if not self.is_leader:
                self.index = ((self.t_phase + self.int_rank) % self.num_players)
                a = self.best_arms_set[self.use][self.index]
            else:
                if self.t_phase < int(self.num_players * np.ceil(np.log(self.T))):
                    self.index = ((self.t_phase + self.int_rank) % self.num_players)
                    a = self.best_arms_set[self.use][self.index]
                else:
                    if self.K != self.num_players:
                        self.index = ((self.t_phase + self.int_rank) % (self.K - self.num_players))
                        a = self.expl_arms_set[self.use][self.index]
                    elif self.bad_arm:
                        a = self.bad_arm[0]
                    else:
                        a = self.temp_arm

        if self.phase == 'communication':
            if not self.is_leader:
                # remove, length = M
                if self.t_phase < self.num_players:
                    self.index = ((self.t_phase_block + self.int_rank) % self.num_players)
                    a = self.best_arms_set[self.use][self.index]
                # add, length = K
                elif self.t_phase < self.K0 + self.num_players:
                    self.index = ((self.t_phase_block + self.int_rank) % self.K0)
                    temp_arm = list(range(self.K0))
                    a = temp_arm[self.index]
                # notify end, length = M
                elif self.t_phase < 2 * self.K0 + self.num_players:
                    self.index = ((self.t_phase_block + self.int_rank) % self.K0)
                    temp_arm = list(range(self.K0))
                    a = temp_arm[self.index]
            else:
                if self.arms_to_remove and self.arms_to_add:
                    # remove, length = M
                    if self.t_phase == 0:
                        self.remove = self.arms_to_remove[-1]
                        self.add = self.arms_to_add[-1]

                    if self.t_phase < self.num_players:
                        a = self.remove
                    # add, length = K
                    elif self.t_phase < self.K0 + self.num_players:
                        a = self.add
                else:
                    if self.t_phase < self.num_players:
                        self.index = ((self.t_phase_block + self.int_rank) % self.num_players)
                        a = self.best_arms_set[self.use][self.index]
                    elif self.t_phase < self.K0 + self.num_players:
                        self.index = ((self.t_phase_block + self.int_rank) % self.K0)
                        temp_arm = list(range(self.K0))
                        a = temp_arm[self.index]

                # notify end, length = M-1
                if self.t_phase < 2 * self.K0 + self.num_players and self.t_phase >= self.K0 + self.num_players:
                    if self.K != self.num_players:
                        self.index = ((self.t_phase_block + self.int_rank) % self.K0)
                        temp_arm = list(range(self.K0))
                        a = temp_arm[self.index]
                    else:
                        self.end += 1
                        a = 0

        # exploitation phase
        if self.phase == 'exploitation':

            if not self.is_leader:
                a = self.best_arms_set[self.use][self.int_rank]
            else:
                a = self.best_arms_set[self.use][self.int_rank]
        return a

    def is_outside_range(self, i, rng):
        a, b = rng
        return not (a <= i <= b)

    def is_inside_range_rm(self, i, ranges):
        for (a, b) in ranges:
            if a <= i <= a + self.num_players - 1:
                return True
        return False

    def is_inside_range_add(self, i, ranges):
        for (a, b) in ranges:
            if a + self.num_players <= i <= a + self.num_players + self.K0 - 1:
                return True
        return False

    def is_inside_range_end(self, i, ranges):
        x = 0
        for (a, b) in ranges:
            if a + self.num_players + self.K0 <= i <= b:
                return True, x
        return False, 0

    def reward_function_l(self, arm, time, reward, collision):
        if all(self.is_outside_range(time, rng) for rng in self.com_stamp):
            if collision == 0:
                self.sums[arm] += reward
                self.npulls[arm] += 1
        x, _ = self.is_inside_range_end(time, self.com_stamp)
        if x and self.K != self.num_players and collision == 0:
            self.sums[arm] += reward
            self.npulls[arm] += 1

    def reward_function_f(self, arm, time, collision):
        if self.is_inside_range_rm(time, self.com_stamp):
            if collision == 1:
                self.arms_to_remove.append(arm)
        if self.is_inside_range_add(time, self.com_stamp):
            if collision == 1:
                self.arms_to_add.append(arm)
        if self.arms_to_add and self.arms_to_remove:
            rm = self.arms_to_remove.pop()
            if rm in self.best_arms_set[self.use]:
                index_to_remove = self.best_arms_set[self.use].index(rm)
                self.best_arms_set[self.use].pop(index_to_remove)
                ad = self.arms_to_add.pop()
                self.best_arms_set[self.use].insert(index_to_remove, ad)
        choice, _ = self.is_inside_range_end(time, self.com_stamp)
        if choice and collision == 1:
            self.phase = 'exploitation'
            self.t_phase = 1

    def update(self):
        if self.phase == 'exploration':
            if not self.is_leader:
                ...
            else:
                if self.K != self.num_players:
                    if np.all(self.npulls[self.active_arms] > 0):
                        b_up = self.sums[self.active_arms] / self.npulls[
                            self.active_arms] + np.sqrt(
                            2 * np.log(self.T) / (self.npulls[self.active_arms]))
                        b_low = self.sums[self.active_arms] / self.npulls[
                            self.active_arms] - np.sqrt(
                            2 * np.log(self.T) / (self.npulls[self.active_arms]))

                        # compute the arms to accept/reject
                        for i, k in enumerate(self.active_arms):
                            better = np.sum(b_low > (b_up[i]))
                            if better >= self.num_players:
                                self.arm_rm_temp.append(self.active_arms[i])
                                print(f'DDSE without delay estimation: successive eliminate {self.active_arms[i]} at time {self.t}')

                    intersecting_arms = set(self.arm_rm_temp).intersection(self.best_arms_set[self.use])
                    if intersecting_arms:

                        arms_actually_remove = set(self.arm_rm_temp) - set(self.best_arms_set[self.use])
                        self.active_arms = self.active_arms[~np.isin(self.active_arms, self.arm_rm_temp)]
                        expl_arms_np = np.array(self.expl_arms_set[self.use])

                        expl_arms_np = expl_arms_np[~np.isin(expl_arms_np, list(arms_actually_remove))]

                        self.expl_arms_set[self.use] = expl_arms_np.tolist()
                        self.bad_arm.extend(intersecting_arms)
                    else:
                        if self.arm_rm_temp:
                            self.temp_arm = self.arm_rm_temp[0]

                        self.active_arms = self.active_arms[~np.isin(self.active_arms, self.arm_rm_temp)]

                        expl_arms_np = np.array(self.expl_arms_set[self.use])

                        expl_arms_np = expl_arms_np[~np.isin(expl_arms_np, self.arm_rm_temp)]

                        self.expl_arms_set[self.use] = expl_arms_np.tolist()
                    self.K = len(self.active_arms)
                    self.arm_rm_temp = []
            self.t_phase += 1

            # end of exploration phase
            if self.t_phase == int(self.num_of_players * self.K0 * np.ceil(np.log(self.T))) - 1:
                if not self.is_leader:
                    ...
                else:
                    divide = 0
                    for i in self.active_arms:
                        if self.npulls[i] > 0:
                            divide += 1

                    if divide == len(self.active_arms):
                        self.arm_emp_reward = [self.sums[i] / self.npulls[i] for i in range(self.K0)]

                    M_minus = set(self.best_arms_set[self.use])
                    available_arms = set(self.active_arms) - M_minus
                    if available_arms:
                        max_reward_arm = max(available_arms, key=lambda x: self.arm_emp_reward[x])
                        if self.bad_arm:
                            if len(self.bad_arm) > 1:
                                self.arms_to_remove.extend(self.bad_arm)
                                rm_arm = self.arms_to_remove[-1]
                                self.bad_arm.pop()
                                self.arms_to_add.append(max_reward_arm)
                                index_to_remove = self.best_arms_set[self.use].index(rm_arm)
                                self.best_arms_set[self.use].pop(index_to_remove)
                                self.best_arms_set[self.use].insert(index_to_remove, max_reward_arm)
                            else:
                                rm_arm = self.bad_arm.pop()
                                self.arms_to_remove.append(rm_arm)
                                self.arms_to_add.append(max_reward_arm)
                                index_to_remove = self.best_arms_set[self.use].index(rm_arm)
                                self.best_arms_set[self.use].pop(index_to_remove)
                                self.best_arms_set[self.use].insert(index_to_remove, max_reward_arm)
                        else:
                            sorted_arms = np.argsort(self.arm_emp_reward)
                            min_reward_arm = min(M_minus, key=lambda x: self.arm_emp_reward[x])
                            if sorted_arms.tolist().index(min_reward_arm) != self.num_players - 1:
                                self.arms_to_remove.append(min_reward_arm)
                                self.arms_to_add.append(max_reward_arm)
                                index_to_remove = self.best_arms_set[self.use].index(min_reward_arm)
                                self.best_arms_set[self.use].pop(index_to_remove)
                                self.best_arms_set[self.use].insert(index_to_remove, max_reward_arm)
                                self.expl_arms_set[self.use] = list(set(self.active_arms) - set(self.best_arms_set[self.use]))

                self.phase = 'communication'
                self.t_phase = 0
                begin_time = self.t + 1
                end_time = self.t + 1 + 2 * self.K0 + self.num_players - 1
                self.com_stamp.append((begin_time, end_time))

        elif self.phase == 'communication':
            if self.t_phase == self.num_players - 1 or self.t_phase == self.K0 + self.num_players - 1:
                self.t_phase_block = 0

            self.t_phase += 1
            self.t_phase_block += 1

            if self.t_phase == 2 * self.K0 + self.num_players:
                self.p += 1
                if not self.is_leader:
                    self.phase = 'exploration'
                    self.t_phase = 0
                    self.t_phase_block = 0
                else:
                    if self.arms_to_add and self.arms_to_remove:
                        self.arms_to_remove.pop()
                        self.arms_to_add.pop()
                    self.phase = 'exploration'

                    if self.end == self.K0:
                        self.phase = 'exploitation'
                    self.t_phase = 0
                    self.t_phase_block = 0

        elif self.phase == 'exploitation':
            self.t_phase += 1
            ...

        self.t += 1
