from .policy import Policy
import numpy as np

class SICMMAB(Policy):
    """Multiplayer bandit policy following the SIC-MMAB phase structure.

    The player cycles through three phases (``self.phase``):

    * ``'exploration'`` -- round ``r`` lasts ``(2 << r) * K`` steps; the player
      hops sequentially over ``active_arms`` starting at offset ``int_rank``, so
      players with distinct ranks pull distinct arms at every step.
    * ``'communication'`` -- lasts ``M * (M - 1) * K * (r + 2)`` steps.  Players
      take turns (ordered by ``int_rank``) sending the ``r + 2`` low bits of
      every entry of ``last_phase_stats`` to every other player: a 1-bit is sent
      by pulling the receiver's arm (creating a collision), a 0-bit by staying on
      one's own arm.  Receivers decode collisions in ``reward_function``.
    * ``'exploitation'`` -- once an arm is accepted for this player it is
      returned by ``play`` from then on.

    At the end of each communication phase, confidence bounds on the pooled
    statistics decide which active arms are accepted (surely among the best
    ``M``) or rejected (surely not), shrinking ``M``, ``K`` and ``active_arms``.

    ``K``, ``T``, ``t`` and ``npulls`` are not assigned in this class; they are
    presumably initialised by ``Policy.__init__`` -- confirm against the base class.
    """

    def __init__(self, narms, internal_rank, NUMOFPLAYERS, T, verbose=False):
        """
        Args:
            narms: number of arms at start-up (kept as ``K0``; per-arm arrays
                are indexed by original arm id).
            internal_rank: this player's 0-based rank among the active players.
            NUMOFPLAYERS: number of players at start-up.
            T: horizon; ``com_update`` holds one slot per step, so step indices
                must stay below ``T``.
            verbose: stored, but not read anywhere in this class.
        """
        Policy.__init__(self, narms, T, internal_rank, NUMOFPLAYERS)
        self.K0 = narms
        self.name = 'SICMMAB'
        # Rank among the *currently active* players; decreases as players leave.
        self.int_rank = internal_rank
        # Original rank; in this class only used in the exploit log message.
        self.changed_rank = internal_rank
        self.M = NUMOFPLAYERS  # number of active players
        self.last_action = internal_rank  # last play for sequential hopping
        self.phase = 'exploration'  # start directly with exploration
        self.t_phase = 0  # step within the current phase
        self.round_number = 0  # index of the current exploration/communication round
        self.active_arms = np.arange(0, self.K)
        # Pooled reward sums: own rewards plus statistics decoded from other players.
        self.sums = np.zeros(self.K)
        # Own rewards in the current exploration phase -- the values transmitted.
        self.last_phase_stats = np.zeros(self.K)
        # Pulls made by *other* players whose rewards were folded into self.sums.
        self.received_pulls = np.zeros(self.K0)
        self.verbose = verbose
        # Inclusive (begin, end) step windows of past communication phases.
        self.com_stamp = []
        # Inclusive (begin, end) step window of the current exploration phase.
        self.expl_stamp = (0, 2 * self.K0 - 1)
        # Per-step snapshot (t_phase, M, K, round_number, active_arms) written by
        # update() during communication; the all-zero tuple means "not recorded".
        self.com_update = [(0, 0, 0, 0, 0) for _ in range(self.T)]
        self.index = 0
        self.final_arm = -1  # arm to exploit; -1 until one is accepted

    @staticmethod
    def _comm_slot(t_phase, M, K, round_number):
        """Decode a communication step into ``(bit, arm_position, receiver)``.

        Within one sender's turn of ``(M - 1) * K * (round_number + 2)`` steps,
        the bit index varies fastest, then the arm position in ``active_arms``,
        then the receiver index (0..M-2, not yet adjusted to skip the sender).
        Requires ``M > 1`` and ``K > 0``.
        """
        n_bits = round_number + 2
        t0 = t_phase % ((M - 1) * K * n_bits)  # offset inside the sender's turn
        bit = int(t0 % n_bits)
        arm_position = int((t0 // n_bits) % K)
        receiver = int(t0 // (n_bits * K))
        return bit, arm_position, receiver

    def play(self):
        """
        Return the arm to pull at the current step, based on the current phase.

        Returns None if called in exploration with no active arm left.
        """
        if self.phase == 'exploration':
            if self.active_arms.size == 0:
                print("active arm is none")
            else:
                # Sequential hopping: distinct ranks never share an arm.
                self.index = (self.t_phase + self.int_rank) % self.K
                return self.active_arms[self.index]

        if self.phase == 'communication':
            turn = (self.M - 1) * self.K * (self.round_number + 2)
            if self.int_rank * turn <= self.t_phase < (self.int_rank + 1) * turn:
                # Our turn to send: pick the bit, the arm it describes and the receiver.
                bit, arm_position, receiver = self._comm_slot(
                    self.t_phase, self.M, self.K, self.round_number)
                arm = self.active_arms[arm_position]
                if (int(self.last_phase_stats[arm]) >> bit) % 2:
                    # Receivers are numbered without the sender, so skip our own rank.
                    receiver += receiver >= self.int_rank
                    return self.active_arms[receiver]  # send 1: collide with the receiver
                return self.active_arms[self.int_rank]  # send 0
            return self.active_arms[self.int_rank]  # receive protocol or wait

        if self.phase == 'exploitation':
            return self.final_arm

    def is_outside_range(self, i, rng):
        """Return True when step ``i`` is outside the inclusive window ``rng = (begin, end)``."""
        a, b = rng
        return not (a <= i <= b)

    def is_inside_expl(self, i, ranges):
        """Return True when step ``i`` lies in the inclusive window ``ranges = (begin, end)``.

        ``ranges`` is a single ``(begin, end)`` pair (``self.expl_stamp``); a plain
        ``i in ranges`` would only match its two endpoints.
        """
        return not self.is_outside_range(i, ranges)

    def is_inside_com(self, i, ranges):
        """Return True when step ``i`` was a *receiving* step of a communication phase.

        ``ranges`` is the list of inclusive windows in ``self.com_stamp``; the
        snapshot ``self.com_update[i]`` tells whose turn it was to send.  Uses the
        *current* ``self.int_rank`` to decide whether the step was our own turn.
        """
        for (a, b) in ranges:
            if a <= i <= b:
                t_phase, M, K, round_number, _ = self.com_update[i]
                # The all-zero tuple means update() never recorded this step.  Test
                # M and K rather than t_phase: t_phase is legitimately 0 on the first
                # step of every communication phase, and that step must be decoded.
                if M > 1 and K > 0:
                    turn = (M - 1) * K * (round_number + 2)
                    if t_phase >= (self.int_rank + 1) * turn or t_phase < self.int_rank * turn:
                        return True
        return False

    def reward_function(self, arm, time, reward, collision):
        """Record the outcome of pulling ``arm`` at step ``time``.

        * Outside every communication window, collision-free rewards go into
          ``sums``/``npulls``.
        * Inside the current exploration window, collision-free rewards go into
          ``last_phase_stats`` (the values transmitted in the next communication).
        * On a receiving communication step, a collision is decoded as a 1-bit
          and ``2**bit`` is added to ``sums`` of the arm being described.

        NOTE(review): decoding reads ``com_update[time]``, which update() writes
        at index ``self.t`` -- so this presumably must be called with the step's
        ``self.t`` value *after* update() ran for that step; confirm with caller.
        """
        if all(self.is_outside_range(time, rng) for rng in self.com_stamp):
            if collision == 0:
                self.sums[arm] += reward
                self.npulls[arm] += 1

        if self.is_inside_expl(time, self.expl_stamp):
            if collision == 0:
                self.last_phase_stats[arm] += reward

        if self.is_inside_com(time, self.com_stamp):
            if collision == 1:
                t_phase, M, K, round_number, active_arms = self.com_update[time]
                bit, arm_position, _ = self._comm_slot(t_phase, M, K, round_number)
                self.sums[active_arms[arm_position]] += 1 << bit

    def _end_communication(self):
        """Accept/reject arms from the pooled statistics and start the next phase."""
        # Every other active player pulled each active arm (2 << round_number)
        # times in the exploration phase just finished and transmitted those
        # rewards into self.sums; count those pulls as well so the pooled mean and
        # the confidence width use the same sample size.
        self.received_pulls[self.active_arms] += (self.M - 1) * (2 << self.round_number)

        reject = []
        accept = []
        own_pulls = self.npulls[self.active_arms]
        if np.all(own_pulls > 0):
            n_pulls = own_pulls + self.received_pulls[self.active_arms]
            means = self.sums[self.active_arms] / n_pulls
            width = np.sqrt(2 * np.log(self.T) / n_pulls)
            b_up = means + width
            b_low = means - width

            # Reject when at least M arms are surely better; accept when the arm
            # is surely better than at least K - M others.
            for i, k in enumerate(self.active_arms):
                better = np.sum(b_low > b_up[i])
                worse = np.sum(b_up < b_low[i])
                if better >= self.M:
                    reject.append(k)
                if worse >= (self.K - self.M):
                    accept.append(k)

            removed = reject + accept
            if removed:
                self.active_arms = np.setdiff1d(self.active_arms, removed)

            # Accepted arms leave together with the players that exploit them.
            self.M -= len(accept)
            self.K -= (len(accept) + len(reject))

        if len(accept) > self.int_rank:  # start exploitation
            self.phase = 'exploitation'
            self.t_phase = 0
            print(f'{self.changed_rank} begin to exploit {accept[self.int_rank]} at time {self.t}')
            self.final_arm = accept[self.int_rank]
        else:  # new exploration phase and update internal rank
            self.phase = 'exploration'
            self.int_rank -= len(accept)
            self.last_phase_stats = np.zeros(self.K0)
            self.t_phase = 0
            begin_time = self.t + 1
            end_time = self.t + 1 + (2 << (self.round_number + 1)) * self.K - 1
            self.expl_stamp = (begin_time, end_time)

        self.round_number += 1

    def update(self):
        """
        Advance the phase state machine by one step, then increment ``self.t``.
        """
        if self.phase == 'exploration':
            self.t_phase += 1

            # end of exploration phase
            if self.t_phase == (2 << self.round_number) * self.K:
                self.phase = 'communication'
                self.t_phase = 0
                begin_time = self.t + 1
                # NOTE(review): the phase occupies steps t+1 .. t+L with
                # L = M*(M-1)*K*(round_number+2), but this window stops at t+L-1
                # (exploration windows use "- 1"). Including the final step would
                # also require decoding it with the pre-update int_rank, since
                # _end_communication may change int_rank on that step -- confirm
                # intent before changing.
                end_time = self.t + 1 + self.M * (self.M - 1) * self.K * (self.round_number + 2) - 2
                self.com_stamp.append((begin_time, end_time))

        elif self.phase == 'communication':
            # Snapshot the protocol state so reward_function can decode this step.
            self.com_update[self.t] = (
                self.t_phase, self.M, self.K, self.round_number, self.active_arms)

            self.t_phase += 1

            if (self.t_phase == self.M * (self.M - 1) * self.K * (self.round_number + 2)
                    or self.M == 1):
                self._end_communication()

        elif self.phase == 'exploitation':
            self.t_phase += 1

        self.t += 1
