    
import numpy as np
from numpy import *
import matplotlib.pyplot as plt
from numpy.ma.core import flatten_structured_array
from tqdm import tqdm
import matplotlib
import datetime
import torch
import math
# ODA
class Decentralized(object):

    def __init__(self, horizon, trial, num_player, num_arm, player_ranking, arm_ranking, player_mean, arm_capacity):
        """Store the problem instance and derive the reference matchings.

        Args:
            horizon: Number of rounds simulated per trial.
            trial: Number of independent trials.
            num_player: Number of players.
            num_arm: Number of arms.
            player_ranking: Per-player sequence of arms; the entry at index 0
                is taken as that player's optimal arm.
            arm_ranking: Per-arm lookup consulted by
                ``accept_player_substitute``.
            player_mean: ``player_mean[p][a]`` is the mean reward of arm ``a``
                for player ``p``.
            arm_capacity: Stored as ``self.arms_capacity``; not otherwise used
                by the constructor.

        Side effects:
            Prints the pessimal matching, the optimal matching, the optimal
            rewards and the per-player gap report.
        """
        # Output location and run sizes.
        self.path = './ResultsData/decen/'
        self.horizon = horizon
        self.trials = trial

        # Fixed numeric constants.
        self.p_lambda = 0.1
        self.epsilon = 10 ** (-10)
        self.varEpsilon = 0.2

        # Problem instance.
        self.num_players = num_player
        self.num_arms = num_arm
        self.players_ranking = player_ranking
        self.arms_rankings = arm_ranking
        self.players_mean = player_mean
        self.arms_capacity = arm_capacity

        self.beta = 1 / (2 * self.num_arms)
        self.gamma = 2

        # Reference matching produced by the arm-side proposal procedure.
        pessimal = self.get_pessimal_matching_substitute(self.players_ranking)
        self.pessimal_matching = pessimal.tolist()
        print("pessimal matching", self.pessimal_matching)

        # Every player's top-ranked arm is treated as its optimal match.
        self.optimal_matching = [
            player_ranking[player][0] for player in range(self.num_players)
        ]

        self.optimal_rewards = np.zeros(self.num_players)
        for player, best_arm in enumerate(self.optimal_matching):
            self.optimal_rewards[player] = self.players_mean[player][best_arm]

        print("optimal matching", self.optimal_matching)
        print("optimal rewards", self.optimal_rewards)

        self.calculate_matching_gap()

    def calculate_matching_gap(self):
        """Print each player's mean-reward gap between its optimal and pessimal arm.

        A player whose optimal or pessimal arm is ``None`` is reported as
        missing and contributes nothing to the total. The summed gap is
        printed at the end; nothing is returned.
        """
        gap_sum = 0
        for player in range(self.num_players):
            best_arm = self.optimal_matching[player]
            worst_arm = self.get_arm_from_matching(self.pessimal_matching, player)

            # Guard: without both arms there is no gap to compute.
            if best_arm is None or worst_arm is None:
                print(f"Player {player}: Missing match in optimal or pessimal")
                continue

            best_value = self.players_mean[player][best_arm]
            worst_value = self.players_mean[player][worst_arm]
            diff = best_value - worst_value
            gap_sum += diff
            print(f"Player {player}: optimal arm {best_arm} value {best_value:.2f}, "
                  f"pessimal arm {worst_arm} value {worst_value:.2f}, gap {diff:.2f}")

        print(f"Total value gap between optimal and pessimal: {gap_sum:.2f}")


    def get_arm_from_matching(self, matching, p_idx):

        if isinstance(matching[p_idx], (list, np.ndarray)):
            if len(matching[p_idx]) > 0:
                return matching[p_idx][0]
            else:
                return None
        else:
            return matching[p_idx]


    def get_pessimal_matching_substitute(self, players_rankings):
        base_set = [list(range(self.num_players)) for a_idx in range(self.num_arms)]
        flag = 0 
        while flag == 0:
            matched = [[] for _ in range(self.num_arms)]
            matching = [[] for _ in range(self.num_players)]

            propose_list = [[] for a_idx in range(self.num_arms)]
            for a_idx in range(self.num_arms):
                propose_list[a_idx] = self.accept_player_substitute(a_idx, base_set[a_idx])
            
                for p_idx in propose_list[a_idx]:
                    matching[p_idx].append(a_idx)
                
            p_choicess = [[] for _ in range(self.num_players)]
            for p_idx in range(self.num_players):
                p_choices = matching[p_idx]
                p_choicess[p_idx] = matching[p_idx]

                if len(p_choices) != 0:    
                    p_choice = next((x for x in players_rankings[p_idx] if x in matching[p_idx]), None)

                    matching[p_idx] = [p_choice]
                   
                    for a_idx in p_choices:
                
                        if a_idx == p_choice:
                            matched[a_idx].append(p_idx)
                        else:
                            base_set[a_idx].remove(p_idx)
        
            flag = 1
            for p_idx in range(self.num_players):
                if len(p_choicess[p_idx]) > 1:
                    flag = 0
                    break                      
        return np.squeeze(matching)

    def accept_player_substitute(self, arm, given_player_set):
        """Return the players that ``arm`` accepts out of ``given_player_set``.

        The candidates are sorted into a tuple and used to index
        ``self.arms_rankings[arm]``; the looked-up value is returned as a list.
        (Presumably ``self.arms_rankings[arm]`` maps each sorted tuple of
        candidates to the accepted subset -- confirm against the caller that
        builds it.)

        The caller's collection is never modified: the sorted key is built
        from a copy, so any sized iterable (list, tuple, set, ndarray) can be
        passed.

        Args:
            arm: Index into ``self.arms_rankings``.
            given_player_set: Candidate players proposing to ``arm``.

        Returns:
            list: Accepted players, or ``[]`` when there are no candidates.
        """
        # ``len`` rather than truthiness so that ndarray inputs keep working.
        if len(given_player_set) == 0:
            return []

        # sorted() copies; list.sort() would reorder the caller's list.
        lookup_key = tuple(sorted(given_player_set))

        return list(self.arms_rankings[arm][lookup_key])

    def isUnstableSubstitute(self, arm_matching):
        player_matching = np.ones(self.num_players) * (-1)
        for p_idx in range(self.num_players):
            for a_idx in range(self.num_arms):
                if p_idx in arm_matching[a_idx]:
                    player_matching[p_idx] = a_idx
        
        has_blocking_pair = False
        for p_idx in range(self.num_players):
            if player_matching[p_idx] != -1:
                for possible_arm_rank in range(self.players_ranking[p_idx].index(player_matching[p_idx])):
                    arm = self.players_ranking[p_idx][possible_arm_rank]
                    
                    newset = arm_matching[arm].copy()
                    newset = list(newset)
                    newset.append(p_idx)
                    
                    if p_idx in self.accept_player_substitute(arm, newset):
                        has_blocking_pair = True
                        break
            else:
                for possible_arm_rank in range(self.num_arms):
                    arm = self.players_ranking[p_idx][possible_arm_rank]
                    
                    newset = arm_matching[arm].copy()
                    newset = list(newset)
                    newset.append(p_idx)
                    
                    if p_idx in self.accept_player_substitute(arm, newset):
                        has_blocking_pair = True
                        break
            
            if has_blocking_pair:
                return 1  

        for p_idx in range(self.num_players):
            if int(player_matching[p_idx]) != self.optimal_matching[p_idx]:
                return 1 
        
        return 0  
    def decen_elimination_substitue(self, N, C, delta):
        """Simulate the decentralized elimination procedure and return averaged curves.

        For each of ``self.trials`` trials and each of ``self.horizon`` rounds:
        every player picks an arm from its plausible set, every arm accepts a
        subset of its proposers through ``accept_player_substitute``, accepted
        players draw a unit-variance Gaussian reward centred on
        ``self.players_mean[p][a]``, and the players' availability lists and
        plausible sets are then updated.

        Args:
            N: Not read by this method.
            C: Not read by this method.
            delta: Not read by this method.

        Returns:
            tuple: ``(mean_regret_across_players, mean_cum_unstable)``, two
            1-D arrays of length ``self.horizon``. The first is the cumulative
            regret averaged over trials and then over players; the second is
            the cumulative sum of the per-round ``isUnstableSubstitute`` flag
            averaged over trials.
        """
        # Buffers used to save data across all trials.
        regrets_trials = np.zeros([self.num_players, self.trials, self.horizon])
        rewards_trials = np.zeros([self.num_players, self.trials, self.horizon])
        unstable_trials = np.zeros([self.trials, self.horizon])

        for trial in tqdm(range(self.trials), ascii=True, desc="Running the Decentralized-Elimination"):

            # Per-trial result buffers.
            unstable_one_trial = np.zeros(self.horizon)
            regrets_one_trial = np.zeros([self.num_players, self.horizon])
            rewards_one_trial = np.zeros([self.num_players, self.horizon])

            # Running empirical mean and pull count for every (player, arm).
            players_es_mean = [np.zeros(self.num_arms) for j in range(self.num_players)]
            players_count = [np.zeros(self.num_arms) for j in range(self.num_players)]
            # Initialised here but not read again in this method.
            players_flag = [np.ones(self.num_arms, int) for j in range(self.num_players)]

            # Initialised here but not read again in this method.
            arm_matchings = [[] for a_idx in range(self.num_arms)]

            # Pull_arms[p]: arms on which player p has been accepted at least once.
            Pull_arms = [[] for i in range(self.num_players)]
            # Pull_arms_each_round[p][t]: arm that accepted p in round t, -1 otherwise.
            Pull_arms_each_round = np.ones([self.num_players, self.horizon], int) * (-1)

            # Available_players[p][a]: player p's list of players still considered
            # for arm a; every list starts with all players.
            Available_players = [[[] for j in range(self.num_arms)] for i in range(self.num_players)]
            for player in range(self.num_players):
                for a_idx in range(self.num_arms):
                    for p_idx in range(self.num_players):
                        Available_players[player][a_idx].append(p_idx) 

            # Plausible_arms[p]: arms that would accept p when choosing from
            # Available_players[p][a].
            Plausible_arms = [[] for i in range(self.num_players)]
            for p_idx in range(self.num_players):
                for a_idx in range(self.num_arms):
                    if p_idx in self.accept_player_substitute(a_idx, Available_players[p_idx][a_idx]):
                        Plausible_arms[p_idx].append(a_idx)

            for round in range(self.horizon):
                # At[p]: arm player p proposes to this round (-1 means none).
                At = np.ones(self.num_players) * (-1)

                for p_idx in range(self.num_players):
                    Whether_for_detect = 0

                    # From round num_arms onwards, propose to the first plausible
                    # arm on which this player was not recorded during the
                    # previous num_arms rounds.
                    for a_idx in Plausible_arms[p_idx]:
                        flag = 0

                        if round >= self.num_arms:
                            t = round - 1
                            while t >= round - self.num_arms:
                                if int(Pull_arms_each_round[p_idx][t]) == int(a_idx):
                                    flag = 1

                                t = t - 1

                            if flag == 0:
                                At[p_idx] = a_idx
                                Whether_for_detect = 1
                                break

                    # Otherwise propose to the plausible arm with the smallest
                    # pull count; `<=` lets the later arm win ties.
                    if Whether_for_detect == 0:
                        min_T = float('inf')
                        for a_idx in Plausible_arms[p_idx]:
                            if players_count[p_idx][a_idx] <= min_T:
                                min_T = players_count[p_idx][a_idx]
                                At[p_idx] = a_idx

                # update reward
                successfully_pulled_players = []
                current_matching = [[] for j in range(self.num_arms)]

                for a_idx in range(self.num_arms):
                    # Players proposing to this arm in the current round.
                    propose_player = []

                    for p_idx in range(self.num_players):
                        if At[p_idx] == a_idx:
                           propose_player.append(p_idx)

                    # The arm decides which of its proposers it accepts.
                    p_idxs = self.accept_player_substitute(a_idx, propose_player)
                    current_matching[a_idx] = p_idxs

                    for p_idx in p_idxs:
                        if a_idx not in Pull_arms[p_idx]:
                            Pull_arms[p_idx].append(a_idx)
                        Pull_arms_each_round[p_idx][round] = a_idx

                        # Noisy observation: Gaussian around the true mean, std 1.
                        reward = np.random.normal(loc=self.players_mean[p_idx][a_idx], scale=1.0, size=None)

                        # Incremental update of the empirical mean.
                        players_count[p_idx][a_idx] += 1
                        players_es_mean[p_idx][a_idx] += (reward - players_es_mean[p_idx][a_idx]) / players_count[p_idx][a_idx]

                        successfully_pulled_players.append(p_idx)

                        # Regret and reward are recorded from the true means, not
                        # from the sampled reward.
                        opt_arm = self.optimal_matching[p_idx]
                        opt_reward = self.players_mean[p_idx][opt_arm]
                        actual_reward = self.players_mean[p_idx][a_idx]

                        regrets_one_trial[p_idx][round] = np.maximum(0, opt_reward - actual_reward)
                        rewards_one_trial[p_idx][round] = actual_reward

                unstable_one_trial[round] = self.isUnstableSubstitute(current_matching)

                # update flags
                for p_idx in range(self.num_players):
                    # Largest lower confidence bound among pulled plausible arms.
                    max_lcb = float('-inf')
                    for a_idx in Plausible_arms[p_idx]:
                        if players_count[p_idx][a_idx] != 0:
                            lcb = players_es_mean[p_idx][a_idx] - 2.5 * np.sqrt(2 * np.log(round + 1) / ((players_count[p_idx][a_idx])))
                            if lcb > max_lcb:
                                max_lcb = lcb

                    # An arm whose upper bound is more than 0.1 below that value
                    # is dropped: the player removes itself from its own
                    # availability list for the arm.
                    for a_idx in Plausible_arms[p_idx]: 
                        if players_count[p_idx][a_idx] != 0:       
                            ucb = players_es_mean[p_idx][a_idx] + 2.5 * np.sqrt(2 * np.log(round + 1) / ((players_count[p_idx][a_idx]))) 
                            if ucb < max_lcb - 0.1:
                                if p_idx in Available_players[p_idx][a_idx]:
                                    Available_players[p_idx][a_idx].remove(p_idx) 

                    # A player that no arm accepted this round gets its full
                    # optimal reward as regret and a reward of 0.
                    if p_idx not in successfully_pulled_players:
                        opt_arm = self.optimal_matching[p_idx]
                        opt_reward = self.players_mean[p_idx][opt_arm]
                        regrets_one_trial[p_idx][round] = opt_reward  
                        rewards_one_trial[p_idx][round] = 0

                # Detect other players: once 2 * num_arms rounds have passed, a
                # player not recorded on a previously pulled arm within the last
                # 2 * num_arms rounds (current round included) is removed from
                # every player's availability list for that arm.
                if round >= self.num_arms * 2:
                    for other_p_idx in range(self.num_players):
                        for arm in Pull_arms[other_p_idx]:
                            flag = 0
                            t = round
                            while t > round - self.num_arms * 2:
                                if int(Pull_arms_each_round[other_p_idx][t]) == int(arm):
                                    flag = 1
                                t = t - 1

                            if flag == 0:
                                for p_idx in range(self.num_players):
                                    if other_p_idx in Available_players[p_idx][arm]:
                                        Available_players[p_idx][arm].remove(other_p_idx) 

                # Rebuild every player's plausible set from the updated lists.
                for p_idx in range(self.num_players):
                    Plausible_arms[p_idx] = []
                    for a_idx in range(self.num_arms):
                        if p_idx in self.accept_player_substitute(a_idx, Available_players[p_idx][a_idx]):
                            Plausible_arms[p_idx].append(a_idx)              

            # Save Data
            for i in range(self.num_players):
                regrets_trials[i][trial] = regrets_one_trial[i]
                rewards_trials[i][trial] = rewards_one_trial[i]

            unstable_trials[trial] = unstable_one_trial

        # Cumulate over rounds, then average over trials (and over players for regret).
        mean_cum_regret = np.mean(np.cumsum(regrets_trials, axis=2), axis=1)
        mean_regret_across_players = np.mean(mean_cum_regret, axis=0)
        mean_cum_unstable = np.mean(np.cumsum(unstable_trials, axis=1), axis=0)

        return mean_regret_across_players, mean_cum_unstable