import os, logging
import numpy as np
import torch
from scipy import stats
import time
from models.priors import log_gauss_approximation
from matplotlib import pyplot as plt
import h5py
import imageio
import pickle as pkl 
import math
from sklearn.decomposition import PCA, FastICA
# from sklearn.manifold import TSNE
from MulticoreTSNE import MulticoreTSNE as TSNE

def get_logger(fname, savefolder):
    """Return an INFO-level logger named ``fname`` that writes to ``<savefolder>/<fname>.log``.

    Records are not propagated to the root logger, so nothing is echoed to
    stdout. Calling this repeatedly with the same arguments returns the same
    logger without attaching a second handler for the same file.

    Args:
        fname: logger name, also used as the log file's base name.
        savefolder: directory in which the ``.log`` file is created.

    Returns:
        The configured ``logging.Logger``.
    """
    flogger = logging.getLogger(name=fname)
    flogger.setLevel(logging.INFO)
    flogger.propagate = False  # do not propagate to the root logger (print to stdout)
    log_path = os.path.abspath(os.path.join(savefolder, fname + ".log"))
    # getLogger caches loggers by name; without this check every repeated call
    # would attach another FileHandler and each record would be written N times.
    for handler in flogger.handlers:
        if isinstance(handler, logging.FileHandler) and handler.baseFilename == log_path:
            return flogger
    f_handler = logging.FileHandler(log_path)
    f_format = logging.Formatter('[%(asctime)s] - %(message)s', datefmt='%d-%b-%y %H:%M:%S')
    f_handler.setFormatter(f_format)
    flogger.addHandler(f_handler)
    return flogger

def Wrapper_FastICA(trX, teX, dim=2):
    """Fit FastICA with ``dim`` components on ``trX`` and project both sets.

    Args:
        trX: training samples used to fit the unmixing.
        teX: samples projected with the model fitted on ``trX``.
        dim: number of independent components.

    Returns:
        Tuple ``(trX_emb, teX_emb)``.
    """
    model = FastICA(n_components=dim)
    model.fit(trX)
    # transform train first, then test (same order as a plain sequential call)
    return tuple(model.transform(samples) for samples in (trX, teX))

def Wrapper_PCA(trX, valX=None, teX=None, dim=2):
    """Fit a ``dim``-component PCA on ``trX`` and project train/val/test sets.

    Args:
        trX: training samples used to fit the PCA.
        valX: optional validation samples (``None`` to skip).
        teX: optional test samples (``None`` to skip).
        dim: number of principal components.

    Returns:
        Tuple ``(trX_emb, valX_emb, teX_emb)``; a skipped set yields ``None``.
    """
    pca = PCA(n_components=dim)
    pca.fit(trX)

    def _project(samples):
        return None if samples is None else pca.transform(samples)

    return pca.transform(trX), _project(valX), _project(teX)

def write_bytes(obj, fpath):
    """Pickle ``obj`` into the file at ``fpath``, overwriting any existing content."""
    with open(fpath, mode='wb') as handle:
        pkl.dump(obj, handle)
        
def load_bytes(fpath):
    """Unpickle and return the object stored in the file at ``fpath``.

    Security: unpickling can execute arbitrary code. Only load files you
    produced yourself (e.g. with ``write_bytes``), never untrusted input.
    """
    with open(fpath, mode='rb') as handle:
        return pkl.load(handle)

def load_mnist_from_mat(path="data/mnist_fea_pca.mat"):
    """Load an MNIST train/test split from an HDF5-backed ``.mat`` file.

    Reads the datasets ``trX``, ``trY``, ``teX`` and ``teY``. Feature
    matrices are transposed after loading (presumably because MATLAB stores
    them column-major, one sample per column; TODO confirm). Labels are
    taken from the first row of the stored label arrays and cast to int.

    Returns:
        ``(trX, trY, teX, teY)`` as numpy arrays.
    """
    # Read everything into memory inside the ``with`` so the HDF5 handle is
    # closed on return; the previous version never closed the file.
    with h5py.File(path, 'r') as f:
        trX = np.array(f.get('trX')).T
        trY = np.array(f.get('trY'))[0]
        teX = np.array(f.get('teX')).T
        teY = np.array(f.get('teY'))[0]
    trY, teY = trY.astype(int), teY.astype(int)
    return trX, trY, teX, teY

def load_celeb_from_mat(path="data/celeb_fea.mat"):
    """Load CelebA features with noisy and ground-truth labels from an HDF5 ``.mat`` file.

    Reads the datasets ``data``, ``labels`` and ``labels_gt``. ``data`` is
    transposed after loading (presumably MATLAB column-major storage; TODO
    confirm). Label arrays are taken from their first row and cast to int.

    Returns:
        ``(data, labels, labels_gt)`` as numpy arrays.
    """
    # Materialise arrays inside ``with`` so the file handle is released;
    # previously the h5py.File was never closed.
    with h5py.File(path, 'r') as f:
        data = np.array(f.get('data')).T
        labels = np.array(f.get('labels'))[0].astype(int)
        labels_gt = np.array(f.get('labels_gt'))[0].astype(int)

    return data, labels, labels_gt

def load_mat_file(path):
    """Load ``trX``, ``trY``, ``teX``, ``teY`` from a MATLAB ``.mat`` file via scipy.

    Values are returned exactly as ``scipy.io.loadmat`` provides them
    (no reshaping or casting).

    Returns:
        ``(trX, trY, teX, teY)``.
    """
    from scipy.io import loadmat
    contents = loadmat(path)
    return tuple(contents[key] for key in ('trX', 'trY', 'teX', 'teY'))

def save_mat_file(path, obj):
    '''Write the dict ``obj`` to a MATLAB ``.mat`` file at ``path`` without compression.

    obj: dict type (variable name -> value)
    '''
    from scipy.io import savemat
    print(f"saving {path} ...")
    savemat(file_name=path, mdict=obj, do_compression=False)

def load_cifar10_from_mat(path):
    """Load a CIFAR-10 train/test split from an HDF5 ``.mat`` file.

    Each variable (``trainXCp``, ``testXCp``, ``trainY``, ``testY``) is read
    from its ``'value'`` sub-dataset. Feature matrices are transposed and cast
    to float64; labels are taken from the first row of the transposed-free
    label arrays and cast to int.

    Returns:
        ``(trX, trY, teX, teY)`` as numpy arrays.
    """
    # ``[:]`` reads each dataset into memory, so the file can be closed right
    # after; the previous version leaked the open h5py.File.
    with h5py.File(path, 'r') as f:
        trX = f.get('trainXCp')['value'][:]
        teX = f.get('testXCp')['value'][:]
        trY = f.get('trainY')['value'][:]
        teY = f.get('testY')['value'][:]

    trX, trY = np.array(trX).T, np.array(trY)[0]
    teX, teY = np.array(teX).T, np.array(teY)[0]
    trX, teX = trX.astype(np.float64), teX.astype(np.float64)
    trY, teY = trY.astype(int), teY.astype(int)
    return trX, trY, teX, teY

def save_h5(path, obj):
    """Write every ``key -> value`` item of dict ``obj`` as a dataset of a new HDF5 file.

    The file at ``path`` is created or truncated (mode ``'w'``).
    """
    print("saving {} ...".format(path))
    # ``with`` closes the file even if a create_dataset call raises.
    with h5py.File(path, 'w') as hf:
        for k, v in obj.items():
            hf.create_dataset(k, data=v)

def getCleanTe(data, labels, labels_gt, npte):
    """Split ``data`` into train/test sets whose test part contains only clean labels.

    For each class ``c`` in ``0..max(labels)``, the samples labelled ``c`` are
    shuffled with ``np.random.shuffle``. The first ``npte`` of them whose
    observed label equals ``labels_gt`` go to the test set. Every other sample
    of the class goes to the training set. Features are then min-max scaled
    with the per-column range of the whole ``data``, and any NaN (e.g. 0/0 on
    a constant column) is replaced by 0.

    Args:
        data: 2-D array, one sample per row.
        labels: observed (possibly noisy) integer labels, one per row.
        labels_gt: ground-truth labels aligned with ``labels``.
        npte: number of clean test samples to take per class.

    Returns:
        ``(trX, trY, trY_gt, teX, teY, teY_gt)``. Label arrays are float
        (allocated with ``np.zeros``). If a class has fewer than ``npte``
        clean samples, a message is printed and the split stops early, so the
        remaining rows stay zero.
    """
    n_classes = max(labels) + 1
    n_features = data.shape[1]
    nte = n_classes * npte
    ntr = data.shape[0] - nte

    trX = np.zeros((ntr, n_features))
    teX = np.zeros((nte, n_features))
    trY, teY = np.zeros(ntr), np.zeros(nte)
    trY_gt, teY_gt = np.zeros(ntr), np.zeros(nte)

    cursor = 0
    for cls in range(n_classes):
        members = np.argwhere(labels == cls)[:, 0]
        np.random.shuffle(members)

        cls_X = data[members, :]
        cls_Y = labels[members]
        cls_Y_gt = labels_gt[members]
        clean = np.argwhere(cls_Y == cls_Y_gt)[:, 0]

        if len(clean) < npte:
            print("Not enough data for testing")
            break

        picked = clean[:npte]
        keep = np.ones(len(members), dtype=bool)
        keep[picked] = False  # test picks are excluded from training

        kept_Y = cls_Y[keep]
        stop = cursor + len(kept_Y)
        trX[cursor:stop, :] = cls_X[keep, :]
        trY[cursor:stop] = kept_Y
        trY_gt[cursor:stop] = cls_Y_gt[keep]
        cursor = stop

        te_rows = slice(cls * npte, (cls + 1) * npte)
        teX[te_rows, :] = cls_X[picked, :]
        teY[te_rows] = cls_Y[picked]
        teY_gt[te_rows] = cls_Y_gt[picked]

    lo = data.min(axis=0)
    span = data.max(axis=0) - lo
    trX = (trX - lo) / span
    teX = (teX - lo) / span

    trX[np.isnan(trX)] = 0
    teX[np.isnan(teX)] = 0
    return trX, trY, trY_gt, teX, teY, teY_gt


def score_ap_from_ranks_1(ranks, nres):
    """Compute the average precision of one search.

    ranks = ordered list of (0-based) ranks of the true positives
    nres  = total number of positives in the dataset

    AP is integrated with the trapezoid rule. Each true positive contributes a
    recall step of width ``1/nres``, weighted by the mean of the precision
    just before and just after that positive is retrieved.
    """
    recall_step = 1.0 / nres
    total = 0.0
    for n_before, rank in enumerate(ranks):
        prec_left = 1.0 if rank == 0 else n_before / float(rank)
        prec_right = (n_before + 1) / float(rank + 1)
        total += (prec_right + prec_left) * recall_step / 2.0
    return total

def mycompute_MAP(Idx, gnd, sdT, topk=None):
    """Mean average precision over a batch of queries.

    Args:
        Idx: not used here; kept for call compatibility.
        gnd: sequence where ``gnd[i]`` holds the database indices that are
            positives for query ``i``.
        sdT: 2-D array of shape ``(num_db, num_query)``; ``sdT[j, i]`` is
            expected to be the 0-based rank of database item ``j`` for query
            ``i`` (as built by getMeanAveragePrecision).
        topk: if given, only positives with rank ``<= topk`` contribute, and
            the AP is normalised by ``topk`` instead of the number of positives.

    Returns:
        ``(mAP, mAPall)``: the mean AP (0.0 when there are no queries) and
        the per-query AP array.
    """
    mAP = 0.0
    num_query = sdT.shape[1]
    mAPall = np.zeros(num_query)
    for i in range(num_query):
        qgnd = gnd[i]
        nres = len(qgnd)
        # np.sort keeps an ndarray, so the boolean topk mask below works.
        # sorted() returned a list, and ``tp <= topk`` raised TypeError.
        tp = np.sort(sdT[qgnd, i])
        if topk is None:
            ap = score_ap_from_ranks_1(tp, nres)
        else:
            ap = score_ap_from_ranks_1(tp[tp <= topk], topk)
        mAP += ap
        mAPall[i] = ap
    if num_query > 0:
        mAP /= num_query
    return mAP, mAPall

def getMeanAveragePrecision(trX, teX, gnd, bs=100, use_cuda=True):
    """Per-query average precision of Euclidean retrieval of ``teX`` against ``trX``.

    Queries are processed in batches of ``bs``. For each batch, squared
    distances ``|q|^2 + |x|^2 - 2 q.x`` to every database row are sorted, and
    the inverse permutation (rank of each database item) is passed to
    ``mycompute_MAP`` together with the batch's ground-truth lists.

    Args:
        trX: database features, one row per item.
        teX: query features, one row per query.
        gnd: ``gnd[i]`` lists the database indices relevant to query ``i``.
        bs: query batch size.
        use_cuda: move tensors to the GPU.

    Returns:
        ``(Mapall, elapsed_seconds)``, where ``Mapall`` has one AP per query.
    """
    start = time.time()
    n_query = teX.shape[0]
    Mapall = np.zeros(n_query)
    n_batches = math.ceil(n_query / bs)
    with torch.no_grad():
        db = torch.tensor(trX)
        if use_cuda:
            db = db.cuda()
        db_sq = torch.sum(db.pow(2), 1).reshape(1, -1)  # (1, Ntr)

        for b in range(n_batches):
            print('\rmean average precision batch: %d in %d' % (b, n_batches), end="")
            lo = b * bs
            hi = lo + bs

            q = torch.tensor(teX[lo:hi, :])
            if use_cuda:
                q = q.cuda()

            q_sq = torch.sum(q.pow(2), 1).reshape(-1, 1)  # (b, 1)
            cross = torch.mm(q, db.T)
            dist = q_sq + db_sq - 2 * cross  # squared Euclidean distances

            _, order = torch.sort(dist, dim=1)  # nearest first
            _, ranks = torch.sort(order, dim=1, descending=False)  # inverse permutation

            _, ap = mycompute_MAP(order.cpu().numpy(),
                                  gnd[lo:hi], ranks.T.cpu().numpy())
            Mapall[lo:hi] = ap
        elapsed = time.time() - start
        return Mapall, elapsed

def randomStratifiedSampleData(label, perN):
    """Draw a random per-class subsample of indices.

    For every class ``c`` in ``0..max(label)``, the indices with
    ``label == c`` are shuffled with ``np.random.permutation`` and the first
    ``perN[c]`` are kept.

    Args:
        label: 1-D label array.
        perN: indexable by class id; the number of samples to keep per class.

    Returns:
        ndarray of the selected indices, grouped class by class.
    """
    picked = []
    for cls in range(int(max(set(label))) + 1):
        members = np.argwhere(label == cls)[:, 0]
        members = members[np.random.permutation(len(members))]
        picked.extend(members[:perN[cls]])
    return np.array(picked)

def generateRandomLabelNoise2(TtrY, noiserate):
    """Corrupt a fixed fraction of the labels in every class.

    In each class, ``floor(count * noiserate)`` randomly chosen samples get
    the label ``(y + shift) % numClasses`` with ``shift`` drawn uniformly from
    ``[2, numClasses - 1]``. The modular shift assumes labels are
    ``0..numClasses-1``.

    Args:
        TtrY: 1-D labels, as an ndarray or an object with ``.numpy()``
            (e.g. a CPU torch tensor).
        noiserate: fraction of each class to corrupt. Values below 1e-4
            return an unmodified copy.

    Returns:
        A noisy copy of the labels (the input is not modified).
    """
    if not isinstance(TtrY, np.ndarray):
        TtrY = TtrY.numpy()

    classes = np.unique(TtrY)
    numClasses = len(classes)
    trY = TtrY.copy()
    if noiserate < 1e-4:
        return trY

    for lpw in classes:
        la = np.argwhere(TtrY == lpw)[:, 0]
        rnla = np.random.permutation(len(la))
        N1 = np.floor(len(la) * noiserate).astype(int)
        flip = la[rnla[:N1]]
        # randint's upper bound is exclusive, so this draws exactly what the
        # deprecated np.random.random_integers(2, numClasses - 1) drew, from the
        # same RNG stream.
        # NOTE(review): shift 1 is never drawn, so y never becomes y+1 (and
        # numClasses == 2 raises). Confirm this is intended ("-2 completely").
        lu = np.random.randint(2, numClasses, size=N1)
        trY[flip] = (TtrY[flip] + lu) % numClasses

    return trY

def generateImBalancedRandomLabelNoise2(TtrY, noiserates):
    """Corrupt labels with a class-specific noise rate.

    For class ``c``, ``floor(count_c * noiserates[c])`` randomly chosen
    samples get the label ``(y + shift) % numClasses`` with ``shift`` drawn
    uniformly from ``[2, numClasses - 1]``. Labels are assumed to be
    ``0..numClasses-1``, because ``noiserates`` is indexed by the label value.

    Args:
        TtrY: 1-D labels, as an ndarray or an object with ``.numpy()``.
        noiserates: indexable by class id; the fraction of that class to corrupt.

    Returns:
        A noisy copy of the labels (the input is not modified).
    """
    if not isinstance(TtrY, np.ndarray):
        TtrY = TtrY.numpy()

    classes = np.unique(TtrY)
    numClasses = len(classes)
    trY = TtrY.copy()

    for lpw in classes:
        la = np.argwhere(TtrY == lpw)[:, 0]
        rnla = np.random.permutation(len(la))
        N1 = np.floor(len(la) * noiserates[lpw]).astype(int)
        flip = la[rnla[:N1]]
        # Same draws/RNG stream as the deprecated
        # np.random.random_integers(2, numClasses - 1): randint's high is exclusive.
        # NOTE(review): shift 1 is never drawn; confirm this is intended.
        lu = np.random.randint(2, numClasses, size=N1)
        trY[flip] = (TtrY[flip] + lu) % numClasses

    return trY

def myPCA(allsamples, redim):
    """PCA by eigen-decomposition of the feature scatter matrix.

    Port of MATLAB ``[base,d,samplemean] = myPCA(allsamples, redim)``.

    Args:
        allsamples: 2-D array, one sample per row.
        redim: number of leading components to keep.

    Returns:
        ``(base, eigvals, samplemean)``: ``base`` has shape
        ``(n_features, redim)`` with the eigenvectors as columns, ``eigvals``
        holds the matching eigenvalues in descending order, and
        ``samplemean`` is the per-feature mean.
    """
    # MATLAB's mean(X) is a per-column mean. np.mean without axis returned a
    # single scalar, which does not centre each feature.
    samplemean = np.mean(allsamples, axis=0)
    xmean = allsamples - samplemean
    sigmao = xmean.T.dot(xmean)
    # The scatter matrix is symmetric. eigh returns real eigenpairs, whereas eig
    # can return complex values because of round-off.
    eigvals, eigvecs = np.linalg.eigh(sigmao)
    descending_order = eigvals.argsort()[::-1]
    eigvals = eigvals[descending_order][:redim]
    base = eigvecs[:, descending_order][:, :redim]
    return base, eigvals, samplemean

def myPCA2(allsamples, p):
    """PCA through the sample Gram matrix (the "snapshot" trick).

    Useful when there are fewer samples than features. If ``v`` is a unit
    eigenvector of ``Xc Xc^T`` with eigenvalue ``d``, then ``Xc^T v / sqrt(d)``
    is a unit eigenvector of ``Xc^T Xc``, where ``Xc`` is the centred data.

    Args:
        allsamples: 2-D array, one sample per row.
        p: number of leading components; each must have a non-zero eigenvalue.

    Returns:
        ``(base, samplemean, redata)``: ``base`` is ``(n_features, p)`` with
        orthonormal columns, ``samplemean`` is the per-feature mean, and
        ``redata = allsamples @ base`` (projected without centring, as before).
    """
    samplemean = np.mean(allsamples, axis=0)  # per-feature, like MATLAB mean(X)
    xmean = allsamples - samplemean

    sigma = xmean.dot(xmean.T)  # (n_samples, n_samples) Gram matrix
    eigvals, eigvecs = np.linalg.eigh(sigma)  # symmetric -> real eigenpairs

    descending_order = eigvals.argsort()[::-1]
    dsort = eigvals[descending_order]
    vsort = eigvecs[:, descending_order]

    # Matrix products are required here. The old elementwise ``*`` (a MATLAB
    # matmul carried over literally) failed to broadcast for general shapes.
    base = xmean.T.dot(vsort[:, :p]) * np.power(dsort[:p], -0.5)
    redata = allsamples.dot(base)
    return base, samplemean, redata

def KNNtest(k, trX, trY, teX, teY, bs=256, use_cuda=True):
    """k-nearest-neighbour classification of ``teX`` against ``trX``.

    Distances are squared Euclidean (``|q|^2 - 2 q.x + |x|^2``). With
    ``k == 1`` the label of the nearest neighbour is used; otherwise the
    ``scipy.stats.mode`` of the k nearest labels is used.

    Args:
        k: number of neighbours.
        trX, teX: feature matrices (ndarray or torch tensor), one row per sample.
        trY, teY: labels (converted to ndarray if needed).
        bs: test batch size.
        use_cuda: run the distance computations on the GPU.

    Returns:
        ``(pred, acc, None, None)``. ``pred`` holds predictions for all test
        rows (the old code returned only the last batch's predictions) in
        ``trY``'s dtype. ``acc`` is the fraction equal to ``teY`` (NaN for an
        empty test set).
    """
    trX, teX = [torch.tensor(vari) if isinstance(vari, np.ndarray)
                else vari for vari in [trX, teX]]
    trY, teY = [np.array(vari) if not isinstance(vari, np.ndarray)
                else vari for vari in [trY, teY]]
    n_test = len(teX)
    Nb = math.ceil(n_test / bs)
    pred_all = np.zeros(n_test, dtype=trY.dtype)
    with torch.no_grad():
        if use_cuda:
            trX = trX.cuda()
        cc = torch.sum(trX.pow(2), 1).reshape(1, -1)  # (1, Ntr)
        for cnt_batch in range(Nb):
            ibegin = cnt_batch * bs
            iend = ibegin + bs
            xxt = teX[ibegin:iend, :]
            if use_cuda:
                xxt = xxt.cuda()
            xx = torch.sum(xxt.pow(2), 1).reshape(-1, 1)  # (b, 1)
            xc = torch.mm(xxt, trX.T)  # (b, Ntr)
            z = xx - 2 * xc + cc  # (b, Ntr) squared distances
            _, Idx = torch.sort(z, 1)
            ly = trY[Idx[:, :k].cpu().numpy()]  # labels of the k nearest, (b, k)

            if k == 1:
                pred = ly
            else:
                pred = stats.mode(ly, 1)[0]
            pred_all[ibegin:iend] = np.array(pred).flatten()
    # Guard avoids np.mean's empty-slice warning; the old code crashed here
    # with an unbound ``pred``.
    acc = np.mean(pred_all == teY) if n_test > 0 else float('nan')
    return pred_all, acc, None, None

def get_embedding(net, traindataset, ensemble=3):
    """Return embeddings of ``traindataset`` from ``net`` as a numpy array.

    With ``ensemble == 0``, a single call with ``sample=False`` is made.
    Otherwise, ``ensemble`` calls with ``sample=True`` are averaged.
    """
    if ensemble == 0:
        return net.get_embeddings(traindataset, sample=False).cpu().detach().numpy()
    total = 0
    for _ in range(ensemble):
        total += net.get_embeddings(traindataset, sample=True).cpu().detach().numpy()
    total /= 1. * ensemble
    return total

def getbatchKNNindex(trX, k, teX, batchsize=256, use_cuda=True):
    """Indices of the ``k + 1`` nearest ``trX`` rows for every ``teX`` row.

    Distances are squared Euclidean. When ``teX`` is ``trX``, column 0 is
    normally the query point itself (distance 0), hence ``k + 1`` columns.

    Args:
        trX: reference points, one per row (needs at least ``k + 1`` rows).
        k: number of neighbours beyond the first.
        teX: query points, one per row.
        batchsize: number of query rows per batch.
        use_cuda: run on the GPU.

    Returns:
        int ndarray of shape ``(len(teX), k + 1)``, nearest first.
    """
    with torch.no_grad():
        trX = torch.tensor(trX)
        teX = torch.tensor(teX)
        if use_cuda:
            trX = trX.cuda()
            teX = teX.cuda()

        # One output row per *query*. The old code sized and batched by trX,
        # which only worked when teX had the same number of rows.
        N = teX.shape[0]
        numBatch = math.ceil(N / batchsize)
        xx = torch.sum(teX.pow(2), 1).reshape(-1, 1)
        cc = torch.sum(trX.pow(2), 1).reshape(1, -1)

        # Integer storage: float32 cannot represent indices above 2**24 exactly.
        Idx = torch.zeros(N, k + 1, dtype=torch.long)

        for lpw in range(1, numBatch + 1):
            print('\rfinding K-nearest neighbors: batch %d in %d' % (lpw, numBatch), end="")
            ibegin = (lpw - 1) * batchsize
            iend = min(lpw * batchsize, N)

            xc = torch.mm(teX[ibegin:iend, :], trX.T)
            z = xx[ibegin:iend] - 2 * xc + cc  # (bs,1) - 2*(bs, Ntr) + (1,Ntr) = (bs,Ntr)

            _, I = torch.sort(z, dim=1)
            Idx[ibegin:iend, :] = I[:, :k + 1]

        print("\n")
        return Idx.cpu().numpy().astype(int)

def getLMNNidx(Y, Idx):
    """Enumerate LMNN triplets from each point's neighbour list.

    For every point ``i`` and every ordered pair ``(j, l)`` taken from the
    row ``Idx[i]``, emit ``(i, j, l)`` when ``Y[j] == Y[i]`` and
    ``Y[l] != Y[i]``, i.e. ``y_ij * (1 - y_il) = 1``.

    Args:
        Y: labels, as an ndarray or an object with ``.numpy()``; assumed 1-D.
        Idx: integer array ``(num_points, k)`` of neighbour indices.

    Returns:
        ``(LP1, LP2, LN)``: int32 arrays of equal length with the anchor,
        same-label neighbour and different-label neighbour indices.
    """
    if not isinstance(Y, np.ndarray):
        Y = Y.numpy()
    k = Idx.shape[1]
    anchors, same_nbrs, diff_nbrs = [], [], []

    for i in range(len(Y)):
        nbrs = np.asarray(Idx[i, :])
        nbr_labels = Y[nbrs]
        # Flattened k*k grid: position a*k + b refers to the pair (nbrs[a], nbrs[b]).
        pair_mask = (np.repeat(nbr_labels, k) == Y[i]) & (np.tile(nbr_labels, k) != Y[i])
        pos = np.repeat(nbrs, k)[pair_mask]
        neg = np.tile(nbrs, k)[pair_mask]
        anchors.extend([i] * len(pos))
        same_nbrs.extend(pos)
        diff_nbrs.extend(neg)

    return (np.array(anchors, dtype=np.int32),
            np.array(same_nbrs, dtype=np.int32),
            np.array(diff_nbrs, dtype=np.int32))

def find_KNN_triplets(dataset, targets, KNN_triplet=21):
    """Build LMNN-style triplets from each point's ``KNN_triplet`` nearest neighbours.

    Args:
        dataset: feature matrix, one point per row.
        targets: labels aligned with ``dataset``.
        KNN_triplet: neighbourhood size (capped at ``len(dataset) - 1``).

    Returns:
        List of ``[anchor, same_label_neighbour, other_label_neighbour]``
        index triples.
    """
    # Request exactly the neighbours needed. A fixed 100 crashed for datasets
    # with <= 100 points and silently truncated KNN_triplet > 99.
    n_neighbors = min(KNN_triplet, len(dataset) - 1)
    Idx_all = getbatchKNNindex(dataset, n_neighbors, dataset)
    # Column 0 is normally the point itself (distance 0), so skip it.
    Idx = Idx_all[:, 1:n_neighbors + 1]
    lp1_all, lp2_all, ln_all = getLMNNidx(targets, Idx)
    return [[i, j, l] for i, j, l in zip(lp1_all, lp2_all, ln_all)]

def getTripletIndex(trY, y_idx, k):
    """Sample random (anchor, positive, negative) index triplets per class.

    For each label in ``y_idx``, every sample of that class is paired with up
    to ``k`` random same-class partners (self-pairs dropped). Each resulting
    pair gets one random sample from another class as the negative.

    Args:
        trY: 1-D label ndarray.
        y_idx: labels to build triplets for.
        k: maximum number of partners per anchor (capped per class at
            ``class_size - 1``).

    Returns:
        Three equal-length lists ``(lp1_all, lp2_all, ln_all)`` of dataset
        indices: anchors, positives, negatives.
    """
    lp1_all, lp2_all, ln_all = [], [], []
    for label in y_idx:
        lc = np.where(trY == label)[0]  # dataset indices of this class
        n_samples = len(lc)
        # Per-class cap. The old code overwrote ``k`` itself, so one small class
        # reduced the number of pairs for every class after it.
        k_cls = min(k, n_samples - 1)
        if k_cls <= 0:
            continue

        # must-link lp1-lp2 (positions within lc)
        lp1 = np.repeat(np.arange(n_samples), k_cls)
        rp = np.random.permutation(n_samples * n_samples) % n_samples
        rp = rp.reshape(n_samples, n_samples)[:, :k_cls]
        lp2 = rp.T.flatten()

        keep = lp1 != lp2  # drop self-pairs
        lp1 = lp1[keep]
        lp2 = lp2[keep]

        lp1_all.extend(lc[lp1])
        lp2_all.extend(lc[lp2])

        # cannot-link: randint draws positions within ``ln``, which must be
        # mapped back to dataset indices (the old code stored the positions).
        ln = np.where(trY != label)[0]
        sn1 = np.random.randint(len(ln), size=len(lp1))
        ln_all.extend(ln[sn1])
    return lp1_all, lp2_all, ln_all

def log_large_margin_loss(margin=1.0):
    """Return ``loss_func(dij, dil) = -2 * relu(margin + dij - dil)``.

    The result is 0 when the different-label distance ``dil`` exceeds the
    same-label distance ``dij`` by at least ``margin``, and negative otherwise.
    """
    relu = torch.nn.functional.relu

    def loss_func(dij, dil):
        return -2 * relu(margin + dij - dil)

    return loss_func

def log_gaussian_approx_large_margin_loss(margin=1.0):
    """Return a closure scoring the margin violation with ``log_gauss_approximation``.

    For ``dd = margin + dij - dil``, the closure returns
    ``log_gauss_approximation(1 + |dd|, dd)``.
    """
    def loss_func(dij, dil):
        violation = margin + dij - dil
        return log_gauss_approximation(1 + torch.abs(violation), violation)

    return loss_func

def create_triplets(targets, seed=29):
    """Build one random ``[anchor, positive, negative]`` triplet per sample.

    For sample ``i``, the positive is a random index with the same label
    (possibly ``i`` itself). The negative is a random index from a randomly
    chosen different label. All draws use a ``RandomState(seed)``, so the
    result is reproducible and independent of the global ``np.random`` state.

    Args:
        targets: 1-D labels, as an ndarray or an object with ``.numpy()``.
            At least two distinct labels are required.
        seed: seed of the private random generator (default 29, as before).

    Returns:
        int ndarray of shape ``(len(targets), 3)``.
    """
    if not isinstance(targets, np.ndarray):
        targets = targets.numpy()
    labels_set = set(targets.tolist())
    label_to_indices = {label: np.where(targets == label)[0]
                        for label in labels_set}
    # Sorted candidate lists, so the draw does not depend on set iteration order.
    other_labels = {label: sorted(labels_set - {label}) for label in labels_set}

    random_state = np.random.RandomState(seed)

    triplets = []
    for i in range(len(targets)):
        label = targets[i].item()
        positive = random_state.choice(label_to_indices[label])
        # The negative class is also drawn from the seeded generator. It used
        # the global np.random before, which made the output non-reproducible.
        neg_label = random_state.choice(other_labels[label])
        negative = random_state.choice(label_to_indices[neg_label])
        triplets.append([i, positive, negative])
    return np.array(triplets).astype(int)


def plot_for_offset(power, y_max):
    """Render ``x**power`` for ``x = 0..99`` and return the figure as an RGB image.

    Intended as one animation frame. The y-limits are fixed to ``[0, y_max]``
    so successive frames share the same axes.

    Returns:
        uint8 array of shape ``(height, width, 3)``.
    """
    t = np.arange(0.0, 100, 1)
    s = t ** power

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(t, s)
    ax.grid()
    ax.set(xlabel='X', ylabel='x^{}'.format(power),
           title='Powers of x')

    # Keep the limits constant across frames.
    ax.set_ylim(0, y_max)

    fig.canvas.draw()  # render so the pixel buffer is populated
    # ``tostring_rgb`` is deprecated (removed in newer Matplotlib). ``buffer_rgba``
    # is the supported accessor; copy the RGB planes before closing the figure.
    image = np.ascontiguousarray(np.asarray(fig.canvas.buffer_rgba())[:, :, :3])
    plt.close(fig)  # the figure used to leak on every call
    return image

def plot_embeddings(embeddings, targets, xlim=None, ylim=None, epoch=0, path=None):
    """Scatter-plot the first two embedding dimensions, one colour per class.

    Args:
        embeddings: 2-D array; columns 0 and 1 are plotted.
        targets: label array aligned with ``embeddings``.
        xlim, ylim: optional ``(low, high)`` axis limits.
        epoch: shown in the title.
        path: if given, the figure is saved there and ``None`` is returned.

    Returns:
        ``None`` when ``path`` is given, otherwise an RGB uint8 image array
        of shape ``(height, width, 3)``.
    """
    classes = sorted(set(targets))

    fig, ax = plt.subplots(figsize=(10, 10))
    for classidx in classes:
        inds = np.where(targets == classidx)[0]
        ax.scatter(embeddings[inds, 0], embeddings[inds, 1], alpha=0.5, label=str(classidx))
    # Axes has no xlim()/ylim() methods (those are pyplot functions), so the
    # old code raised AttributeError whenever limits were passed.
    if xlim:
        ax.set_xlim(xlim[0], xlim[1])
    if ylim:
        ax.set_ylim(ylim[0], ylim[1])

    ax.set(title="Epoch: {}".format(epoch))
    ax.legend()
    if path is not None:
        fig.savefig(path)  # save this figure explicitly, not pyplot's "current" one
        plt.close(fig)
        return None

    fig.canvas.draw()
    # buffer_rgba replaces the deprecated tostring_rgb; keep only the RGB planes.
    image = np.ascontiguousarray(np.asarray(fig.canvas.buffer_rgba())[:, :, :3])
    plt.close(fig)
    return image


def plot_highest_var_dim(embeddings, targets, epoch=0, tnse=True, path=None):
    """Reduce embeddings to 2-D and plot them with ``plot_embeddings``.

    With ``tnse=True``, a 2-component t-SNE (8 jobs) is used. Otherwise, the
    two columns with the largest standard deviation are selected (in the
    order returned by an unsorted ``torch.topk``).
    """
    if not tnse:
        spread = torch.tensor(embeddings.std(axis=0))
        top2 = torch.topk(spread, 2, dim=0, largest=True, sorted=False)[1].numpy()
        projected = embeddings[:, top2]
    else:
        projected = TSNE(n_components=2, n_jobs=8).fit_transform(embeddings)
    return plot_embeddings(projected, targets, epoch=epoch, path=path)
    