import numpy as np
from collections import deque
from typing import List


class FullSensingMultiPlayerMAB(object):
    """Multi-player stochastic bandit simulator with randomly delayed feedback.

    Every round each of the ``M`` players picks an arm. Players that pick the
    same arm collide and receive zero reward. The observation of player ``i``
    at round ``t`` is delivered to that player only at round
    ``t + delays[i, t]``; observations that would mature at or after the
    simulated horizon are never delivered.
    """

    def __init__(self, means, nplayers, horizon, policy, delay_mean, delay_std, reward='Gaussian', max_delay_count=30,
                 **kwargs):
        """Build the environment and its players.

        Args:
            means: per-arm reward means; ``len(means)`` is the number of arms K.
            nplayers: number of players M.
            horizon: number of rounds the delay and feedback buffers are sized for.
            policy: factory called as ``policy(narms=K, T=horizon,
                internal_rank=i, NUMOFPLAYERS=M, **kwargs)`` for player ``i``.
            delay_mean, delay_std: parameters of the normal draw whose absolute
                value, cast to ``int``, is the delay of each observation.
            reward: ``'Gaussian'`` for unit-variance normal rewards clipped at
                0; any other value draws Bernoulli rewards.
            max_delay_count: maximum number of observations that may mature for
                one player in the same round.
            **kwargs: forwarded unchanged to ``policy``.
        """
        self.max_delay_count = max_delay_count

        self.K = len(means)
        self.means = np.array(means)
        self.horizon = horizon
        self.M = nplayers
        # feedback[t, arm] = (reward, collision) of the last observation on `arm` delivered at round t
        self.feedback = np.zeros((self.horizon, self.K, 2))
        self.reward_type = reward

        self.players = [
            policy(narms=self.K, T=self.horizon, internal_rank=i, NUMOFPLAYERS=self.M, **kwargs) for i in
            range(nplayers)
        ]

        # delays[i, t]: number of rounds before player i sees its round-t observation
        self.delays = np.abs(np.random.normal(delay_mean, delay_std, size=(self.M, horizon))).astype(int)

        # queues[i][t]: staging deque for the observations delivered to player i at round t
        self.queues: List[List[deque]] = [[deque() for _ in range(self.horizon)] for _ in range(self.M)]

        # future_rewards[i, t, slot] = (pre_time, arm, reward, collision), or all-None when the slot is free
        self.future_rewards = np.empty((self.M, self.horizon, self.max_delay_count, 4),
                                       dtype=object)
        self.future_rewards.fill(None)
        # Leader -> follower channel used by the 'delayedMMAB_ct' policy
        self.alist = []
        self.end = False

    def simulate_single_step_rewards(self):
        """Draw one reward per arm: clipped unit-variance Gaussian or Bernoulli."""
        if self.reward_type == 'Gaussian':
            rewards = np.random.normal(self.means, 1)
            return np.maximum(rewards, 0)
        return np.random.binomial(1, self.means)

    def simulate_single_step(self, plays):
        """Resolve one round given each player's chosen arm.

        Returns:
            A tuple ``(obs, rewards, cols)``:
            - obs: list of ``(raw arm reward, collided)`` pairs, one per player.
            - rewards: per-player rewards, zeroed for colliding players.
            - cols: boolean mask of players that collided.
        """
        rews = self.simulate_single_step_rewards()
        unique, counts = np.unique(plays, return_counts=True)
        collisions = unique[counts > 1]
        cols = np.isin(plays, collisions)
        drawn = rews[plays]
        rewards = drawn * (1 - cols)

        return list(zip(drawn, cols)), rewards, cols

    def _schedule_feedback(self, i, due, entry):
        """Store ``entry`` in the first free slot of player ``i``'s buffer for round ``due``.

        Raises:
            ValueError: if all ``max_delay_count`` slots for that round are taken.
        """
        slots = self.future_rewards[i][due, :, 0]
        free = np.flatnonzero(slots == None)  # noqa: E711 -- elementwise test on an object array
        # Check emptiness before indexing so a full buffer raises the intended ValueError
        if free.size == 0:
            raise ValueError(f"Exceed {self.max_delay_count}")
        self.future_rewards[i][due, free[0]] = entry

    def _dispatch_feedback(self, i, t, pre_time, arm, reward, collision):
        """Route one matured observation to player ``i`` based on its policy ``name``.

        The last player (index ``M - 1``) acts as leader and receives the reward;
        the others are followers that receive only the collision indicator.
        """
        player = self.players[i]
        is_leader = i == self.M - 1
        if player.name == 'SICMMAB':
            player.reward_function(arm, pre_time, reward, collision)
        elif player.name == 'delayedMMAB_pro':
            player.est_delay(pre_time, t)
            if is_leader:
                player.reward_function_l(arm, pre_time, reward, collision)
            else:
                player.reward_function_f(arm, pre_time, collision)
        elif player.name == 'delayedMMAB_ct':
            if is_leader:
                player.reward_function_l(arm, pre_time, reward, collision)
                # Publish the leader's state for the followers
                self.alist[:] = player.send_set()
                self.end = player.send_end()
            else:
                if self.alist:
                    player.receive_set(self.alist)
                if self.end:
                    player.receive_end()
        else:
            if is_leader:
                player.reward_function_l(arm, pre_time, reward, collision)
            else:
                player.reward_function_f(arm, pre_time, collision)

    def _deliver_matured(self, i, t):
        """Hand every observation that matures at round ``t`` to player ``i``, in slot order."""
        buffer = self.future_rewards[i][t]
        queue = self.queues[i][t]
        queue.extend(tuple(row) for row in buffer[buffer[:, 0] != None])  # noqa: E711
        while queue:
            pre_time, arm, reward, collision = queue.popleft()
            self.feedback[t][arm] = (reward, collision)
            self._dispatch_feedback(i, t, pre_time, arm, reward, collision)

    def simulate(self, horizon=None, exit_condition=None):
        """Play the game round by round and compute cumulative regret.

        Args:
            horizon: number of rounds to play; defaults to ``self.horizon``.
                It must not exceed ``self.horizon``, because the delay and
                feedback buffers are only allocated for that many rounds.
            exit_condition: optional callable that receives the player list
                after each round; a truthy result stops the simulation early.

        Returns:
            ``(regret, play_history)``. ``regret[t]`` is the best achievable
            cumulative mean reward minus the collected cumulative reward after
            round ``t``. ``play_history`` holds each round's list of chosen arms.

        Raises:
            ValueError: if ``horizon`` exceeds ``self.horizon``, or if more than
                ``max_delay_count`` observations mature for one player in one round.

        Buffered feedback is not reset between calls.
        """
        T = self.horizon if horizon is None else horizon
        if T > self.horizon:
            raise ValueError(f"horizon={T} exceeds the environment horizon {self.horizon}")

        rewards = []
        play_history = []

        for t in range(T):
            plays = [int(player.play()) for player in self.players]

            _, rw, cols = self.simulate_single_step(plays)

            for i in range(self.M):
                due = t + self.delays[i][t]
                # Observations maturing at or after the horizon are never delivered
                if due < T:
                    self._schedule_feedback(i, due, (t, plays[i], float(rw[i]), bool(cols[i])))

                self._deliver_matured(i, t)

                player = self.players[i]
                if player.name == 'DPE':
                    player.update(self.feedback)
                else:
                    player.update()

            rewards.append(np.sum(rw))
            play_history.append(plays)

            if exit_condition is not None and exit_condition(self.players):
                T = t + 1
                break

        # kth = M - 1 selects the M largest means and stays in bounds when M == K
        top_means = -np.partition(-self.means, self.M - 1)[:self.M]

        best_case_reward = np.sum(top_means) * np.arange(1, T + 1)
        cumulated_reward = np.cumsum(rewards)

        regret = best_case_reward - cumulated_reward
        self.regret = (regret, best_case_reward, cumulated_reward)
        self.top_means = top_means
        return regret, play_history

    def get_players(self):
        """Return the list of player policy objects."""
        return self.players
