import numpy as np
import scipy as sp
import numpy.linalg as la
from sklearn import datasets
import sklearn.decomposition
from itertools import permutations
import tensorly as tn
import tensorly.decomposition as dc

# A sample is always treated as being a column

def otimes(a,b):
    '''Returns the tensor (outer) product of arrays a and b.

    The product has shape a.shape + b.shape, with entry [i..., j...] equal to
    a[i...] * b[j...]. Every length-1 axis of that product is then removed
    by np.squeeze.
    '''
    product = np.multiply.outer(a, b)
    return np.squeeze(product)

def loadMnistPca(dim = 2):
    '''Returns scikit-learn's handwritten-digits data projected onto its top
    `dim` principal components.

    The data comes from datasets.load_digits (8x8 digit images, an MNIST-like
    set, not full MNIST). The result has one sample per column, so its shape
    is (dim, n_samples). The columns are put in a random order by dataPerm.
    '''
    digits = datasets.load_digits()['data']
    # PCA.fit returns the fitted estimator, so fit and transform chain.
    projector = sklearn.decomposition.PCA(n_components=dim).fit(digits)
    projected = projector.transform(digits)
    return dataPerm(projected.T)

def loadDiabetesPca(dim = 2):
    '''Returns scikit-learn's diabetes dataset (load_diabetes features)
    projected to the top dim principal components.

    Samples are in the columns, so the result has shape (dim, n_samples).
    Column order is randomized by dataPerm.
    '''
    X = datasets.load_diabetes()['data']
    pcaer = sklearn.decomposition.PCA(n_components=dim)
    pcaer.fit(X)
    X = pcaer.transform(X)
    # Switch to the file's samples-as-columns convention.
    X = X.T
    X = dataPerm(X)
    return X

def loadBreastPca(dim = 2):
    '''Returns scikit-learn's breast-cancer dataset (load_breast_cancer
    features) projected to the top dim principal components.

    Samples are in the columns, so the result has shape (dim, n_samples).
    Column order is randomized by dataPerm.
    '''
    X = datasets.load_breast_cancer()['data']
    pcaer = sklearn.decomposition.PCA(n_components=dim)
    pcaer.fit(X)
    X = pcaer.transform(X)
    # Switch to the file's samples-as-columns convention.
    X = X.T
    X = dataPerm(X)
    return X

def loadWinePca(dim = 2):
    '''Returns scikit-learn's wine dataset (load_wine features) projected to
    the top dim principal components.

    Samples are in the columns, so the result has shape (dim, n_samples).
    Column order is randomized by dataPerm.
    '''
    X = datasets.load_wine()['data']
    pcaer = sklearn.decomposition.PCA(n_components=dim)
    pcaer.fit(X)
    X = pcaer.transform(X)
    # Switch to the file's samples-as-columns convention.
    X = X.T
    X = dataPerm(X)
    return X

def loadBostonPca(dim = 2):
    '''Returns scikit-learn's Boston housing dataset (load_boston features)
    projected to the top dim principal components.

    Samples are in the columns, so the result has shape (dim, n_samples).
    Column order is randomized by dataPerm.

    NOTE: datasets.load_boston was deprecated in scikit-learn 1.0 and removed
    in 1.2. On newer versions the call below fails, so this loader needs an
    older scikit-learn.
    '''
    X = datasets.load_boston()['data']
    pcaer = sklearn.decomposition.PCA(n_components=dim)
    pcaer.fit(X)
    X = pcaer.transform(X)
    # Switch to the file's samples-as-columns convention.
    X = X.T
    X = dataPerm(X)
    return X

def unifNormalize(X):
    '''Rescales each row of X independently so all entries lie in [0, 1).

    With samples in columns, each row is one feature. Each row is shifted so
    its minimum is 0, then divided by its resulting maximum and by 1.00001.
    The extra factor keeps the largest value strictly below 1, so scaling by a
    bin count never yields an out-of-range bin index. A constant row becomes
    all zeros instead of 0/0 = NaN.
    '''
    X = np.asarray(X)
    shifted = X - np.min(X,1).reshape(-1,1)
    span = np.max(shifted,1).reshape(-1,1).astype(float)
    # A constant row has span 0. Dividing its all-zero entries by 1 leaves
    # them at 0 rather than producing NaNs.
    span[span == 0] = 1.0
    return shifted/span/1.00001

def __histEntry__(X,binVec):
    '''Maps one sample vector to the integer index of its histogram bin.

    Each coordinate of X is multiplied by the matching entry of binVec. It is
    then truncated toward zero by the int cast. Coordinates in [0, 1) land in
    bins 0 .. binVec[d]-1.
    '''
    scaled = X * binVec
    return scaled.astype(int)

def histTransform(X,binVec):
    '''Accepts a collection of datapoints and returns a tensor representing
    their normalized histogram.

    X holds one sample per column, shape (d, n). Coordinates are expected to
    lie in [0, 1), e.g. the output of unifNormalize. Values outside that range
    give out-of-range (IndexError) or negative (wrapped) bin indices.
    binVec gives the number of bins per dimension (length d) and is also the
    histogram's shape.

    Returns an array of shape binVec in which each sample adds 1/n to its
    bin, so the entries sum to 1. With n == 0 the division yields NaNs.
    '''
    X = np.array(X)
    hist = np.zeros(binVec)
    # Bin index of every sample at once. The truncation is the same as in
    # __histEntry__; column i is the index vector of sample i.
    idx = (X * np.reshape(binVec, (-1, 1))).astype(int)
    # np.add.at is unbuffered, so samples sharing a bin each add 1.
    np.add.at(hist, tuple(idx), 1)
    return hist/X.shape[1]

def lrApprox(X,rank):
    '''Returns the truncated-SVD approximation of matrix X with the given rank.

    This keeps the `rank` largest singular triplets, which is the best
    approximation of that rank in the Frobenius norm (Eckart-Young). `rank`
    is clipped to [0, min(X.shape)], so a non-positive rank yields an all-zero
    matrix of X's shape.
    '''
    # The thin SVD is enough: only the leading singular vectors are used.
    u, s, vh = sp.linalg.svd(X, full_matrices=False)
    r = max(0, min(rank, s.shape[0]))
    # Scaling the leading left singular vectors by their singular values and
    # multiplying once equals summing r rank-one outer products.
    return (u[:, :r] * s[:r]) @ vh[:r, :]

def simpProject(X):
    '''Accepts a tensor and returns its Euclidean projection onto the
    probability simplex.

    All entries of X are treated as one flat vector. The result has X's
    shape, is non-negative and sums to 1, and is the closest such array to X
    in the Frobenius norm. It uses the sort-based algorithm (Wang &
    Carreira-Perpinan, 2013). Raises ValueError for an empty input.
    '''
    X = np.asarray(X)
    flat = X.astype(float).ravel()
    if flat.size == 0:
        raise ValueError("cannot project an empty array onto the simplex")
    mu = np.sort(flat)[::-1]
    cssv = np.cumsum(mu) - 1.0
    ks = np.arange(1, flat.size + 1)
    # rho is the LARGEST k with mu_k - (sum_{j<=k} mu_j - 1)/k > 0. The
    # condition always holds for k = 1, so this set is never empty.
    rho = np.nonzero(mu - cssv / ks > 0)[0][-1] + 1
    theta = cssv[rho - 1] / rho
    out = np.maximum(flat - theta, 0.0)
    # The projection already sums to 1 mathematically. Renormalize only to
    # absorb rounding; the sum is > 0 because the largest entry exceeds theta.
    out = out / np.sum(out)
    return out.reshape(X.shape)

def likelihood(hist,numsamp):
    '''Returns prod(hist.shape) / (numsamp - 1) * (numsamp * ||hist||^2 - 1).

    ||hist||^2 is the sum of squared entries over all dimensions. The original
    author marks this as an approximation.

    Assume hist is the empirical histogram (bin frequencies summing to 1) of
    numsamp iid samples from a distribution p. Then
    (numsamp*||hist||^2 - 1)/(numsamp - 1) is an unbiased estimate of
    sum_i p_i^2. The leading factor prod(hist.shape) gives the same scaling as
    genHistInner.

    Raises ValueError if numsamp < 2, where the formula divides by zero or is
    meaningless.
    '''
    if numsamp < 2:
        raise ValueError("numsamp must be at least 2")
    # The sum of squares is the squared Frobenius norm for 2-D input and also
    # works for 1-D and N-D histograms. np.linalg.norm(..., 'fro') rejects those.
    sqNorm = np.sum(np.square(hist))
    return np.prod(np.shape(hist))/(numsamp-1)*(numsamp*sqNorm-1)

def dataPerm(X):
    '''Returns a copy of X with its columns (samples) in a random order.

    Each column vector stays intact; only column positions change. The order
    is drawn from NumPy's global random state.
    '''
    order = np.random.permutation(X.shape[1])
    return np.take(X, order, axis=1)
        
def genInner(x,y):
    '''Returns the inner product of two equally sized arrays, each read as a
    flat vector. No complex conjugation is applied.'''
    flatX = x.ravel()
    flatY = y.ravel()
    return np.dot(flatX, flatY)

def genHistInner(x,y):
    '''Accepts two histogram tensors of equal shape and returns
    prod(x.shape) * <x, y>.

    Suppose x and y hold bin probabilities over equal-width cells of
    [0, 1)^d, as produced by histTransform. Then this is the L2 inner product
    of the corresponding piecewise-constant densities, because each cell has
    volume 1/prod(shape).
    '''
    # np.product was deprecated in NumPy 1.25 and removed in 2.0. np.prod is
    # the supported equivalent.
    return np.prod(x.shape)*genInner(x,y)

def genHistXent(p,q):
    '''Returns -(<p, log q> + log(number of bins in q)).

    For histograms over equal-width cells of [0, 1)^d, with p summing to 1,
    this is the cross-entropy between the corresponding piecewise-constant
    densities.
    '''
    # log(prod(q.shape)), written as a sum of per-mode logs.
    logBinCount = np.sum(np.log(q.shape))
    # A zero entry of q gives log(0) = -inf. nan_to_num turns that into the
    # most negative finite float (and NaN into 0) instead of -inf.
    logQ = np.nan_to_num(np.log(q))
    return -(genInner(p, logQ) + logBinCount)

def lrHist(X,binVec,rank):
    '''Accepts samples X and returns a low-rank histogram projected onto the
    probabilistic simplex.

    X holds one sample per column, binned with binVec bins per dimension by
    histTransform. The histogram is then reduced to rank `rank` by lrApprox
    and mapped back onto the probability simplex by simpProject. lrApprox
    uses an SVD, so this presumably expects two-dimensional samples (a 2-D
    histogram).
    '''
    empirical = histTransform(X, binVec)
    lowRank = lrApprox(empirical, rank)
    return simpProject(lowRank)

def randHist(binVec):
    '''Returns a random histogram of shape binVec whose entries sum to 1.

    The entries come from a flat Dirichlet(1, ..., 1) draw, i.e. uniformly
    over the probability simplex, using NumPy's global random state.
    '''
    numCells = np.prod(binVec)
    weights = np.random.dirichlet(np.ones(numCells))
    return weights.reshape(binVec)

def nnegParfac(inHist,rank,inInit = 'svd'): #Note that this isn't exactly correct
    '''Fits a non-negative CP (PARAFAC) model of rank `rank` to inHist and
    returns its reconstruction projected onto the probability simplex.

    The fit uses TensorLy's non_negative_parafac with initialization
    `inInit`. The simplex projection happens after fitting, which presumably
    is why the author notes the result "isn't exactly correct". It is not the
    best rank-`rank` distribution.
    '''
    cpModel = dc.non_negative_parafac(inHist, rank, init=inInit)
    # TensorLy renamed Kruskal tensors to CP tensors, and newer releases may
    # only provide cp_to_tensor. Use whichever name the installed version has.
    toTensor = getattr(tn, 'cp_to_tensor', None) or tn.kruskal_to_tensor
    return simpProject(toTensor(cpModel))

def nnegTucker(inHist,rank,inInit = 'svd'):
    '''Fits a non-negative Tucker decomposition to inHist and returns its
    reconstruction projected onto the probability simplex.

    Every mode uses the same Tucker rank `rank`. The fit is TensorLy's
    non_negative_tucker with initialization `inInit`.
    '''
    numModes = len(inHist.shape)
    tuckerRanks = [rank for _ in range(numModes)]
    coreTensor, factorMats = dc.non_negative_tucker(inHist, tuckerRanks, init=inInit)
    # NOTE(review): recent TensorLy releases appear to take tucker_to_tensor's
    # input as a single (core, factors) tuple. Confirm this two-argument call
    # against the installed version.
    fitted = tn.tucker_to_tensor(coreTensor, factorMats)
    return simpProject(fitted)

def normHist(D,numComponents,dbg = False,numPasses = 10,numSteps = 100,stepSize = 0.1):
    '''Accepts a histogram tensor D and returns a low-rank model fitting it.

    The model is sum_k componentWeights[k] * outer(components[0][:, k], ...,
    components[-1][:, k]). It is fitted by alternating minimization of the
    squared Frobenius error. Each of `numPasses` passes visits every mode and
    runs `numSteps` projected-gradient steps of size `stepSize` on that mode's
    weighted factor matrix. After every step the whole matrix is projected
    onto the probability simplex with simpProject.

    Parameters:
        D: histogram tensor to fit.
        numComponents: number of rank-one components.
        dbg: currently unused; kept for backward compatibility.
        numPasses, numSteps, stepSize: optimisation schedule. The defaults
            equal the previously hard-coded 10 passes, 100 steps and 0.1.

    Returns (components, componentWeights):
        components[m] has shape (D.shape[m], numComponents), and each column
            sums to 1. A column that collapsed to all zeros is left at zero.
        componentWeights has length numComponents.
    Initialization draws from NumPy's global random state.
    '''
    bins = D.shape
    numModes = len(bins)

    def colNormalize(M):
        # Scale each column to sum to 1. A column summing to 0 would give
        # 0/0 = NaN and poison every later update, so it is left as zeros.
        sums = np.sum(M, axis=0).astype(float)
        sums[sums == 0] = 1.0
        return M / sums

    componentWeights = randHist(numComponents)
    # components[m]: column k is the mode-m factor of component k.
    components = [colNormalize(randHist([bins[m], numComponents])) for m in range(numModes)]

    for _ in range(numPasses):     # full passes of alternating minimization
        for mode in range(numModes):
            # Same as components[mode] @ np.diag(componentWeights).
            hX = components[mode] * componentWeights
            # Column k is the outer product of component k's factors over all
            # other modes, flattened in C order with the remaining modes in
            # increasing order. This is intended to match the columns of
            # tn.unfold(D, mode). It is not a hist itself.
            hY = np.zeros([int(np.prod(bins)) // bins[mode], numComponents])
            for ycomp in range(numComponents):
                curComp = np.ones([1])
                for ymode in range(numModes):
                    if ymode != mode:
                        curComp = otimes(curComp, components[ymode][:, ycomp])
                hY[:, ycomp] = curComp.flatten()

            # The gradient of ||unfold(D) - hX hY^T||^2 is
            # 2 hX (hY^T hY) - 2 unfold(D) hY. Both hY products are fixed
            # during the inner loop, so compute them once.
            gram = hY.T @ hY
            target = tn.unfold(D, mode) @ hY
            for _ in range(numSteps):  # projected gradient descent steps
                grad = 2 * hX @ gram - 2 * target
                hX = simpProject(hX - stepSize * grad)
            componentWeights = np.sum(hX, 0)
            components[mode] = colNormalize(hX)

    return components,componentWeights

def reconHist(inModel):
    '''Rebuilds the full histogram tensor from a (components, componentWeights)
    model, such as the pair returned by normHist.

    components[m] is a (bins_m, numComponents) factor matrix, and
    componentWeights has length numComponents. Returns
    sum_k w_k * outer(components[0][:, k], ..., components[-1][:, k]) with
    shape (bins_0, ..., bins_{M-1}). Modes with a single bin stay as
    length-1 axes, so the result matches the fitted tensor's shape. A model
    with no components yields an all-zero tensor of that shape.
    '''
    components = inModel[0]
    componentWeights = inModel[1]
    histOut = np.zeros(tuple(factor.shape[0] for factor in components))
    for k, weight in enumerate(componentWeights):
        term = np.asarray(weight, dtype=float)
        for factor in components:
            # Unlike otimes, np.multiply.outer never squeezes length-1 axes.
            term = np.multiply.outer(term, factor[:, k])
        histOut += term
    return histOut
            

#TODO
def mlHist(hist,c,numCycles = 100):
    '''Fits a c-component mixture of product distributions to a 2-D histogram
    by expectation-maximization and returns the fitted histogram.

    The model is P(i, j) = sum_k w[k] * q[i, k] * r[j, k], a non-negative
    rank-c factorization in the style of PLSA / the aspect model. Each of the
    `numCycles` EM cycles does two steps:
    - E-step: compute responsibilities z[i, j, k] proportional to
      w[k] q[i, k] r[j, k].
    - M-step: re-estimate q, r and w from hist-weighted responsibilities.

    hist: 2-D array of non-negative masses with positive total. It may be
        rectangular.
    c: number of mixture components.
    numCycles: number of EM cycles (default 100, as before).

    Returns an array of hist's shape that sums to 1. Initialization draws
    from NumPy's global random state.
    '''
    rows, cols = hist.shape[0], hist.shape[1]
    w = np.random.dirichlet([1]*c)
    q = np.zeros([rows, c])
    r = np.zeros([cols, c])
    # Same random draw order as before: q's column, then r's column, per component.
    for k in range(c):
        q[:, k] = np.random.dirichlet([1]*rows)
        r[:, k] = np.random.dirichlet([1]*cols)

    def safeNormalize(a, axis):
        # Divide by the sum along `axis`; slices summing to 0 become 0, not NaN.
        total = np.sum(a, axis=axis, keepdims=True)
        return np.divide(a, total, out=np.zeros_like(a), where=total > 0)

    for _ in range(numCycles):
        # E-step: unnormalized z[i, j, k] = q[i, k] * r[j, k] * w[k].
        z = q[:, None, :] * r[None, :, :] * w
        # A cell where every component has zero probability (e.g. an empty
        # row of hist after the first M-step) would give 0/0 = NaN and poison
        # every estimate. In exact arithmetic such cells carry no hist mass,
        # so giving them zero responsibility leaves the updates unchanged.
        z = safeNormalize(z, 2)
        # M-step.
        weighted = hist[:, :, None] * z
        q = safeNormalize(np.sum(weighted, axis=1), 0)
        r = safeNormalize(np.sum(weighted, axis=0), 0)
        w = np.sum(weighted, axis=(0, 1))
        w = w/np.sum(w)
    # Same as sum_k w[k] * outer(q[:, k], r[:, k]).
    return (q * w) @ r.T


