from .policy import Policy
import numpy as np

class delayedMMAB_pro(Policy):
    """Leader/follower multi-player bandit policy that signals through collisions.

    Each player cycles through the phases stored in ``self.phase``:

    * ``'exploration'`` -- players rotate over the agreed set of
      ``NUMOFPLAYERS`` arms (row ``p - x`` of ``best_arms_set``). Once the
      leader has spent ``NUMOFPLAYERS * ceil(log T)`` steps of the phase on
      the agreed set, it rotates over ``expl_set`` instead and runs successive
      elimination on confidence bounds. The phase lasts
      ``NUMOFPLAYERS * K0 * ceil(log T)`` steps.
    * ``'communication'`` -- a window of ``K0 + 2 * NUMOFPLAYERS`` steps.
      - First ``NUMOFPLAYERS`` steps: the leader parks on the position it
        wants to replace.
      - Next ``K0`` steps: it parks on the arm to add.
      - Last ``NUMOFPLAYERS`` steps: it either rotates or, once only
        ``NUMOFPLAYERS`` arms remain active, parks on ``notify_set[0]`` to
        announce the end.
      Followers decode these collisions in ``reward_function_f``.
    * ``'exploitation'`` -- each player pulls the arm at index
      ``internal_rank`` of its final set.

    The player with ``internal_rank == NUMOFPLAYERS - 1`` is the leader.
    ``self.x`` (from ``find_q``) is a lag: players read
    ``best_arms_set[p - x]`` rather than the newest row.
    NOTE(review): the lag presumably compensates for the delays recorded by
    ``est_delay`` -- confirm against the caller.
    """

    def __init__(self, narms, internal_rank, NUMOFPLAYERS, T, verbose=False):
        """Initialise the per-player state.

        Args:
            narms: total number of arms (kept as ``K0``).
            internal_rank: this player's rank; rank ``NUMOFPLAYERS - 1`` is
                the leader.
            NUMOFPLAYERS: number of players ``M``.
            T: horizon; sizes the per-round bookkeeping arrays.
            verbose: stored on the instance; not read by this class.
        """
        Policy.__init__(self, narms, T, internal_rank, NUMOFPLAYERS)
        self.K0 = narms
        self.name = 'delayedMMAB_pro'
        self.int_rank = internal_rank
        self.num_players = NUMOFPLAYERS
        self.phase = 'exploration'
        self.t_phase = 0
        self.verbose = verbose
        self.is_leader = bool(internal_rank == NUMOFPLAYERS - 1)
        # Every player seeds identically so all of them draw the same initial
        # agreed set. NOTE(review): this reseeds the *global* NumPy RNG, which
        # also affects any other code drawing from np.random.
        np.random.seed(42)
        # Row p holds the agreed set after communication round p; -1 = unknown.
        self.best_arms_set = np.full((T, NUMOFPLAYERS), -1, dtype=int)
        self.best_arms_set[0] = np.random.choice(range(narms), NUMOFPLAYERS, replace=False)
        # Arms the leader explores outside the agreed set.
        self.expl_set = np.setdiff1d(np.arange(narms), self.best_arms_set[0], assume_unique=True).tolist()
        # Leader's tentative agreed set (updated before it is communicated).
        self.best_update = self.best_arms_set[0].copy()
        # Arms not yet eliminated; expl_set is rebuilt from this list.
        self.expl_update = list(range(narms))
        # Leader's queue of pending swaps (arm to remove / arm to add).
        self.arms_to_remove = []
        self.arms_to_add = []
        self.sent = []
        # (start, end) step interval of each communication window.
        self.com_stamp = []
        # Follower: decoded position / arm per round; -1 until observed.
        self.f_arm_rm = np.full(self.T, -1, dtype=int)
        self.f_arm_add = np.full(self.T, -1, dtype=int)
        self.active_arms = np.arange(0, narms)
        self.sums = np.zeros(narms)
        self.arm_emp_reward = [0 for _ in range(narms)]
        self.arm_rm_temp = []
        # Leader: number of "end" notifications sent.
        self.end = 0
        self.index = 0
        # Step counter within the current sub-window of a communication phase.
        self.t_phase_block = 0
        self.add = 0
        self.remove = 0
        # Eliminated arms that are still in the leader's tentative set.
        self.bad_arm = []
        self.temp_arm = -1
        self.delays = np.array([], dtype=int)
        self.d_mean = 0
        self.d_var = 0
        # Number of completed communication rounds.
        self.p = 0
        self.q = 0
        # Follower: decoded (rm_position, add_arm, round) tuples awaiting application.
        self.package = []
        # Follower: next round whose package may be applied.
        self.next_min_c = 0
        # Sentinel 999999 = the final round is not known yet.
        self.max_p = 999999
        self.x = 0
        # Follower: collision-free pulls per round in the remove+add windows.
        self.no_col = [0 for _ in range(self.T)]
        # Leader: position of the agreed set being replaced.
        self.position = 0
        self.enter = 0
        # Follower: per-round [no-collision, collision] counts in the notify window.
        self.pac_end = np.zeros((self.T, 2), dtype=int)
        self.notify_set = np.arange(self.num_of_players)

    def play(self):
        """Return the arm to pull at the current step.

        The choice depends on the phase, the position inside it
        (``t_phase`` / ``t_phase_block``), the rank and the agreed sets.
        ``self.index`` (and, for the leader, ``self.end`` / ``self.add`` /
        ``self.remove``) are updated as side effects.
        """
        a = 0
        if self.phase == 'exploration':
            leader_explores = (
                self.is_leader
                and self.K != self.num_players
                and self.t_phase >= int(self.num_players * np.ceil(np.log(self.T)))
            )
            if leader_explores:
                # Round-robin over the arms actually listed in expl_set. K - M
                # (used previously) is not kept equal to len(expl_set) by
                # update(), so it could skip listed arms or index past the end.
                self.index = ((self.t_phase + self.int_rank) % len(self.expl_set))
                a = self.expl_set[self.index]
            else:
                self.index = ((self.t_phase + self.int_rank) % self.num_players)
                a = self.best_arms_set[self.p - self.x][self.index]

        # communication phase
        if self.phase == 'communication':
            if self.max_p == 999999:
                if not self.is_leader:
                    # remove window, length = M
                    if self.t_phase < self.num_players:
                        self.index = ((self.t_phase_block + self.int_rank) % self.num_players)
                        a = self.best_arms_set[self.p - self.x][self.index]
                    # add window, length = K0: rotate over every arm
                    elif self.t_phase < self.K0 + self.num_players:
                        self.index = ((self.t_phase_block + self.int_rank) % self.K0)
                        a = self.index
                    # notify-end window, length = M
                    elif self.t_phase < self.K0 + 2 * self.num_players:
                        self.index = ((self.t_phase_block + self.int_rank) % self.num_players)
                        a = self.notify_set[self.index]
                else:
                    if self.arms_to_remove and self.arms_to_add:
                        if self.t_phase == 0:
                            self.remove = self.arms_to_remove[-1]
                            self.add = self.arms_to_add[-1]
                        # remove window: park on the position being replaced
                        if self.t_phase < self.num_players:
                            a = self.best_arms_set[self.p - self.x][self.position]
                        # add window: park on the arm being added
                        elif self.t_phase < self.K0 + self.num_players:
                            a = self.add
                    else:
                        # Nothing to signal: rotate like a follower (no collisions).
                        if self.t_phase < self.num_players:
                            self.index = ((self.t_phase_block + self.int_rank) % self.num_players)
                            a = self.best_arms_set[self.p - self.x][self.index]
                        elif self.t_phase < self.K0 + self.num_players:
                            self.index = ((self.t_phase_block + self.int_rank) % self.K0)
                            a = self.index

                    # notify-end window, length = M; the leader signals the followers
                    if self.K0 + self.num_players <= self.t_phase < self.K0 + 2 * self.num_players:
                        if self.K != self.num_players:
                            self.index = ((self.t_phase_block + self.int_rank) % self.num_players)
                            a = self.notify_set[self.index]
                        else:
                            # Only M arms remain: park on notify_set[0] so each
                            # follower sees exactly one collision in this window.
                            self.end += 1
                            a = self.notify_set[0]
            else:
                self.index = ((self.t_phase + self.int_rank) % self.num_players)
                a = self.best_arms_set[self.p - self.x][self.index]

        # exploitation phase
        if self.phase == 'exploitation':
            if not self.is_leader:
                a = self.best_arms_set[self.max_p - self.x][self.int_rank]
            else:
                a = self.best_update[self.int_rank]

        return a

    def is_outside_range(self, i, rng):
        """Return True when step ``i`` lies outside the closed interval ``rng = (a, b)``."""
        a, b = rng
        return not (a <= i <= b)

    def is_inside_range_rm(self, i, ranges):
        """Find the communication window whose remove sub-window contains step ``i``.

        The remove sub-window is the first ``num_players`` steps of
        ``ranges[j] = (start, end)``. Returns ``(True, j)`` for the first match,
        ``(False, 0)`` otherwise.
        """
        for j, (start, _end) in enumerate(ranges):
            if start <= i <= start + self.num_players - 1:
                return True, j
        return False, 0

    def is_inside_range_add(self, i, ranges):
        """Find the communication window whose add sub-window contains step ``i``.

        The add sub-window spans ``K0`` steps starting ``num_players`` steps
        after the window start. Returns ``(True, j)`` or ``(False, 0)``.
        """
        for j, (start, _end) in enumerate(ranges):
            if start + self.num_players <= i <= start + self.num_players + self.K0 - 1:
                return True, j
        return False, 0

    def f_is_inside_range_end(self, i, ranges, collision):
        """Find the window whose notify-end sub-window contains step ``i``.

        ``collision`` is accepted for interface compatibility and is unused.
        Returns ``(True, j)`` or ``(False, 0)``.
        """
        for j, (start, end) in enumerate(ranges):
            if start + self.num_players + self.K0 <= i <= end:
                return True, j
        return False, 0

    def l_is_inside_range_end(self, i, ranges):
        """Return True when step ``i`` lies in any notify-end sub-window."""
        return any(start + self.num_players + self.K0 <= i <= end for start, end in ranges)

    def _apply_pending_packages(self):
        """Follower: apply decoded ``(rm, add, p)`` packages in round order.

        ``(-2, -2, p)`` means round ``p`` changed nothing. Otherwise position
        ``rm`` of row ``p`` is replaced by arm ``add`` in row ``p + 1``.
        Packages for rounds ``>= max_p`` are discarded.
        """
        i = 0
        while i < len(self.package):
            rm, add, p = self.package[i]
            if p >= self.max_p:
                # Do not advance i: pop() shifts the next package into slot i.
                self.package.pop(i)
            elif p == self.next_min_c:
                # Always carry the previous row forward (as the leader does),
                # so row p + 1 is never left at -1 when `add` is already present.
                self.best_arms_set[p + 1] = self.best_arms_set[p].copy()
                if (rm, add) != (-2, -2) and add not in self.best_arms_set[p]:
                    self.best_arms_set[p + 1][rm] = add
                self.next_min_c += 1
                self.package.pop(i)
            else:
                i += 1

    def f_final_update(self):
        """Follower, exploitation: keep applying packages until row ``max_p`` is filled."""
        if np.all(self.best_arms_set[self.max_p] == -1):
            self._apply_pending_packages()

    def reward_function_l(self, arm, time, reward, collision):
        """Leader: accumulate a collision-free reward into the arm statistics.

        Rewards count outside all communication windows. Inside a notify-end
        sub-window they count only while more than ``num_players`` arms are active.
        """
        if all(self.is_outside_range(time, rng) for rng in self.com_stamp):
            if collision == 0:
                self.sums[arm] += reward
                self.npulls[arm] += 1

        if self.l_is_inside_range_end(time, self.com_stamp):
            if self.K != self.num_players and collision == 0:
                self.sums[arm] += reward
                self.npulls[arm] += 1

    def reward_function_f(self, arm, time, collision):
        """Follower: decode the leader's collision signals observed at step ``time``.

        * remove sub-window: a collision reveals the position being replaced,
          ``(time - window_start + int_rank) % num_players``.
        * add sub-window: a collision reveals the arm being added (``arm``).
        * Once both are known for a round, ``(position, arm, round)`` is queued
          in ``self.package``. If ``K0 + num_players - 1`` pulls of a round are
          collision-free, ``(-2, -2, round)`` ("no change") is queued instead.
        * notify-end sub-window: collisions / non-collisions are tallied in
          ``pac_end``.
        """
        choice_i, i = self.is_inside_range_rm(time, self.com_stamp)
        if choice_i:
            if collision == 1:
                index = (time - self.com_stamp[i][0] + self.int_rank) % self.num_players
                self.f_arm_rm[i] = index
                if self.f_arm_add[i] != -1 and self.f_arm_rm[i] != -1:
                    self.package.append((index, self.f_arm_add[i], i))
                    self.f_arm_add[i] = -1
                    self.f_arm_rm[i] = -1
            else:
                self.no_col[i] += 1
                if self.no_col[i] == self.K0 + self.num_players - 1:
                    self.package.append((-2, -2, i))

        choice_j, j = self.is_inside_range_add(time, self.com_stamp)
        if choice_j:
            if collision == 1:
                self.f_arm_add[j] = arm
                if self.f_arm_rm[j] != -1 and arm != -1:
                    self.package.append((self.f_arm_rm[j], self.f_arm_add[j], j))
                    self.f_arm_add[j] = -1
                    self.f_arm_rm[j] = -1
            else:
                self.no_col[j] += 1
                if self.no_col[j] == self.K0 + self.num_players - 1:
                    self.package.append((-2, -2, j))

        choice_l, l = self.f_is_inside_range_end(time, self.com_stamp, collision)
        if choice_l:
            if collision == 0:
                self.pac_end[l, 0] += 1
            else:
                self.pac_end[l, 1] += 1

    def f_update_best_arm(self):
        """Follower, end of a communication window: apply any ready packages."""
        self._apply_pending_packages()

    def est_delay(self, pre_t, t):
        """Record the delay ``t - pre_t``; refresh the mean and ddof=1 variance of all delays."""
        current_delay = t - pre_t
        self.delays = np.append(self.delays, current_delay)

        if len(self.delays) > 1:
            self.d_mean = np.mean(self.delays)
            self.d_var = np.var(self.delays, ddof=1)
        else:
            self.d_mean = current_delay
            self.d_var = 0

    def find_q(self, p):
        """Return the lag ``q`` used to read ``best_arms_set[p - q]``.

        Returns 0 for ``p == 0`` and ``p`` when the delay variance is zero.
        Otherwise it returns the smallest ``q`` in ``[0, p]`` with
        ``self.t >= t_hat(q)``, or ``p`` if there is none. Here
        ``t_hat(q) = ceil(2*d_mean + sqrt(2*d_var*log((M-1)(K0+2M)T)))``
        ``+ (p-q)*K0*M*ceil(log T) + (K0+2M)*(p-q) - 1``.
        """
        if p == 0:
            return 0
        if self.d_var == 0:
            return p
        if p < 0:
            # Defensive: a round counter should not be negative.
            return 0

        log_T = np.log(self.T)
        log_num_players_term = np.log((self.num_players - 1) * (self.K0 + 2 * self.num_players) * self.T)
        delay_bound = np.ceil(2 * self.d_mean + np.sqrt(2 * self.d_var * log_num_players_term))

        def t_hat(q):
            rounds = p - q
            return (delay_bound + rounds * self.K0 * self.num_of_players * np.ceil(log_T)
                    + (self.K0 + 2 * self.num_players) * rounds - 1)

        q = 0
        while self.t < t_hat(q):
            q += 1
            if q > p:
                return p
        return q

    def _leader_successive_elimination(self):
        """Leader, exploration: eliminate active arms beaten by at least M others.

        Arm ``i`` is beaten by ``j`` when ``b_up[i] < b_low[j]``, with bounds
        ``mean +/- sqrt(2 log T / npulls)``. Eliminated arms leave
        ``active_arms`` and both exploration lists. Eliminated arms that are
        in the tentative set are queued in ``bad_arm`` for swapping out.
        """
        if self.K == self.num_players or not np.all(self.npulls[self.active_arms] > 0):
            return
        pulls_sqrt_term = np.sqrt(2 * np.log(self.T) / self.npulls[self.active_arms])
        mean_rewards = self.sums[self.active_arms] / self.npulls[self.active_arms]

        b_up = mean_rewards + pulls_sqrt_term
        b_low = mean_rewards - pulls_sqrt_term

        # better_matrix[i, j]: arm j's lower bound exceeds arm i's upper bound.
        better_matrix = b_up[:, np.newaxis] < b_low
        better_count = np.sum(better_matrix, axis=1)

        self.arm_rm_temp = self.active_arms[better_count >= self.num_players].tolist()

        for arm in self.arm_rm_temp:
            print(f'DDSE: successive eliminate {arm} at time {self.t}')
        if not self.arm_rm_temp:
            return

        eliminated = set(self.arm_rm_temp)
        intersecting_arms = eliminated.intersection(self.best_update)
        if intersecting_arms:
            self.bad_arm.extend(intersecting_arms)
        else:
            self.temp_arm = self.arm_rm_temp[0]
        self.active_arms = np.array(list(set(self.active_arms) - eliminated))
        # Remove *every* eliminated arm from the exploration lists. Previously
        # only the first one was removed, and only when none was in the set.
        for arm in self.arm_rm_temp:
            if arm in self.expl_update:
                self.expl_update.remove(arm)
            if arm in self.expl_set:
                self.expl_set.remove(arm)

        self.K = len(self.active_arms)
        self.arm_rm_temp = []

    def _follower_end_exploration(self):
        """Follower: detect the final round from notify tallies; maybe start exploiting."""
        if self.max_p == 999999:
            # The final round's notify window shows M-1 clean pulls and one collision.
            result = np.where((self.pac_end[:self.p, 0] == self.num_players - 1) & (self.pac_end[:self.p, 1] == 1))[0]
            if len(result) > 0:
                self.max_p = result[0] + 1

        if self.max_p != 999999:
            x = self.find_q(self.max_p)
            if x == 0:
                self.x = x
                self.phase = 'exploitation'
                self.t_phase = 0

    def _leader_plan_swap(self):
        """Leader: choose the next (remove, add) swap and apply it to ``best_update``."""
        if not np.all(self.npulls[self.active_arms] > 0):
            return
        self.arm_emp_reward = np.where(self.npulls > 0, self.sums / self.npulls, 0)
        M_minus = set(self.best_update)
        available_arms = list(set(self.active_arms) - M_minus)
        if not available_arms:
            return
        max_reward_arm = max(available_arms, key=lambda x: self.arm_emp_reward[x])

        if self.bad_arm:
            self.arms_to_remove.extend(self.bad_arm)
            self.arms_to_add.append(max_reward_arm)
            rm_arm = self.arms_to_remove[0]
            self.bad_arm = []
        elif self.arms_to_remove:
            self.arms_to_add.append(max_reward_arm)
            rm_arm = self.arms_to_remove[0]
        else:
            sorted_arms = np.argsort(self.arm_emp_reward)
            min_reward_arm = min(M_minus, key=lambda x: self.arm_emp_reward[x])
            # NOTE(review): compares the weakest agreed arm's ascending rank with
            # M - 1 and never checks that max_reward_arm is actually better --
            # confirm this is the intended swap criterion.
            if sorted_arms.tolist().index(min_reward_arm) == self.num_players - 1:
                return
            self.arms_to_remove.append(min_reward_arm)
            self.arms_to_add.append(max_reward_arm)
            rm_arm = min_reward_arm

        index_to_remove = np.where(self.best_update == rm_arm)[0][0]
        self.best_update[index_to_remove] = max_reward_arm
        self.position = index_to_remove

    def _leader_end_exploration(self):
        """Leader: fix the final round once announced, else plan the next swap."""
        if self.end == self.num_players and self.max_p == 999999:
            self.max_p = self.p
            print(f'L: max p ={self.max_p}, x = {self.x}, t = {self.t}')
        if self.max_p != 999999:
            self.x = self.find_q(self.max_p)
            if self.x == 0:
                self.phase = 'exploitation'
        else:
            self._leader_plan_swap()

    def _leader_commit_round(self):
        """Leader, end of communication: write row ``p + 1`` and refresh ``expl_set``."""
        rm, ad = -1, -1
        if self.arms_to_add and self.arms_to_remove:
            rm = self.arms_to_remove.pop(0)
            ad = self.arms_to_add.pop(0)

        current = self.best_arms_set[self.p]
        self.best_arms_set[self.p + 1] = current.copy()
        if rm > -1 and ad > -1 and (rm in current) and (ad not in current):
            index_to_remove = np.where(self.best_arms_set[self.p + 1] == rm)[0][0]
            self.best_arms_set[self.p + 1][index_to_remove] = ad

        intersecting_arms = set(self.best_arms_set[self.p + 1 - self.x]).intersection(self.expl_update)
        if intersecting_arms:
            self.expl_set = list(set(self.expl_update) - set(intersecting_arms))
        elif set(self.expl_set) != set(self.expl_update):
            self.expl_set = self.expl_update[:]

    def update(self):
        """Advance the policy state by one step and handle phase transitions."""
        if self.phase == 'exploration':
            if self.is_leader:
                self._leader_successive_elimination()
            self.t_phase += 1

            # end of exploration phase
            if self.t_phase == int(self.num_of_players * self.K0 * np.ceil(np.log(self.T))):
                self.phase = 'communication'
                self.t_phase = 0
                self.com_stamp.append((self.t + 1, self.t + self.K0 + 2 * self.num_players))

                if not self.is_leader:
                    self._follower_end_exploration()
                else:
                    self._leader_end_exploration()

        elif self.phase == 'communication':
            self.t_phase += 1
            self.t_phase_block += 1

            # Restart the sub-window counter at the add and notify boundaries.
            if self.t_phase == self.num_players or self.t_phase == self.K0 + self.num_players:
                self.t_phase_block = 0

            if self.t_phase == self.K0 + 2 * self.num_players:
                if not self.is_leader:
                    self.f_update_best_arm()
                else:
                    self._leader_commit_round()

                self.phase = 'exploration'
                self.t_phase = 0
                self.t_phase_block = 0
                self.p += 1
                self.x = self.find_q(self.p)

        elif self.phase == 'exploitation':
            if not self.is_leader:
                self.f_final_update()

            self.t_phase += 1

        self.t += 1
