from itertools import combinations
import matplotlib.pyplot as plt

import numpy as np
import torch
import h5py
import os, sys, time, logging
import math
from scipy import stats
from sklearn.decomposition import PCA, FastICA
from MulticoreTSNE import MulticoreTSNE as TSNE
import imageio


def get_logger(fname, savefolder):
    """Return a logger named ``fname`` that writes INFO+ records to ``<savefolder>/<fname>.log``.

    Records are not propagated to the root logger. Calling this repeatedly with
    the same arguments returns the same logger without attaching a second file
    handler, so each record is written once.

    Args:
        fname: logger name, also used as the log file's base name.
        savefolder: existing directory in which the ``.log`` file is created.

    Returns:
        The configured ``logging.Logger``.
    """
    flogger = logging.getLogger(name=fname)
    flogger.setLevel(logging.INFO)
    flogger.propagate = False  # do not also hand records to the root logger (which may print to stdout)
    log_path = os.path.abspath(os.path.join(savefolder, fname + ".log"))
    # logging.getLogger returns the same object for the same name, so only attach
    # the file handler once; otherwise every call duplicates each log line.
    already_attached = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == log_path
        for h in flogger.handlers
    )
    if not already_attached:
        f_handler = logging.FileHandler(log_path)
        f_format = logging.Formatter('[%(asctime)s] - %(message)s', datefmt='%d-%b-%y %H:%M:%S')
        f_handler.setFormatter(f_format)
        flogger.addHandler(f_handler)
    return flogger


def pdist(vectors):
    """Pairwise squared Euclidean distances between the rows of a 2-D tensor.

    Uses the expansion ||a||^2 + ||b||^2 - 2 a.b, so entries may be slightly
    negative due to floating-point rounding.

    Returns:
        An (n, n) tensor where entry (i, j) is the squared distance between rows i and j.
    """
    gram = vectors.mm(torch.t(vectors))
    sq_norms = vectors.pow(2).sum(dim=1)
    return -2 * gram + sq_norms.view(1, -1) + sq_norms.view(-1, 1)

def Wrapper_PCA(trX, valX=None, teX=None, dim=2):
    """Fit PCA on ``trX`` and project the training, validation and test sets.

    Args:
        trX: training data used to fit the projection.
        valX, teX: optional extra sets projected with the same fitted PCA.
        dim: number of principal components.

    Returns:
        (trX_emb, valX_emb, teX_emb); an optional set that was ``None`` stays ``None``.
    """
    projector = PCA(n_components=dim)
    projector.fit(trX)

    def project(X):
        return None if X is None else projector.transform(X)

    return projector.transform(trX), project(valX), project(teX)

def cprint(color, text, **kwargs):
    """Print ``text`` wrapped in an ANSI colour escape sequence, then flush stdout.

    Args:
        color: one of 'a' (black), 'r', 'g', 'y', 'b', 'p' (magenta), 'c', 'w';
            prefix with '*' for bold (e.g. '*r').
        text: object to print.
        **kwargs: forwarded to ``print``.
    """
    ansi_codes = dict(a='30', r='31', g='32', y='33', b='34', p='35', c='36', w='37')
    is_bold = color[0] == '*'
    prefix = '1;' if is_bold else ''
    key = color[1:] if is_bold else color
    print("\x1b[%s%sm%s\x1b[0m" % (prefix, ansi_codes[key], text), **kwargs)
    sys.stdout.flush()


def getCleanTe(data, labels, labels_gt, npte):
    """Split data into train/test sets where the test set holds only clean samples.

    For every class ``lpw`` in ``range(max(labels) + 1)``, the indices of samples
    whose (possibly noisy) label is ``lpw`` are shuffled in place with
    ``np.random.shuffle`` (global RNG). The first ``npte`` of them whose label
    agrees with ``labels_gt`` become test samples; all other samples of that class
    go to the training set. Both splits are then min-max scaled per column using
    the min/max of the full ``data`` array, and NaNs from that scaling (0/0 on
    constant columns) are replaced with 0.

    Args:
        data: array of shape (n_samples, n_features).
        labels: observed integer labels, one per sample. Assumed to be 0..C-1,
            since the class count is taken as ``max(labels) + 1``.
        labels_gt: ground-truth labels aligned with ``labels``.
        npte: number of clean test samples to draw per class.

    Returns:
        trX, trY, trY_gt, teX, teY, teY_gt. Label arrays are float (built with
        ``np.zeros``). The test set holds ``npte`` samples per class, stored
        contiguously in class order.

    NOTE(review): if a class has fewer than ``npte`` clean samples, the loop prints
    a message and ``break``s. The remaining rows of every output array then stay
    unfilled (zero-initialised before scaling). The training array size also
    assumes every class contributes exactly ``npte`` test samples.
    """
    numClasses = max(labels) + 1

    # Everything not reserved for the per-class test quota goes to training.
    ntr = data.shape[0] - numClasses*npte
    nte = numClasses*npte

    trX = np.zeros((ntr, data.shape[1]))
    teX = np.zeros((nte, data.shape[1]))

    trY = np.zeros(ntr)
    teY = np.zeros(nte)
    trY_gt = np.zeros(ntr)
    teY_gt = np.zeros(nte)
    iend = 0  # write cursor into the training arrays, advanced class by class
    for lpw in range(numClasses):
        # %     fprintf('class %d in %d\n',lpw,numClasses)
        lt = np.argwhere(labels == lpw)[:, 0]
        np.random.shuffle(lt)

        curX = data[lt, :]
        curY = labels[lt]
        curY_gt = labels_gt[lt]
        # Positions (within this class) of samples whose label matches the ground truth.
        lte = np.argwhere(curY==curY_gt)[:, 0]
        
        if(len(lte) < npte):
            print("Not enough data for testing")
            break
        
        # Boolean mask of training samples: everything except the first npte clean ones.
        ltr = np.ones(len(lt))
        ltr[lte[:npte]] = 0  # not pick test samples
        ltr = ltr.astype(bool)
        
        xr = curX[ltr,:]
        yr = curY[ltr]
        yr_gt = curY_gt[ltr]
        
        xe = curX[lte[:npte],:]
        ye = curY[lte[:npte]]
        ye_gt = curY_gt[lte[:npte]]

        # Training rows are appended contiguously after the previous class.
        ibegin = iend
        iend = ibegin + len(yr)
        trX[ibegin:iend,:] = xr
        trY[ibegin:iend] = yr
        trY_gt[ibegin:iend] = yr_gt

        # Test rows for class lpw occupy a fixed block of npte rows.
        teX[lpw*npte: (lpw+1)*npte,:] = xe
        teY[lpw*npte: (lpw+1)*npte] = ye
        teY_gt[lpw*npte: (lpw+1)*npte] = ye_gt

    # Per-feature min-max scaling with statistics of the full data array.
    maxv = data.max(axis=0)
    minv = data.min(axis=0)
    max_min = maxv - minv
    trX = (trX - minv) / max_min
    teX = (teX - minv) / max_min

    # Constant feature columns give 0/0 = NaN; map them to 0.
    trX[np.isnan(trX)] = 0
    teX[np.isnan(teX)] = 0
    return trX, trY, trY_gt, teX, teY, teY_gt


def score_ap_from_ranks_1(ranks, nres):
    """Compute the average precision of a single search.

    The precision/recall curve is integrated with the trapezoid rule. Every true
    positive advances recall by ``1 / nres``. The area of that step is the mean
    of the precision just before and just after the hit.

    Args:
        ranks: ordered ranks of the true positives, where rank r means r items
            are ranked ahead of the hit (rank 0 is the first position).
        nres: total number of positives in the dataset.

    Returns:
        The average precision as a float.
    """
    step = 1.0 / nres
    total = 0.0
    for hits_before, rank in enumerate(ranks):
        before = 1.0 if rank == 0 else hits_before / float(rank)
        after = (hits_before + 1) / float(rank + 1)
        total += (after + before) * step / 2.0
    return total

def mycompute_MAP(Idx, gnd, sdT, topk=None):
    """Compute mean average precision over a batch of queries.

    Args:
        Idx: unused; kept for call compatibility.
        gnd: sequence where ``gnd[i]`` holds the database indices of the true
            positives of query i.
        sdT: array of shape (num_database, num_query). ``sdT[j, i]`` is the rank
            of database item j for query i (getMeanAveragePrecision passes the
            transposed argsort-of-argsort of the distances).
        topk: if given, only positives ranked ``<= topk`` are scored and the AP
            is normalised by ``topk`` instead of the number of positives.

    Returns:
        (mAP, mAPall): the mean AP over all queries and the per-query AP array.
    """
    mAP = 0
    num_query = sdT.shape[1]
    mAPall = np.zeros(num_query)
    for i in range(num_query):
        qgnd = gnd[i]
        nres = len(qgnd)
        # Ranks of this query's positives in ascending order. np.sort (rather than
        # sorted(), which returns a list) keeps an ndarray so the topk mask works.
        tp = np.sort(sdT[qgnd, i])
        if topk is None:
            ap = score_ap_from_ranks_1(tp, nres)
        else:
            ap = score_ap_from_ranks_1(tp[tp <= topk], topk)
        mAP += ap
        mAPall[i] = ap
    mAP /= num_query
    return mAP, mAPall

def getMeanAveragePrecision(trX, teX, gnd, bs=100, use_cuda=True):
    """Per-query retrieval average precision of each test row against the training set.

    Test rows are processed in batches of ``bs``. Training rows are ranked by
    squared Euclidean distance, and each query is scored with ``mycompute_MAP``
    using ``gnd[i]``, the indices of the true positives for query i.

    Args:
        trX: (Ntr, d) numpy array of database vectors.
        teX: (Nte, d) numpy array of query vectors.
        gnd: per-query sequences of positive database indices.
        bs: batch size.
        use_cuda: move tensors to the GPU.

    Returns:
        (Mapall, seconds): AP for every query, and elapsed wall-clock time.
    """
    tic = time.time()
    Nte = teX.shape[0]
    Mapall = np.zeros(Nte)
    Nb = math.ceil(Nte / bs)
    with torch.no_grad():
        trX = torch.tensor(trX)
        if use_cuda:
            trX = trX.cuda()
        # Squared norms of the database rows do not depend on the batch: compute once.
        cc = torch.sum(trX.pow(2), 1).reshape(1, -1)
        for cnt_batch in range(Nb):
            print('\rmean average precision batch: %d in %d' % (cnt_batch, Nb), end="")
            ibegin = cnt_batch * bs
            iend = ibegin + bs

            xxt = torch.tensor(teX[ibegin:iend, :])
            if use_cuda:
                xxt = xxt.cuda()

            xx = torch.sum(xxt.pow(2), 1).reshape(-1, 1)
            # ||x||^2 + ||c||^2 - 2 x.c -> (batch, Ntr) squared distances.
            z = xx + cc - 2 * torch.mm(xxt, trX.T)

            # Idx[q] lists database indices by increasing distance; argsorting Idx
            # again gives sd[q, j] = rank of database item j for query q.
            _, Idx = torch.sort(z, dim=1)
            _, sd = torch.sort(Idx, dim=1, descending=False)

            _, mapall = mycompute_MAP(Idx.cpu().numpy(),
                                      gnd[ibegin:iend], sd.T.cpu().numpy())
            Mapall[ibegin:iend] = mapall
        print()  # terminate the carriage-return progress line
    toc = time.time()
    return Mapall, toc - tic


def myPCA(allsamples, redim):
    """PCA via eigendecomposition of the (unnormalised) scatter matrix.

    Args:
        allsamples: array of shape (n_samples, n_features).
        redim: number of principal components to keep.

    Returns:
        base: (n_features, redim) eigenvectors ordered by decreasing eigenvalue.
        eigvals: the ``redim`` largest eigenvalues, in decreasing order.
        samplemean: per-feature mean, shape (n_features,). Subtract it before
            projecting onto ``base``.
    """
    # Centre each feature (column). np.mean without an axis would subtract one
    # global scalar, which does not centre the data (MATLAB's mean() is column-wise).
    samplemean = np.mean(allsamples, axis=0)
    xmean = allsamples - samplemean
    scatter = xmean.T.dot(xmean)
    # The scatter matrix is symmetric: eigh guarantees real eigenpairs, whereas
    # eig may return complex values with spurious imaginary parts.
    eigvals, eigvecs = np.linalg.eigh(scatter)
    descending_order = eigvals.argsort()[::-1]
    eigvals = eigvals[descending_order][:redim]
    base = eigvecs[:, descending_order][:, :redim]
    return base, eigvals, samplemean

def load_data_from_mat(path="data/mnist_fea_pca.mat"):
    """Load train/test features and labels from an HDF5-based .mat file.

    Reads the datasets 'trX', 'trY', 'teX' and 'teY'. Feature matrices are
    transposed, presumably to undo MATLAB's column-major layout. Labels are
    taken from the first row and cast to int.

    Returns:
        trX, trY, teX, teY
    """
    # Context manager so the HDF5 handle is closed; np.array(...) reads each
    # dataset fully into memory, so it must happen while the file is open.
    with h5py.File(path, 'r') as f:
        trX = np.array(f.get('trX')).T
        trY = np.array(f.get('trY'))[0]
        teX = np.array(f.get('teX')).T
        teY = np.array(f.get('teY'))[0]
    trY, teY = trY.astype(int), teY.astype(int)
    return trX, trY, teX, teY

def load_celeb_from_mat(path="data/celeb_fea.mat"):
    """Load features with observed and ground-truth labels from an HDF5-based .mat file.

    Returns:
        data: dataset 'data', transposed (presumably undoing MATLAB's column-major layout).
        labels: first row of dataset 'labels', as int.
        labels_gt: first row of dataset 'labels_gt', as int.
    """
    # Read everything while the file is open; the context manager closes the handle.
    with h5py.File(path, 'r') as f:
        data = np.array(f.get('data')).T
        labels = np.array(f.get('labels'))[0].astype(int)
        labels_gt = np.array(f.get('labels_gt'))[0].astype(int)

    return data, labels, labels_gt

def load_cifar10_from_mat(path):
    """Load CIFAR-10 features and labels from an HDF5-based .mat file.

    Each of 'trainXCp', 'testXCp', 'trainY' and 'testY' is a group whose 'value'
    dataset is read in full. Features are transposed and cast to float64.
    Labels are taken from the first row and cast to int.

    Returns:
        trX, trY, teX, teY
    """
    # `[:]` materialises each dataset as an ndarray, so the file can be closed afterwards.
    with h5py.File(path, 'r') as f:
        trX = f.get('trainXCp')['value'][:]
        teX = f.get('testXCp')['value'][:]
        trY = f.get('trainY')['value'][:]
        teY = f.get('testY')['value'][:]

    trX, trY = np.array(trX).T, np.array(trY)[0]
    teX, teY = np.array(teX).T, np.array(teY)[0]
    trX, teX = trX.astype(np.float64), teX.astype(np.float64)
    trY, teY = trY.astype(int), teY.astype(int)
    return trX, trY, teX, teY
    

def randomSampleData(data, label, perN):
    """Randomly draw up to ``perN[c - 1]`` samples from each class ``c``.

    Classes are visited in ascending order (``np.unique``). ``perN`` is indexed
    by ``class value - 1``, i.e. labels are assumed to start at 1 (a class 0
    would read ``perN[-1]``).

    Returns:
        (sdata, slabel): the sampled rows and their labels, grouped by class.
    """
    rows, picked_labels = [], []
    for cls in np.unique(label):
        members = np.argwhere(label == cls)[:, 0]
        members = members[np.random.permutation(len(members))]
        quota = perN[cls - 1]
        rows.extend(data[members[:quota], :])
        picked_labels.extend(label[members[:quota]])
    return np.array(rows), np.array(picked_labels)


def randomStratifiedSampleData(label, perN):
    """Return shuffled indices of up to ``perN[i]`` samples of the i-th class.

    Labels are treated as 1-based when their maximum equals the number of
    distinct labels, and 0-based otherwise; the detected convention is printed.

    Returns:
        1-D array of selected indices into ``label``, grouped by class.
    """
    classes = set(label)
    one_based = max(classes) == len(classes)
    if one_based:
        print("LabelID starts from 1!")
    else:
        print("LabelID starts from 0!")
    offset = 1 if one_based else 0

    picked = []
    for class_pos in range(len(classes)):
        members = np.argwhere(label == (class_pos + offset))[:, 0]
        members = members[np.random.permutation(len(members))]
        picked.extend(members[:perN[class_pos]])
    return np.array(picked)


def generateRandomLabelNoise2(TtrY, noiserate, plus_one=True):
    """Corrupt a fixed fraction of every class with a different random label.

    For each class, ``floor(count * noiserate)`` randomly chosen samples have
    their label cyclically shifted by an offset drawn uniformly from
    ``[2, numClasses - 1]`` (modulo ``numClasses``). A corrupted label therefore
    always differs from the original. That offset range is empty when
    ``numClasses < 3``.

    Args:
        TtrY: labels, as a numpy array or an object exposing ``.numpy()``.
        noiserate: fraction of each class to corrupt; below 1e-4 an unmodified
            copy is returned.
        plus_one: True when labels are 1..numClasses, False when 0..numClasses-1.

    Returns:
        A noisy copy of the labels; the input is not modified.
    """
    if not isinstance(TtrY, np.ndarray):
        TtrY = TtrY.numpy()

    classes = np.unique(TtrY)
    numClasses = len(classes)
    trY = TtrY.copy()
    if noiserate < 1e-4:
        return trY

    base = 1 if plus_one else 0
    for lpw in classes:
        la = np.argwhere(TtrY == lpw)[:, 0]
        rnla = np.random.permutation(len(la))
        N1 = np.floor(len(la) * noiserate).astype(int)
        # randint's upper bound is exclusive: this replaces the deprecated
        # np.random.random_integers(2, numClasses - 1).
        lu = np.random.randint(2, numClasses, size=N1)
        picked = la[rnla[:N1]]
        # Shift in 0-based label space, then restore the label base. The previous
        # "(y + lu) % C + 1" for 1-based labels really shifted by lu + 1, and
        # lu = C - 1 mapped the label back onto itself (no noise injected).
        trY[picked] = (TtrY[picked] - base + lu) % numClasses + base

    return trY


def generateImBalancedRandomLabelNoise2(TtrY, noiserates):
    """Corrupt labels with a per-class noise rate.

    For each class ``lpw``, ``floor(count * noiserates[lpw])`` randomly chosen
    samples have their label shifted by an offset drawn uniformly from
    ``[2, numClasses - 1]``, modulo ``numClasses``. A corrupted label therefore
    always differs from the original. Labels are assumed to be 0..numClasses-1,
    because ``noiserates`` is indexed by label value and the modulo wraps into
    that range.

    Args:
        TtrY: labels, as a numpy array or an object exposing ``.numpy()``.
        noiserates: indexable of per-class corruption fractions.

    Returns:
        A noisy copy of the labels; the input is not modified.
    """
    if not isinstance(TtrY, np.ndarray):
        TtrY = TtrY.numpy()

    classes = np.unique(TtrY)
    numClasses = len(classes)
    trY = TtrY.copy()

    for lpw in classes:
        la = np.argwhere(TtrY == lpw)[:, 0]
        rnla = np.random.permutation(len(la))
        N1 = np.floor(len(la) * noiserates[lpw]).astype(int)
        # randint's upper bound is exclusive: this replaces the deprecated
        # np.random.random_integers(2, numClasses - 1).
        lu = np.random.randint(2, numClasses, size=N1)
        picked = la[rnla[:N1]]
        trY[picked] = (TtrY[picked] + lu) % numClasses

    return trY

def mkdir(paths):
    """Create one or more directories, including missing parents.

    Paths that already exist as directories are left alone.

    Args:
        paths: a single path, or a list/tuple of paths.

    Raises:
        FileExistsError: if a path exists but is not a directory.
    """
    if not isinstance(paths, (list, tuple)):
        paths = [paths]
    for path in paths:
        # exist_ok avoids the check-then-create race of isdir() + makedirs().
        os.makedirs(path, exist_ok=True)


def getbatchKNNindex(trX, k, teX, batchsize=256, use_cuda=True):
    """Indices of the ``k + 1`` nearest training rows for every query row.

    Squared Euclidean distances are computed for ``batchsize`` query rows at a
    time. When ``teX`` is ``trX``, column 0 is normally the query itself
    (distance 0), hence ``k + 1`` columns.

    Args:
        trX: (Ntr, d) array of reference points.
        k: number of neighbours wanted besides the closest one; needs ``k + 1 <= Ntr``.
        teX: (Nte, d) array of query points.
        batchsize: number of query rows per batch.
        use_cuda: move tensors to the GPU.

    Returns:
        int ndarray of shape (Nte, k + 1). Row i lists training indices in order
        of increasing distance to ``teX[i]``.
    """
    with torch.no_grad():
        trX = torch.tensor(trX)
        teX = torch.tensor(teX)
        if use_cuda:
            trX = trX.cuda()
            teX = teX.cuda()

        # One output row per *query*: sizing by the training set broke whenever
        # teX and trX had different lengths.
        N = teX.shape[0]
        numBatch = math.ceil(N / batchsize)
        xx = torch.sum(teX.pow(2), 1).reshape(-1, 1)
        cc = torch.sum(trX.pow(2), 1).reshape(1, -1)

        # Integer storage: float32 cannot represent indices above 2**24 exactly.
        Idx = torch.zeros(N, k + 1, dtype=torch.long)

        for lpw in range(numBatch):
            print('\rfinding K-nearest neighbors: batch %d in %d' % (lpw + 1, numBatch), end="")
            ibegin = lpw * batchsize
            iend = min(ibegin + batchsize, N)

            xc = torch.mm(teX[ibegin:iend, :], trX.T)
            z = xx[ibegin:iend] - 2 * xc + cc  # (bs,1) - 2*(bs,Ntr) + (1,Ntr) = (bs,Ntr)

            _, I = torch.sort(z, dim=1)
            Idx[ibegin:iend, :] = I[:, :k + 1].cpu()

        print()  # terminate the carriage-return progress line
        return Idx.numpy().astype(int)

        
def getLMNNidx(Y, Idx):
    """Enumerate LMNN (anchor, same-label neighbour, different-label neighbour) triples.

    For every point i and every ordered pair (j, l) of positions in ``Idx[i]``,
    a triple ``(i, Idx[i, j], Idx[i, l])`` is emitted when
    ``Y[Idx[i, j]] == Y[i]`` and ``Y[Idx[i, l]] != Y[i]``. Triples are ordered
    by i, then j, then l.

    Args:
        Y: labels (numpy array, or an object exposing ``.numpy()``), one per point.
        Idx: 2-D integer array; row i holds the neighbour indices of point i.

    Returns:
        (LP1, LP2, LN): int32 arrays of anchor, positive-neighbour and
        negative-neighbour indices.
    """
    if not isinstance(Y, np.ndarray):
        Y = Y.numpy()
    k = Idx.shape[1]
    anchors, positives, negatives = [], [], []
    for i in range(len(Y)):
        nbr = Idx[i, :]
        nbr_lab = Y[nbr]
        # np.repeat walks the outer index j (each neighbour k times);
        # np.tile walks the inner index l (all k neighbours, k times over).
        keep = (np.repeat(nbr_lab, k) == Y[i]) & (np.tile(nbr_lab, k) != Y[i])
        pos = np.repeat(nbr, k)[keep]
        neg = np.tile(nbr, k)[keep]
        anchors.extend(np.full(len(pos), i))
        positives.extend(pos)
        negatives.extend(neg)

    return (np.array(anchors, dtype=np.int32),
            np.array(positives, dtype=np.int32),
            np.array(negatives, dtype=np.int32))

def find_KNN_triplets(dataset, targets, KNN_triplet=21, use_cuda=True):
    """Build LMNN (anchor, positive, negative) index triplets from k-NN neighbourhoods.

    Neighbours are found with ``getbatchKNNindex`` on ``dataset`` against itself.
    Column 0 (normally the point itself) is dropped and the next ``KNN_triplet``
    neighbours are passed to ``getLMNNidx``.

    Args:
        dataset: (N, d) array of points.
        targets: labels, one per point.
        KNN_triplet: neighbourhood size used to form triplets.
        use_cuda: forwarded to ``getbatchKNNindex``.

    Returns:
        List of ``[anchor, positive, negative]`` index triplets.
    """
    # Request only the neighbours actually used (+1 for the point itself). The
    # previous hard-coded 100 wasted work and silently capped KNN_triplet at 100.
    Idx_all = getbatchKNNindex(dataset, KNN_triplet, dataset, use_cuda=use_cuda)
    # NOTE(review): assumes column 0 is the point itself; exact duplicates in
    # `dataset` can break that assumption.
    Idx = Idx_all[:, 1:KNN_triplet + 1]
    lp1_all, lp2_all, ln_all = getLMNNidx(targets, Idx)
    return [[i, j, l] for i, j, l in zip(lp1_all, lp2_all, ln_all)]
    
def get_all_embedding_from_dataloader(dataloader, feature_size):
    """Stack the inputs of every batch from ``dataloader`` into one float64 array.

    Each batch's inputs are flattened to ``(batch, feature_size)``; targets are
    ignored.

    Returns:
        Array of shape ``(len(dataloader.dataset), feature_size)``. Rows not
        filled by the loader stay zero.
    """
    out = np.zeros((len(dataloader.dataset), feature_size))
    start = 0
    for batch, _target in dataloader:
        stop = start + len(batch)
        out[start:stop] = batch.view(-1, feature_size).numpy()
        start = stop
    return out

def plot_embeddings(embeddings, targets, xlim=None, ylim=None, epoch=0, path=None):
    """Scatter-plot 2-D embeddings, one colour per class.

    Args:
        embeddings: array with at least two columns; columns 0 and 1 are plotted.
        targets: array of class labels aligned with ``embeddings``.
        xlim, ylim: optional ``(min, max)`` axis limits.
        epoch: shown in the plot title.
        path: if given, save the figure there and return None.

    Returns:
        None when ``path`` is given; otherwise the rendered figure as a uint8 RGB
        array of shape (height, width, 3).
    """
    # Sorted so class colours/legend order are deterministic across calls.
    classes = sorted(set(targets))

    fig, ax = plt.subplots(figsize=(10, 10))
    for classidx in classes:
        inds = np.where(targets == classidx)[0]
        ax.scatter(embeddings[inds, 0], embeddings[inds, 1], alpha=0.5)
    # Axes has no xlim()/ylim() methods (those are pyplot functions); the
    # previous calls raised AttributeError whenever limits were passed.
    if xlim:
        ax.set_xlim(xlim[0], xlim[1])
    if ylim:
        ax.set_ylim(ylim[0], ylim[1])

    ax.set(title="Epoch: {}".format(epoch))
    ax.legend(classes)
    if path is not None:
        fig.savefig(path)
        plt.close(fig)
        return None

    fig.canvas.draw()
    # tostring_rgb() was deprecated in Matplotlib 3.8 and removed in 3.10;
    # buffer_rgba() is the supported accessor. Drop alpha, and copy so the image
    # does not reference the canvas buffer after the figure is closed.
    image = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    plt.close(fig)
    return image

def plot_highest_var_dim(embeddings, targets, epoch=0, tnse=True):
    """Plot embeddings in 2-D, via t-SNE or via the two highest-variance columns.

    Args:
        embeddings: 2-D numpy array (samples x features).
        targets: class labels used to colour the points.
        epoch: shown in the plot title.
        tnse: if True (parameter name kept as-is), project with MulticoreTSNE
            (2 components, 4 jobs); otherwise keep the two columns with the largest
            standard deviation (order as returned by ``torch.topk(..., sorted=False)``).

    Returns:
        Whatever ``plot_embeddings`` returns for the 2-D points.
    """
    if not tnse:
        spread = torch.tensor(embeddings.std(axis=0))
        top_two = torch.topk(spread, 2, dim=0, largest=True, sorted=False)[1].numpy()
        return plot_embeddings(embeddings[:, top_two], targets, epoch=epoch)
    projected = TSNE(n_components=2, n_jobs=4).fit_transform(embeddings)
    return plot_embeddings(projected, targets, epoch=epoch)


def extract_embeddings(dataloader, model, embed_dim, cuda=True):
    """Compute ``model.get_embedding`` for every batch, without gradients.

    The model is switched to eval mode and left in it.

    Args:
        dataloader: yields ``(images, target)`` batches; must expose ``.dataset``.
        model: object with ``eval()`` and ``get_embedding(images)``.
        embed_dim: embedding width.
        cuda: move each image batch to the GPU before the forward pass.

    Returns:
        (embeddings, labels): float64 arrays of shape ``(len(dataset), embed_dim)``
        and ``(len(dataset),)``.
    """
    with torch.no_grad():
        model.eval()
        n_items = len(dataloader.dataset)
        embeddings = np.zeros((n_items, embed_dim))
        labels = np.zeros(n_items)
        cursor = 0
        for images, target in dataloader:
            if cuda:
                images = images.cuda()
            batch_len = len(images)
            embeddings[cursor:cursor + batch_len] = model.get_embedding(images).data.cpu().numpy()
            labels[cursor:cursor + batch_len] = target.numpy()
            cursor += batch_len
    return embeddings, labels

def KNNtest(k, trX, trY, teX, teY, bs=256, use_cuda=True):
    """k-nearest-neighbour classification of ``teX`` against labelled ``trX``.

    Squared Euclidean distances are computed in batches of ``bs`` test rows. The
    prediction is the label of the single nearest neighbour (``k == 1``) or the
    mode (``scipy.stats.mode``) of the k nearest labels.

    Args:
        k: number of neighbours.
        trX, teX: 2-D numpy arrays or tensors of train/test features.
        trY, teY: train/test labels (array-like).
        bs: test batch size.
        use_cuda: move tensors to the GPU.

    Returns:
        (pred, acc, None, None): predictions for *all* test rows (dtype of
        ``trY``) and accuracy. The trailing Nones are kept for interface
        compatibility.
    """
    trX, teX = [torch.tensor(vari) if isinstance(vari, np.ndarray)
                else vari for vari in [trX, teX]]
    trY, teY = [np.array(vari) if not isinstance(vari, np.ndarray)
                else vari for vari in [trY, teY]]
    Nb = math.ceil(len(teX) / bs)
    pred_all = np.zeros(len(teX), dtype=trY.dtype)
    with torch.no_grad():
        if use_cuda:
            trX = trX.cuda()
        cc = torch.sum(trX.pow(2), 1).reshape(1, -1)
        for cnt_batch in range(Nb):
            ibegin = cnt_batch * bs
            iend = ibegin + bs
            xxt = teX[ibegin:iend, :]
            if use_cuda:
                xxt = xxt.cuda()
            xx = torch.sum(xxt.pow(2), 1).reshape(-1, 1)
            xc = torch.mm(xxt, trX.T)
            z = xx - 2 * xc + cc  # (bs,1) - 2*(bs,Ntr) + (1,Ntr) = (bs,Ntr)
            _, Idx = torch.sort(z, 1)
            ly = trY[Idx[:, :k].cpu().numpy()]  # (bs, k) neighbour labels

            if k == 1:
                pred = ly
            else:
                pred = stats.mode(ly, 1)[0]
            pred_all[ibegin:iend] = np.array(pred).flatten()
    acc = np.mean(pred_all == teY)
    # Return every batch's predictions; `pred` alone held only the last batch
    # (and was undefined when teX was empty).
    return pred_all, acc, None, None


if __name__ == '__main__':
    # Rebuttal_agg()
    # NOTE(review): DeepNN_agg is neither defined nor imported anywhere in the
    # visible source, so running this module as a script raises NameError.
    # Confirm where it is meant to come from.
    DeepNN_agg()