import numpy as npo
import matplotlib.pyplot as plt
import scipy
# Apply the "ggplot" look to every figure produced by this module.
plt.style.use("ggplot")
# Colour cycle of the active style; the plotting helpers below pick colours
# from `colors` by integer index (e.g. colors[0], colors[4]).
prop_cycle = plt.rcParams['axes.prop_cycle']
colors = prop_cycle.by_key()['color']

# def project2simplex(p):
#     x = p
#     phat = []
#     row = x[idx,...]
#     xinit = npo.ones_like(row) / npo.size(row)
#     J = lambda x : 1/2 * ( npo.linalg.norm(x - row) )**2
#     cons = scipy.optimize.LinearConstraint( A=npo.ones_like(row), lb=1.0, ub=1.0)
#     bnds = scipy.optimize.Bounds( npo.zeros_like(row) ,  npo.ones_like(row) )
#     sol = scipy.optimize.minimize( J, xinit, bounds=bnds, constraints=cons)
#     phat.append( sol.x[None,...] )
#     return npo.concatenate(phat, axis=0 )

# def vec2simplex(row):
#     # row = p
#     phat = []
#     xinit = npo.ones_like(row) / npo.size(row)
#     J = lambda x : 1/2 * ( npo.linalg.norm(x - row) )**2
#     cons = scipy.optimize.LinearConstraint( A=npo.ones_like(p), lb=1.0, ub=1.0)
#     bnds = scipy.optimize.Bounds( 0 ,  1 )
#     sol = scipy.optimize.minimize( J, xinit, bounds=bnds, constraints=cons)
# #     phat.append( sol.x[None,...] )
#     return sol.x

def vec2simplex(y):
    """Euclidean projection of a 1-D vector onto the probability simplex.

    Python implementation of:
    https://arxiv.org/abs/1101.6081

    A threshold ``t`` is found by scanning the sorted entries from the
    largest downwards; the projection is ``max(y - t, 0)``.

    Parameters
    ----------
    y : 1-D array_like
        Vector to project.

    Returns
    -------
    ndarray
        ``npo.maximum(y - t, 0)`` for the computed threshold ``t``.
    """
    ordered = npo.sort(y)
    n = len(y)

    # Visit i = n-2, ..., 0 while accumulating sum(ordered[i+1:]); the first
    # candidate threshold that reaches ordered[i] is the answer.
    tail_sum = 0
    for i in reversed(range(n - 1)):
        tail_sum += ordered[i + 1]
        threshold = (tail_sum - 1) / (n - i - 1)
        if threshold >= ordered[i]:
            break
    else:
        # No candidate qualified (or n < 2): shift every entry uniformly.
        threshold = (npo.sum(ordered) - 1) / n

    return npo.maximum(y - threshold, 0)

def project2simplex(p):
    """Project each row of ``p`` onto the probability simplex.

    Every ``p[i, ...]`` is projected with ``vec2simplex`` and the results are
    joined with ``npo.concatenate(..., axis=0)``.  For a 2-D ``p`` of shape
    (r, m) each projected row is 1-D, so the result is a *flat* vector of
    length ``r * m`` (row 0 first, then row 1, ...), not an (r, m) matrix.
    ``joint`` handles that flat layout by splitting it in half.
    """
    projected = [vec2simplex(p[i, ...]) for i in range(p.shape[0])]
    return npo.concatenate(projected, axis=0)


def squeezeStrategy(jpMat):
    """Collapse a 2x2 joint-probability matrix into three masses.

    Returns ``[jpMat[0,0], jpMat[1,0] + jpMat[0,1], jpMat[1,1]]``: the two
    off-diagonal cells are merged into the middle entry.
    """
    off_diagonal = jpMat[1, 0] + jpMat[0, 1]
    return npo.array([jpMat[0, 0], off_diagonal, jpMat[1, 1]])

def map2subface(sqzMat):
    """Map three barycentric weights to a point in the plane.

    The weights multiply the vertices ``(sqrt(2)/2) * exp(1j*theta)`` for
    ``theta`` in ``{0, 2*pi/3, 4*pi/3}``.  All products are summed (over every
    element of the broadcast result) and the complex total is returned as
    ``npo.array([real, imag])``.
    """
    angles = npo.linspace(0, 2 * npo.pi, 4)[:-1]
    vertices = npo.e ** (1j * angles) * npo.sqrt(2) / 2
    point = npo.sum(vertices * sqzMat)
    return npo.array([npo.real(point), npo.imag(point)])

def joint(M):
    """Outer product of two strategy vectors stored in ``M``.

    If ``M`` supports 2-D indexing, returns ``outer(M[0, :], M[1, :])``.
    Otherwise (a 1-D array, or a plain list/tuple), ``M`` is treated as the
    two vectors stored back to back and ``outer(M[:mid], M[mid:])`` with
    ``mid = len(M) // 2`` is returned.

    Only the errors raised by the 2-D indexing attempt on such inputs
    (``IndexError`` for 1-D arrays, ``TypeError`` for sequences) trigger the
    fallback; any other error propagates instead of being silently swallowed.
    """
    try:
        first, second = M[0, :], M[1, :]
    except (IndexError, TypeError):
        # 1-D layout: first half is one vector, second half the other.
        mid = len(M) // 2
        first, second = M[:mid], M[mid:]
    return npo.outer(first, second)

# def projectTeams(A, B, simplex=True, product=[1,1]):

#     j0 = lambda x : joint(x) if product[0] == 1 else x
#     j1 = lambda x : joint(x) if product[1] == 1 else x
#     if simplex:
#         h0 = lambda x : squeezeStrategy( j0( project2simplex(x) ) )
#         h1 = lambda x : squeezeStrategy( j1( project2simplex(x) ) )
#     else:
#         h0 = lambda x : squeezeStrategy( j0(x) )
#         h1 = lambda x : squeezeStrategy( j1(x) )

#     proja = npo.zeros( (A.shape[2], 2) )
#     projb = npo.zeros( (B.shape[2], 2) )

#     for idx in range(A.shape[2]):
#         proja[idx, :] = map2subface( h0(A[..., idx]) )
#     for idx in range(A.shape[2]):
#         projb[idx, :] = map2subface( h1(B[..., idx]) ) 
#     return proja, projb
def _projectTeam(X, product_flag, simplex):
    """Project one team's trajectory (Players x Pure_Strategies x Length) to (Length, 2)."""
    proj = npo.zeros((X.shape[2], 2))
    for idx in range(X.shape[2]):
        snapshot = X[..., idx]
        if product_flag == 1:
            # Per-player rows are (optionally) projected, then combined into a
            # 2x2 matrix via their outer product (``joint``).
            if simplex:
                snapshot = project2simplex(snapshot)
            jp = joint(snapshot)
        else:
            # The snapshot itself is flattened, optionally projected as one
            # vector, and reshaped into a 2x2 matrix.
            flat = snapshot.flatten()
            if simplex:
                flat = vec2simplex(flat)
            jp = npo.reshape(flat, (2, 2))
        proj[idx, :] = map2subface(squeezeStrategy(jp))
    return proj


def projectTeams(A, B, simplex=True, product=(1, 1)):
    """ Project both Teams to Simplex

    A shape: Players x Pure_Strategies x Length
    B shape: Players x Pure_Strategies x Length

    Parameters
    ----------
    simplex : bool
        If True, project strategies onto the probability simplex first
        (``project2simplex`` row-wise, or ``vec2simplex`` on the flattened
        snapshot) before forming the 2x2 matrix.
    product : pair of ints
        ``product[0]`` / ``product[1]`` select the mode for A / B: ``1`` forms
        the 2x2 matrix as the outer product of the two rows (``joint``); any
        other value reshapes the flattened snapshot to (2, 2) directly.
        A tuple default is used so the default is never shared mutable state.

    Returns
    -------
    proja, projb : ndarrays of shape (A.shape[2], 2) and (B.shape[2], 2)
        ``map2subface(squeezeStrategy(...))`` for every step of each team.
    """
    proja = _projectTeam(A, product[0], simplex)
    projb = _projectTeam(B, product[1], simplex)
    return proja, projb

def _drawSimplexExtras(ax, c):
    """Draw the optional decorations of ``getProjectedPlot(extras=True)``.

    In drawing order: a star labelled "NE" at the image of (1/4, 1/2, 1/4);
    a labelled marker at each of the three vertices; and a shaded band
    between the (1,0,0)--(0,0,1) edge and the curve traced by the squeezed
    product ``outer(z, z)`` for ``z = (lam, 1 - lam)``, ``lam`` in [0, 1].

    ``map2subface`` returns (real, imag); the plot puts imag on the
    horizontal axis and real on the vertical axis.
    """
    y, x = map2subface([1/4, 2/4, 1/4])
    ax.scatter(x, y, marker="*", s=200, zorder=-1, label="", c=colors[4])
    ax.annotate("NE", (x, y))

    vertex_labels = (
        ([1, 0, 0], "(1, 0, 0)\nHH", (+5, 0)),
        ([0, 1, 0], "(0, 1, 0)\nHT/TH", (+5, -35)),
        ([0, 0, 1], "(0, 0, 1)\nTT", (-35, -35)),
    )
    for vertex, text, offset in vertex_labels:
        y, x = map2subface(vertex)
        ax.scatter(x, y, zorder=-1, label="", c=c)
        ax.annotate(text, (x, y), xytext=offset, textcoords='offset points')

    # Squeezed product path: outer(z, z) -> (z0*z0, z0*z1 + z1*z0, z1*z1).
    lam = npo.linspace(0, 1, 1001)
    p1 = npo.array([1, 0])
    p2 = npo.array([0, 1])
    z = p1 * lam[:, None] + (1 - lam[:, None]) * p2
    p = z[:, :, npo.newaxis] * z[:, npo.newaxis, :]
    squeezed = npo.stack([p[:, 0, 0], p[:, 0, 1] + p[:, 1, 0], p[:, 1, 1]], axis=1)
    ppath = npo.apply_along_axis(map2subface, 1, squeezed)

    # Edge (1,0,0)--(0,0,1): real = slope * imag + real(HH), so at a given
    # real value the edge sits at imag = (real - real(HH)) / slope.
    # ppath[-1] is the image of (1, 0, 0), i.e. the HH vertex.
    hh = map2subface([1, 0, 0])
    tt = map2subface([0, 0, 1])
    slope = (tt[0] - hh[0]) / (tt[1] - hh[1])
    ax.fill_betweenx(y=ppath[:, 0], x1=(ppath[:, 0] - ppath[-1, 0]) / slope,
                     x2=ppath[:, 1], color=colors[0], alpha=0.2, zorder=-1)


def getProjectedPlot(extras=True, c='k', ax=None):
    """Get the projected subface of the 3-simplex.

    Draws the triangle whose corners are ``map2subface`` of (1,0,0), (0,1,0)
    and (0,0,1), with the imaginary part on the horizontal axis and the real
    part on the vertical axis.

    Parameters
    ----------
    extras : bool
        If True, also draw the "NE" marker, labelled vertices and the shaded
        product-strategy region (see ``_drawSimplexExtras``).
    c : color
        Colour of the triangle outline and of the vertex markers.
    ax : matplotlib Axes or None
        Axes to draw on; when None a new 12x12 figure is created.

    Returns
    -------
    ``(fig, ax)`` when a new figure was created, otherwise ``ax``.
    """
    # Images of the three pure squeezed strategies.
    tri = npo.asarray([map2subface(vertex) for vertex in npo.eye(3)])

    created_figure = ax is None
    if created_figure:
        fig, ax = plt.subplots(figsize=(12, 12))
    ax.triplot(tri[:, 1].tolist(), tri[:, 0].tolist(), zorder=-1,
               label="Probability Simplex", c=c)
    if extras:
        _drawSimplexExtras(ax, c)

    if created_figure:
        return fig, ax
    return ax




def _plotTrajectory(ax, proj, color, lw, init_label, fin_label, fin_marker, fin_offset):
    """Draw one (T, 2) projected trajectory with labelled start/end markers.

    Column 1 of ``proj`` is drawn horizontally and column 0 vertically.
    Returns the Line2D of the trajectory so it can be used as a legend handle.
    """
    line, = ax.plot(proj[:, 1], proj[:, 0], c=color, lw=lw)

    ax.scatter(proj[0, 1], proj[0, 0], s=100, marker="o", c=color)
    ax.annotate(init_label, (proj[0, 1], proj[0, 0]), xytext=(0, 10), textcoords='offset points', )
    ax.scatter(proj[-1, 1], proj[-1, 0], s=100, marker=fin_marker, c=color)
    ax.annotate(fin_label, (proj[-1, 1], proj[-1, 0]), xytext=fin_offset, textcoords='offset points', )
    return line


def plotGame(X, Y, c1=1, c2=4, fig=None, ax=None, flag=(1, 1), product=(1, 0), simplex=True, extras=False, projected=False, lw=0.6):
    """Plot both teams' trajectories on the projected simplex.

    X: players x pure_strategies x rounds/length
    Y: players x pure_strategies x rounds/length
    (or, when ``projected`` is True, already-projected (T, 2) arrays).

    Parameters
    ----------
    c1, c2 : int
        Indices into the module colour cycle for team A / team B.
    fig, ax : matplotlib Figure / Axes or None
        Target to draw on; if either is None a new projected-simplex figure
        is created with ``getProjectedPlot(extras=extras)``.
    flag : pair of numbers
        Team A / team B is drawn only when its entry is > 0.
    product, simplex : forwarded to ``projectTeams``.
        Tuple defaults avoid shared mutable default arguments.
    projected : bool
        If True, ``X`` and ``Y`` are used as the projected points directly.
    lw : float
        Trajectory line width.

    Returns
    -------
    ``(fig, ax)`` when a new figure was created, otherwise ``ax``.
    """
    created_figure = fig is None or ax is None
    if created_figure:
        fig, ax = getProjectedPlot(extras=extras)

    if projected:
        proja, projb = X, Y
    else:
        proja, projb = projectTeams(X, Y, simplex=simplex, product=product)

    handles = {}
    if flag[0] > 0:
        line_a = _plotTrajectory(ax, proja, colors[c1], lw, 'initA', 'finA', "+", (0, 20))
        handles[line_a] = "A-MIN player"
    if flag[1] > 0:
        line_b = _plotTrajectory(ax, projb, colors[c2], lw, 'initB', 'finB', "x", (0, 10))
        handles[line_b] = "B-MAX player"
    ax.legend(list(handles.keys()), list(handles.values()))

    if created_figure:
        return fig, ax
    return ax

def vectorize(x, k, m):
    """Split an array holding both teams' strategies into two (k, m, -1) tensors.

    ``x`` stacks team A's ``k*m`` entries followed by team B's ``k*m``
    entries along one axis of length ``2*k*m``.  The stacking axis is the
    first of axis 0 / axis 1 whose length equals ``2*k*m``.  It is moved to
    the front before slicing, so both layouts (``(2km, T)`` and ``(T, 2km)``)
    yield A[i, j, t] / B[i, j, t] indexed as player x strategy x step
    (consistent with ``vectSol`` for the ``(T, 2km)`` layout).

    Returns
    -------
    A, B : ndarrays of shape (k, m, -1)

    Raises
    ------
    ValueError
        If neither axis 0 nor axis 1 of ``x`` has length ``2*k*m``.
    """
    x = npo.asarray(x)
    n = k * m
    for axis in (0, 1):
        if axis < x.ndim and x.shape[axis] == 2 * n:
            # Put the stacked-strategy axis first; remaining axes become the tail.
            stacked = npo.moveaxis(x, axis, 0)
            break
    else:
        raise ValueError(
            f"expected axis 0 or 1 of length {2 * n}, got shape {x.shape}"
        )
    A = npo.reshape(stacked[:n], (k, m, -1))
    B = npo.reshape(stacked[n:2 * n], (k, m, -1))
    return A, B


def vectSol(sol, k=2, m=2):
    """Unpack solution rows into two (k, m, T) strategy tensors.

    The last axis of ``sol`` holds team A's ``k*m`` entries followed by team
    B's ``k*m`` entries; all leading axes are flattened into the trailing
    ``T`` axis of each result.
    """
    width = k * m

    def _unpack(flat):
        # (..., k*m) -> (T, k, m) -> (k, m, T)
        return npo.transpose(npo.reshape(flat, (-1, k, m)), (1, 2, 0))

    return _unpack(sol[..., :width]), _unpack(sol[..., width:2 * width])

def init_player(k=2):
    """Return a random probability vector of length ``k``.

    Entries are drawn with ``npo.random.rand`` (uniform on [0, 1)) and then
    divided by their sum so they total one.
    """
    weights = npo.random.rand(k)
    return weights / npo.sum(weights)

def newteamproject(X, length=None):
    """Project a trajectory of stacked strategy vectors to 2-D points.

    Parameters
    ----------
    X : ndarray
        Trajectory array.  Each slice taken along the iteration axis is split
        as ``x[:2]`` / ``x[2:]`` and combined with ``npo.outer`` (presumably
        two players with two pure strategies each -- confirm with callers).
    length : int or None
        Despite its name, the *axis* of ``X`` to iterate over.  When None,
        the longest axis (``npo.argmax(X.shape)``) is used.

    Returns
    -------
    ndarray of shape (X.shape[axis], 2)
        ``map2subface(squeezeStrategy(outer(x[:2], x[2:])))`` per slice.
    """
    axis = npo.argmax(X.shape) if length is None else length
    steps = X.shape[axis]
    points = npo.zeros((steps, 2))
    for step in range(steps):
        snapshot = X.take(indices=step, axis=axis)
        joint_mat = npo.outer(snapshot[:2], snapshot[2:])
        points[step, ...] = map2subface(squeezeStrategy(joint_mat))
    return points
