

import numpy as np
import pickle
import networkx as nx
import matplotlib.pyplot as plt
import random 
from time import perf_counter
#random.seed(43)

def add_random_edges(graph, probability):
    """Add each candidate edge of ``graph`` independently with ``probability``.

    For an undirected graph every unordered pair {u, v} (u != v) is tried
    exactly once; for a directed graph every ordered pair (u, v), u != v, is
    tried. Self-loops are never added.

    Parameters:
    graph: networkx-style graph exposing ``nodes()``, ``is_directed()`` and
        ``add_edges_from()``; it is modified in place.
    probability (float): chance that a single candidate edge is added.

    Returns:
    The same ``graph`` object.
    """
    nodes = list(graph.nodes())
    if graph.is_directed():
        candidates = ((u, v) for u in nodes for v in nodes if u != v)
    else:
        # Trying both (u, v) and (v, u) on an undirected graph would give each
        # pair two chances, i.e. an effective probability of 1-(1-p)**2.
        candidates = ((nodes[a], nodes[b])
                      for a in range(len(nodes))
                      for b in range(a + 1, len(nodes)))
    new_edges = [edge for edge in candidates if random.random() < probability]
    graph.add_edges_from(new_edges)
    return graph

def generateNetwork(M,N,p):
    """
    Build a random M-team network, plot it, and return it with its signed adjacency matrix.

    M := Number of teams
    N := Number of players in each team
    p := Probability passed to add_random_edges for every candidate edge

    Node ids are 0..M*N-1; node m*N+n is player n of team m and carries the
    node attribute 'type' = m.

    Returns:
    G   -- the networkx graph.
    RAM -- matrix built from nx.adjacency_matrix(G).todense(): the adjacency
           entry for edges between teammates, -1 for edges between players of
           different teams, 0 for non-edges, plus 1 added on the diagonal.

    Side effects: draws the graph, saves it to "currentGraphs.png" and calls
    plt.show().
    """
    G = nx.Graph()

    # Add nodes with type information; one random colour per team for plotting.
    numNodes = 0
    colorMap = []
    for m in range(M):
        G.add_nodes_from([(i+numNodes, {'type': m}) for i in range(N)])
        numNodes += N
        teamColor = "#%06x" % random.randint(0, 0xFFFFFF)
        colorMap.extend([teamColor]*N)

    G = add_random_edges(G,p)
    pos = nx.spring_layout(G)
    node_labels = {node: node for node in G.nodes()}
    nx.draw(G, pos, with_labels=True, labels=node_labels, font_weight='bold', node_color=colorMap)
    plt.savefig("currentGraphs.png")
    plt.show()

    adjacencyMatrix = nx.adjacency_matrix(G).todense()
    RAM = adjacencyMatrix.copy()

    # Mark every cross-team edge as adversarial (-1) in both directions.
    # A single O(|E|) pass; previously all edges were rescanned once per team
    # with list-membership tests.
    nodeTypes = nx.get_node_attributes(G,'type')
    for u, v in G.edges():
        if nodeTypes[u] != nodeTypes[v]:
            RAM[u,v] = -1
            RAM[v,u] = -1

    RAM += np.eye(M*N,dtype=RAM.dtype)

    return G, RAM

def generate_random_stochastic_matrix(n):
    """
    Return an n x n row-stochastic matrix with uniform random entries.

    Entries come from np.random.rand; each row is divided by its own sum so
    that every row sums to one.
    """
    raw = np.random.rand(n, n)
    rowTotals = raw.sum(axis=1, keepdims=True)
    return raw / rowTotals

def generate_N_random_stochastic_matrices(N, n):
    """
    Return an (N, n, n) array holding N independent random row-stochastic matrices.

    Each matrix is drawn with np.random.rand and row-normalised, in order, so
    the random stream is consumed exactly as N successive single draws.
    """
    stacked = []
    for _ in range(N):
        raw = np.random.rand(n, n)
        stacked.append(raw / raw.sum(axis=1, keepdims=True))
    return np.array(stacked)

def generateEdgeUtilities(G,A):
    """Given a graph, G, and number of actions for each player A, this function generates edge utilities.

    For every node i and every neighbour j on another team, a random table
    u_ij (np.random.rand) is drawn over the joint actions of i's closed team
    neighbourhood (i plus its adjacent teammates) together with i's
    neighbours on j's team. Node i's utility is the sum of its u_ij tables,
    evaluated for every joint action of i's closed neighbourhood.

    Parameters:
    G -- networkx graph whose nodes carry a 'type' (team id) attribute; node
         ids are used directly as integers.
    A -- number of actions per player.

    Returns:
    utilities     -- list (in G.nodes order) of 1-D arrays of length
                     A**(deg(i)+1), indexed by getActProfileIdx of the local
                     action profile ordered by ascending node id.
    edgeUtilities -- dict keyed by ordered pairs (i, j): cross-team pairs map
                     to their random table, same-team pairs map to 0.
    """
    nodeTypes = nx.get_node_attributes(G,'type')
    utilities = []
    edgeUtilities = {}
    for edge in G.edges:
        edgeUtilities[edge] = 0
        edgeUtilities[edge[1],edge[0]] = 0
    for i in G.nodes:
        team = nodeTypes[i]
        teamNeighbours = [m for m in G.neighbors(i) if nodeTypes[m]==team]
        teamNeighbours.append(i)
        teamNeighbours.sort()
        teamNeighbours = np.array(teamNeighbours)
        enemyNeighbours = np.array([m for m in G.neighbors(i) if nodeTypes[m]!=team])

        # Local ordering of the closed neighbourhood: ascending node id.
        neighbours = np.concatenate((teamNeighbours,enemyNeighbours))
        neighbours.sort()
        neighbours = neighbours.astype(int)
        localTeamIdx = np.nonzero(teamNeighbours[:,None] == neighbours)[1]

        # Draw u_ij for each enemy neighbour (same order as before, so the
        # random stream is unchanged) and record, once per j, the ascending
        # local positions that feed it: i's team plus j's team.
        pairwiseLocalIdx = []
        for j in enemyNeighbours:
            enemyTeamId = nodeTypes[j]
            enemyTeamNeighbours = np.array([m for m in enemyNeighbours if nodeTypes[m]==enemyTeamId])
            localEnemyTeamIdx = np.nonzero(enemyTeamNeighbours[:, None] == neighbours)[1]
            edgeUtilities[(i,j)] = np.random.rand(A**(len(enemyTeamNeighbours)+len(teamNeighbours)))
            pairwiseLocalIdx.append((j, np.union1d(localTeamIdx, localEnemyTeamIdx)))

        actProfilesList = np.array(getActProfileList(A,len(neighbours)))
        nodeUtility = np.zeros(A**len(neighbours)) #A utility function for each node
        for actProfile in actProfilesList:
            actId = getActProfileIdx(actProfile,A)
            for j, pairIdx in pairwiseLocalIdx:
                nodeUtility[actId] += edgeUtilities[(i,j)][getActProfileIdx(actProfile[pairIdx],A)]

        utilities.append(nodeUtility)

    return utilities, edgeUtilities

def globalToLocalActProfile(RAM,actProfile):
    """
    Split a global action profile into one local profile per player.

    Row r of RAM selects the players with a non-zero entry (r's neighbours,
    and r itself when RAM[r, r] != 0). The local profile for r is actProfile
    restricted to those players, in ascending player index.
    """
    return [actProfile[np.nonzero(RAM[rowId,:] != 0)[0]] for rowId in range(RAM.shape[0])]

def neighbourGlobalToLocalIndex(agent_i, team_i, enemy_i):
    """
    Map global neighbour ids to positions in agent_i's local action profile.

    The local ordering is the ascending sort of team_i and enemy_i combined.

    Parameters:
    agent_i -- global id of the agent; should be contained in team_i.
    team_i  -- sorted ids of agent_i's same-team neighbours (including agent_i).
    enemy_i -- sorted ids of agent_i's neighbours on other teams.
    Any sequence (list, tuple, ndarray) is accepted for team_i and enemy_i.

    Returns:
    (localTeamIdx, localEnemyIdx, localSelfIdx) -- integer index arrays into
    the sorted combined neighbour list.
    """
    # np.asarray is a no-op for ndarrays and also accepts tuples and other
    # sequences, which previously failed on the `[:, None]` indexing below.
    team_i = np.asarray(team_i)
    enemy_i = np.asarray(enemy_i)

    neighbours = np.concatenate((team_i,enemy_i))
    neighbours.sort()
    neighbours = neighbours.astype(int)

    localTeamIdx = np.nonzero(team_i[:,None] == neighbours)[1]
    localEnemyIdx = np.nonzero(enemy_i[:, None] == neighbours)[1]
    localSelfIdx = np.where(neighbours==agent_i)[0]

    return localTeamIdx, localEnemyIdx, localSelfIdx

def generatePotentialFunctions(G,RAM,utility,M,N,A,edgeUtilities=None):
    """
    Build one team-level potential table per team from pairwise edge utilities.

    For team m and every joint action a of all players, phi_m[a] sums, over
    every ordered pair (i, j) of adjacent players where exactly one of i, j is
    on team m, the table edgeUtilities[i, j] evaluated on the actions of i,
    i's adjacent teammates and i's neighbours on j's team (ascending player
    id). The term is added when i is on team m and subtracted when j is.

    Parameters:
    G        -- graph with a 'type' (team id) node attribute.
    RAM      -- kept for interface compatibility; not used.
    utility  -- per-player list; only its length (number of players) is used.
    M, N, A  -- number of teams, players per team, actions per player.
    edgeUtilities -- dict produced by generateEdgeUtilities. When omitted,
                the module-level ``edgeUtilities`` global is used, which the
                previous implementation relied on implicitly.

    Returns:
    list of M arrays, each reshaped to (A**N,)*M (axis k = team k's joint action).

    Raises:
    ValueError -- if edgeUtilities is neither passed nor defined globally.
    """
    if edgeUtilities is None:
        edgeUtilities = globals().get('edgeUtilities')
        if edgeUtilities is None:
            raise ValueError("edgeUtilities must be passed or defined at module level")

    potentialFunctions = []
    nodeNumber = len(utility)
    noJointActProfiles = A**(M*N)
    nodeTypes = nx.get_node_attributes(G,'type')

    # Ordered pairs in the order they were visited before (every edge as
    # (u, v), then every edge as (v, u)), so floating-point sums are unchanged.
    orientedEdges = [(u, v) for u, v in G.edges] + [(v, u) for u, v in G.edges]

    # Which players feed edge (i, j) does not depend on the joint action;
    # compute the sorted player set once per ordered cross-team pair.
    pairwiseIdx = {}
    for i, j in orientedEdges:
        if nodeTypes[i] == nodeTypes[j]:
            continue
        i_team = [l for l in G.neighbors(i) if nodeTypes[l] == nodeTypes[i]]
        j_team = [l for l in G.neighbors(i) if nodeTypes[l] == nodeTypes[j]]
        pairwiseIdx[(i, j)] = np.array(sorted(set(i_team) | set(j_team) | {i}), dtype=int)

    for m in range(M):
        # (i, j, i-is-on-team-m) for the ordered pairs that contribute to team m.
        terms = [(i, j, nodeTypes[i] == m)
                 for i, j in orientedEdges
                 if (nodeTypes[i] == m) != (nodeTypes[j] == m)]
        phi_m = np.zeros(noJointActProfiles)
        for a in range(noJointActProfiles):
            globalActProfile = np.array(getActProfile(a,A,nodeNumber))
            u = 0
            for i, j, iOnTeam in terms:
                pairwiseActProfile = globalActProfile[pairwiseIdx[(i, j)]]
                value = edgeUtilities[i,j][getActProfileIdx(pairwiseActProfile,A)]
                if iOnTeam:
                    u += value
                else:
                    u -= value
            phi_m[a] = u
        potentialFunctions.append(phi_m.reshape((A**N,)*M))

    return potentialFunctions

def entropy(pi):
    """
    Return the Shannon entropy -sum(pi * log(pi)) in nats.

    Entries below 1e-14 (including exact zeros) contribute 0, following the
    convention 0*log(0) = 0. Works for arrays of any shape (the sum runs over
    all entries) and for plain sequences.
    """
    pi = np.asarray(pi, dtype=float)
    # Boolean mask over every entry: the previous np.where(...)[0] only
    # indexed the first axis and zeroed whole rows of a multi-dimensional pi.
    negligible = pi < 1e-14
    # Feed 1 to the log at negligible entries so np.log never sees 0 (no
    # divide-by-zero warning) and their log term is exactly 0.
    logits = np.log(np.where(negligible, 1.0, pi))
    return -np.sum(pi*logits)

def teamNashGap(potentialFunctions,Pi):
    NG = 0
    for m in range(len(potentialFunctions)):
        teamPotentialFunction = potentialFunctions[m]
        pi_m = Pi[m,:]
        for n in range(m):
            teamPotentialFunction = np.einsum('i...,i',teamPotentialFunction,Pi[n,:])
        for n in range(len(potentialFunctions)-m-1):
            teamPotentialFunction = (teamPotentialFunction @ Pi[len(potentialFunctions)-n-1,:]).squeeze()
        bestResponseValue = max(teamPotentialFunction)
        #if m == 0: #only for settin 7. Remove it afterwards.
        NG += (bestResponseValue - pi_m@teamPotentialFunction)
    return NG

def teamNashGap8(randomRewards,potentialFunction1,Pi,N):
    """
    Nash gap for setting 8: team 0's gap plus team 1's gap.

    Team 0's payoffs are randomRewards @ Pi[1], so its gap is the best entry
    minus the expectation under Pi[0]. Team 1's payoffs are
    Pi[0] @ potentialFunction1, so its gap is the best entry minus the
    expectation under Pi[1]. N is not used.
    """
    pi0, pi1 = Pi[0], Pi[1]
    payoffs0 = randomRewards @ pi1
    payoffs1 = pi0 @ potentialFunction1
    gap0 = np.max(payoffs0) - pi0 @ payoffs0
    gap1 = np.max(payoffs1) - payoffs1 @ pi1
    return 0 + gap0 + gap1

# def nashGap(utility,PiAvg,A,G):
#     NG = 0
#     for m in range(len(utility)):
#         neighbours = list(G.neighbors(m))
#         neighbours.sort()
#         noNeighbors = len(neighbours)+1
#         utilityMatrix = utility[m].reshape((A,)*noNeighbors)
#         pi_m = PiAvg[m,:]

#         for n in range(m):
#             if n in neighbours:
#                 utilityMatrix = np.einsum('i...,i',utilityMatrix,PiAvg[n,:])
#         for n in range(len(utility)-m-1):
#             if len(utility)-n-1 in neighbours:
#                 utilityMatrix = (utilityMatrix @ PiAvg[len(utility)-n-1,:]).squeeze()

#         bestResponseValue = np.max(utilityMatrix)
#         NG += (bestResponseValue - pi_m @ utilityMatrix)
    
#     return NG

def nashGap(G,RAM,utility,allRewards,allActProfiles,PiAvg,m,n,M,N,A,perspectivePlayer=None):
    """
    Total unilateral-deviation gain over all players under independent strategies PiAvg.

    For each global player index k (team m = k // N, position n = k % N),
    expectedRewardsForMWU_FP(allRewards[k], ...) gives the expected reward of
    each action. The player's gap is the best of those minus the expected
    reward of its own mixed strategy PiAvg[k].

    The m and n arguments are overwritten by the loop and have no effect. G,
    RAM, utility and perspectivePlayer are also unused.
    """
    totalGap = 0
    for playerIdx in range(M*N):
        m, n = divmod(playerIdx, N)
        payoffs = expectedRewardsForMWU_FP(allRewards[playerIdx],allActProfiles,PiAvg,m,n,M,N,A)
        ownStrategy = PiAvg[playerIdx,:]
        totalGap += (np.max(payoffs) - ownStrategy @ payoffs.squeeze())
    return totalGap


#region conversion functions
def indexToTeamAndPlayer(playerIdx,N):
    """Split a global player index into (team, position within team) for teams of size N."""
    return divmod(playerIdx, N)

def indexToTeamAndPlayerList(playerIdxList,N):
    """Apply the (team, position) split to every global index in playerIdxList."""
    return [divmod(playerIdx, N) for playerIdx in playerIdxList]

def teamAndPlayerToIndex(m,n,N):
    """Global index of player n on team m when every team has N players."""
    teamOffset = m*N
    return teamOffset + n

def teamAndPlayerToIndexList(players,N):
    """Global index for every (team, position) pair in players, in the same order."""
    return [player[0]*N + player[1] for player in players]

def getActProfileIdx(actProfile,A):
    """
    Encode action profile(s) as base-A integers, first entry most significant.

    Parameters:
    actProfile -- a 1-D sequence of actions, or a 2-D array with one profile
                  per row (the second dimension is for actions, the first is
                  the list length).
    A          -- number of actions per player.

    Returns:
    int for a 1-D profile; 1-D integer array (one code per row) for 2-D input.

    Raises:
    ValueError -- if actProfile is neither 1-D nor 2-D.
    """
    profiles = np.asarray(actProfile)
    if profiles.ndim not in (1, 2):
        # The previous `raise("...")` raised a TypeError about raising a str.
        raise ValueError("ERROR: Action Profiles should be 1 or 2 dimensional array")
    numPlayers = profiles.shape[-1]
    placeValues = A ** np.arange(numPlayers - 1, -1, -1)
    codes = profiles @ placeValues
    if profiles.ndim == 1:
        return int(codes)
    # One code per row; the previous int(...) on this array raised for >1 row.
    return codes.astype(int)

def getActProfile(actProfileIdx,A,N):
    """
    Decode a joint-action code into N base-A digits (most significant first).

    Inverse of getActProfileIdx for a single profile; the result is a list of
    Python ints of length N.
    """
    remaining = actProfileIdx
    lowToHigh = []
    for _ in range(N):
        remaining, digit = divmod(remaining, A)
        lowToHigh.append(int(digit))
    return lowToHigh[::-1]

def getActProfileList(A,N):
    """
    Enumerate every joint action of N players with A actions each.

    Entry k of the returned list is the N-digit base-A expansion of k (most
    significant digit first), for k = 0 .. A**N - 1.
    """
    profiles = []
    for code in range(A**N):
        digits = [0] * N
        remaining = code
        for position in reversed(range(N)):
            remaining, digit = divmod(remaining, A)
            digits[position] = int(digit)
        profiles.append(digits)
    return profiles
#endregion

def expectedPayoff(G,RAM,utility,Pi,lastA,m,n,M,N,A,perspectivePlayer=None):
    """
    Calculates the expected payoff for player n from team m.

    The player (global index m*N+n) is evaluated for each of its A actions.
    Neighbours on the player's own team play their current actions from
    lastA[m, :]. Actions of neighbours on other teams are enumerated and
    weighted by the marginal of that team's joint-action belief Pi[team, :],
    a vector over the A**N joint actions of the team.

    When perspectivePlayer is None (top-level call), the result sums over every
    node in the player's RAM row. The player itself contributes its own
    expected utility. Every other node contributes a recursive call made from
    that node's perspective: added for teammates, subtracted for enemies.

    When perspectivePlayer is given, only that perspective's contribution is
    returned. For a perspective player on another team, the contribution is
    read from the module-level ``edgeUtilities`` dict. NOTE(review): that
    dict is not a parameter and must exist globally before calling.

    Returns:
    (A, 1) array; row a corresponds to the player choosing action a.
    """
    nodeTypes = nx.get_node_attributes(G,'type')
    selfGlobalIdx = m*N+n
    selfTeam = nodeTypes[selfGlobalIdx]
    if perspectivePlayer == None:
        perspectivePlayer = selfGlobalIdx

    # Neighbourhood of the perspective player (non-zero RAM entries), split by
    # whether each neighbour is on the evaluated player's team.
    neighbours = RAM[perspectivePlayer,:]
    neighbourIdxs = np.where(neighbours!=0)[0]
    teamNeighbours = np.array([neighbourIdx for neighbourIdx in neighbourIdxs if nodeTypes[neighbourIdx] == selfTeam ])
    enemyNeighbours = np.array([neighbourIdx for neighbourIdx in neighbourIdxs if nodeTypes[neighbourIdx] != selfTeam ])


    # Local profile order: ascending global id of all neighbours.
    neighbours = np.concatenate((teamNeighbours,enemyNeighbours))
    neighbours.sort()
    neighbours = neighbours.astype(int)

    # Positions of teammates, enemies and the evaluated player in that order.
    localTeamIdx = np.nonzero(teamNeighbours[:,None] == neighbours)[1]
    localEnemyIdx = np.nonzero(enemyNeighbours[:, None] == neighbours)[1]
    localSelfIdx = np.where(neighbours==selfGlobalIdx)[0]

    expectedUtilityForPerspectivePlayer = np.zeros((A,1))
    expectedRewardForSelf = np.zeros((A,1))
    Pi_M = Pi.copy() # NOTE(review): unused

    # All A**N joint actions of one team, used to marginalise Pi[team, :].
    actProfilesList = np.array(getActProfileList(A,N))
    for a in range(A):
        payoff = 0
        neighbourActProfile = np.zeros(len(neighbours),dtype=np.int32)
        # Teammates keep their last actions; the evaluated player plays a.
        neighbourActProfile[localTeamIdx] = lastA[m,teamNeighbours-m*N]
        neighbourActProfile[localSelfIdx] = a
        # Enumerate every joint action of the enemy neighbours.
        for j in range(A**len(enemyNeighbours)):
            enemyNeighboursActProfile = np.array(getActProfile(j,A,len(enemyNeighbours)))
            neighbourActProfile[localEnemyIdx] = enemyNeighboursActProfile
            enemyActProfileProb = 1

            # Probability of this enemy configuration: product over enemy
            # teams of that team's belief marginalised onto its neighbours.
            for team in range(M):
                if team == m: #Player's own team
                    pass
                else: #enemy teams
                    # Enemy neighbours whose global id lies in team's range.
                    teamEnemyNeighboursIdx = np.where((team*N<=enemyNeighbours) & (enemyNeighbours<(team+1)*N))[0]
                    if len(teamEnemyNeighboursIdx) != 0:
                        teamEnemyNeighbours = enemyNeighbours[teamEnemyNeighboursIdx]
                        teamEnemyNeighboursActProfile = enemyNeighboursActProfile[teamEnemyNeighboursIdx]

                        # for j in range(A**len(teamEnemyNeighbours)):
                        #     enemyNeighboursActProfile = np.array(getActProfile(j,A,len(teamEnemyNeighbours)))
                        # Sum Pi[team] over team joint actions consistent with
                        # the neighbours' actions.
                        teamEnemyActProfileProb = Pi[team,np.where((actProfilesList[:,teamEnemyNeighbours-team*N]==teamEnemyNeighboursActProfile).all(1))[0]]
                        teamEnemyActProfileProb = np.sum(teamEnemyActProfileProb) #Marginalize the belief for local information.
                        enemyActProfileProb *= teamEnemyActProfileProb

            neighbourActProfileIdx = getActProfileIdx(neighbourActProfile,A)
            #for pairwise zero-sum idea:
            if nodeTypes[perspectivePlayer] == selfTeam:
                # Perspective on the player's team: use its full local utility.
                payoff += utility[perspectivePlayer][neighbourActProfileIdx] * enemyActProfileProb
            else:
                # Perspective on another team: sum its edge tables towards the
                # evaluated player's teammates. (The comprehension variable `m`
                # is local to the comprehension and does not overwrite m.)
                perspectivePlayerTeam = np.array([m for m in neighbours if nodeTypes[m]==nodeTypes[perspectivePlayer]])
                perspectivePlayerTeamLocalIdx = np.nonzero(perspectivePlayerTeam[:, None] == neighbours)[1]
                pairwiseActProfile = [neighbourActProfile[actor] for actor in range(len(neighbourActProfile)) if (actor in localTeamIdx) or (actor in perspectivePlayerTeamLocalIdx)]
                # NOTE(review): the same pairwiseActProfile indexes every
                # teammate's table; confirm against generateEdgeUtilities.
                for enemyNeighbourMyTeamNode in teamNeighbours:
                    payoff += edgeUtilities[(perspectivePlayer,enemyNeighbourMyTeamNode)][getActProfileIdx(pairwiseActProfile,A)] * enemyActProfileProb

        expectedUtilityForPerspectivePlayer[a] = payoff
    if perspectivePlayer != selfGlobalIdx:
        return expectedUtilityForPerspectivePlayer
    else:
        # Top level: own term plus signed neighbour terms (recursive).
        for neighbourNode in neighbours:
            if neighbourNode == selfGlobalIdx:
                expectedRewardForSelf += expectedUtilityForPerspectivePlayer
            else:
                expectedUtilityFromNeighbour = expectedPayoff(G,RAM,utility,Pi,lastA,m,n,M,N,A,perspectivePlayer=neighbourNode)
                if neighbourNode in teamNeighbours:
                    expectedRewardForSelf += expectedUtilityFromNeighbour
                else: #node is from enemy neighbours.
                    expectedRewardForSelf -= expectedUtilityFromNeighbour

        return expectedRewardForSelf
    
def expectedPotential(G,RAM,potentialFunctions,Pi,lastA,m,n,M,N,A,perspectivePlayer=None,groups=1):
    """
    Expected team-m potential for every joint action of the group containing player n.

    Team m's potential is averaged over every other team's mixed joint action
    (rows of Pi). Team m's own joint action is fixed to lastA[m], except for
    the `groups` consecutive players starting at groups*(n//groups). Those
    players range over all A**groups joint actions.

    G, RAM, M, N and perspectivePlayer are accepted for interface
    compatibility and are not used.

    Returns:
    (A**groups, 1) array; row a is the expected potential when the group
    plays the profile encoded by a.
    """
    numTeams = len(potentialFunctions)
    marginal = potentialFunctions[m]
    # Contract the axes of the teams before m ...
    for other in range(m):
        marginal = np.einsum('i...,i',marginal,Pi[other,:])
    # ... then those after m, starting from the last axis.
    for other in range(numTeams-1, m, -1):
        marginal = (marginal @ Pi[other,:]).squeeze()

    groupStart = groups*(n//groups)
    candidate = lastA[m].copy()
    result = np.zeros((A**groups,1))
    for code in range(A**groups):
        candidate[groupStart:groupStart+groups] = getActProfile(code,A,groups)
        result[code] = marginal[getActProfileIdx(candidate,A)]
    return result
    

def expectedPayoffMWU(G,RAM,utility,PiInd,m,n,M,N,A,perspectivePlayer=None):
    """
    Calculates the expected payoff for player n from team m.

    Variant of expectedPayoff for independent per-player mixed strategies:
    every other player j plays PiInd[j, :] (length A). Utility tables are
    reshaped to one axis per player involved and contracted with those
    players' strategies. Only the axis of player m*N+n is left uncontracted.

    At the top level (perspectivePlayer None), the player's own term is
    combined with the recursive term of every other node in its neighbourhood:
    added for teammates, subtracted for enemies. For a perspective player on
    another team, the term is built from the module-level ``edgeUtilities``
    dict. NOTE(review): that dict is not a parameter and must exist globally.

    Returns:
    (1, A) array at the top level. Perspective calls return the contracted
    utility (presumably shape (A,) — confirm).
    """
    nodeTypes = nx.get_node_attributes(G,'type')
    selfGlobalIdx = m*N+n
    selfTeam = nodeTypes[selfGlobalIdx]
    if perspectivePlayer == None:
        perspectivePlayer = selfGlobalIdx

    # Neighbourhood of the perspective player, split by the evaluated player's team.
    neighbours = RAM[perspectivePlayer,:]
    neighbourIdxs = np.where(neighbours!=0)[0]
    teamNeighbours = np.array([neighbourIdx for neighbourIdx in neighbourIdxs if nodeTypes[neighbourIdx] == selfTeam ])
    enemyNeighbours = np.array([neighbourIdx for neighbourIdx in neighbourIdxs if nodeTypes[neighbourIdx] != selfTeam ])

    # Neighbours that share the perspective player's team (used for edge tables).
    perspectiveTeamEnemyNeighbours = np.array([i for i in enemyNeighbours if nodeTypes[i]==nodeTypes[perspectivePlayer]])

    neighbours = np.concatenate((teamNeighbours,enemyNeighbours))
    neighbours.sort()
    neighbours = neighbours.astype(int)


    expectedRewardForSelf = np.zeros((1,A))
    if nodeTypes[perspectivePlayer] == selfTeam:

        utilityMatrix = utility[perspectivePlayer].reshape((A,)*len(neighbours)).copy() #Convert utilities to matrix form for easy multiplication with beliefs.

        # Contract leading axes (players below selfGlobalIdx) ...
        for j in range(selfGlobalIdx):
            if j in neighbours:
                utilityMatrix = np.einsum('i...,i',utilityMatrix,PiInd[j,:])
        # ... then trailing axes (players above selfGlobalIdx), last first.
        for j in range(M*N-selfGlobalIdx-1):
            if M*N-j-1 in neighbours:
                utilityMatrix = (utilityMatrix @ PiInd[M*N-j-1,:]).squeeze()

        expectedUtilityForPerspectivePlayer = utilityMatrix
    else:
        relatedPlayers = np.concatenate((perspectiveTeamEnemyNeighbours,teamNeighbours))
        payoff = np.zeros_like(expectedRewardForSelf)
        # Sum the perspective player's edge tables towards nodes on the
        # evaluated player's team, each contracted the same way as above.
        for edge in edgeUtilities.keys():
            if edge[0] == perspectivePlayer and nodeTypes[edge[1]] == selfTeam:
                utilityMatrix = edgeUtilities[edge].reshape((A,)*len(relatedPlayers)).copy()
                for j in range(selfGlobalIdx):
                    if j in relatedPlayers:
                        utilityMatrix = np.einsum('i...,i',utilityMatrix,PiInd[j,:])
                for j in range(M*N-selfGlobalIdx-1):
                    if M*N-j-1 in relatedPlayers:
                        utilityMatrix = (utilityMatrix @ PiInd[M*N-j-1,:]).squeeze()
                payoff += utilityMatrix

        expectedUtilityForPerspectivePlayer = payoff

    if perspectivePlayer != selfGlobalIdx:
        return expectedUtilityForPerspectivePlayer
    else:
        # Top level: own term plus signed neighbour terms (recursive).
        for neighbourNode in neighbours:
            if neighbourNode == selfGlobalIdx:
                expectedRewardForSelf += expectedUtilityForPerspectivePlayer
            else:
                expectedUtilityFromNeighbour = expectedPayoffMWU(G,RAM,utility,PiInd,m,n,M,N,A,perspectivePlayer=neighbourNode)
                if neighbourNode in teamNeighbours:
                    expectedRewardForSelf += expectedUtilityFromNeighbour
                else: #node is from enemy neighbours.
                    expectedRewardForSelf -= expectedUtilityFromNeighbour

        return expectedRewardForSelf
    
#region Consensus
def generateConsensusMatrix(RAM,M,N):
    """
    Build one doubly stochastic consensus matrix per team.

    For team m, the N x N diagonal block RAM[m*N:(m+1)*N, m*N:(m+1)*N] holds
    the intra-team connections, including the diagonal. It is rescaled by
    normalize_double_stochastic.

    Returns:
    list of M matrices, in team order.
    """
    return [normalize_double_stochastic(RAM[team*N:(team+1)*N, team*N:(team+1)*N]) for team in range(M)]

def normalize_double_stochastic(A, tol=1e-9, max_iter=1000):
    """
    Rescale a 2-D matrix towards doubly stochastic form with Sinkhorn-Knopp.

    Each sweep divides every row by its sum and then every column by its sum.
    Iteration stops early once all row and column sums are within `tol` of 1
    (np.allclose with atol=tol), or after `max_iter` sweeps.

    Parameters:
    A (numpy.ndarray): 2-D input matrix (not modified).
    tol (float): absolute tolerance on the row/column sums.
    max_iter (int): maximum number of sweeps.

    Returns:
    numpy.ndarray: the rescaled matrix (A itself when max_iter is 0).
    """
    numRows, numCols = A.shape  # raises unless A is 2-D
    scaled = A
    for _ in range(max_iter):
        scaled = scaled / scaled.sum(axis=1, keepdims=True)
        scaled = scaled / scaled.sum(axis=0, keepdims=True)
        rowsConverged = np.allclose(scaled.sum(axis=1), 1, atol=tol)
        if rowsConverged and np.allclose(scaled.sum(axis=0), 1, atol=tol):
            break
    return scaled

def generateRandomTrajectory(G,utility,M,N,A,L=1000,RAM=None):
    """
    Sample L uniformly random joint actions and the team-signed reward of each player.

    For each sample, player idx on team m receives the sum of the local
    utilities of itself and its adjacent teammates, minus the sum of the local
    utilities of its neighbours on other teams. A local utility is utility[j]
    evaluated at j's local action profile (RAM row j).

    Parameters:
    G        -- graph over M*N players (node m*N+n is player n of team m).
    utility  -- per-player utility tables indexed by getActProfileIdx.
    M, N, A  -- teams, players per team, actions per player.
    L        -- number of samples.
    RAM      -- neighbourhood matrix used to extract local action profiles.
                Defaults to the module-level ``RAM`` global, which the
                previous implementation used implicitly.

    Returns:
    (allRewards, allActProfiles) -- arrays with one row of length M*N per sample.

    Raises:
    ValueError -- if RAM is neither passed nor defined globally.
    """
    if RAM is None:
        RAM = globals().get('RAM')
        if RAM is None:
            raise ValueError("RAM must be passed or defined at module level")
    allActionList = getActProfileList(A,M*N)
    # Neighbourhoods are fixed; build each once as a set for O(1) membership.
    neighbourSets = [set(G.neighbors(idx)) for idx in range(M*N)]
    allRewards = []
    allActProfiles = []
    for _ in range(L):
        rewards = np.zeros(M*N)
        actProfile = np.array(random.choice(allActionList))
        localActProfiles = globalToLocalActProfile(RAM,actProfile)
        allActProfiles.append(actProfile)
        # Each player's own utility under this joint action, evaluated once.
        localUtility = [utility[j][getActProfileIdx(localActProfiles[j],A)] for j in range(M*N)]
        for m in range(M):
            for n in range(N):
                idx = m*N+n
                neighbours = neighbourSets[idx]
                rewards[idx] += sum([localUtility[m*N+i] for i in range(N) if (m*N+i in neighbours or m*N+i == idx)])
                rewards[idx] -= sum([localUtility[j] for j in range(M*N) if (j in neighbours and (j<m*N or j>=m*N+N))])
        allRewards.append(rewards)
    return np.array(allRewards), np.array(allActProfiles)

def initializeLinearQEstimates(G,M,N,A):
    """
    Initial features and parameters for the per-player linear reward estimates.

    Returns:
    Phi    -- identity matrix of size A**(M*N) (one feature per joint action).
    thetas -- array of M*N all-zero parameter rows, each of length A**(M*N).

    G is not used.
    """
    numJoint = A**(M*N)
    Phi = np.eye(numJoint, numJoint)
    thetas = np.array([np.zeros(numJoint) for _ in range(M*N)])
    return Phi, thetas

def DIGing(thetas,alpha,Cs,allRewards,allActProfiles,TrajectoryL,L=1000):
    """
    Consensus-based gradient tracking for the per-player linear reward estimates.

    Player idx's loss compares (thetas[idx] @ Phi) at the sampled joint
    actions with allRewards[:, idx]. Its gradient is
    (2/TrajectoryL) * (prediction - reward) @ Phi[:, sampled].T.

    Each iteration, player idx = m*N+n mixes teammates i in its closed
    neighbourhood with weights Cs[m][n, i]:
      theta_new = mix(theta) - alpha * gamma
      gamma_new = mix(gamma) + grad(theta_new) - grad(theta)

    Parameters:
    thetas         -- (M*N, features) array of parameters.
    alpha          -- step size.
    Cs             -- per-team mixing matrices indexed Cs[m][n, i].
    allRewards     -- rewards, one column per player (indexed [:, idx]).
    allActProfiles -- sampled joint actions, one row per sample.
    TrajectoryL    -- number of samples, used to average the gradient.
    L              -- number of iterations.

    NOTE(review): Phi, A, G, M and N are read from module-level globals, not
    passed in. The code also relies on getActProfileIdx returning one index
    per row for a 2-D allActProfiles.

    Returns the parameters after L iterations.
    """
    # Initial tracking variable = local gradients at the starting thetas.
    grad = (1/TrajectoryL)*2*((thetas@Phi)[:,getActProfileIdx(allActProfiles,A)] - allRewards.T)@Phi[:,getActProfileIdx(allActProfiles,A)].T
    gamma = grad
    for l in range(L):
        newthetas = thetas.copy()
        newgamma = gamma.copy()
        for m in range(M):
            for n in range(N):
                idx = m*N+n
                # Closed neighbourhood; only teammates are mixed below.
                neighbours = list(G.neighbors(idx)) + [idx]
                newthetas[idx] = sum([Cs[m][n,i]*thetas[m*N+i] for i in range(N) if (m*N+i in neighbours)]) - alpha*gamma[idx,:]
                # Mixed gamma plus the change of player idx's local gradient.
                newgamma[idx] = sum([Cs[m][n,i]*gamma[m*N+i] for i in range(N) if (m*N+i in neighbours)]) + (1/TrajectoryL)*2*((newthetas[idx,:]@Phi)[getActProfileIdx(allActProfiles,A)]-allRewards[:,idx])@Phi[:,getActProfileIdx(allActProfiles,A)].T - (1/TrajectoryL)*2*((thetas[idx,:]@Phi)[getActProfileIdx(allActProfiles,A)]-allRewards[:,idx])@Phi[:,getActProfileIdx(allActProfiles,A)].T

        gamma = newgamma
        thetas = newthetas
    return thetas
#endregion

def MWUpdate(PiInd,expectedPayoffs,eta=0.1):
    """
    One multiplicative-weights step per row: pi <- pi * exp(eta * payoff), renormalised.

    Parameters:
    PiInd           -- 2-D array of mixed strategies, one per row.
    expectedPayoffs -- payoffs broadcastable against PiInd.
    eta             -- learning rate.

    Returns:
    New array with the shape of PiInd whose rows are normalised to sum to one.
    """
    scaled = eta*np.asarray(expectedPayoffs)
    # Subtracting the per-row maximum leaves the normalised result unchanged
    # mathematically but keeps np.exp from overflowing to inf (and the
    # division from producing nan) when eta*payoff is large.
    scaled = scaled - np.max(scaled, axis=-1, keepdims=True)
    weights = PiInd*np.exp(scaled)
    return weights/np.sum(weights,axis=1).reshape(PiInd.shape[0],1)

def softmax(x,tau):
    """Compute softmax values for each sets of scores in x.

    Normalisation runs along axis 0, so a 1-D vector or a (K, 1) column of
    scores yields a distribution over its K entries. tau is the temperature:
    smaller tau puts more mass on the largest score.
    """
    z = np.asarray(x)/tau
    # Shift by the max along axis 0 so exp never overflows for small tau;
    # the shift cancels in the ratio below.
    z = z - np.max(z, axis=0, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=0)

def receivedPayoff(R,lastA,M,A):
    """
    Index R by every team's current joint action.

    For each team m in turn, the leading axis of the remaining array is
    indexed by the code getActProfileIdx(lastA[m, :], A), and the result is
    squeezed. Any axes left after M lookups are returned as-is (presumably
    the per-player rewards — confirm).
    """
    payoff = R
    for team in range(M):
        payoff = payoff[getActProfileIdx(lastA[team,:],A),...].squeeze()
    return payoff

def calculateMu(G,RAM,utility,Pi,lastA,m,n,M,N,A,tau):
    """Smoothed best response of player n on team m: softmax (temperature tau) of expectedPayoff."""
    return softmax(expectedPayoff(G,RAM,utility,Pi,lastA,m,n,M,N,A), tau)

def updatePi(Pi,lastA,k,M,N,A):
    """
    In-place fictitious-play update of every team's joint-action belief.

    Row m of Pi (length A**N) moves to Pi[m] + (e - Pi[m])/(k+1), where e is
    the one-hot vector at the code of team m's last joint action lastA[m, :].
    The same (mutated) Pi is returned.
    """
    stepSize = 1/(k+1)
    for team in range(M):
        indicator = np.zeros(A**N)
        indicator[getActProfileIdx(lastA[team,:],A)] = 1
        Pi[team,:] += stepSize*(indicator-Pi[team,:])
    return Pi

def updatePiforTeam(Pi,lastA,k,M,N,A):
    """
    In-place running-average update of a single team's joint-action belief.

    lastA is that team's joint action; its code selects the hot entry of the
    length-A**N target vector. Pi moves to Pi + (target - Pi)/(k+1), and the
    same array is returned. M is not used.
    """
    target = np.zeros(A**N)
    target[getActProfileIdx(lastA,A)] = 1
    weight = 1/(k+1)
    Pi += weight*(target-Pi)
    return Pi

def updatePiAvg(PiAvg,lastA,k,M,N,A):
    """
    In-place running average of every player's empirical action frequency.

    Row m*N+n of PiAvg moves to PiAvg[row] + (e - PiAvg[row])/(k+1), where e
    is the length-A one-hot vector of that player's last action lastA[m, n].
    The same (mutated) PiAvg is returned.
    """
    stepSize = 1/(k+1)
    for playerIdx in range(M*N):
        team, member = divmod(playerIdx, N)
        indicator = np.zeros(A)
        indicator[lastA[team, member]] = 1
        PiAvg[playerIdx,:] += stepSize*(indicator-PiAvg[playerIdx,:])
    return PiAvg

def calculateAllRewards(G,RAM,utility,edgeUtilities,M,N,A):
    """
    Reward of every player under every joint action of all M*N players.

    Player idx on team m receives the sum of the local utilities of itself and
    its adjacent teammates, minus the sum of the local utilities of its
    neighbours on other teams. A local utility is utility[j] evaluated at j's
    local action profile (RAM row j, see globalToLocalActProfile).

    edgeUtilities is accepted for interface compatibility and is not used.

    Returns:
    (M*N, A**(M*N)) array; column c holds the rewards for the joint action
    getActProfile(c, A, M*N).
    """
    #only works for 2-teams (original note; NOTE(review): the formula below is
    # written for general M — confirm whether the restriction still applies)
    allActProfiles = getActProfileList(A,M*N)
    allRewards = np.zeros((M*N,len(allActProfiles)))
    # Neighbourhoods are fixed; build each once as a set for O(1) membership.
    neighbourSets = [set(G.neighbors(idx)) for idx in range(M*N)]
    # Enumerating directly is safe now: the previous version appended every
    # profile back onto allActProfiles inside the loop, doubling it for no use.
    for actProfileIdx, actProfile in enumerate(allActProfiles):
        localActProfiles = globalToLocalActProfile(RAM,np.array(actProfile))
        # Each player's own utility under this joint action, evaluated once
        # instead of once per player that sums it.
        localUtility = [utility[j][getActProfileIdx(localActProfiles[j],A)] for j in range(M*N)]
        rewards = np.zeros(M*N)
        for m in range(M):
            for n in range(N):
                idx = m*N+n
                neighbours = neighbourSets[idx]
                rewards[idx] += sum([localUtility[m*N+i] for i in range(N) if (m*N+i in neighbours or m*N+i == idx)])
                rewards[idx] -= sum([localUtility[j] for j in range(M*N) if (j in neighbours and (j<m*N or j>=m*N+N))])
        allRewards[:,actProfileIdx] = rewards
    return allRewards

def expectedRewardsForMWU_FP(rewards,allActProfiles,PiInd,m,n,M,N,A):
    """
    Expected reward of player m*N+n for each of its A actions.

    All other players play independently according to their rows of PiInd.
    For each joint action (row of allActProfiles), the probability of the
    other players' part is the product of their PiInd entries. The expected
    reward of action a sums rewards, weighted by that probability, over the
    joint actions in which player m*N+n plays a.

    Parameters:
    rewards        -- 1-D array, one reward per row of allActProfiles.
    allActProfiles -- 2-D integer array (joint actions x players).
    PiInd          -- 2-D array (players x actions) of mixed strategies.

    Returns:
    1-D float array of length A.
    """
    numPlayers, _numActions = PiInd.shape
    selfIdx = m*N+n
    others = np.delete(np.arange(numPlayers), selfIdx)
    othersProb = np.prod(PiInd[others, allActProfiles[:,others]], axis=1)
    selfActions = allActProfiles[:,selfIdx]
    return np.array([rewards@(othersProb*(selfActions==a)) for a in range(A)], dtype=float)
            
def expectedRewardForSetting8(randomRewards,lastA,Pi,m,n,M,N,A):
    """
    Expected rewards used by setting 8 (team 0 against team 1).

    m == 0: returns randomRewards @ Pi[1], one entry per row of randomRewards.
    m != 0: for each action a of player n, team m's joint action is lastA[m]
            with entry n replaced by a. The entry is Pi[0] @ randomRewards[:, c],
            where c is that joint action's code.

    lastA is not modified.
    """
    if m == 0:
        return randomRewards @ Pi[1]
    expectedReward = np.zeros(A)
    # Work on a copy: the previous version wrote into lastA[m] itself, leaving
    # lastA[m][n] == A-1 as a side effect of merely evaluating rewards.
    actProfile = np.array(lastA[m])
    for a in range(A):
        actProfile[n] = a
        actIdx = int(getActProfileIdx(actProfile,A))
        expectedReward[a] = Pi[0]@randomRewards[:,actIdx]
    return expectedReward


def syncTeamFPNetowrk(G,RAM,utility,Pi,lastA,M,N,A,T=100000,tau=0.1,groups=1):
    """
    Run T stages of team fictitious play with softmax (temperature tau) responses.

    At each stage, for every team m, one uniformly random player n resamples
    its action from softmax(expectedPotential(...), tau) and writes it into
    lastA[m]. With groups > 1, the `groups` consecutive players containing n
    resample jointly. Team beliefs Pi are then updated with updatePi.

    NOTE(review): potentialFunctions is read from a module-level global, not
    passed in. T must be >= 1, otherwise the final np.concatenate fails on an
    empty list.

    Returns:
    PiHistory -- copies of Pi taken every 10 stages, stacked on a new last axis.
    NGHistory -- teamNashGap(potentialFunctions, Pi) every 10 stages.
    """
    k = 0 # initialize stage no
    PiHistory = []
    NGHistory = []
    muHistory = {(0,0):[],(0,1):[],(1,0):[],(1,1):[]} # NOTE(review): unused
    for t in range(T):
        for m in range(M):
            n = np.random.choice(np.arange(N)) #choose n th player randomly
            #mu = calculateMu(G,RAM,utility,Pi,lastA,m,n,M,N,A,tau)

            muP = softmax(expectedPotential(G,RAM,potentialFunctions,Pi,lastA,m,n,M,N,A,groups=groups),tau)
            if groups == 1:
                a = np.random.choice(range(A),p=muP.squeeze())
                lastA[m,n] = a #change the action of the chosen player
            else:
                # Sample a joint action for the whole group containing n.
                a = np.random.choice(range(A**groups),p=muP.squeeze())
                actProfile = getActProfile(a,A,groups)
                lastA[m,groups*(n//groups):groups*(n//groups)+groups] = actProfile


        Pi = updatePi(Pi,lastA,k,M,N,A)
        k += 1

        if t%10 == 0:
            NG = teamNashGap(potentialFunctions,Pi)
            NGHistory.append(NG)

        if t % 10 == 0:
            PiHistory.append(Pi[:,:,None].copy())

        # NG is always bound here: t % 10000 == 0 implies t % 10 == 0.
        if t% 10000 == 0:
            print(NG)

    PiHistory = np.concatenate(PiHistory,axis=2)
    NGHistory = np.array(NGHistory)

    return PiHistory, NGHistory

def syncTeamFPNetowrk7(G,RAM,utility,Pi,lastA,M,N,A,T=100000,tau=0.1,groups=1):
    """
    Setting 7: team 0 runs fictitious play; every other team plays a fixed random policy.

    Before the loop, a random stationary policy sPi[m] over the A**N joint
    actions is drawn for each team. At every stage:
    - Team 0: one random player (or group of `groups` players) resamples from
      softmax(expectedPotential(...), tau).
    - Every other team: samples its whole joint action from sPi[m].
    Beliefs Pi are then updated with updatePi.

    NOTE(review): potentialFunctions is read from a module-level global, not
    passed in. T must be >= 1, otherwise the final np.concatenate fails on an
    empty list. NG is the gap summed over all teams, not only team 0.

    Returns:
    PiHistory -- copies of Pi taken every 10 stages, stacked on a new last axis.
    NGHistory -- teamNashGap(potentialFunctions, Pi) every 10 stages.
    """
    k = 0 # initialize stage no
    PiHistory = []
    NGHistory = []
    muHistory = {(0,0):[],(0,1):[],(1,0):[],(1,1):[]} # NOTE(review): unused

    # Create stationary policies for teams.
    sPi = np.random.rand(M,A**N)
    sPi = sPi / np.sum(sPi,axis=1,keepdims=True)

    for t in range(T):
        for m in range(M):
            if m == 0:
                n = np.random.choice(np.arange(N)) #choose n th player randomly
                #mu = calculateMu(G,RAM,utility,Pi,lastA,m,n,M,N,A,tau)

                muP = softmax(expectedPotential(G,RAM,potentialFunctions,Pi,lastA,m,n,M,N,A,groups=groups),tau)
                if groups == 1:
                    a = np.random.choice(range(A),p=muP.squeeze())
                    lastA[m,n] = a #change the action of the chosen player
                else:
                    a = np.random.choice(range(A**groups),p=muP.squeeze())
                    actProfile = getActProfile(a,A,groups)
                    lastA[m,groups*(n//groups):groups*(n//groups)+groups] = actProfile
            else:
                # Non-learning team: sample a full joint action from sPi[m].
                actProfileIdxTeam = np.random.choice(range(A**N),p=sPi[m,:])
                actProfileTeam = getActProfile(actProfileIdxTeam,A,N)
                for n in range(N):
                    lastA[m,n] = actProfileTeam[n]


        Pi = updatePi(Pi,lastA,k,M,N,A)
        k += 1

        if t%10 == 0:
            NG = teamNashGap(potentialFunctions,Pi)
            NGHistory.append(NG)

        if t % 10 == 0:
            PiHistory.append(Pi[:,:,None].copy())

        # NG is always bound here: t % 10000 == 0 implies t % 10 == 0.
        if t% 10000 == 0:
            print(NG)

    PiHistory = np.concatenate(PiHistory,axis=2)
    NGHistory = np.array(NGHistory)

    return PiHistory, NGHistory


def syncTeamFPNetowrk8(G,RAM,utility,Pi,lastA,M,N,A,T=100000,tau=0.1,groups=1):
    """
    Setting 8: one player with random rewards (team 0) against one team (team 1).

    Team 0 is a single player with A actions. Its reward table
    ``randomRewards0`` (A x A**N) is drawn uniformly at random once per call.
    Each iteration, one uniformly chosen player of team 1 takes a softmax
    (temperature ``tau``) response computed by ``expectedRewardForSetting8``;
    team 0 responds the same way to its random rewards. Both beliefs are then
    updated with ``updatePiforTeam``.

    The incoming ``Pi`` and ``lastA`` are ignored; fresh per-team beliefs and
    last actions are built here. They are two-element lists indexed by team,
    so only M == 2 is supported (m >= 2 would raise IndexError). The function
    reads the module-level ``potentialFunctions`` and ``edgeUtilities``.

    Returns
    -------
    PiHistory : ndarray, shape (n_belief_entries, n_snapshots)
        One column every 10 iterations. Each column holds team 0's belief
        followed by team 1's belief, both flattened. The two beliefs have
        different lengths (A and A**N initially), so they are joined rather
        than stacked.
    NGHistory : ndarray
        ``teamNashGap8`` values, one every 10 iterations.
    """
    k = 0 # initialize stage no
    PiHistory = []
    NGHistory = []

    # Fresh beliefs: team 0 is one player (A actions), team 1 has A**N joint actions.
    Pi = [np.zeros(A), np.zeros(A**N)]
    # Integer storage so recorded actions stay usable as indices
    # (np.zeros would otherwise default to float64).
    lastA = [np.array([0]), np.zeros(N, dtype=int)]
    allRewards = calculateAllRewards(G,RAM,utility,edgeUtilities,M,N,A)
    # NOTE(review): the [0,1] selections below hard-code two options for the
    # single opponent player, i.e. they appear to assume A == 2 — confirm.
    potentialFunction1 = potentialFunctions[1][[0,1],:]
    randomRewards0 = np.random.rand(A,A**N)
    rewardsTeam1 = allRewards[N:,].reshape(N,A**N,A**N)[:,[0,1],:]
    for t in range(T):
        for m in range(M):
            if m != 0:
                n = np.random.choice(np.arange(N)) #choose n th player randomly

                muP = softmax(expectedRewardForSetting8(rewardsTeam1[n,:,:],lastA,Pi,m,n,M,N,A),tau)
                if groups == 1:
                    a = np.random.choice(range(A),p=muP.squeeze())
                    lastA[m][n] = a #change the action of the chosen player
                else:
                    a = np.random.choice(range(A**groups),p=muP.squeeze())
                    actProfile = getActProfile(a,A,groups)
                    lastA[m][groups*(n//groups):groups*(n//groups)+groups] = actProfile
            else:
                muP = softmax(expectedRewardForSetting8(randomRewards0,lastA,Pi,m,0,M,N,A),tau)
                a = np.random.choice(range(A),p=muP)
                lastA[m][0] = a

        Pi[0] = updatePiforTeam(Pi[0],lastA[0],k,1,1,A)
        Pi[1] = updatePiforTeam(Pi[1],lastA[1],k,1,N,A)
        k += 1

        if t % 10 == 0:
            NG = teamNashGap8(randomRewards0,potentialFunction1,Pi,N)
            NGHistory.append(NG)
            # Record both beliefs as a single flattened vector. np.concatenate
            # copies, so later in-place belief updates cannot alias a snapshot.
            PiHistory.append(np.concatenate([np.ravel(Pi[0]), np.ravel(Pi[1])]))

        if t % 10000 == 0:
            print(NG)

    PiHistory = np.stack(PiHistory,axis=1)
    NGHistory = np.array(NGHistory)

    return PiHistory, NGHistory

def syncTeamFPNetowrk9(G,RAM,utility,Pi,lastA,M,N,A,T=100000,tau=0.1,groups=1):
    """
    Setting 9 ("potential of potential"): synchronous Team-FP where every team
    uses team 0's potential function.

    Each iteration, every team revises one uniformly chosen player (or that
    player's group when ``groups > 1``) with a softmax (temperature ``tau``)
    response from ``expectedPotential``. ``lastA`` is modified in place.

    The shared potentials are built on a private copy of the module-level
    ``potentialFunctions``. The global game is left untouched, so other
    settings run later in the same process, and the game pickled by the
    driver, still see the original potentials.

    Returns
    -------
    PiHistory : ndarray
        ``Pi`` snapshots stacked along a new last axis, one every 10 iterations.
    NGHistory : ndarray
        Team Nash gaps w.r.t. the shared potentials, one every 10 iterations.
    """
    import copy

    k = 0 # initialize stage no
    PiHistory = []
    NGHistory = []
    # Shallow container copy (copy.copy duplicates an ndarray's data, or
    # creates a new list/dict), then replace every team's entry with team 0's.
    sharedPotentials = copy.copy(potentialFunctions)
    for m in range(1,M):
        sharedPotentials[m] = potentialFunctions[0].copy()
    for t in range(T):
        for m in range(M):
            n = np.random.choice(np.arange(N)) #choose n th player randomly

            muP = softmax(expectedPotential(G,RAM,sharedPotentials,Pi,lastA,m,n,M,N,A,groups=groups),tau)
            if groups == 1:
                a = np.random.choice(range(A),p=muP.squeeze())
                lastA[m,n] = a #change the action of the chosen player
            else:
                a = np.random.choice(range(A**groups),p=muP.squeeze())
                actProfile = getActProfile(a,A,groups)
                lastA[m,groups*(n//groups):groups*(n//groups)+groups] = actProfile

        Pi = updatePi(Pi,lastA,k,M,N,A)
        k += 1

        if t % 10 == 0:
            NG = teamNashGap(sharedPotentials,Pi)
            NGHistory.append(NG)
            PiHistory.append(Pi[:,:,None].copy())

        if t % 10000 == 0:
            print(NG)

    PiHistory = np.concatenate(PiHistory,axis=2)
    NGHistory = np.array(NGHistory)

    return PiHistory, NGHistory

def independentTeamFPNetowrk(G,RAM,utility,Pi,lastA,M,N,A,T=100000,tau=0.1,delta=0.1,groups=1):
    """
    Independent Team-FP: every player revises with probability ``delta``.

    Each iteration, every player (m, n) independently, with probability
    ``delta``, takes a softmax (temperature ``tau``) response from
    ``expectedPotential`` using the module-level ``potentialFunctions``.
    With ``groups > 1`` the response covers the A**groups joint actions of the
    player's group, and the whole group's actions are overwritten. This
    matches the synchronous Team-FP variants. ``lastA`` is modified in place.

    Returns
    -------
    PiHistory : ndarray
        ``Pi`` snapshots stacked along a new last axis, one every 10 iterations.
    NGHistory : ndarray
        ``teamNashGap`` values, one every 10 iterations.
    """
    k = 0 # initialize stage no
    PiHistory = []
    NGHistory = []
    for t in range(T):
        for m in range(M):
            for n in range(N):
                if np.random.rand()<delta:
                    muP = softmax(expectedPotential(G,RAM,potentialFunctions,Pi,lastA,m,n,M,N,A,groups=groups),tau)
                    if groups == 1:
                        a = np.random.choice(range(A),p=muP.squeeze())
                        lastA[m,n] = a #change the action of the chosen player
                    else:
                        # muP is over the A**groups joint actions of n's group.
                        a = np.random.choice(range(A**groups),p=muP.squeeze())
                        actProfile = getActProfile(a,A,groups)
                        lastA[m,groups*(n//groups):groups*(n//groups)+groups] = actProfile

        Pi = updatePi(Pi,lastA,k,M,N,A)
        k += 1

        if t % 10 == 0:
            NG = teamNashGap(potentialFunctions,Pi)
            NGHistory.append(NG)
            PiHistory.append(Pi[:,:,None].copy())

        if t % 10000 == 0:
            print(delta,t,NG)

    PiHistory = np.concatenate(PiHistory,axis=2)
    NGHistory = np.array(NGHistory)

    return PiHistory, NGHistory


def MWU(G,RAM,utility,Pi,PiInd,PiAvg,lastA,M,N,A,T=100000,tau=0.1):
    """
    Independent multiplicative-weights learners on the networked game.

    Each iteration, every player's expected payoff is evaluated against the
    individual mixed strategies ``PiInd``, and the player samples an action
    from its own row of ``PiInd``. Afterwards ``PiInd`` is updated by
    ``MWUpdate`` (rate ``tau``), while the team beliefs ``Pi`` and the average
    plays ``PiAvg`` track the sampled actions. ``lastA`` is modified in place.
    Reads the module-level ``edgeUtilities`` and ``potentialFunctions``.

    Returns
    -------
    PiHistory : ndarray
        ``Pi`` snapshots stacked along a new last axis, every 10 iterations.
    NGHistory : ndarray
        Team Nash gaps, every 10 iterations.
    actualNGHistory : ndarray
        Individual Nash gaps of ``PiAvg`` (``nashGap``), every 10 iterations.
    """
    profiles = np.array(getActProfileList(A, M*N))
    rewards = calculateAllRewards(G, RAM, utility, edgeUtilities, M, N, A)

    beliefSnapshots = []
    teamGaps = []
    playerGaps = []
    for t in range(T):
        payoffRows = []
        for team in range(M):
            for player in range(N):
                row = team*N + player
                payoffRows.append(expectedRewardsForMWU_FP(rewards[row], profiles, PiInd, team, player, M, N, A))
                lastA[team, player] = np.random.choice(np.arange(A), p=PiInd[row, :])
        stackedPayoffs = np.vstack(payoffRows)

        # The stage counter equals the iteration index t.
        Pi = updatePi(Pi, lastA, t, M, N, A)
        PiInd = MWUpdate(PiInd, stackedPayoffs, tau)
        PiAvg = updatePiAvg(PiAvg, lastA, t, M, N, A)

        if t % 10 == 0:
            TNG = teamNashGap(potentialFunctions, Pi)
            # team/player still hold the last loop indices here.
            NG = nashGap(G, RAM, utility, rewards, profiles, PiAvg, team, player, M, N, A, perspectivePlayer=None)
            teamGaps.append(TNG)
            playerGaps.append(NG)
            beliefSnapshots.append(Pi[:, :, None].copy())

        if t % 1000 == 0:
            print("TNG: ", TNG)
            print("NG: ", NG)

    return np.concatenate(beliefSnapshots, axis=2), np.array(teamGaps), np.array(playerGaps)

def FP(G,RAM,utility,Pi,PiInd,PiAvg,lastA,M,N,A,T=100000,tau=0.1):
    """
    Smoothed fictitious play by individual players.

    Each iteration, every player evaluates its expected payoff against the
    average plays ``PiAvg`` and samples an action from the softmax
    (temperature ``tau``) of those payoffs. ``Pi`` and ``PiAvg`` are then
    updated with the sampled actions. ``lastA`` is modified in place.
    ``PiInd`` is accepted only for signature parity with ``MWU`` and is not
    used. Reads the module-level ``edgeUtilities`` and ``potentialFunctions``.

    Returns
    -------
    PiHistory : ndarray
        ``Pi`` snapshots stacked along a new last axis, every 10 iterations.
    NGHistory : ndarray
        Team Nash gaps, every 10 iterations.
    actualNGHistory : ndarray
        Individual Nash gaps of ``PiAvg`` (``nashGap``), every 10 iterations.
    """
    k = 0 # initialize stage no
    PiHistory = []
    NGHistory = []
    actualNGHistory = []
    allActProfiles = np.array(getActProfileList(A,M*N))
    allRewards = calculateAllRewards(G,RAM,utility,edgeUtilities,M,N,A)

    for t in range(T):
        for m in range(M):
            for n in range(N):
                expPayoff = expectedRewardsForMWU_FP(allRewards[m*N+n],allActProfiles,PiAvg,m,n,M,N,A)
                mu = softmax(expPayoff.squeeze(),tau)
                a = np.random.choice(np.arange(A),p=mu)
                lastA[m,n]=a

        Pi = updatePi(Pi,lastA,k,M,N,A)
        PiAvg = updatePiAvg(PiAvg,lastA,k,M,N,A)
        k += 1

        if t % 10 == 0:
            TNG = teamNashGap(potentialFunctions,Pi)
            # m, n are the last loop indices (M-1, N-1), passed as in MWU;
            # presumably ignored when perspectivePlayer is None — confirm.
            NG = nashGap(G,RAM,utility,allRewards,allActProfiles,PiAvg,m,n,M,N,A,perspectivePlayer=None)
            NGHistory.append(TNG)
            actualNGHistory.append(NG)
            PiHistory.append(Pi[:,:,None].copy())

        if t % 1000 == 0:
            print("TNG: ", TNG)
            print("NG: ", NG)

    PiHistory = np.concatenate(PiHistory,axis=2)
    NGHistory = np.array(NGHistory)
    actualNGHistory = np.array(actualNGHistory)

    return PiHistory, NGHistory, actualNGHistory


def stochasticTeamFP(G,RAM,utility,Pi,Qs,states,P,p0,lastA,M,N,A,T=100000,H=3,tau=0.1,delta=0.1,groups=1,modelBased=True):
    """
    Team-FP in a finite-horizon Markov game with learned Q-functions.

    Each of the T episodes rolls out H steps. The first state is drawn from
    ``p0``; every later state is drawn from ``P[jointIdx][state, :]``, where
    jointIdx is the previous step's joint action. At every step, one uniformly
    chosen player per team samples from the softmax (temperature ``tau``) of
    ``brQ``. After the episode, each player's value ``v`` at every
    (state, step) is evaluated. Then ``Qs`` and ``Pi`` are updated at each
    visited (state, step).

    Parameters (shapes as allocated by the driver script)
    ----------
    utility : per-state utilities, indexed by state (``edgeUtilities`` is read
        from module scope the same way).
    Pi : ndarray (M, A**N, len(states), H); beliefs, updated in place.
    Qs : ndarray (M*N, len(states), H, A**(M*N)); Q-tables, updated in place.
    states : array of state indices 0..S-1.
    P : transition array indexed ``P[jointIdx, s, s']``.
    p0 : initial-state distribution.
    lastA : ndarray (M, N, len(states), H); last actions, updated in place.
    modelBased : if True, ``updateQ`` refreshes the whole row
        ``Qs[:, s_h, h, :]``; otherwise only the taken joint action's entry.
    delta : unused.

    Returns
    -------
    (PiHistory, NGHistory, []): PiHistory is always an empty list; NGHistory
    holds ``valueIterationTNG`` values, one every 10 episodes.
    """
    k = 0 # initialize stage no
    PiHistory = []
    NGHistory = []
    allRewards = []
    for s in states:
        allRewards.append(calculateAllRewards(G,RAM,utility[s],edgeUtilities[s],M,N,A))
    # Visit counters per (state, step, joint action) and per (state, step).
    counterSHA = np.zeros((len(states),H,A**(M*N)))
    counterSH = np.zeros((len(states),H))

    for t in range(T):
        trajectoryH = []
        for h in range(H):
            if h == 0:
                state = np.random.choice(states,p=p0) #initial state
            else:
                # `state` still holds step h-1's state at this point.
                lastActionsIdx = getActProfileIdx(lastA[:,:,state,h-1].flatten(),A)
                state = np.random.choice(states,p=P[lastActionsIdx][state,:]) # state transition

            for m in range(M):
                n = np.random.choice(np.arange(N)) #choose n th player randomly

                muP = softmax(brQ(Qs[m*N+n,state,h,:],lastA[m,:,state,h],Pi[:,:,state,h],m,n,M,N,groups),tau)
                a = np.random.choice(np.arange(A),p=muP.squeeze())
                lastA[m,n,state,h] = a #update last actions

            actionIndex = getActProfileIdx(lastA[:,:,state,h].flatten(),A)
            r = allRewards[state][:,actionIndex]

            counterSHA[state,h,actionIndex] += 1
            counterSH[state,h] += 1

            trajectoryH.append((state,lastA[:,:,state,h],r))

        # Q-value of each player's last action at every (state, step).
        v = np.zeros((M*N,len(states),H))
        for m in range(M):
            for n in range(N):
                for s in states:
                    for h in range(H):
                        qColumn = brQ(Qs[m*N+n,s,h,:],lastA[m,:,s,h],Pi[:,:,s,h],m,n,M,N)
                        # brQ returns an (A, 1) column; index both axes so a
                        # true scalar (not a size-1 array) is stored.
                        v[m*N+n,s,h] = qColumn[lastA[m,n,s,h],0]

        for h in range(H):
            s_h = trajectoryH[h][0]
            a_h = trajectoryH[h][1]
            r_h = trajectoryH[h][2]
            if h<H-1:
                s_hp1 = trajectoryH[h+1][0]
            else:
                s_hp1 = s_h
            actionIndex = getActProfileIdx(a_h.flatten(),A)
            if modelBased:
                Qs[:,s_h,h,:] = updateQ(Qs[:,s_h,h,:],v,allRewards[s_h],h,s_h,s_hp1,r_h,M,N,counterSH[s_h,h],counterSHA[s_h,h,actionIndex],P,modelBased)
            else:
                Qs[:,s_h,h,actionIndex] = updateQ(Qs[:,s_h,h,actionIndex],v,allRewards[s_h],h,s_h,s_hp1,r_h,M,N,counterSH[s_h,h],counterSHA[s_h,h,actionIndex],P,modelBased)
            Pi[:,:,s_h,h] = updatePi(Pi[:,:,s_h,h],a_h,counterSH[s_h,h],M,N,A)

        if t%10 == 0:
            TNG = valueIterationTNG(potentialFunctions,P,Pi,p0,states,H,M)
            NGHistory.append(TNG)
        if t%1000 == 0:
            print(TNG)

    return PiHistory,NGHistory,[]


def updateQ(Q,v,rewards,h,s_h,s_hp1,r_h,M,N,c1,c2,P,modelBased):
    """
    One averaging step of the per-player Q update at (state s_h, step h).

    The horizon length is read from ``v.shape[2]`` and the number of joint
    actions from ``Q.shape``, not from module-level ``H`` / ``A``. A caller
    whose horizon differs from the global one still gets correct terminal
    handling.

    Parameters
    ----------
    Q : model-based: (M*N, J) Q rows over all J joint actions at (s_h, h);
        model-free: (M*N,) Q values of the taken joint action.
    v : ndarray (M*N, S, H); per-player continuation values.
    rewards : (M*N, J) rewards at s_h (model-based only).
    h, s_h, s_hp1 : current step, current state, next sampled state.
    r_h : (M*N,) realised rewards (model-free only).
    c1, c2 : visit counts; the step size is 1/c1 (model-based) or 1/c2
        (model-free). Assumed >= 1.
    P : transition array indexed ``P[jointIdx, s, s']`` (model-based only).
    modelBased : selects between the two update rules above.

    Returns
    -------
    ndarray with the same shape as ``Q``.
    """
    horizon = v.shape[2]
    players = M*N
    if modelBased:
        Qupdated = np.zeros((players, Q.shape[1]))
        for i in range(players):
            Qhat = rewards[i]
            if h < horizon - 1:
                # Expected next-step value under the transition model from s_h.
                Qhat = Qhat + P[:,s_h,:]@v[i,:,h+1]
            Qupdated[i,:] = (1-1/c1)*Q[i,:] + (1/c1)*Qhat
    else:
        Qupdated = np.zeros(players)
        for i in range(players):
            Qhat = r_h[i]
            if h < horizon - 1:
                # Bootstrap from the sampled next state.
                Qhat = Qhat + v[i,s_hp1,h+1]
            Qupdated[i] = (1-1/c2)*Q[i] + (1/c2)*Qhat
    return Qupdated

def brQ(Q,lastA,Pi,m,n,M,N,groups=1):
    """
    Q-values of player n's (group's) options, averaged over the other teams.

    ``Q`` is reshaped into one axis of size A**N per team (A is the
    module-level action count). Teams before ``m`` are averaged out along the
    leading axes using their beliefs ``Pi[j]``. Teams after ``m`` are averaged
    out from the last axis inward. The resulting vector over team m's joint
    actions is then read at ``lastA`` with player n's group of ``groups``
    players replaced by each of its A**groups options.

    Returns
    -------
    ndarray of shape (A**groups, 1).
    """
    teamSize = A**N
    q = Q.reshape((teamSize,) * M)
    # Average out the teams that precede m (leading axes)...
    for earlier in range(m):
        q = np.einsum('i...,i', q, Pi[earlier, :])
    # ...then the teams after m, starting from the last axis.
    for later in range(M - 1, m, -1):
        q = (q @ Pi[later, :]).squeeze()

    options = A**groups
    start = groups * (n // groups)
    profile = lastA.copy()
    values = np.zeros((options, 1))
    for choice in range(options):
        profile[start:start + groups] = getActProfile(choice, A, groups)
        values[choice] = q[getActProfileIdx(profile, A)]
    return values

def valueIterationTNG(potentialFunctions,P,Pi,p0,states,H,M):
    """
    Team Nash gap of the belief profile ``Pi`` in the finite-horizon Markov game.

    For each team m, backward induction over steps h = H-1 .. 0 computes
      v[s, h]   : the value of team m's best response to the other teams' beliefs,
      vpi[s, h] : the value of team m following its own belief ``Pi[m]``.
    Stage rewards are team m's potential averaged over the other teams'
    beliefs. Continuation values come from the transitions out of the current
    state s, also averaged over the other teams' beliefs. Each value uses its
    own continuation (v with v, vpi with vpi).

    Parameters
    ----------
    potentialFunctions : indexed ``[s][m]``; each an M-dimensional array with
        one axis of size ``Pi.shape[1]`` per team.
    P : transition array indexed ``P[jointIdx, s, s']``; jointIdx is the
        row-major flattening of the M team-action indices.
    Pi : ndarray (M, teamActions, len(states), H); team beliefs.
    p0 : initial-state distribution.
    states : array of state indices 0..S-1.
    H, M : horizon and number of teams.

    Returns
    -------
    float : sum over teams of p0 @ (v[:, 0] - vpi[:, 0]).
    """
    teamActions = Pi.shape[1]
    numStates = len(states)
    TNG = 0
    for m in range(M):
        v = np.zeros((numStates,H))
        vpi = np.zeros((numStates,H))
        for h in reversed(range(H)):
            for s in states:
                teamPotentialFunction = potentialFunctions[s][m]
                for j in range(m):
                    teamPotentialFunction = np.einsum('i...,i',teamPotentialFunction,Pi[j,:,s,h])
                for j in range(M-m-1):
                    teamPotentialFunction = (teamPotentialFunction @ Pi[M-j-1,:,s,h]).squeeze()

                if h == H-1:
                    qBest = teamPotentialFunction
                    qPi = teamPotentialFunction
                else:
                    # Transitions out of the current state s: one axis per team, then s'.
                    Pr = P[:,s,:].reshape((teamActions,)*M+(numStates,))
                    for j in range(m):
                        Pr = np.einsum('i...,i',Pr,Pi[j,:,s,h])
                    # The remaining trailing team axes sit just before the
                    # next-state axis; contract them last team first, exactly
                    # as for the potential above.
                    for j in range(M-m-1):
                        Pr = np.einsum('...ij,i',Pr,Pi[M-j-1,:,s,h])
                    qBest = teamPotentialFunction + Pr@v[:,h+1]
                    qPi = teamPotentialFunction + Pr@vpi[:,h+1]

                v[s,h] = np.max(qBest)
                vpi[s,h] = Pi[m,:,s,h]@qPi
        TNG += p0@v[:,0] - p0@vpi[:,0]
    return TNG


#region Game Settings
# M = 2 #number of teams
# N = 3 #number of players in each team
# A = 2 #number of actions for each player. Possible actions are {0,1,2,...,A-1}
#loadGraphSettingFile = "simResults2_graph2_3.pkl"
#loadGraphSettingFile = "simResults_tau0_1.pkl" # N=3,M=3
#loadGraphSettingFile = "simResults5_tau0_1_delta_0_1.pkl"
#loadGraphSettingFile = None
#loadGraphSettingFile = "graph_21_05_24.pkl" #equivalent to simResults_21_05_24_setting_2_tau0_1_delta_0_1.pkl , N=4, M=2
#loadGraphSettingFile = "simResults_21_05_24_setting_2_tau0_1_delta_0_1.pkl"
loadGraphSettingFile = None

M = 3 #number of teams
N = 3 #number of players in each team
A = 2 #number of actions for each player. Possible actions are {0,1,...,A-1}

K = 10 #number of independent trials
noIterations = 1e6
p=0.3 #random edge connectivity probability
tauList=[0.1]
deltaList = [0]  #independent update probability of independent team-fp
# Settings: 0 Team-FP, 1 independent Team-FP, 2 MWU, 3 smoothed FP,
# 4 / 5 model-based / model-free Markov game, 7 team vs stationary opponents,
# 8 one player vs one team, 9 potential of potential.
settingList = [0]
groupsList = [1] #Only for setting 0 or setting 1.

MARKOV_SETTINGS = (4, 5) # settings that run stochasticTeamFP
#endregion


def _loadGameSetting(path):
    """
    Load (G, edgeUtilities, utility, RAM, potentialFunctions) from a pickle.

    Result files written by this script hold 8 items (three histories, then
    the game). Only the trailing five items are used, so files with fewer
    leading items (e.g. 7) load as well. Only load trusted local files:
    unpickling can execute arbitrary code.
    """
    with open(path,"rb") as fp:
        *_,G,edgeUtilities,utility,RAM,potentialFunctions = pickle.load(fp)
    return G,edgeUtilities,utility,RAM,potentialFunctions


def _initialLearningState(setting):
    """
    Return fresh (Pi, PiInd, PiAvg, lastA, Qs) for `setting`.

    Markov settings get per-(state, step) beliefs, actions and Q-tables
    (reading the module-level `states` and `H`). All other settings get
    one-shot shapes, and Qs is None for them.
    """
    PiInd = np.ones((M*N,A))/A #Individual beliefs (average play of each player)
    PiAvg = np.ones((M*N,A))/A #Average Plays of individuals
    if setting in MARKOV_SETTINGS:
        Pi = np.ones((M,A**N,len(states),H))/(A**N) #beliefs per state and step
        Qs = np.zeros((M*N,len(states),H,A**(M*N)))
        lastA = np.zeros((M,N,len(states),H),dtype=np.int32) #last actions of players
    else:
        Pi = np.ones((M,A**N))/(A**N) #beliefs of players about others, (average play of players in this case)
        Qs = None
        lastA = np.zeros((M,N),dtype=np.int32) #last actions of players
    return Pi,PiInd,PiAvg,lastA,Qs


if any(s in MARKOV_SETTINGS for s in settingList): #This part is for Markov games
    states = np.array([0,1]) # the list of states
    H = 10 # Horizon length

    if loadGraphSettingFile is not None:  #load a previous graph setting if it is given.
        G,edgeUtilities,utility,RAM,potentialFunctions = _loadGameSetting(loadGraphSettingFile)
    else:
        utility = []
        edgeUtilities = []
        potentialFunctions = []
        G,RAM = generateNetwork(M,N,p) #generate new random graph (game).
        for s in states:
            utility1, edgeUtilities1 = generateEdgeUtilities(G,A)
            # NOTE(review): this call passes edgeUtilities1, while the one-shot
            # branch below calls generatePotentialFunctions without edge
            # utilities — confirm which signature is current.
            potentialFunctions1 = generatePotentialFunctions(G,RAM,utility1,edgeUtilities1,M,N,A)
            utility.append(utility1)
            edgeUtilities.append(edgeUtilities1)
            potentialFunctions.append(potentialFunctions1)

    P = generate_N_random_stochastic_matrices(A**(M*N),len(states))
    p0 = np.ones(len(states))/len(states) #start with equal probabilities
elif loadGraphSettingFile is not None:  #load a previous graph setting if it is given.
    G,edgeUtilities,utility,RAM,potentialFunctions = _loadGameSetting(loadGraphSettingFile)
else:
    G,RAM = generateNetwork(M,N,p) #generate new random graph (game).
    utility, edgeUtilities = generateEdgeUtilities(G,A)
    potentialFunctions = generatePotentialFunctions(G,RAM,utility,M,N,A)


for delta in deltaList:
    for tau in tauList:
        for setting in settingList:
            for group in groupsList:
                # Shapes depend on the setting (Markov settings need per-state/step arrays).
                Pi,PiInd,PiAvg,lastA,Qs = _initialLearningState(setting)

                allPiHistory = []
                allNGHistory = []
                allNGActualHistory = []

                # This part is only about decentralized batch learning of Q function, ignore for other things.
                # alpha = 0.1
                # trajectoryLength = 10000
                # DIGingLength  = 1000

                # allRewards, allActProfiles = generateRandomTrajectory(G,utility,M,N,A,L=trajectoryLength)
                # Cs = generateConsensusMatrix(RAM,M,N)
                # Phi,thetas= initializeLinearQEstimates(G,M,N,A)
                # DIGing(thetas,alpha,Cs,allRewards,allActProfiles,TrajectoryL=trajectoryLength,L=DIGingLength)
                # End of this part.

                for k in range(K): #Try each configuration K times.
                    if setting == 0:
                        PiHistory, NGHistory = syncTeamFPNetowrk(G,RAM,utility,Pi,lastA,M,N,A,T=int(noIterations),tau=tau,groups=group)
                        NGActualHistory = 0
                    elif setting == 1:
                        PiHistory, NGHistory = independentTeamFPNetowrk(G,RAM,utility,Pi,lastA,M,N,A,T=int(noIterations),tau=tau,delta=delta,groups=group)
                        NGActualHistory = 0
                    elif setting == 2: #MWU
                        PiHistory, NGHistory,NGActualHistory = MWU(G,RAM,utility,Pi,PiInd,PiAvg,lastA,M,N,A,T=int(noIterations),tau=tau)
                    elif setting == 3: #Full FP
                        PiHistory, NGHistory,NGActualHistory = FP(G,RAM,utility,Pi,PiInd,PiAvg,lastA,M,N,A,T=int(noIterations),tau=tau)
                    elif setting == 4: #Model Based Markov Game
                        PiHistory, NGHistory,NGActualHistory = stochasticTeamFP(G,RAM,utility,Pi,Qs,states,P,p0,lastA,M,N,A,T=int(noIterations),H=H,tau=tau,delta=delta,groups=group,modelBased=True)
                    elif setting == 5: #Model Free Markov Game
                        PiHistory, NGHistory,NGActualHistory = stochasticTeamFP(G,RAM,utility,Pi,Qs,states,P,p0,lastA,M,N,A,T=int(noIterations),H=H,tau=tau,delta=delta,groups=group,modelBased=False)
                    elif setting == 7: #team vs stationary opponents
                        PiHistory, NGHistory = syncTeamFPNetowrk7(G,RAM,utility,Pi,lastA,M,N,A,T=int(noIterations),tau=tau,groups=group)
                        NGActualHistory = 0
                    elif setting == 8: #1 person with 2 action vs 1 team (1 person has random rewards)
                        PiHistory, NGHistory = syncTeamFPNetowrk8(G,RAM,utility,Pi,lastA,M,N,A,T=int(noIterations),tau=tau,groups=group)
                        NGActualHistory = 0
                    elif setting == 9: #Potential of potential
                        PiHistory, NGHistory = syncTeamFPNetowrk9(G,RAM,utility,Pi,lastA,M,N,A,T=int(noIterations),tau=tau,groups=group)
                        NGActualHistory = 0
                    else:
                        # Without this, results of a previous setting would be silently re-saved.
                        raise ValueError(f"Unknown setting {setting}")

                    # The learners mutate lastA/Pi/Qs in place: reset them for the next trial.
                    Pi,PiInd,PiAvg,lastA,Qs = _initialLearningState(setting)

                    allPiHistory.append(PiHistory)
                    allNGHistory.append(NGHistory)
                    allNGActualHistory.append(NGActualHistory)

                    print(f"tau: {tau} , epoch: {k}")

                    print(f"The process for {K} trials, where each trial is {int(noIterations)} iterations and \n setting {setting}, is a graph with {M*N} nodes with {M} teams and {N} nodes within each team. \n edge connectivity probability is given as {p}")

                    # Re-written after every trial so partial results survive an interrupted run.
                    with open("simResults3_21_05_24_setting_"+str(setting)+"_groups_"+str(group)+"_tau"+str.replace(str(tau),'.','_')+"_delta_"+str.replace(str(delta),'.','_')+".pkl","wb") as fp:
                        pickle.dump([allPiHistory, allNGHistory,allNGActualHistory,G,edgeUtilities,utility,RAM,potentialFunctions],fp)


print("Done")

# Disabled example plotting code, kept as a string literal so none of it runs.
# NOTE(review): its loader opens "simResults_21_05_24_setting_<s>_tau..._delta_....pkl",
# whereas the driver above writes
# "simResults3_21_05_24_setting_<s>_groups_<g>_tau..._delta_....pkl" —
# adjust the filename pattern before re-enabling it.
"""
# Example Plotting part
#%%

#Game settings
tauList = [0.1,0.15,0.2]
tauList = [0.1]
deltaList = [0.1]
settingList = [0,1,2,3]
#tauList = [0,0.3,0.5]
A = 2
M = 2
N = 4

labels = {0:"Team-FP",1:"Independent Team-FP",2:"MWU",3:"Smoothed FP"}
plt.figure()

for tau in tauList:
    for delta in deltaList:
        for setting in settingList:
            with open("simResults_21_05_24_setting_"+str(setting)+"_tau"+str.replace(str(tau),'.','_')+"_delta_"+str.replace(str(delta),'.','_')+".pkl","rb") as fp:
                allPiHistory, allNGHistory,allNGActualHistory,G,edgeUtilities,utility,RAM,potentialFunctions = pickle.load(fp)

            bound1 = tau*M*np.log(A**M)

            allPi = np.stack(allPiHistory,axis=3)
            allNG = np.stack(allNGHistory,axis=1)
            newNG = allNG

            newNG = newNG[1:]
            meanNG = np.mean(newNG,axis=1)
            stdNG = np.std(newNG,axis=1)

            t = np.arange(10,10*len(newNG)+10,10)

            lines = plt.plot(t,meanNG, label=labels[setting])

            plt.xscale('log',base=10)
            plt.fill_between(x=t,y1=meanNG-stdNG,y2=meanNG+stdNG,where=t<len(newNG)*10,color=lines[0].get_color(),alpha=0.1)

plt.legend()
plt.ylabel("Team Nash Gap")
plt.xlabel("Number of Iterations")
plt.ylim([0,3])
plt.xlim([10,1e6])
plt.title("Team-FP vs. MWU and SFP")
plt.savefig("TeamFPvsOthers.eps",format='eps')

"""