import numpy as np
from bandit.mab import Arm, MAB
from bandit.best_reward import SuccessiveEliminationBME


def simple_single_item_matching(player_types, return_pairs=False):
    """Return the welfare of the efficient buyer/seller matching for one item type.

    Positive types are buyers (their value for the item); negative types are
    sellers (minus their cost); zero types never trade. Buyers are ranked by
    decreasing value and sellers by increasing cost. The k-th buyer is paired
    with the k-th seller, and every pair with a positive gain trades.

    Args:
        player_types: iterable of numeric types, one per player.
        return_pairs: when True, also return which players were matched.

    Returns:
        The total gain ``sum(buyer + seller)`` over the trading pairs. If
        ``return_pairs`` is True, a tuple ``(value, pairs)``, where each pair is
        ``(buyer_index, seller_index)`` indexing into ``player_types`` as passed
        in (not into any internally sorted copy).
    """
    values = list(player_types)
    # Same value order as sorted(values, reverse=True), including the stable
    # treatment of ties, but remembers each value's original position.
    order = sorted(range(len(values)), key=values.__getitem__, reverse=True)
    sorted_types = np.array([values[i] for i in order])

    # positions (within sorted_types) of buyers and sellers
    buyer_pos = np.flatnonzero(sorted_types > 0)
    seller_pos = np.flatnonzero(sorted_types < 0)
    buyer_values = sorted_types[buyer_pos]
    seller_values = sorted_types[seller_pos]

    # Pair high-value buyers with low-cost sellers. Both sequences are
    # non-increasing, so the gains are too, and the profitable pairs form a prefix.
    n = min(len(buyer_values), len(seller_values))
    gains = buyer_values[:n] + seller_values[:n]
    n_matched = int(np.count_nonzero(gains > 0))

    matched_value = sum(gains[:n_matched])

    if not return_pairs:
        return matched_value

    matched_pairs = [
        (order[int(buyer_pos[i])], order[int(seller_pos[i])])
        for i in range(n_matched)
    ]
    return matched_value, matched_pairs


class SimpleGame:
    """Finite game where each player's type is drawn uniformly from a private type set.

    Each of ``n_players`` players receives ``n_types`` distinct types sampled
    from ``type_space``. The welfare w*(t) of a type profile t is given by
    ``single_item_matching`` and is memoised on the sorted profile. The class
    exposes exact (enumerated) and bandit-estimated constant pivots, together
    with the resulting individual-rationality and budget-balance margins.
    """

    class Player:
        """Container for the set of types a single player may hold."""

        def __init__(self, types):
            self._types = types

    def __init__(self, n_players, n_types, type_space, single_item_matching, rng=None):
        """
        Args:
            n_players: number of players N.
            n_types: number of types given to each player (sampled without
                replacement from ``type_space``).
            type_space: pool of candidate types.
            single_item_matching: callable returning the welfare of a collection
                of types.
            rng: random generator with a ``choice`` method (e.g.
                ``np.random.RandomState``). When omitted, a fresh
                ``np.random.RandomState(42)`` is created for this instance. Pass
                distinct generators explicitly to obtain different games.
        """
        if rng is None:
            # A generator created in the signature would be one object shared by
            # every instance, so each game's types would depend on how many games
            # had been built before it.
            rng = np.random.RandomState(42)

        self._n_players = n_players
        self._n_types = n_types
        self._type_space = type_space
        self._single_item_matching = single_item_matching
        self._rng = rng
        # bandit-estimated pivots and the arguments they were learned with
        self._estimated_pivots = None
        self._estimated_pivots_key = None
        self._n_evaluations = 0

        self._players = list()
        for _ in range(n_players):
            # for each player, choose distinct types from type_space
            types = self._rng.choice(type_space, n_types, replace=False)
            self._players.append(self.Player(types))

        # welfare memo keyed by the sorted type profile
        self._values = dict()
        # E[w*(t) | t_n], filled lazily; entry [n, k] is for player n holding its k-th type
        self._conditional_value = np.full((n_players, n_types), None)

    # ----- profile enumeration helpers -------------------------------------------

    def _index_profiles(self, n_slots):
        """Yield every length-``n_slots`` list of type indices (slot 0 varies fastest)."""
        for number in range(self._n_types ** n_slots):
            yield [(number // self._n_types ** i) % self._n_types for i in range(n_slots)]

    def _types_of(self, index_list):
        """Map one type index per player to the corresponding type profile."""
        return [player._types[index_list[i]] for i, player in enumerate(self._players)]

    @staticmethod
    def _check_rule(rule):
        """Raise NotImplementedError for any delta-splitting rule other than "equal"."""
        if rule != "equal":
            raise NotImplementedError(f"unsupported rule: {rule!r}")

    def _pivots_from(self, min_values, expected_value, IR):
        """Return ``min_values - delta`` with the surplus split equally among players.

        The surplus is ``sum(min_values) - (N-1) * expected_value``. When ``IR``
        is True, a negative surplus is clipped to 0.
        """
        min_values = np.asarray(min_values)
        total_delta = sum(min_values) - (self._n_players - 1) * expected_value
        if IR and total_delta < 0:
            total_delta = 0
        deltas = np.full(self._n_players, total_delta / self._n_players)
        return min_values - deltas

    def _pivots(self, epsilon, delta, rule, IR):
        """Use exact pivots when ``epsilon == delta == 0``; otherwise learn them."""
        if epsilon == 0 and delta == 0:
            return self.get_constant_pivots(rule, IR)
        return self.learn_constant_pivots(epsilon, delta, rule, IR)

    # ----- public API --------------------------------------------------------------

    def compute_all_values(self):
        """Populate the welfare memo for every combination of player types."""
        for index_list in self._index_profiles(self._n_players):
            self._get_value(self._types_of(index_list))

    def get_feasibility(self):
        """Return ``sum_n min_{t_n} E[w*(t)|t_n] - (N-1) E[w*(t)]``.

        The game is feasible iff the returned value is >= 0.
        """
        feasibility = sum(self._get_all_min_conditional_values())
        feasibility -= (self._n_players - 1) * self._get_expected_value()
        return feasibility

    def get_individual_rationality(self, player_idx, epsilon=0, delta=0, rho=0, rule="equal", IR=False):
        """Return ``E[w*(t)|t_n] - h_n`` for every type ``t_n`` of player ``player_idx``.

        The player is individually rational iff every entry is >= 0. The pivot
        ``h_n`` is exact when ``epsilon == delta == 0`` and bandit-estimated
        otherwise. ``rho / N`` is added to it.
        """
        h = self._pivots(epsilon, delta, rule, IR)
        h_n = h[player_idx] + rho / self._n_players
        return [self._get_conditional_value(player_idx, t_idx) - h_n for t_idx in range(self._n_types)]

    def get_budget_balance(self, epsilon=0, delta=0, rho=0, rule="equal", IR=False):
        """Return ``sum_n h_n - (N-1) E[w*(t)] + rho``.

        The mechanism is weakly budget balanced iff the value is >= 0. Pivots are
        exact when ``epsilon == delta == 0`` and bandit-estimated otherwise.
        """
        h = self._pivots(epsilon, delta, rule, IR)
        return sum(h) - (self._n_players - 1) * self._get_expected_value() + rho

    def get_constant_pivots(self, rule="equal", IR=False):
        """Return the exact pivots ``min_{t_n} E[w*(t)|t_n] - delta_n`` for every player.

        ``delta_n`` is the same for all players when ``rule == "equal"`` (the only
        supported rule). With ``IR=True``, a negative total delta is clipped to 0.

        Raises:
            NotImplementedError: for any other ``rule``.
        """
        self._check_rule(rule)
        min_values = np.array(self._get_all_min_conditional_values())
        return self._pivots_from(min_values, self._get_expected_value(), IR)

    def learn_constant_pivots(self, epsilon, delta, rule="equal", IR=False, sample_expected_value=False):
        """Estimate the pivots ``min_{t_n} E[w*(t)|t_n] - delta_n`` with bandits.

        One ``SimpleMAB`` per player (one arm per type) is solved with
        ``SuccessiveEliminationBME(epsilon, delta, ...)``. The sample count used
        is added to ``self._n_evaluations``. E[w*(t)] is estimated by averaging
        the per-type arm means. With ``sample_expected_value=True``, it is instead
        estimated by a single-armed ``SimpleSAB`` bandit.

        The result is cached per ``(epsilon, delta, rule, IR,
        sample_expected_value)``, so repeated calls with the same arguments reuse
        the same estimate.

        Raises:
            NotImplementedError: for any ``rule`` other than "equal".
        """
        self._check_rule(rule)
        key = (epsilon, delta, rule, IR, sample_expected_value)
        if self._estimated_pivots is not None and self._estimated_pivots_key == key:
            return self._estimated_pivots

        # NOTE(review): arm rewards are negated welfare (see SimpleArm), so they lie
        # in [-max_value, 0], while support is given as [0, max_value]. Confirm that
        # SuccessiveEliminationBME only uses the support width. Also, max_value is
        # the welfare of type_space with each type once. A profile where several
        # players share a type could exceed it.
        max_value = self._single_item_matching(self._type_space)

        min_values = list()
        expectations = list()
        for p_idx in range(self._n_players):
            env = SimpleMAB(self, p_idx, self._rng)
            alg = SuccessiveEliminationBME(epsilon, delta, support=[0, max_value])
            alg.run(env, detail=True)
            # Rewards are negated welfare, so minus the best (presumably largest)
            # mean reward estimates the smallest conditional welfare.
            min_values.append(-alg.get_best_mean_reward())
            self._n_evaluations += alg.get_total_n_sample()

            expectations.append(-np.mean([alg.get_mean_reward(arm) for arm in range(env.n_arms())]))

        if sample_expected_value:
            # estimate E[w*(t)] directly by sampling full type profiles
            env = SimpleSAB(self, self._rng)
            alg = SuccessiveEliminationBME(epsilon, delta, support=[0, max_value])
            alg.run(env, detail=True)
            expected_value = -alg.get_best_mean_reward()
            self._n_evaluations += alg.get_total_n_sample()
        else:
            # Each player's type is uniform over its type set, so averaging the
            # per-type conditional means estimates E[w*(t)].
            expected_value = np.mean(expectations)

        self._estimated_pivots = self._pivots_from(min_values, expected_value, IR)
        self._estimated_pivots_key = key
        return self._estimated_pivots

    def _get_value(self, types):
        """Return w*(types), memoised on the sorted profile."""
        types = tuple(sorted(types))
        if types not in self._values:
            self._values[types] = self._single_item_matching(types)
        return self._values[types]

    def get_n_evaluated(self):
        """Return the number of distinct sorted profiles whose welfare was computed."""
        return len(self._values)

    def _get_expected_value(self):
        """Return E[w*(t)], averaging uniformly over every combination of player types."""
        total_value = 0
        for index_list in self._index_profiles(self._n_players):
            total_value += self._get_value(self._types_of(index_list))
        return total_value / self._n_types ** self._n_players

    def _get_conditional_value(self, player_idx, type_idx):
        """Return E[w*(t) | player ``player_idx`` holds its ``type_idx``-th type] (memoised)."""
        cached = self._conditional_value[player_idx, type_idx]
        if cached is None:
            n_others = self._n_players - 1
            total_value = 0
            for others in self._index_profiles(n_others):
                # insert the fixed player's type index at its own position
                index_list = others[:player_idx] + [type_idx] + others[player_idx:]
                total_value += self._get_value(self._types_of(index_list))
            cached = total_value / self._n_types ** n_others
            self._conditional_value[player_idx, type_idx] = cached
        return cached

    def _get_min_conditional_value(self, player_idx):
        """Return ``min_{t_n} E[w*(t) | t_n]`` for player ``player_idx``."""
        conditional_values = [self._get_conditional_value(player_idx, t_idx) for t_idx in range(self._n_types)]
        return min(conditional_values)

    def _get_all_min_conditional_values(self):
        """Return the minimum conditional welfare of every player, in player order."""
        return [self._get_min_conditional_value(p_idx) for p_idx in range(self._n_players)]

    def sample_types(self, batch_size, rng):
        """Return a (batch_size, n_players) array of uniformly sampled type profiles."""
        types = [rng.choice(p._types, batch_size) for p in self._players]
        return np.vstack(types).T

    def sample_conditional_types(self, player_idx, player_type, batch_size, rng):
        """Like ``sample_types``, but column ``player_idx`` is fixed to ``player_type``."""
        types = [rng.choice(self._players[i]._types, batch_size) for i in range(player_idx)]
        types += [np.array([player_type] * batch_size)]
        types += [rng.choice(self._players[i]._types, batch_size) for i in range(player_idx + 1, self._n_players)]
        return np.vstack(types).T


class SimpleArm(Arm):
    """Arm whose reward is the negated welfare of a profile with one player's type fixed.

    Each sample draws the other players' types through
    ``game.sample_conditional_types`` while player ``player_idx`` holds
    ``player_type``.

    NOTE: the default ``rng`` is evaluated once, when the class is defined, so
    every instance constructed without an explicit ``rng`` shares one generator.
    """

    def __init__(self, game, player_idx, player_type, batch_size=10, rng=np.random.default_rng(42)):
        # Set these before the base initializer runs, which may rely on them.
        self._rng = rng
        self._player_type = player_type
        self._player_idx = player_idx
        self._game = game
        super().__init__(batch_size, rng)

    def _sample_batch(self, batch_size=None):
        """Store negated welfare for ``batch_size`` (default ``self._batch_size``) fresh profiles."""
        size = self._batch_size if batch_size is None else batch_size
        profiles = self._game.sample_conditional_types(
            self._player_idx, self._player_type, size, self._rng
        )
        value_of = self._game._get_value
        self._samples = [-value_of(profile) for profile in profiles]


class SimpleMAB(MAB):
    """Bandit with one ``SimpleArm`` per type of player ``player_idx``, in type order."""

    def __init__(self, game, player_idx, rng, batch_size=10):
        candidate_types = game._players[player_idx]._types
        arm_list = []
        for candidate in candidate_types:
            arm_list.append(
                SimpleArm(game, player_idx, candidate, batch_size=batch_size, rng=rng)
            )
        super().__init__(arm_list)


class SimpleOneArm(Arm):
    """Arm whose reward is the negated welfare of an unconditionally sampled profile.

    NOTE: the default ``rng`` is evaluated once, when the class is defined, so
    every instance constructed without an explicit ``rng`` shares one generator.
    """

    def __init__(self, game, batch_size=10, rng=np.random.default_rng(42)):
        # Set these before the base initializer runs, which may rely on them.
        self._rng = rng
        self._game = game
        super().__init__(batch_size, rng)

    def _sample_batch(self, batch_size=None):
        """Store negated welfare for ``batch_size`` (default ``self._batch_size``) fresh profiles."""
        size = self._batch_size if batch_size is None else batch_size
        profiles = self._game.sample_types(size, self._rng)
        value_of = self._game._get_value
        self._samples = [-value_of(profile) for profile in profiles]


class SimpleSAB(MAB):
    """Single-armed bandit whose only arm is a ``SimpleOneArm`` over ``game``."""

    def __init__(self, game, rng, batch_size=10):
        only_arm = SimpleOneArm(game, batch_size=batch_size, rng=rng)
        super().__init__([only_arm])
