from .policy import Policy
import numpy as np

class delayedMMAB_ct(Policy):
    """Multi-player bandit policy in which one leader eliminates arms.

    During exploration every player rotates through a shared "best" set of
    ``NUMOFPLAYERS`` arms. The leader (the player whose internal rank is
    ``NUMOFPLAYERS - 1``) additionally pulls the remaining "exploration" arms,
    maintains confidence bounds and successively eliminates arms. The best set
    and the end of exploration appear to be exchanged between players through
    ``send_set``/``receive_set`` and ``send_end``/``receive_end``. Once the
    number of active arms equals the number of players, each player exploits
    the best-set arm at its own rank.

    NOTE(review): ``self.K``, ``self.t``, ``self.T`` and ``self.npulls`` are not
    assigned in this class; they are presumably initialised by
    ``Policy.__init__`` (``K`` as the number of arms, ``npulls`` as a NumPy
    array of per-arm pull counts) — confirm against ``policy.py``.
    """

    def __init__(self, narms, internal_rank, NUMOFPLAYERS, T,  verbose=False):
        """Initialise the policy for one player.

        Args:
            narms: number of arms.
            internal_rank: this player's rank; rank ``NUMOFPLAYERS - 1`` is the
                leader.
            NUMOFPLAYERS: number of players.
            T: horizon; one best/exploration set slot is allocated per round
                and ``log(T)`` sets the confidence width.
            verbose: when True, print diagnostic messages (eliminations and
                the arm chosen at the start of exploitation).
        """
        Policy.__init__(self, narms, T, internal_rank, NUMOFPLAYERS)
        self.K0 = narms
        self.name = 'delayedMMAB_ct'
        self.int_rank = internal_rank
        self.num_players = NUMOFPLAYERS
        self.phase = 'exploration'
        self.t_phase = 0
        self.verbose = verbose
        self.is_leader = internal_rank == NUMOFPLAYERS - 1

        # Every instance reseeds the *global* NumPy RNG with the same value,
        # presumably so all players draw the same initial best set without
        # communicating. This also resets the RNG for any other code using it.
        np.random.seed(42)

        # best_arms_set[u]: the NUMOFPLAYERS arms currently rotated through.
        # Only slot self.use (always 0 in this class) is read or written here.
        self.best_arms_set = [[-1] * NUMOFPLAYERS for _ in range(T)]
        self.best_arms_set[0] = list(np.random.choice(range(narms), NUMOFPLAYERS, replace=False))

        # expl_arms_set[u]: the complement of the best set, explored by the leader.
        self.expl_arms_set = [[-1] * (narms - NUMOFPLAYERS) for _ in range(T)]
        self.expl_arms_set[0] = list(set(range(narms)) - set(self.best_arms_set[0]))

        self.use = 0

        self.arms_to_remove = []
        self.arms_to_add = []

        self.com_stamp = []
        self.active_arms = np.arange(0, narms)  # arms not yet eliminated
        self.sums = np.zeros(narms)  # cumulative collision-free reward per arm
        self.arm_emp_reward = [0 for _ in range(narms)]
        self.arm_rm_temp = []  # arms marked for elimination in the current round
        self.end = 0
        self.index = 0
        self.t_phase_block = 0
        self.add = 0
        self.remove = 0
        self.bad_arm = []
        self.temp_arm = 0
        self.count_final = np.full(T, '', dtype=object)
        self.p = 0
        self.qumo = self.K - self.num_players  # size of the exploration set

    def play(self):
        """Return the arm this player pulls in the current round."""
        a = 0
        if self.phase == 'exploration':
            if not self.is_leader:
                # Offsetting by rank makes distinct players pick distinct
                # positions of the best set in the same round.
                self.index = (self.t + self.int_rank) % self.num_players
                a = self.best_arms_set[self.use][self.index]
            elif self.t % self.K + 1 <= self.num_players:
                # First num_players rounds of each K-round cycle: the leader
                # rotates through the best set like everyone else.
                self.index = (self.t + self.int_rank) % self.num_players
                a = self.best_arms_set[self.use][self.index]
            elif self.K != self.num_players:
                # Remaining rounds of the cycle: explore the non-best arms.
                self.index = (self.t + self.int_rank) % self.qumo
                a = self.expl_arms_set[self.use][self.index]
            elif self.bad_arm:
                a = self.bad_arm[0]
            else:
                a = self.temp_arm
        elif self.phase == 'exploitation':
            # Each player settles on the best-set arm at its own rank.
            a = self.best_arms_set[self.use][self.int_rank]
            if self.t_phase == 0 and self.verbose:
                print(f'Player {self.int_rank} selects {a} exploitation phase')
        return a

    def reward_function_l(self, arm, time, reward, collision):
        """Record ``reward`` on ``arm`` when the pull was collision-free.

        ``time`` is accepted for interface compatibility and is not used.
        """
        if collision == 0:
            self.sums[arm] += reward
            self.npulls[arm] += 1

    def send_set(self):
        """Return the current best set (the list object itself, not a copy)."""
        return self.best_arms_set[self.use]

    def receive_set(self, alist):
        """Overwrite the current best set in place with the contents of ``alist``."""
        self.best_arms_set[self.use][:] = alist

    def send_end(self):
        """Return True when only ``num_players`` arms remain active."""
        return self.K == self.num_players

    def receive_end(self):
        """Switch to exploitation and reset the round counters ``t`` and ``t_phase``."""
        self.phase = 'exploitation'
        self.t = 0
        self.t_phase = 0

    def update(self):
        """Advance one round.

        During exploration the leader (while more arms than players remain)
        eliminates confidently-beaten arms. Any player whose ``K`` equals the
        number of players moves to exploitation.
        """
        self.t += 1
        if self.phase == 'exploration':
            # Non-leaders do no bookkeeping here; they receive the best set
            # and the end-of-exploration signal from outside.
            if self.is_leader and self.K != self.num_players:
                self._mark_eliminated_arms()
                self._remove_marked_arms()

            self.t_phase += 1

            if self.K == self.num_players:
                self.phase = 'exploitation'
                self.t_phase = 0

        elif self.phase == 'exploitation':
            self.t_phase += 1

    def _mark_eliminated_arms(self):
        """Append to ``arm_rm_temp`` every active arm that is confidently beaten.

        Does nothing until every active arm has been pulled at least once. An
        arm is marked when at least ``num_players`` active arms have a lower
        confidence bound strictly above its upper confidence bound.
        """
        pulls = self.npulls[self.active_arms]
        if not np.all(pulls > 0):
            return
        means = self.sums[self.active_arms] / pulls
        radius = np.sqrt(2 * np.log(self.T) / pulls)
        b_up = means + radius
        b_low = means - radius

        for i, arm in enumerate(self.active_arms):
            better = np.sum(b_low > b_up[i])
            if better >= self.num_players:
                self.arm_rm_temp.append(arm)
                if self.verbose:
                    print(f'successive eliminate {arm} at time {self.t}')

    def _remove_marked_arms(self):
        """Drop the arms in ``arm_rm_temp`` and refill the best set.

        Removed arms leave the active and exploration sets. Every removed arm
        that was in the best set is replaced by the remaining active arms with
        the highest empirical mean, and the exploration set is rebuilt as the
        active arms outside the best set. Finally ``qumo`` and ``K`` are
        refreshed and ``arm_rm_temp`` is cleared.
        """
        removed = set(self.arm_rm_temp)  # built once for O(1) membership tests
        intersecting_arms = removed.intersection(self.best_arms_set[self.use])
        self.active_arms = [arm for arm in self.active_arms if arm not in removed]
        self.expl_arms_set[self.use] = [
            arm for arm in self.expl_arms_set[self.use] if arm not in removed
        ]

        if intersecting_arms:
            best = [
                arm for arm in self.best_arms_set[self.use] if arm not in intersecting_arms
            ]
            in_best = set(best)
            # Highest empirical mean first; every active arm has been pulled
            # (arms are only marked once all active arms have pulls > 0).
            ranked = sorted(
                self.active_arms,
                key=lambda arm: self.sums[arm] / self.npulls[arm],
                reverse=True,
            )
            candidates = [arm for arm in ranked if arm not in in_best and arm not in removed]
            best.extend(candidates[:len(intersecting_arms)])
            self.best_arms_set[self.use] = best

            self.expl_arms_set[self.use] = [
                arm for arm in self.active_arms if arm not in best
            ]

        self.qumo = len(self.expl_arms_set[self.use])
        self.K = len(self.active_arms)
        self.arm_rm_temp = []
