import numpy as np

import preliminary
import preliminary as pre
import math
def findtime(t_total_collision, K, agent):
    """Sum the first ``K`` entries of one agent's row of ``t_total_collision``.

    The other functions in this file add 1 to ``t_total_collision`` for every
    pull an agent makes, and use this sum as the elapsed-round counter that
    triggers reward snapshots.

    Args:
        t_total_collision: 2-D table indexed as ``[agent][arm]``.
        K: number of leading columns (arms) to add up.
        agent: either a single agent index (Python ``int`` or NumPy integer
            scalar) or an indexable collection of agent indices, in which
            case its first element selects the row.

    Returns:
        The sum of ``t_total_collision[row][0:K]``; ``0`` when ``K`` is 0.
    """
    # NumPy integer scalars (e.g. an element taken from an integer array) are
    # not instances of ``int`` and cannot be subscripted, so they must be
    # recognised here explicitly instead of falling into the sequence branch.
    if isinstance(agent, (int, np.integer)):
        row_index = agent
    else:
        row_index = agent[0]
    return sum(t_total_collision[row_index][j] for j in range(K))
def success_information(leader,follower,arm,success,arm_preference,information,M,t_total,t_total_collision,agent,reward,K,value):
    """Play the pairwise signalling rounds and refresh the ``success`` flags.

    The whole schedule is repeated ``len(leader)`` times.  Inside one
    repetition there is one round for every arm position ``k`` and every pair
    ``m1 < m2`` drawn from ``range(M)``.  In each round:

    1. every leader and follower chooses either ``arm[k]`` or the next arm
       ``arm[(k + 1) % len(arm)]``, based on ``information`` and (for
       leaders) on their current ``success`` flag;
    2. ``t_total`` and ``t_total_collision`` are updated, followers first and
       leaders second;
    3. ``success`` is updated for the agents selected by ``information``;
    4. when ``findtime(t_total_collision, K, agent)`` is a multiple of
       100000, the total ``t_total * value`` over all ``M`` agents and ``K``
       arms is stored in ``reward``.

    Args:
        leader: iterable of leader agent indices.
        follower: iterable of follower agent indices.
        arm: indexable collection of the arm ids used in this layer.
        success: per-agent flags, read and overwritten in place.
        arm_preference: passed through to ``pre.collision_indicator``.
        information: table indexed as ``[agent, arm, slot]``; an entry equal
            to the agent's own index makes that agent pull the current arm.
        M: total number of agents (size of ``pull`` and range of the slots).
        t_total: ``[agent, arm]`` table of collision-free pulls, updated in
            place.
        t_total_collision: ``[agent, arm]`` table of all pulls, updated in
            place.
        agent: passed to ``pre.collision_indicator`` and ``findtime``.
        reward: indexable store of reward snapshots, updated in place.
        K: number of arms covered by ``findtime`` and by the reward total.
        value: ``[agent, arm]`` table of per-pull values.

    Returns:
        ``(success, t_total, t_total_collision, reward)`` - the same objects
        that were passed in.

    Note:
        ``pre.collision_indicator`` is assumed to return 1 on a collision and
        0 otherwise, and to depend only on its arguments, so it is evaluated
        once per agent per round - TODO confirm against ``preliminary``.
    """
    K_layer = len(arm)
    M_layer = len(leader)
    # -1 marks an agent that has not been assigned an arm.
    pull = np.zeros(M, int) - 1
    for _ in range(M_layer):
        for k in range(K_layer):
            # Resolve both candidate arms once per k.  Nothing below rebinds
            # k, so a reward snapshot can no longer change which arm the
            # remaining (m1, m2) rounds of this k refer to.
            current_arm = arm[k]
            next_arm = arm[(k + 1) % K_layer]
            for m1 in range(M):
                for m2 in range(m1 + 1, M):
                    # --- arm selection -------------------------------------
                    for j in leader:
                        if information[j, current_arm, m1] == j:
                            pull[j] = current_arm if success[j] == 0 else next_arm
                        elif information[j, current_arm, m2] == j:
                            pull[j] = current_arm
                        else:
                            pull[j] = next_arm
                    for i in follower:
                        if information[i, current_arm, m2] == i:
                            pull[i] = current_arm
                        else:
                            pull[i] = next_arm

                    # --- bookkeeping: followers first, then leaders --------
                    for j in follower:
                        no_collision = 1 - pre.collision_indicator(
                            j, arm_preference, pull, M, agent)
                        t_total[j, pull[j]] = t_total[j, pull[j]] + no_collision
                        t_total_collision[j, pull[j]] = t_total_collision[j, pull[j]] + 1
                        if (information[j, current_arm, m2] == j
                                and information[j, current_arm, m1] in leader):
                            success[j] = no_collision
                    for i in leader:
                        no_collision = 1 - pre.collision_indicator(
                            i, arm_preference, pull, M, agent)
                        t_total[i, pull[i]] = t_total[i, pull[i]] + no_collision
                        t_total_collision[i, pull[i]] = t_total_collision[i, pull[i]] + 1
                        # A leader's flag can only be cleared here, never set.
                        if information[i, current_arm, m2] == i and success[i] == 1:
                            success[i] = no_collision

                    # --- periodic reward snapshot --------------------------
                    elapsed = findtime(t_total_collision, K, agent)
                    if elapsed % 100000 == 0:
                        # The generator has its own scope, so the round's loop
                        # indices are left untouched.
                        reward[int(elapsed // 100000)] = sum(
                            t_total[a, b] * value[a, b]
                            for a in range(M)
                            for b in range(K)
                        )
    return success, t_total, t_total_collision, reward
'''def success_exchange(leader,follower,arm,success,arm_preference,information,M,t_total,t_total_collision,agent):
    K_layer=len(arm)
    M_layer=len(leader)
    M=len(agent)
    pull=np.zeros(M,int)
    for k in arm:
        for m1 in range(M_layer):
            for m2 in range(m1+1,M_layer):
                for j in leader:
                    if information[j, k, m1] == j:
                          if success[j] ==1:
                            pull[j] =arm [k]
                          else:
                            m= (k + 1) % (K_layer)
                            pull[j] = arm[m]
                    elif information[j,k,m2] == j:
                        pull[j] = arm[k]
                    else:
                        m = (k + 1) % (K_layer)
                        pull[j] = arm[m]
                for j in follower:
                    pull[j]=arm[0]
                for j in follower:
                    t_total[j, pull[j]] = t_total[j, pull[j]] + 1 * (
                                1 - pre.collision_indicator(j, arm_preference, pull, M,agent))
                    t_total_collision[j, pull[j]] = t_total_collision[j, pull[j]] + 1
                for i in leader:
                    t_total[i, pull[i]] = t_total[i, pull[i]] + 1 * (
                                1 - pre.collision_indicator(i, arm_preference, pull, M,agent))
                    t_total_collision[i, pull[i]] = t_total_collision[i, pull[i]] + 1
                    if information[i, k, m2] == i:
                        if success[i]==1:
                             success[i] = pre.collision_indicator(i, arm_preference, pull, M,leader)
    return success,t_total,t_total_collision
def success_information_1(leader,follower,agent,arm,arm_preference,M,information,success,t_total,t_total_collision):
    K_layer = int(np.shape(arm)[0])
    M_layer=len(agent)
    M_layer_leader = len(leader)
    M_layer_follower = len(follower)
    pull=np.zeros(M,int)
    for t in range(M_layer_follower):
        for i in leader:
            if success[i]==1:
                pull[i]=arm[0]
            else:
                pull[i]=arm[1]
        for j in follower:
            if j==follower[t]:
                pull[j]=arm[0]
            else:
                pull[j]=arm[1]
        for j in follower:
            t_total[j, pull[j]] = t_total[j, pull[j]] + 1 * (1 - pre.collision_indicator(j, arm_preference, pull, M,agent))
            t_total_collision[j, pull[j]] = t_total_collision[j, pull[j]] + 1
            if j == follower[t]:
                success[j]=pre.collision_indicator(j,arm_preference,pull,M,agent)
        for i in leader:
            t_total[i, pull[i]] = t_total[i, pull[i]] + 1 * (1 - pre.collision_indicator(i, arm_preference, pull, M,agent))
            t_total_collision[i, pull[i]] = t_total_collision[i, pull[i]] + 1
    return success,t_total,t_total_collision'''
def GS_arm(information,u,leader,arm_preference,arm,follower,M,agent,K,t_total,t_total_collision,reward,value):
    """Let leaders settle on arms, then let followers probe which arms are taken.

    Phase 1 runs ``len(leader) ** 2`` rounds.  Each leader pulls the best arm
    (by descending ``u``) that belongs to ``arm`` and that it has not yet
    given up on; a leader that collides moves on to its next candidate.
    Followers pull ``arm[0]`` throughout this phase.

    Phase 2 runs one round per (arm position, follower) pair.  Leaders keep
    pulling their current candidate.  The probing follower pulls ``arm[k]``
    while every other follower pulls ``arm[(k + 1) % len(arm)]``; if the
    probing follower collides, ``arm[k]`` is removed from the list of free
    arms.

    In both phases ``t_total`` / ``t_total_collision`` are updated for every
    pull, and whenever ``findtime(t_total_collision, K, agent)`` is a multiple
    of 100000 the total ``t_total * value`` is stored in
    ``reward[elapsed // 100000]``.

    Args:
        information: unused by this function.
        u: ``[agent, arm]`` table of estimates; each leader's row is ranked
            in descending order.
        leader: iterable of leader agent indices.
        arm_preference: passed through to ``pre.collision_indicator``.
        arm: indexable collection of the arm ids available in this layer.
        follower: iterable of follower agent indices.
        M: total number of agents.
        agent: passed to ``pre.collision_indicator`` and ``findtime``.
        K: total number of arms.
        t_total: ``[agent, arm]`` table of collision-free pulls, updated in
            place.
        t_total_collision: ``[agent, arm]`` table of all pulls, updated in
            place.
        reward: indexable store of reward snapshots, updated in place.
        value: ``[agent, arm]`` table of per-pull values.

    Returns:
        ``(pull, arm_left, t_total, t_total_collision, reward)`` where
        ``pull`` holds the arm each agent pulled in the last round (-1 for
        agents that never pulled) and ``arm_left`` lists the arms of ``arm``
        that no probing follower collided on.

    Note:
        ``pre.collision_indicator`` is assumed to return 1 on a collision and
        0 otherwise, and to depend only on its arguments, so it is evaluated
        once per agent per round - TODO confirm against ``preliminary``.
    """
    K_layer = len(arm)
    leader_rounds = len(leader) ** 2

    estimation = np.zeros((M, K), int)
    # -1 marks an agent that has not been assigned an arm.
    pull = np.zeros(M, int) - 1
    # optimal[i] is the position in estimation[i] that leader i currently targets.
    optimal = np.zeros(M, int)
    arm_left = list(arm)
    print(arm_left)

    def record_reward():
        # Snapshot the accumulated reward on every 100000th round.  The
        # generator uses its own index names, so the caller's loop variables
        # (in particular the phase-2 arm position) are never overwritten.
        elapsed = findtime(t_total_collision, K, agent)
        if elapsed % 100000 == 0:
            reward[int(elapsed // 100000)] = sum(
                t_total[a, b] * value[a, b]
                for a in range(M)
                for b in range(K)
            )

    for i in leader:
        # Arm ids sorted from highest to lowest estimate.
        estimation[i, :] = u[i, :].argsort()[::-1]

    # ---- Phase 1: leaders search for a collision-free arm ------------------
    for _ in range(leader_rounds):
        for i in leader:
            # Skip candidates that are not part of this layer's arm set.
            # NOTE(review): there is no bound check, so this raises IndexError
            # once a leader has exhausted all K candidates.
            while estimation[i, optimal[i]] not in arm:
                optimal[i] = optimal[i] + 1
            pull[i] = estimation[i, optimal[i]]
        for j in follower:
            pull[j] = arm[0]

        for j in follower:
            no_collision = 1 - pre.collision_indicator(j, arm_preference, pull, M, agent)
            t_total[j, pull[j]] = t_total[j, pull[j]] + no_collision
            t_total_collision[j, pull[j]] = t_total_collision[j, pull[j]] + 1
        for i in leader:
            collided = pre.collision_indicator(i, arm_preference, pull, M, agent)
            t_total[i, pull[i]] = t_total[i, pull[i]] + (1 - collided)
            t_total_collision[i, pull[i]] = t_total_collision[i, pull[i]] + 1
            if collided == 1:
                optimal[i] = optimal[i] + 1

        # The snapshot index divides by the same 100000 as the trigger (the
        # previous code divided by 10000 here only, unlike every other site).
        record_reward()

    # ---- Phase 2: followers probe the arms one at a time -------------------
    for k in range(K_layer):
        probe_arm = arm[k]
        other_arm = arm[(k + 1) % K_layer]
        for prober in follower:
            for i in leader:
                pull[i] = estimation[i, optimal[i]]
            for j in follower:
                pull[j] = probe_arm if j == prober else other_arm

            for i in leader:
                no_collision = 1 - pre.collision_indicator(i, arm_preference, pull, M, agent)
                t_total[i, pull[i]] = t_total[i, pull[i]] + no_collision
                t_total_collision[i, pull[i]] = t_total_collision[i, pull[i]] + 1
            for j in follower:
                collided = pre.collision_indicator(j, arm_preference, pull, M, agent)
                t_total[j, pull[j]] = t_total[j, pull[j]] + (1 - collided)
                t_total_collision[j, pull[j]] = t_total_collision[j, pull[j]] + 1
                # A collision on the probed arm means it is occupied.
                if j == prober and collided == 1 and probe_arm in arm_left:
                    arm_left.remove(probe_arm)
                # NOTE(review): this check runs once per follower rather than
                # once per round; kept as-is to preserve the existing
                # snapshot timing.
                record_reward()
    return pull, arm_left, t_total, t_total_collision, reward