import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import matplotlib
import datetime
import torch
import math
from queue import Queue
import os
import itertools 

class Indifference(object):

    def __init__(self, horizon, trial, num_player, num_arm, arm_preferences, player_mean):
        self.path = './ResultsData/decen/'
        self.horizon = horizon
        self.trials = trial

        self.p_lambda = 0.1
        self.epsilon = 10**(-10)

        # phased ETC algorithm
        self.varEpsilon = 0.2

        self.num_players = num_player
        self.num_arms = num_arm
        self.arms_preferrences = arm_preferences
        self.players_mean = player_mean
        self.arms_rankings = np.zeros([num_arm, num_player])
        self.arms_rankings = self.arms_rankings.astype(int)

        random_tie_breaking = self.arms_preferrences.copy()
        for i in range(self.num_arms):
            for j in range(self.num_players):
                random_tie_breaking[i][j] += np.random.rand() / 1000
        for j in range(self.num_arms):
            self.arms_rankings[j] = np.argsort(-random_tie_breaking[j])

        # UCB-D4
        self.beta = 1/(2*self.num_arms)
        self.gamma = 2
        
        self.pessimal_matching = self.get_pessimal_matching()
        # print("arms_preferrences",'\n' , self.arms_preferrences)
        # print("players_mean",'\n', self.players_mean)
        # print("players_mean", self.players_mean)
        print("pessimal matching",self.pessimal_matching)

        '''
        At=np.zeros(self.num_players)
        for a,p in enumerate(self.pessimal_matching):
                At[p]=a
        print("Pessimal Matching Player:",At)
        '''

    def isUnstable(self, arm_matching):
        arm_matching = arm_matching.tolist()
        
        player_matching = np.ones(self.num_players) * (-1)
        player_matching = player_matching.astype(int)
        for p_idx in range(self.num_players):
            if p_idx in arm_matching:
                player_matching[p_idx] = arm_matching.index(p_idx)

        if -1 in player_matching:
            return 1
        
        # print(player_matching)

        # find blocking pair
        for p_idx in range(self.num_players):
            for p_idy in range(self.num_players):
                if p_idx == p_idy:
                    continue 
                a_px = player_matching[p_idx]
                a_py = player_matching[p_idy]
                if self.players_mean[p_idx][a_px] < self.players_mean[p_idx][a_py] and self.arms_preferrences[a_py][p_idy] < self.arms_preferrences[a_py][p_idx]:
                    # print([p_idx, p_idy, a_px, a_py])
                    return 1
        arm = 1
        return 0
    
    def Gale_Shapley(self, player_ranking):
            # propose_order records the order players should follow while proposing
        init_propose_order = np.zeros(self.num_players, int)
        propose_order = init_propose_order
        # matched record whether a specific player is matched or not
        matched = np.zeros(self.num_players, bool)
        # matching records the choice of a player for a specific arm
        matching = [[] for _ in range(self.num_arms)]

        # Terminates if all matched
        while np.sum(matched) != self.num_players:

            # players propose at the same time
            for p_idx in range(self.num_players):
                if not matched[p_idx]:
                    # p_proposal is the index of an arm
                    # propose_order is the vector, p_o[i] is the order of player i's next proposal
                    p_proposal = player_ranking[p_idx][propose_order[p_idx]]
                    matching[p_proposal].append(p_idx)

            # arms choose its player
            for a_idx in range(self.num_arms):
                a_choices = matching[a_idx]

                if len(a_choices) != 0:    
                    # each arm chooses the its most preferable one
                    a_choice = next((x for x in self.arms_rankings[a_idx] if x in matching[a_idx]), None)
                    # update arm's choice where there should only be one left
                    matching[a_idx] = [a_choice]
                    # update player's state of matched
                    for p_idx in a_choices:
                        matched[p_idx] = (p_idx == a_choice)
                        propose_order[p_idx] += (1 - (p_idx == a_choice))
    
        return np.squeeze(matching)

    def Next_permutation(a):
        i = len(a) - 2
        while not (i < 0 or a[i] < a[i+1]):
            i -= 1
            if i < 0:
                return False
            # else
        j = len(a) - 1
        while not (a[j] > a[i]):
            j -= 1
        (a[i], a[j]) = (a[j], a[i])
        a[i+1:] = reversed(a[i+1:])
        return True

    def get_pessimal_matching(self):
        worst = np.ones([self.num_players]) * (10000000.0)
        per = list(range(self.num_players))
        while (1):
            A = np.ones(self.num_arms) * (-1) 
            A = A.astype(int)
            for i in range(self.num_players):
                A[per[i]] = i
            if(self.isUnstable(A) == 0):
                for i in range(self.num_arms):
                    if(A[i] != -1):
                        if(self.players_mean[A[i]][i] < worst[A[i]]):
                            worst[A[i]] = self.players_mean[A[i]][i]
            if(Indifference.Next_permutation(per) == 0):
                break
        return worst

    def run_AP(self,delta):
        """Simulate the arm-proposing algorithm and save its traces.

        In every round the arms' rankings are re-drawn (random tie-breaking),
        then arms propose to players in ranking order.  A player accepts a
        proposal when it holds no arm, when the proposing arm's LCB exceeds
        the held arm's UCB, or when the proposing arm's UCB exceeds the held
        arm's LCB and the player has pulled the proposing arm fewer times.
        Matched players observe a unit-variance Gaussian reward and update
        their empirical mean and confidence bounds.

        Args:
            delta: not used by this method; kept so existing calls still work.

        Side effects:
            Writes ``Armpropose_regret.npz`` (key ``regret``, shape
            ``[num_players, trials, horizon]``) and ``Armpropose_Unstable.npz``
            (key ``unstable``, shape ``[trials, horizon]``) to the current
            working directory and consumes the global NumPy RNG.
        """
        regrets_trials = np.zeros([self.num_players, self.trials, self.horizon])
        unstable_trials = np.zeros([self.trials, self.horizon])

        # Float copy so the tie-breaking noise below is not truncated when the
        # preferences are stored with an integer dtype.
        base_preferences = np.asarray(self.arms_preferrences, dtype=float)[:self.num_arms]

        for trial in tqdm(range(self.trials), ascii=True, desc="Running the AP"):

            unstable_one_trial = np.zeros(self.horizon)
            regrets_one_trial = np.zeros([self.num_players, self.horizon])

            player_es_mean = np.zeros([self.num_players, self.num_arms])
            players_count = np.zeros([self.num_players, self.num_arms])
            # Wide initial bounds: before any pull every interval overlaps.
            LCB = np.full([self.num_players, self.num_arms], -1e9)
            UCB = np.full([self.num_players, self.num_arms], 1e9)

            for t in range(self.horizon):
                # Fresh random tie-breaking of the arms' preferences each round.
                noise = np.random.rand(self.num_arms, self.num_players) / 1000
                # player_index[a] lists players from most to least preferred by arm a.
                player_index = np.argsort(-(base_preferences + noise), axis=1)

                # Indexed by arm: position in player_index[a] of arm a's latest
                # proposal.  Must have one slot per arm, not per player.
                arm_propose_now = np.full(self.num_arms, -1, dtype=int)
                # At[p] is the arm currently held by player p, -1 if none.
                At = np.full(self.num_players, -1, dtype=int)

                unmatched_arm = Queue()
                for j in range(self.num_arms):
                    unmatched_arm.put(j)

                while unmatched_arm.qsize() != 0:

                    Arm = int(unmatched_arm.get())
                    arm_propose_now[Arm] += 1
                    if arm_propose_now[Arm] == self.num_players:
                        # The arm has proposed to every player; it stays unmatched.
                        continue

                    Player = int(player_index[Arm][arm_propose_now[Arm]])
                    held = At[Player]
                    accepts = (
                        held == -1
                        or UCB[Player][held] < LCB[Player][Arm]
                        or (UCB[Player][Arm] > LCB[Player][held]
                            and players_count[Player][Arm] < players_count[Player][held])
                    )
                    if accepts:
                        if held != -1:
                            # The displaced arm goes back to proposing.
                            unmatched_arm.put(held)
                        At[Player] = Arm
                    else:
                        unmatched_arm.put(Arm)

                # last_pulled[a] is the player holding arm a, -1 if none.
                last_pulled = np.full(self.num_arms, -1, dtype=int)
                for p_idx in range(self.num_players):
                    last_pulled[At[p_idx]] = p_idx

                for p_idx in range(self.num_players):
                    arm = At[p_idx]
                    if last_pulled[arm] == p_idx:
                        regrets_one_trial[p_idx][t] = max (0, self.pessimal_matching[p_idx] - self.players_mean[p_idx][arm])

                        reward=np.random.normal(loc=self.players_mean[p_idx][arm], scale=1.0, size=None)
                        players_count[p_idx][arm] += 1
                        player_es_mean[p_idx][arm] += (reward - player_es_mean[p_idx][arm]) / players_count[p_idx][arm]

                        # Confidence radius; the log(e) divisor equals 1 and is
                        # kept verbatim so results stay numerically unchanged.
                        tmp = np.sqrt(2 * np.log(self.horizon) / np.log(np.exp(1)) / players_count[p_idx][arm])
                        LCB[p_idx][arm] = player_es_mean[p_idx][arm] - tmp
                        UCB[p_idx][arm] = player_es_mean[p_idx][arm] + tmp

                unstable_one_trial[t] = self.isUnstable(last_pulled)

            regrets_trials[:, trial, :] = regrets_one_trial
            unstable_trials[trial] = unstable_one_trial

        np.savez('Armpropose_regret.npz', regret=regrets_trials)
        np.savez('Armpropose_Unstable.npz', unstable=unstable_trials)

    def run_phasedETC(self,Beta,delta): 
        """Simulate the decentralized phased explore-then-commit algorithm.

        Each trial has two stages:

        1. Index estimation (rounds ``0 .. num_players - 1``): every player
           pulls the arm stored in ``arms`` (arm 0 at first).  The player that
           wins arm 0 records the current round as its index and pulls arm 1
           from then on.
        2. Phased explore/commit (remaining rounds): with
           ``i = floor(log2(round))``, the round is an exploration round when
           ``round - 2**i + 1 <= num_arms * floor(i ** varEpsilon)``; players
           then pull arms in an index-shifted round-robin and winners update
           their empirical means.  Otherwise players rank arms by empirical
           mean and the matching returned by ``Gale_Shapley`` is played.

        Args:
            Beta: not used by this method.
            delta: not used by this method.

        Side effects:
            Overwrites ``self.arms_rankings`` every round, writes
            ``PhasedETC_regret.npz`` and ``PhasedETC_Unstable.npz`` to the
            current working directory, prints the total number of unstable
            rounds, and consumes the global NumPy RNG.
        """
        regrets_trials = np.zeros([self.num_players, self.trials, self.horizon])
        # rewards_trials = np.zeros([self.num_players, self.trials, self.horizon])
        unstable_trials = np.zeros([self.trials, self.horizon])

        for trial in tqdm(range(self.trials), ascii=True, desc="Running the decentralized phasedETC"):
            # Starts at all ones, so rounds that never call isUnstable (the
            # index-estimation stage) are counted as unstable.
            unstable_one_trial = np.ones(self.horizon)
            regrets_one_trial = np.zeros([self.num_players, self.horizon])
            rewards_one_trial = np.zeros([self.num_players, self.horizon])

            players_es_mean = [np.zeros(self.num_arms) for j in range(self.num_players)]
            players_count = [np.zeros(self.num_arms) for j in range(self.num_players)]


            # Index_estimation 
            # Every index starts at num_players - 1 (operator precedence:
            # ones * num_players, then minus 1).
            indexs = np.ones(self.num_players)*self.num_players-1
            # arms[p] is the arm player p pulls during index estimation.
            arms = np.zeros(self.num_players)

            At = np.ones(self.num_players)*(-1)
            last_pulled = np.ones(self.num_arms)*(-1)



            for round in range(self.num_players):
                
                # Re-draw the arms' rankings with random tie-breaking.
                # NOTE(review): the noise is added in place to a copy that keeps
                # the dtype of arms_preferrences; with an integer dtype it would
                # be truncated away -- confirm callers pass floats.
                random_tie_breaking = self.arms_preferrences.copy()
                for i in range(self.num_arms):
                    for j in range(self.num_players):
                        random_tie_breaking[i][j] += np.random.rand() / 1000
                for j in range(self.num_arms):
                    self.arms_rankings[j] = np.argsort(-random_tie_breaking[j])

                for p_idx in range(self.num_players):
                    At[p_idx] = arms[p_idx]

                At = At.astype(int)
                # Conflict resolution: each pulled arm goes to the highest-ranked
                # player among those pulling it.
                last_pulled = np.ones(self.num_arms)*(-1)
                for a_idx in range(self.num_arms):
                    if a_idx in At:
                        for p_rank in range(self.num_players):
                            if At[self.arms_rankings[a_idx][p_rank]]==a_idx:
                                last_pulled[a_idx] = self.arms_rankings[a_idx][p_rank]
                                break
                last_pulled = last_pulled.astype(int)
               
                for p_idx in range(self.num_players):
                    if last_pulled[At[p_idx]]==p_idx:
                        regrets_one_trial[p_idx][round]=max(0,self.pessimal_matching[p_idx] - self.players_mean[p_idx][At[p_idx]])
                        rewards_one_trial[p_idx][round] = self.players_mean[p_idx][At[p_idx]]
                        if At[p_idx]==0:
                            # Winner of arm 0: record the round as this player's
                            # index and move it to arm 1 for later rounds.
                            indexs[p_idx]=round
                            arms[p_idx] = 1
                    else:
                        # A player that lost the conflict is charged its full
                        # pessimal-matching reward as regret.
                        regrets_one_trial[p_idx][round] = self.pessimal_matching[p_idx]
                        rewards_one_trial[p_idx][round] = 0
            
            # print(indexs)
            current_player_ranking = [ np.zeros(self.num_arms) for j in range(self.num_players)]
            current_match = np.zeros(self.num_arms)

            for round in range(self.num_players,self.horizon):
                
                # Re-draw the arms' rankings with random tie-breaking (same
                # dtype caveat as in the index-estimation stage).
                random_tie_breaking = self.arms_preferrences.copy()
                for i in range(self.num_arms):
                    for j in range(self.num_players):
                        random_tie_breaking[i][j] += np.random.rand() / 1000
                for j in range(self.num_arms):
                    self.arms_rankings[j] = np.argsort(-random_tie_breaking[j])

                # Phase number: largest i with 2**i <= round.
                i = math.floor(math.log(round,2))
                # exploration
                if round-2**i+1 <= self.num_arms*math.floor(i**self.varEpsilon):
                    
                    # Round-robin over arms, shifted by each player's index.
                    for p_idx in range(self.num_players):
                        At[p_idx] = (round+2+indexs[p_idx]-2**i)%self.num_arms
                    # print("Explore-------round ",round, At)
                    last_pulled = np.ones(self.num_arms)*(-1)
                    for a_idx in range(self.num_arms):
                        if a_idx in At:
                            for p_rank in range(self.num_players):
                                if At[self.arms_rankings[a_idx][p_rank]]==a_idx:
                                    last_pulled[a_idx] = self.arms_rankings[a_idx][p_rank]
                                    break
                    # Here: whether stable matching according to last_pulled.
                    last_pulled = last_pulled.astype(int)
                    unstable_one_trial[round] = self.isUnstable(last_pulled)
                    
                    At = At.astype(int)
                    for p_idx in range(self.num_players):
                        if last_pulled[At[p_idx]]==p_idx:
                            # update
                            # reward = np.random.binomial(1, self.players_mean[p_idx][At[p_idx]])

                            # Unit-variance Gaussian reward around the true mean.
                            reward=np.random.normal(loc=self.players_mean[p_idx][At[p_idx]], scale=1.0, size=None)


                            # Incremental update of the empirical mean.
                            players_count[p_idx][At[p_idx]]+=1
                            players_es_mean[p_idx][At[p_idx]]+= (reward-players_es_mean[p_idx][At[p_idx]]) / players_count[p_idx][At[p_idx]]
                            
                            # record
                            regrets_one_trial[p_idx][round]=max(0,self.pessimal_matching[p_idx] - self.players_mean[p_idx][At[p_idx]])
                            rewards_one_trial[p_idx][round] = self.players_mean[p_idx][At[p_idx]]
                        else:
                            regrets_one_trial[p_idx][round] = self.pessimal_matching[p_idx]
                            rewards_one_trial[p_idx][round] = 0
                # commit
                else:
                    # Rank arms by empirical mean, best first.
                    for j in range(self.num_players):
                        current_player_ranking[j] = np.argsort(-players_es_mean[j])
                
                    # Recomputed on every commit round, using the arms' rankings
                    # drawn for this round.
                    current_match = self.Gale_Shapley(current_player_ranking)
                    
                    # At=np.zeros(self.num_players)

                    # for a_idx,p_idx in enumerate(current_match):
                    #     At[p_idx] = a_idx

                    # print("-----round ",round, At)
                    
                    # Here: whether stable matching according to last_pulled.
                    unstable_one_trial[round] = self.isUnstable(current_match)

                    # No reward is sampled and no estimate is updated while committing.
                    for a_idx, p_idx in enumerate(current_match):
                        # print("Commit: (",p_idx, a_idx,")" )
                        regrets_one_trial[p_idx][round]=max(0,self.pessimal_matching[p_idx] - self.players_mean[p_idx][a_idx])
                        rewards_one_trial[p_idx][round] = self.players_mean[p_idx][a_idx]
                    
            
            for i in range(self.num_players):
                regrets_trials[i][trial] = regrets_one_trial[i]
                #rewards_trials[i][trial] = rewards_one_trial[i]
            # print(np.sum(regrets_one_trial, axis = 0))
            unstable_trials[trial] = unstable_one_trial
        # np.savez('./ResultsData/Decen_PhasedETC_Beta_'+str(Beta)+'N_'+str(self.num_players)+'_Regret.npz', regret=regrets_trials)
        # np.savez('Decen_PhasedETC_Beta_'+str(Beta)+'N_'+str(self.num_players)+'_Reward.npz', reward=rewards_trials)
        np.savez('PhasedETC_regret.npz', regret=regrets_trials)
        np.savez('PhasedETC_Unstable.npz', unstable=unstable_trials)
        # np.savez('PhasedETC_'+'delta_'+str(delta)+'_regret.npz', regret=regrets_trials)
        # np.savez('PhasedETC_'+'delta_'+str(delta)+'_Unstable.npz', unstable=unstable_trials)
        print(np.sum(unstable_trials))
     
    def run_ETC(self, h, Beta,delta ):
        """Simulate the decentralized explore-then-commit algorithm.

        For the first ``h * num_arms`` rounds players pull arms in a
        round-robin order (player ``i`` pulls arm ``(i + step) % num_arms``)
        and winners update their empirical means.  At round ``h * num_arms``
        players rank arms by empirical mean, ``Gale_Shapley`` is run once, and
        the resulting matching is played for all remaining rounds.

        Args:
            h: number of round-robin sweeps over the arms before committing.
            Beta: not used by this method.
            delta: not used by this method.

        Side effects:
            Overwrites ``self.arms_rankings`` every round, writes
            ``ETC_regret.npz`` and ``ETC_Unstable.npz`` to the current working
            directory, prints the cumulative regrets, and consumes the global
            NumPy RNG.
        """
       
        regrets_trials = np.zeros([self.num_players, self.trials, self.horizon])
        rewards_trials = np.zeros([self.num_players, self.trials, self.horizon])
        unstable_trials = np.zeros([self.trials, self.horizon])

        for trial in tqdm(range(self.trials), ascii=True, desc="Running the Decentralized-ETC"):
            
            unstable_one_trial = np.zeros(self.horizon)
            regrets_one_trial = np.zeros([self.num_players, self.horizon])
            rewards_one_trial = np.zeros([self.num_players, self.horizon])
            
            etc_player_es_mean = [np.zeros(self.num_arms) for j in range(self.num_players)]
            etc_players_count = [np.zeros(self.num_arms) for j in range(self.num_players)]

            # arm_order[p][s] is the arm player p pulls at step s of a sweep.
            arm_order = np.zeros([self.num_players, self.num_arms])
            for round in range(self.horizon):

                # Re-draw the arms' rankings with random tie-breaking.
                # NOTE(review): the noise is added in place to a copy that keeps
                # the dtype of arms_preferrences; with an integer dtype it would
                # be truncated away -- confirm callers pass floats.
                random_tie_breaking = self.arms_preferrences.copy()
                for i in range(self.num_arms):
                    for j in range(self.num_players):
                        random_tie_breaking[i][j] += np.random.rand() / 1000
                for j in range(self.num_arms):
                    self.arms_rankings[j] = np.argsort(-random_tie_breaking[j])

                # print("----------------TS time step: ",round, "-------------")
                
                # Explore
                if round < h * self.num_arms:
                    
                    step = round % self.num_arms
                    # Rebuilt at the start of every sweep; the values are the
                    # same each time.
                    if step ==0:
                        for i in range(self.num_players):
                            for j in range(self.num_arms):
                                arm_order[i][j] = (i + j) % self.num_arms
                    # print(arm_order)
                    At = np.ones(self.num_players)*(-1)
                    last_pulled = np.ones(self.num_arms)*(-1)

                    for p_idx in range(self.num_players):
                        At[p_idx] = arm_order[p_idx][step]

                    At = At.astype(int)
                    
                    # Conflict resolution: each pulled arm goes to the
                    # highest-ranked player among those pulling it.
                    for a_idx in range(self.num_arms):
                        if a_idx in At:
                            for p_rank in range(self.num_players):
                                if At[self.arms_rankings[a_idx][p_rank]]==a_idx:
                                    last_pulled[a_idx] = self.arms_rankings[a_idx][p_rank]
                                    break
                    last_pulled = last_pulled.astype(int)
               
                    for p_idx in range(self.num_players):
                        if last_pulled[At[p_idx]]==p_idx:
                            regrets_one_trial[p_idx][round]=max(0,self.pessimal_matching[p_idx] - self.players_mean[p_idx][At[p_idx]])
                            rewards_one_trial[p_idx][round] = self.players_mean[p_idx][At[p_idx]]

                            # reward = np.random.binomial(1, self.players_mean[p_idx][At[p_idx]])
                            # Unit-variance Gaussian reward, then incremental
                            # update of the empirical mean.
                            reward=np.random.normal(loc=self.players_mean[p_idx][At[p_idx]], scale=1.0, size=None)
                            etc_players_count[p_idx][At[p_idx]]+=1
                            etc_player_es_mean[p_idx][At[p_idx]]+= (reward-etc_player_es_mean[p_idx][At[p_idx]]) / etc_players_count[p_idx][At[p_idx]]
                            
                            
                        else:
                            # A player that lost the conflict is charged its full
                            # pessimal-matching reward as regret.
                            regrets_one_trial[p_idx][round] = self.pessimal_matching[p_idx]
                            rewards_one_trial[p_idx][round] = 0
                
    
                    # Exploration rounds are always counted as unstable,
                    # without calling isUnstable.
                    unstable_one_trial[round] = 1
                    
                
                # Commit
                else:
                    
                    # The matching is computed once, on the first commit round,
                    # and reused afterwards.
                    # NOTE(review): current_match is only assigned when round
                    # equals h * self.num_arms exactly; a non-integer product
                    # would leave it undefined below -- confirm h is an integer.
                    if round == h * self.num_arms:
                        current_player_ranking = [ np.zeros(self.num_arms) for j in range(self.num_players)]
                        for j in range(self.num_players):
                            current_player_ranking[j] = np.argsort(-etc_player_es_mean[j])
                        current_match = self.Gale_Shapley(current_player_ranking)
                    
                    # self.two_side_market.proceed()
                    for a_idx, p_idx in enumerate(current_match):
                        # reward = np.random.binomial(1, self.players_mean[p_idx][a_idx])
    
                        # etc_players_count[p_idx][a_idx]+=1
                        # etc_player_es_mean[p_idx][a_idx]+= (reward-etc_player_es_mean[p_idx][a_idx]) / etc_players_count[p_idx][a_idx]
                        
                    
                        regrets_one_trial[p_idx][round]=max(0,self.pessimal_matching[p_idx] - self.players_mean[p_idx][a_idx])
                        rewards_one_trial[p_idx][round] = self.players_mean[p_idx][a_idx]
                        
                
                        
                    # Here: whether stable matching according to last_pulled.
                    unstable_one_trial[round] = self.isUnstable(current_match)
                

                # if round==self.horizon-1:
                #    print('--------ETC:', 'Arm matching = ',current_match)    
                    

            for i in range(self.num_players):
                regrets_trials[i][trial] = regrets_one_trial[i]
                rewards_trials[i][trial] = rewards_one_trial[i]

            unstable_trials[trial] = unstable_one_trial


        # np.savez('./ResultsData/Decen_ETC_Beta_'+str(Beta)+'N_'+str(self.num_players)+'h_'+str(h)+'_Regret.npz', regret=regrets_trials)
        
        # np.savez('ResultsData\Decen_ETC_Beta_'+str(Beta)+'N_'+str(self.num_players)+'h_'+str(h)+'_Reward.npz', reward=rewards_trials)
        # np.savez('Decen_ETC_Beta_'+str(Beta)+'N_'+str(self.num_players)+'h_'+str(h)+'_Reward.npz', reward=rewards_trials)
        # Only regret and instability are saved; rewards_trials is filled but not written.
        np.savez('ETC_regret.npz', regret=regrets_trials)
        np.savez('ETC_Unstable.npz', unstable=unstable_trials)
        # np.savez('ETC_'+'delta_'+str(delta)+'_regret.npz', regret=regrets_trials)
        # np.savez('ETC_'+'delta_'+str(delta)+'_Unstable.npz', unstable=unstable_trials)
        # cumulative_unstable = np.cumsum(np.array(unstable_trials), axis=1)
        # print(np.sum(unstable_trials))


        # Running sum of regret over rounds (last axis).
        cumulative_regrets=np.cumsum(np.array(regrets_trials), axis=2)
        print(cumulative_regrets)