import numpy as np
import torch
from scipy import stats
import time
from matplotlib import pyplot as plt
import pickle as pkl
import math

def write_bytes(obj, fpath):
    """Pickle ``obj`` and write the bytes to ``fpath``, replacing any existing file."""
    payload = pkl.dumps(obj)
    with open(fpath, mode='wb') as handle:
        handle.write(payload)

def load_bytes(fpath):
    """Read ``fpath`` and return the object it contains, unpickled.

    Security note: unpickling can execute arbitrary code -- only load files
    from a trusted source.
    """
    with open(fpath, mode='rb') as handle:
        raw = handle.read()
    return pkl.loads(raw)

def randomStratifiedSampleData(label, perN):
    """Randomly pick ``perN[c]`` sample indices from each class ``c``.

    Class ids are assumed to be contiguous. They are treated as 1-based when
    the largest id equals the number of distinct ids, otherwise as 0-based.

    Args:
        label: numpy array of integer class ids (compared with ``==``).
        perN: per-class sample counts, indexed by 0-based class position.

    Returns:
        numpy array of selected indices, grouped class by class.
    """
    distinct = set(label)
    n_cls = len(distinct)
    one_based = max(distinct) == n_cls
    if one_based:
        print("LabelID starts from 1!")
    else:
        print("LabelID starts from 0!")
    first_id = 1 if one_based else 0

    chosen = []
    for c in range(n_cls):
        members = np.argwhere(label == c + first_id)[:, 0]
        shuffled = members[np.random.permutation(len(members))]
        chosen.extend(shuffled[:perN[c]])
    return np.array(chosen)


def generateRandomLabelNoise2(TtrY, noiserate):
    """Corrupt a fraction ``noiserate`` of the labels of every class.

    For each class, ``floor(n_c * noiserate)`` randomly chosen samples get
    label ``(y + offset) % numClasses`` with ``offset`` drawn uniformly from
    ``[2, numClasses - 1]``. The offsets 0 and 1 are deliberately excluded,
    matching the original MATLAB port. Labels are assumed to be
    ``0..numClasses-1``.

    Args:
        TtrY: labels as a numpy array, a torch CPU tensor (converted with
            ``.numpy()``), or any array-like.
        noiserate: fraction of each class to corrupt. Values below 1e-4
            return an unmodified copy.

    Returns:
        numpy array with the noisy labels. The input is not modified.
    """
    if not isinstance(TtrY, np.ndarray):
        # torch tensors expose .numpy(); anything else (e.g. a list) goes through asarray
        TtrY = TtrY.numpy() if hasattr(TtrY, 'numpy') else np.asarray(TtrY)

    classes = np.unique(TtrY)
    numClasses = len(classes)
    trY = TtrY.copy()
    if noiserate < 1e-4:
        return trY

    for lpw in classes:
        la = np.argwhere(TtrY == lpw)[:, 0]
        rnla = np.random.permutation(len(la))
        N1 = int(np.floor(len(la) * noiserate))
        # randint's upper bound is exclusive: offsets in [2, numClasses-1], the
        # same values and random stream as the deprecated random_integers(2, numClasses-1)
        lu = np.random.randint(2, numClasses, size=N1)
        flipped = la[rnla[:N1]]
        trY[flipped] = (TtrY[flipped] + lu) % numClasses

    return trY

def generateImBalancedRandomLabelNoise2(TtrY, noiserates):
    """Corrupt labels with a class-specific noise rate.

    Same scheme as ``generateRandomLabelNoise2``, but class ``c`` uses rate
    ``noiserates[c]``. ``floor(n_c * noiserates[c])`` samples get label
    ``(y + offset) % numClasses`` with ``offset`` in ``[2, numClasses - 1]``.
    Labels are assumed to be ``0..numClasses-1`` because they index
    ``noiserates``.

    Args:
        TtrY: labels as a numpy array, a torch CPU tensor (converted with
            ``.numpy()``), or any array-like.
        noiserates: sequence of per-class noise rates indexed by class id.

    Returns:
        numpy array with the noisy labels. The input is not modified.
    """
    if not isinstance(TtrY, np.ndarray):
        # torch tensors expose .numpy(); anything else (e.g. a list) goes through asarray
        TtrY = TtrY.numpy() if hasattr(TtrY, 'numpy') else np.asarray(TtrY)

    classes = np.unique(TtrY)
    numClasses = len(classes)
    trY = TtrY.copy()

    for lpw in classes:
        la = np.argwhere(TtrY == lpw)[:, 0]
        rnla = np.random.permutation(len(la))
        N1 = int(np.floor(len(la) * noiserates[lpw]))
        # randint's upper bound is exclusive: identical draws to the deprecated
        # random_integers(2, numClasses - 1)
        lu = np.random.randint(2, numClasses, size=N1)
        flipped = la[rnla[:N1]]
        trY[flipped] = (TtrY[flipped] + lu) % numClasses

    return trY
    
def randomSampleData(data, label, perN):
    """Randomly draw ``perN[c-1]`` rows of ``data`` for every class ``c``.

    Labels are assumed to be 1-based (``1..max(label)``).

    Args:
        data: 2-D numpy array, one sample per row.
        label: numpy array of class ids aligned with the rows of ``data``.
        perN: per-class counts; ``perN[c-1]`` belongs to class ``c``.

    Returns:
        (sdata, slabel): numpy arrays of the picked rows and their labels,
        grouped class by class.
    """
    picked_rows = []
    picked_labels = []
    top_class = max(label)

    for cls in range(1, top_class + 1):
        members = np.argwhere(label == cls)[:, 0]
        members = members[np.random.permutation(len(members))]
        take = members[:perN[cls - 1]]
        picked_rows.extend(data[take, :])
        picked_labels.extend(label[take])

    return np.array(picked_rows), np.array(picked_labels)


def generateRandomLabelNoise2_Origin(TtrY, noiserate):
    """1-based variant of the label-noise generator (direct MATLAB port).

    Labels are assumed to be ``1..numClasses`` with ``numClasses = max(TtrY)``.
    For each class, ``floor(n_c * noiserate)`` random samples get
    ``(y + lu) % numClasses + 1`` with ``lu`` in ``[1, numClasses - 2]``.
    This is a cyclic shift of 2..numClasses-1 positions, so a label never
    maps to itself.

    Args:
        TtrY: numpy array of 1-based integer labels.
        noiserate: fraction of each class to corrupt. Values below 1e-4
            return an unmodified copy.

    Returns:
        (trY, fl): the noisy labels and the fraction of labels that changed.
        ``fl`` is 0 when no noise is applied.
    """
    numClasses = max(TtrY)
    trY = TtrY.copy()
    fl = 0
    if noiserate < 1e-4:
        return trY, fl

    for lpw in range(1, numClasses + 1):
        la = np.argwhere(TtrY == lpw)[:, 0]
        rnla = np.random.permutation(len(la))
        N1 = int(np.floor(len(la) * noiserate))
        # randint's upper bound is exclusive: same draws as the deprecated
        # random_integers(1, numClasses - 2)
        lu = np.random.randint(1, numClasses - 1, size=N1)
        flipped = la[rnla[:N1]]
        trY[flipped] = (trY[flipped] + lu) % numClasses + 1

    fl = np.count_nonzero(trY != TtrY) / len(TtrY)
    return trY, fl

def getCleanTe(data, labels, labels_gt, npte):
    """Split ``data`` into train/test sets with a clean, class-balanced test set.

    For every class id in ``range(max(labels) + 1)``, the rows of that class
    are shuffled. The first ``npte`` rows whose observed label equals the
    ground-truth label go to the test split. All other rows of the class,
    noisy or not, go to the train split. Features of both splits are then
    min-max scaled with the per-column min/max of the full ``data``.

    Args:
        data: 2-D feature matrix, one sample per row.
        labels: observed (possibly noisy) labels, used as 0-based class ids.
        labels_gt: ground-truth labels aligned with ``labels``.
        npte: number of clean test samples taken per class.

    Returns:
        trX, trY, trY_gt, teX, teY, teY_gt. ``trY``/``teY`` are cast to int;
        ``trY_gt``/``teY_gt`` are left as float arrays.

    NOTE(review): if a class has fewer than ``npte`` clean samples, the loop
    prints a message and stops (``break``). Rows of the pre-allocated
    outputs that were not filled keep their zero initialization, which is
    then min-max scaled like real rows.
    """
    numClasses = max(labels) + 1

    # sizes assume every class contributes exactly npte test rows
    ntr = data.shape[0] - numClasses*npte
    nte = numClasses*npte

    trX = np.zeros((ntr, data.shape[1]))
    teX = np.zeros((nte, data.shape[1]))

    trY = np.zeros(ntr)
    teY = np.zeros(nte)
    trY_gt = np.zeros(ntr)
    teY_gt = np.zeros(nte)
    iend = 0  # running write position in the train arrays
    for lpw in range(numClasses):
        # indices of class lpw, shuffled in place so the clean picks are random
        lt = np.argwhere(labels == lpw)[:, 0]
        np.random.shuffle(lt)

        curX = data[lt, :]
        curY = labels[lt]
        curY_gt = labels_gt[lt]
        # positions (within this class) whose observed label is correct
        lte = np.argwhere(curY==curY_gt)[:, 0]

        if(len(lte) < npte):
            print("Not enough data for testing")
            break

        # boolean mask: True = train row, False = one of the npte test rows
        ltr = np.ones(len(lt))
        ltr[lte[:npte]] = 0  # not pick test samples
        ltr = ltr.astype(bool)

        xr = curX[ltr,:]
        yr = curY[ltr]
        yr_gt = curY_gt[ltr]

        xe = curX[lte[:npte],:]
        ye = curY[lte[:npte]]
        ye_gt = curY_gt[lte[:npte]]

        # train rows are appended contiguously
        ibegin = iend
        iend = ibegin + len(yr)
        trX[ibegin:iend,:] = xr
        trY[ibegin:iend] = yr
        trY_gt[ibegin:iend] = yr_gt

        # test rows occupy a fixed npte-sized slot per class
        teX[lpw*npte: (lpw+1)*npte,:] = xe
        teY[lpw*npte: (lpw+1)*npte] = ye
        teY_gt[lpw*npte: (lpw+1)*npte] = ye_gt

    # min-max scaling with statistics of the whole dataset
    maxv = data.max(axis=0)
    minv = data.min(axis=0)
    max_min = maxv - minv
    trX = (trX - minv) / max_min
    teX = (teX - minv) / max_min

    # constant columns give 0/0 = NaN above; map them to 0
    trX[np.isnan(trX)] = 0
    teX[np.isnan(teX)] = 0
    trY, teY = trY.astype(int), teY.astype(int)
    return trX, trY, trY_gt, teX, teY, teY_gt


def score_ap_from_ranks_1(ranks, nres):
    """Average precision of one query, integrated with trapezoids on the PR curve.

    Args:
        ranks: ascending 0-based ranks at which the true positives were retrieved.
        nres: total number of positives in the dataset; each hit adds 1/nres recall.

    Returns:
        The average precision as a float.
    """
    step = 1.0 / nres
    total = 0.0
    for hits_before, rank in enumerate(ranks):
        # precision just before and just after this hit
        prec_left = 1.0 if rank == 0 else hits_before / float(rank)
        prec_right = (hits_before + 1) / float(rank + 1)
        total += (prec_right + prec_left) * step / 2.0
    return total

def mycompute_MAP(Idx, gnd, sdT, topk=None):
    """Compute per-query average precision and their mean.

    Args:
        Idx: unused; kept for interface compatibility.
        gnd: sequence where ``gnd[i]`` holds the database indices relevant
            to query ``i``.
        sdT: 2-D array where ``sdT[j, i]`` is the 0-based rank of database
            item ``j`` for query ``i``.
        topk: if given, only hits with rank ``<= topk`` are scored and
            ``topk`` is used as the number of positives.

    Returns:
        (mAP, mAPall): the mean AP over queries and the per-query APs.
    """
    mAP = 0
    num_query = sdT.shape[1]
    mAPall = np.zeros(num_query)
    for i in range(num_query):
        qgnd = gnd[i]
        nres = len(qgnd)
        # np.sort keeps an ndarray; the builtin sorted() returned a list, so the
        # topk mask below raised TypeError
        tp = np.sort(sdT[qgnd, i])
        if topk is None:
            ap = score_ap_from_ranks_1(tp, nres)
        else:
            ap = score_ap_from_ranks_1(tp[tp <= topk], topk)
        mAP += ap
        mAPall[i] = ap
    mAP /= num_query
    return mAP, mAPall

def getMeanAveragePrecision(trX, teX, gnd, bs=100, use_cuda=True):
    """Per-query average precision of Euclidean retrieval of ``teX`` against ``trX``.

    Queries are processed in batches of ``bs``. For each query, database rows
    are ranked by squared Euclidean distance and scored with
    ``mycompute_MAP`` against the relevant indices in ``gnd``.

    Args:
        trX: database features (numpy array or tensor), one row per item.
        teX: query features (numpy array or tensor), one row per query.
        gnd: ``gnd[q]`` lists the database indices relevant to query ``q``.
        bs: query batch size.
        use_cuda: move tensors to the default CUDA device.

    Returns:
        (Mapall, seconds): numpy array of per-query APs and the elapsed wall time.
    """
    tic = time.time()
    Nte = teX.shape[0]
    Mapall = np.zeros(Nte)
    Nb = math.ceil(Nte / bs)
    with torch.no_grad():
        trX = torch.as_tensor(trX)
        if use_cuda:
            trX = trX.cuda()
        # squared norms of the database rows do not depend on the batch: compute once
        cc = torch.sum(trX.pow(2), 1).reshape(1, -1)
        for cnt_batch in range(Nb):
            print('\rmean average precision batch: %d in %d' %(cnt_batch,Nb), end="")
            ibegin = cnt_batch * bs
            iend = ibegin + bs

            xxt = torch.as_tensor(teX[ibegin:iend, :])
            if use_cuda:
                xxt = xxt.cuda()

            xx = torch.sum(xxt.pow(2), 1).reshape(-1, 1)
            z = xx + cc - 2 * torch.mm(xxt, trX.T)  # squared distances (batch, Ntr)

            _, Idx = torch.sort(z, dim=1)
            # sorting the permutation gives its inverse: ranks[q, j] = rank of item j
            _, ranks = torch.sort(Idx, dim=1, descending=False)

            _, mapall = mycompute_MAP(Idx.cpu().numpy(),
                                      gnd[ibegin:iend], ranks.T.cpu().numpy())

            Mapall[ibegin:iend] = mapall
    toc = time.time()
    return Mapall, toc - tic

def myPCA(allsamples, redim):
    """Principal component analysis via the feature scatter matrix.

    Args:
        allsamples: 2-D array, one sample per row, features in columns.
        redim: number of principal directions to keep.

    Returns:
        (base, eigvals, samplemean): ``base`` has the top ``redim``
        eigenvectors of the centered scatter matrix as columns (descending
        eigenvalue order), ``eigvals`` holds the matching eigenvalues, and
        ``samplemean`` is the per-feature mean that was subtracted.
    """
    allsamples = np.asarray(allsamples)
    # MATLAB's mean() on a matrix is column-wise; numpy needs axis=0 to match
    # (without it the data were centered by a single scalar)
    samplemean = np.mean(allsamples, axis=0)
    xmean = allsamples - samplemean
    sigmao = xmean.T.dot(xmean)
    # the scatter matrix is symmetric: eigh returns real eigenpairs, whereas
    # eig may return complex values with spurious imaginary parts
    eigvals, eigvecs = np.linalg.eigh(sigmao)
    descending_order = eigvals.argsort()[::-1]
    eigvals = eigvals[descending_order][:redim]
    base = eigvecs[:, descending_order][:, :redim]
    return base, eigvals, samplemean

def myPCA2(allsamples, p):
    """Dual (Gram-matrix) PCA, a port of MATLAB ``[base,samplemean,redata] = myPCA2(allsamples,p)``.

    The N x N Gram matrix of the centered samples is eigendecomposed. This is
    cheaper than the D x D covariance when there are fewer samples than
    features. Each eigenvector ``v_i`` is lifted to feature space as
    ``xmean.T @ v_i / sqrt(lambda_i)``, which is a unit-norm principal axis.

    Args:
        allsamples: 2-D array, one sample per row.
        p: number of principal axes to keep. The top ``p`` Gram eigenvalues
            must be positive.

    Returns:
        (base, samplemean, redata): ``base`` has the principal axes as
        columns, ``samplemean`` is the per-feature mean, and
        ``redata = allsamples @ base`` (uncentered samples, as in the MATLAB
        source).
    """
    allsamples = np.asarray(allsamples)
    # MATLAB mean() is column-wise
    samplemean = np.mean(allsamples, axis=0)
    xmean = allsamples - samplemean

    # MATLAB `*` is a matrix product; numpy's `*` would be elementwise
    sigma = xmean @ xmean.T
    eigvals, eigvecs = np.linalg.eigh(sigma)  # symmetric -> real eigenpairs

    descending_order = eigvals.argsort()[::-1]
    dsort = eigvals[descending_order]
    vsort = eigvecs[:, descending_order]

    # base = bsxfun(@times, dsort(1:p)'.^(-1/2), xmean' * vsort(:,1:p)):
    # scale column i of xmean' * v_i by lambda_i^(-1/2)
    base = (xmean.T @ vsort[:, :p]) * np.power(dsort[:p], -0.5)
    redata = allsamples @ base
    return base, samplemean, redata

def KNNtest(k, trX, trY, teX, teY, bs=256, use_cuda=True):
    """Batched k-nearest-neighbour classification with squared Euclidean distance.

    Args:
        k: number of neighbours. For ``k > 1`` the prediction is the mode
            of the neighbour labels (``scipy.stats.mode``).
        trX, teX: train/test features as numpy arrays (converted with
            ``torch.tensor``) or tensors.
        trY, teY: train/test labels; non-ndarrays are converted with ``np.array``.
        bs: number of test rows per batch.
        use_cuda: move tensors to the default CUDA device.

    Returns:
        (pred_all, acc, None, None): predictions for ALL test rows (a float
        numpy array) and the accuracy against ``teY``. The previous version
        returned only the last batch's predictions.
    """
    trX, teX = [torch.tensor(vari) if isinstance(vari, np.ndarray)
                else vari for vari in [trX, teX]]
    trY, teY = [np.array(vari) if not isinstance(vari, np.ndarray)
                else vari for vari in [trY, teY]]
    Nb = math.ceil(len(teX) / bs)
    pred_all = np.zeros(len(teX))
    with torch.no_grad():
        if use_cuda:
            trX = trX.cuda()
        cc = torch.sum(trX.pow(2), 1).reshape(1, -1)  # (1, Ntr) squared train norms
        for cnt_batch in range(Nb):
            print('\r evaluating... %d in %d ' %(cnt_batch,Nb), end="")
            ibegin = cnt_batch * bs
            iend = ibegin + bs
            xxt = teX[ibegin:iend, :]
            if use_cuda:
                xxt = xxt.cuda()
            xx = torch.sum(xxt.pow(2), 1).reshape(-1, 1)
            xc = torch.mm(xxt, trX.T)
            z = xx - 2 * xc + cc  # (batch, Ntr) squared distances
            _, Idx = torch.sort(z, 1)
            ly = trY[Idx[:, :k].cpu().numpy()]  # labels of the k nearest rows

            pred = ly if k == 1 else stats.mode(ly, 1)[0]
            pred_all[ibegin:iend] = np.array(pred).flatten()
    acc = np.mean(pred_all == teY)
    print()
    return pred_all, acc, None, None
    
def KNNtest_OLD(k, trX, trY, teX, teY):
    """Legacy single-shot k-NN classifier on numpy inputs.

    Squared Euclidean distances are built in numpy, then sorted with torch
    after conversion via ``torch.Tensor`` (float32).

    Returns:
        (pred, acc, dis, Idx): flattened numpy predictions, accuracy against
        ``teY``, a tensor with the first ``max(5, k)`` sorted distances per
        test row, and the full tensor of sorted train indices.
    """
    te_sq = np.sum(np.power(teX, 2), 1).reshape(-1, 1)
    tr_sq = np.sum(np.power(trX, 2), 1).reshape(1, -1)
    cross = teX.dot(trX.T)
    sq_dist = te_sq - 2 * cross + tr_sq

    sorted_dist, order = torch.sort(torch.Tensor(sq_dist), 1)

    votes = trY[order[:, :k]]
    if k != 1:
        votes = stats.mode(votes, 1)[0]  # majority vote over the k neighbours
    pred = np.array(votes).flatten()

    accuracy = np.mean(pred == teY)
    return pred, accuracy, sorted_dist[:, :max(5, k)], order

def getbatchKNNindex(trX, k, teX, batchsize=256, use_cuda=True):
    """Indices of the ``k + 1`` nearest rows of ``trX`` for every row of ``teX``.

    Distances are squared Euclidean, computed batch by batch over the rows
    of ``teX``. When ``teX`` is ``trX`` the first neighbour is normally the
    point itself, hence ``k + 1`` columns.

    Args:
        trX: reference features (numpy array or tensor), one row per sample.
        k: number of neighbours requested besides the first column.
        teX: query features (numpy array or tensor).
        batchsize: query rows per batch.
        use_cuda: move tensors to the default CUDA device.

    Returns:
        int numpy array of shape ``(teX.shape[0], k + 1)``.
    """
    with torch.no_grad():
        trX = torch.as_tensor(trX)
        teX = torch.as_tensor(teX)
        if use_cuda:
            trX = trX.cuda()
            teX = teX.cuda()

        # batch over the QUERY rows (was trX rows, which broke when teX != trX)
        N = teX.shape[0]
        numBatch = math.ceil(N / batchsize)
        xx = torch.sum(teX.pow(2), 1).reshape(-1, 1)
        cc = torch.sum(trX.pow(2), 1).reshape(1, -1)

        # int64 storage: a float32 buffer loses index precision beyond 2**24 rows
        Idx = torch.zeros(N, k + 1, dtype=torch.long)

        for lpw in range(1, numBatch + 1):
            print('\rfinding K-nearest neighbors: batch %d in %d'%(lpw,numBatch), end="")
            ibegin = (lpw - 1) * batchsize
            iend = min(lpw * batchsize, N)

            xc = torch.mm(teX[ibegin:iend, :], trX.T)
            z = xx[ibegin:iend] - 2 * xc + cc  # (bs, Ntr) squared distances

            _, I = torch.sort(z, dim=1)
            Idx[ibegin:iend, :] = I[:, :k + 1].cpu()

        print("\n")
        return Idx.numpy().astype(int)

        
def getLMNNidx(Y, Idx):
    """Enumerate LMNN-style triplets (i, j, l) from a neighbour table.

    For every sample ``i`` and every ordered pair ``(j, l)`` of its
    neighbours ``Idx[i, :]``, a triplet is kept when ``j`` has the same label
    as ``i`` and ``l`` has a different one.

    Args:
        Y: labels, indexable by the entries of ``Idx``.
        Idx: 2-D integer array; row ``i`` lists the neighbours of sample ``i``.

    Returns:
        (LP1, LP2, LN) int32 arrays of anchors ``i``, same-label neighbours
        ``j`` and different-label neighbours ``l``.
    """
    k = Idx.shape[1]
    anchors, targets, impostors = [], [], []

    for i in range(len(Y)):
        nbrs = Idx[i, :]
        label_grid = np.tile(Y[nbrs], [k, 1])  # every row repeats the neighbour labels
        index_grid = np.tile(nbrs, [k, 1])

        # flattened transpose enumerates j-major, flattened grid enumerates l-major
        keep = (label_grid.T.reshape(k * k) == Y[i]) & (label_grid.reshape(k * k) != Y[i])

        targets.extend(index_grid.T.reshape(k * k)[keep])
        impostors.extend(index_grid.reshape(k * k)[keep])
        anchors.extend(np.ones(keep.sum()) * i)

    return (np.array(anchors, dtype=np.int32),
            np.array(targets, dtype=np.int32),
            np.array(impostors, dtype=np.int32))


def getTripletIndex(trY, y_idx, k):
    """Sample random (anchor, must-link, cannot-link) triplets for each class.

    For every class label in ``y_idx``, each sample of that class is paired
    with up to ``min(k, n_c - 1)`` randomly drawn same-class partners
    (self-pairs are discarded). Each surviving pair gets one random sample
    from a different class.

    Args:
        trY: numpy array of labels.
        y_idx: class labels to generate triplets for.
        k: maximum number of must-link partners per sample.

    Returns:
        (lp1_all, lp2_all, ln_all): lists of sample indices into ``trY``.
        ``trY[lp1] == trY[lp2]`` and ``trY[ln] != trY[lp1]`` element-wise.
    """
    lp1_all = []
    lp2_all = []
    ln_all = []
    for label in y_idx:
        lc = np.where(trY == label)[0]  # sample indices of this class
        n_samples = len(lc)
        if n_samples == 0:
            continue
        # per-class cap; do NOT overwrite k, or later classes inherit a smaller value
        k_c = min(k, n_samples - 1)

        # must-link lp1-lp2 (positions within lc)
        lp1 = np.tile(np.arange(n_samples), [k_c, 1]).T.reshape(k_c * n_samples)
        rp = np.random.permutation(n_samples * n_samples) % n_samples
        rp = rp.reshape(n_samples, n_samples)[:, :k_c]
        lp2 = rp.T.flatten()

        keep = lp1 != lp2  # drop self-pairs
        lp1 = lp1[keep]
        lp2 = lp2[keep]

        lp1_all.extend(lc[lp1])
        lp2_all.extend(lc[lp2])

        # cannot-link: draw positions inside ln, then map them to sample indices
        ln = np.where(trY != label)[0]
        sn1 = np.random.randint(len(ln), size=len(lp1))
        ln_all.extend(ln[sn1])
    return lp1_all, lp2_all, ln_all

def symInitA2(dim, mu_scalar, v_scalar, nondiag_scalar):
    """Build symmetric initial mean/variance matrices and their upper-triangle vectors.

    Each matrix has ``scalar`` on the diagonal and ``scalar * nondiag_scalar``
    off the diagonal.

    Returns:
        (Am, mu_0, v_0, I_triuA): the dim x dim mean matrix, the row-major
        upper-triangle (diagonal included) entries of the mean and variance
        matrices, and the index tuple used to extract them.
    """
    upper = np.triu_indices(dim)
    strict_upper = np.triu_indices(dim, 1)
    diagonal = np.diag_indices(dim)

    def build(scalar):
        mat = np.eye(dim)
        mat[diagonal] = scalar
        # the strict upper part is doubled, then halved by the symmetrization
        mat[strict_upper] = scalar * nondiag_scalar * 2
        return (mat + mat.T) / 2

    Am = build(mu_scalar)
    Av = build(v_scalar)
    return Am, Am[upper], Av[upper], upper

## functions support updating
def getLambda_(mt, WW):
    """Return ``1 + |1 - mt @ WW|`` flattened to one value per column of ``WW``.

    mt: tensor of shape (D(D+1)/2,) or (1, D(D+1)/2).
    WW: tensor of shape (D(D+1)/2, bs).
    """
    row = mt.reshape(1, -1) if len(mt.shape) == 1 else mt
    margin = 1 - torch.mm(row, WW).reshape(-1)
    return 1 + torch.abs(margin)

def getApproxCov(X, n):
    """Approximate the scatter matrix ``X @ X.T`` from fold-averaged columns.

    The columns of ``X`` (shape (D, N)) are randomly permuted, truncated to
    ``n * m`` with ``m = N // n``, and averaged over ``n`` folds into ``Xm``
    (shape (D, m)). The estimate is
    ``Sa = n^2 * Xm @ Xm.T - n (n - 1) m * mu mu^T`` with ``mu`` the row
    mean of ``Xm``.

    Args:
        X: 2-D tensor of shape (D, N).
        n: number of folds.

    Returns:
        (D, D) tensor.

    Raises:
        ValueError: if ``N < n`` (no complete fold can be formed).
    """
    D, N = X.shape
    m = N // n  # torch.floor() cannot take a Python float
    if m == 0:
        raise ValueError("getApproxCov needs at least n columns (N=%d, n=%d)" % (N, n))

    rp = np.random.permutation(N)
    cols = torch.as_tensor(rp[:n * m], device=X.device)
    Xf = X[:, cols].reshape(D, m, n)
    Xm = torch.mean(Xf, 2)   # (D, m) fold averages
    mu = torch.mean(Xm, 1)   # (D,)

    Sm = torch.mm(Xm, Xm.T)
    # mu is 1-D, so MATLAB's mu*mu' is an outer product (torch.mm needs 2-D inputs)
    Sa = Sm * n * n - n * (n - 1) * m * torch.outer(mu, mu)
    return Sa

def solve_pytorch(B, A):
    """Solve ``A @ x = B`` for ``x``.

    The argument order (B first) mirrors the removed ``torch.solve(B, A)``.
    The work is done by ``torch.linalg.solve``, which takes ``(A, B)``.
    """
    return torch.linalg.solve(A, B)

def solve_numpy(B, A):
    """Solve ``A @ x = B`` in double precision with numpy.

    Tensor inputs are moved to the CPU for the solve. The solution is then
    returned as a tensor on ``B``'s original device, reshaped to a column
    ``(-1, 1)``. Numpy inputs return a numpy array.
    """
    # must be initialised for numpy inputs (previously raised UnboundLocalError)
    was_tensor = not isinstance(B, np.ndarray)
    if was_tensor:
        device = B.device
        B = B.cpu().numpy()
        A = A.cpu().numpy()

    x = np.linalg.solve(A.astype(np.double), B.astype(np.double))
    assert np.allclose(np.dot(A, x), B)  # sanity check on the numerical solve
    if was_tensor:
        x = torch.tensor(x).to(device).reshape(-1, 1)
    return x

def getPosterior(L, WW, pair_px, m0, v0, mt_1, lambda_, rho, flag):
    """One damped update of the posterior mean of the vectorized metric.

    Solves ``Vt_inv @ mt = v0^-1 * m0 + WW @ (1 + 1/L)`` with
    ``Vt_inv = diag(1/v0) + WWL @ WWL.T + lambda_ * pair_px @ pair_px.T``
    (``WWL = WW * sqrt(1/L)``), then blends the result with the previous
    mean: ``mt = (1 - rho) * mt_1 + rho * mt``.

    Args:
        L: (bs,) tensor of per-triplet lambda values.
        WW: (D(D+1)/2, bs) tensor.
        pair_px: 2-D tensor with D(D+1)/2 rows.
        m0, v0: (D(D+1)/2,) prior mean and variance tensors.
        mt_1: (D(D+1)/2,) previous posterior mean.
        lambda_: weight of the in-class (pair) term.
        rho: step size in [0, 1].
        flag: object with ``apx`` (0 = exact covariances, otherwise the
            fold approximation via ``getApproxCov``) and ``fold`` (used when
            ``apx != 0``).

    Returns:
        The updated mean, squeezed to 1-D.
    """
    Lt_1 = 1. / L
    v0_1 = 1. / v0
    Vt_inv = torch.diag(v0_1)

    # standard: A = bsxfun(@times,WW,Lt_1)*WW'
    WWL = WW * torch.sqrt(Lt_1)  # (D(D+1)/2, bs)
    if flag.apx == 0:
        A = torch.mm(WWL, WWL.T)           # sum of x_ijl * lambda_ijl^-1 * x_ijl^T
        Ap = torch.mm(pair_px, pair_px.T)  # (D(D+1)/2, D(D+1)/2)
        # single update, matching the approximate branch; A used to be added
        # twice together with a stray Ap*mean(Lt_1) term
        Vt_inv = Vt_inv + A + Ap * lambda_
    else:
        # approximating the Vt
        A1 = getApproxCov(WWL, flag.fold)
        Ap1 = getApproxCov(pair_px, flag.fold)
        Vt_inv = Vt_inv + A1 + Ap1 * lambda_

    B = (v0_1 * m0).unsqueeze(1) + torch.mm(WW, 1 + Lt_1.unsqueeze(1))  # (D(D+1)/2, 1)
    # torch.solve was removed from PyTorch; linalg.solve takes (A, B)
    mt = torch.linalg.solve(Vt_inv, B)
    mt = mt_1.unsqueeze(1) * (1 - rho) + mt * rho
    return mt.squeeze()

def evaluate(mu, s0trX, s0teX, trY, teY, I_triuA, option):
    """Evaluate a learned metric by k-NN accuracy or mean average precision.

    The symmetric matrix ``A`` is rebuilt from its upper-triangle vector
    ``mu``. Negative eigenvalues are clipped to zero ("tapering"). Both
    feature sets are mapped through ``T = V @ diag(sqrt(eigvals))``, so that
    Euclidean distance in the mapped space equals the ``A``-metric.

    Args:
        mu: values for the upper-triangle positions ``I_triuA``.
        s0trX, s0teX: train/test feature matrices (numpy), one row per sample.
        trY, teY: train/test labels (used in 'acc' mode).
        I_triuA: index tuple of the upper-triangle entries.
        option: object with ``dim``, ``KNN_te``, ``mode`` ('acc' or 'map'),
            ``use_cuda`` (defaults to True if absent) and, for 'map', ``gnd``.

    Returns:
        The accuracy in 'acc' mode, ``(Map_all, seconds)`` in 'map' mode,
        and None for any other mode.
    """
    dim = option.dim
    KNN_te = option.KNN_te
    use_cuda = getattr(option, 'use_cuda', True)

    # reconstruct matrix A
    A_blmnn = np.zeros((dim, dim))
    A_blmnn[I_triuA] = mu
    A_blmnn = (A_blmnn + A_blmnn.T) / 2

    # A is symmetric: eigh yields real eigenvalues and an orthonormal V
    # (eig could return complex values or a non-orthogonal basis)
    eigvals, V = np.linalg.eigh(A_blmnn)
    if np.any(eigvals < 0):
        # tapering (clipping negative eigenvalues to zero)
        print('tapering\n')
        eigvals = np.where(eigvals < 0, 0.0, eigvals)

    # transform: V * sqrt(eigvals) == V @ diag(sqrt(eigvals))
    Tt = V * np.sqrt(eigvals)
    s0trX_blmnn = s0trX.dot(Tt)
    s0teX_blmnn = s0teX.dot(Tt)

    if option.mode == 'acc':
        _, acc0, _, _ = KNNtest(KNN_te, s0trX_blmnn, trY, s0teX_blmnn, teY, use_cuda=use_cuda)
        return acc0
    elif option.mode == 'map':
        Map_A0, ttt = getMeanAveragePrecision(s0trX_blmnn, s0teX_blmnn, option.gnd, bs=256, use_cuda=use_cuda)
        return Map_A0, ttt
