import numpy as np
import keras.utils
from keras import backend as K
import tensorflow as tf
import scipy.io as sio
from scipy.stats import linregress
from sklearn.semi_supervised import LabelSpreading
from pylab import *
import os

def reset_weights(model):
    """Re-run the kernel initializer of every layer that declares one.

    Only the ``kernel`` variable of each such layer is re-initialized, using
    the current Keras backend session; other variables are left untouched.
    """
    sess = K.get_session()
    layers_with_kernel = (
        lyr for lyr in model.layers if hasattr(lyr, 'kernel_initializer')
    )
    for lyr in layers_with_kernel:
        lyr.kernel.initializer.run(session=sess)

def getSelectedModel(run_path):
    """Return the model number stored in ``<run_path>/selected_model.txt``.

    The first line of the file is split on ':'; the field after the first
    colon is parsed as a float and truncated to an int.
    """
    path = os.path.join(run_path, 'selected_model.txt')
    with open(path, 'r') as handle:
        first_line = handle.readlines()[0]
    fields = first_line.split(':')
    return int(float(fields[1]))

def finished_acc(loss, nave, tol):
    """Report whether the trend of the last ``nave`` values reaches ``tol``.

    A least-squares line is fitted to the trailing ``nave`` entries of
    ``loss``; the result is True when its slope is at least ``tol``. Returns
    False while fewer than ``nave`` values are available.

    NOTE(review): the test here is ``slope >= tol`` whereas ``finished`` uses
    ``slope >= -tol`` -- confirm this is the intended stopping rule.
    """
    window = np.asarray(loss)[-nave:]
    n_points = window.shape[0]
    if n_points < nave:
        return False
    fit = linregress(range(n_points), window)
    # The first element of the regression result is the slope.
    return bool(fit[0] >= tol)

def finished(loss, nave, tol):
    """Report whether the last ``nave`` values have stopped decreasing.

    A least-squares line is fitted to the trailing ``nave`` entries of
    ``loss``; the result is True when its slope is no smaller than ``-tol``.
    Returns False while fewer than ``nave`` values are available.
    """
    window = np.asarray(loss)[-nave:]
    n_points = window.shape[0]
    if n_points < nave:
        return False
    fit = linregress(range(n_points), window)
    # The first element of the regression result is the slope.
    return bool(fit[0] >= -tol)

def transferWeights(source, target):
    """Copy weights layer-by-layer from ``source`` into ``target``.

    Layers are paired by position: every layer of ``target`` receives the
    weights of the ``source`` layer at the same index.
    """
    for position, dest_layer in enumerate(target.layers):
        src_layer = source.layers[position]
        dest_layer.set_weights(src_layer.get_weights())

def generateTheta(L,endim):
    """Draw ``L`` random unit vectors of dimension ``endim``.

    Rows are sampled from a standard normal distribution and each row is
    divided by its Euclidean norm. Returns an array of shape (L, endim).
    """
    directions = np.random.normal(size=(L, endim))
    for row in directions:
        # ``row`` is a view, so dividing in place rescales ``directions``.
        row /= np.sqrt(np.sum(row ** 2))
    return directions

def preoneDWassersteinV3(p,q):
    """Guarded wrapper around ``oneDWassersteinV3``.

    Builds a ``tf.cond`` that yields the constant 0. when either ``p`` or
    ``q`` has no entries along its first axis, and otherwise the result of
    ``oneDWassersteinV3(p, q)``.
    """
    p_is_empty = tf.shape(p)[0] < 1
    q_is_empty = tf.shape(q)[0] < 1
    either_empty = tf.logical_or(p_is_empty, q_is_empty)
    return tf.cond(either_empty,
                   lambda: 0.,
                   lambda: oneDWassersteinV3(p, q))

def oneDWassersteinV3(p,q):
    """Mean squared difference between the matched, sorted samples of p and q.

    Each input is sorted along axis 0 (an input of rank < 2 is instead given
    a leading axis and left unsorted). Both are shifted by the smallest value
    found in either of them, turned into cumulative sums rescaled by
    ``max(n_p, n_q) / n``, sampled at ``min(n_p, n_q)`` evenly spaced
    positions and differenced again, so the two samples are compared on a
    common grid even when their sizes differ.

    Assumes ``p`` and ``q`` are float32 tensors laid out as
    (samples, projections) -- TODO confirm against callers.

    Returns the mean over axis 0 of the squared differences (one value per
    column).
    """
    # ~10 Times faster than V1
    psort = tf.cond(tf.shape(tf.shape(p))[0] < 2, lambda: tf.expand_dims(p,0),
                   lambda: tf.contrib.framework.sort(p,axis=0))
    qsort = tf.cond(tf.shape(tf.shape(q))[0] < 2, lambda: tf.expand_dims(q,0),
                   lambda: tf.contrib.framework.sort(q,axis=0))

    n_p=tf.shape(psort)[0]
    n_q=tf.shape(qsort)[0]
    # Common offset: the minimum over BOTH samples. (The previous code took
    # the minimum of psort twice and never looked at qsort.)
    pqmin=tf.minimum(K.min(psort,axis=0),K.min(qsort,axis=0))
    psort=psort-pqmin
    qsort=qsort-pqmin

    # Cumulative sums, rescaled so both samples carry the same total weight.
    n_max=tf.cast(tf.maximum(n_p,n_q),dtype='float32')
    pcum=tf.multiply(n_max,tf.divide(tf.cumsum(psort),tf.cast(n_p,dtype='float32')))
    qcum=tf.multiply(n_max,tf.divide(tf.cumsum(qsort),tf.cast(n_q,dtype='float32')))

    # min(n_p, n_q) + 1 evenly spaced positions into each cumulative sum.
    n_bins=tf.minimum(n_p,n_q)+1
    indp=tf.cast(tf.floor(tf.linspace(0.,tf.cast(n_p,dtype='float32')-1.,n_bins)),dtype='int32')
    indq=tf.cast(tf.floor(tf.linspace(0.,tf.cast(n_q,dtype='float32')-1.,n_bins)),dtype='int32')

    # Differencing the sampled cumulative sums gives per-bin totals.
    phat=tf.gather(pcum,indp[1:],axis=0)
    phat=K.concatenate((K.expand_dims(phat[0,:],0),phat[1:,:]-phat[:-1,:]),0)

    qhat=tf.gather(qcum,indq[1:],axis=0)
    qhat=K.concatenate((K.expand_dims(qhat[0,:],0),qhat[1:,:]-qhat[:-1,:]),0)

    W2=K.mean((phat-qhat)**2,axis=0)
    return W2
    

def sWasserstein(P,Q,theta,Cp=None,Cq=None):
    """Projection-based distance between ``P`` and ``Q`` with class terms.

    ``P`` and ``Q`` are projected onto the rows of ``theta``; the mean of
    ``oneDWassersteinV3`` over the projections is scaled by 1e4. When both
    ``Cp`` and ``Cq`` are given, one extra term per class column (columns
    0..9), weighted by 10, is added, computed on the rows whose entry in that
    column is non-zero.

    Presumably ``Cp``/``Cq`` are label matrices with at least 10 columns --
    verify against callers.
    """
    global_weight = 1.e4
    directions = tf.cast(K.transpose(theta), 'float32')
    p = K.dot(P, directions)
    q = K.dot(Q, directions)
    total = global_weight * K.mean(oneDWassersteinV3(p, q))
    if Cp is None or Cq is None:
        return total
    for cls in range(10):
        p_rows = tf.squeeze(tf.where(tf.not_equal(Cp[:, cls], 0)))
        q_rows = tf.squeeze(tf.where(tf.not_equal(Cq[:, cls], 0)))
        class_term = preoneDWassersteinV3(tf.gather(p, p_rows),
                                          tf.gather(q, q_rows))
        total = total + 10. * K.mean(class_term)
    return total

def sWasserstein_supervised(P,Q,theta,Cp=None,Cq=None):
    """Class-conditional part of the projection-based distance (10 classes).

    ``P`` and ``Q`` are projected onto the rows of ``theta``. For each class
    column 0..9, the rows whose entry in ``Cp``/``Cq`` is non-zero are
    compared with ``preoneDWassersteinV3`` and the mean, weighted by 10, is
    accumulated. Returns the integer 0 when ``Cp`` or ``Cq`` is None.
    """
    directions = tf.cast(K.transpose(theta), 'float32')
    p = K.dot(P, directions)
    q = K.dot(Q, directions)
    total = 0
    if Cp is None or Cq is None:
        return total
    for cls in range(10):
        p_rows = tf.squeeze(tf.where(tf.not_equal(Cp[:, cls], 0)))
        q_rows = tf.squeeze(tf.where(tf.not_equal(Cq[:, cls], 0)))
        class_term = preoneDWassersteinV3(tf.gather(p, p_rows),
                                          tf.gather(q, q_rows))
        total = total + 10. * K.mean(class_term)
    return total

def sWasserstein_supervisedImage(P,Q,theta,Cp=None,Cq=None):
    """Class-conditional part of the projection-based distance (9 classes).

    Same computation as ``sWasserstein_supervised`` but iterating over class
    columns 0..8 only. Returns the integer 0 when ``Cp`` or ``Cq`` is None.
    """
    directions = tf.cast(K.transpose(theta), 'float32')
    p = K.dot(P, directions)
    q = K.dot(Q, directions)
    total = 0
    if Cp is None or Cq is None:
        return total
    for cls in range(9):
        p_rows = tf.squeeze(tf.where(tf.not_equal(Cp[:, cls], 0)))
        q_rows = tf.squeeze(tf.where(tf.not_equal(Cq[:, cls], 0)))
        class_term = preoneDWassersteinV3(tf.gather(p, p_rows),
                                          tf.gather(q, q_rows))
        total = total + 10. * K.mean(class_term)
    return total

def sWasserstein_unsupervised(P,Q,theta):
    """Projection-based distance between ``P`` and ``Q`` without class terms.

    ``P`` and ``Q`` are projected onto the rows of ``theta`` and the mean of
    ``oneDWassersteinV3`` over the projections is scaled by 1e4.
    """
    global_weight = 1.e4
    directions = tf.cast(K.transpose(theta), 'float32')
    p = K.dot(P, directions)
    q = K.dot(Q, directions)
    return global_weight * K.mean(oneDWassersteinV3(p, q))

def reinitLayers(model):
    """Re-run the initializer of every variable held by ``model``'s layers.

    Nested containers are handled recursively. For every other layer, each
    instance attribute that exposes an ``initializer`` has that initializer
    run in the current Keras backend session.
    """
    sess = K.get_session()
    for lyr in model.layers:
        if isinstance(lyr, keras.engine.topology.Container):
            # Sub-model: descend into its own layers instead.
            reinitLayers(lyr)
        else:
            for attr_name in lyr.__dict__:
                attr_value = getattr(lyr, attr_name)
                if hasattr(attr_value, 'initializer'):
                    attr_value.initializer.run(session=sess)
    
def randperm(X,y):
    """Shuffle ``X`` and ``y`` along their first axis with one permutation.

    Both arrays must have the same length along axis 0; the same random
    order is applied to each so corresponding rows stay aligned.
    """
    n_rows = X.shape[0]
    assert n_rows==y.shape[0]
    order = np.random.permutation(n_rows)
    return X[order, ...], y[order, ...]

def batchGenerator(label,batchsize,nofclasses=2,seed=1,noflabeledsamples=None):
    """Pick class-balanced sample indices for one batch.

    ``label`` is indexed as ``label[:, i]``; a sample belongs to class ``i``
    when that entry is non-zero.

    Fully labeled mode (``noflabeledsamples`` is None or 0):
        ``batchsize // nofclasses`` random indices are drawn per class and
        the matching label rows are returned.

    Partially labeled mode:
        With the NumPy global RNG seeded by ``seed``, ``noflabeledsamples``
        indices per class are set aside as the labeled pool and the rest
        form the unlabeled pool. The RNG is then re-seeded without an
        argument, labeled indices are drawn from the labeled pool, and the
        batch is filled up from the shuffled unlabeled pool. Rows of the
        returned labels that belong to unlabeled samples are all zero.

    Returns:
        (ind, labelout): the integer index array and the label array.
    """
    if not noflabeledsamples:
        # Floor division: the per-class count is used as a slice bound, and a
        # float bound raises TypeError on Python 3.
        M = batchsize // nofclasses
        perclass = []
        for i in range(nofclasses):
            # flatnonzero always yields a 1-D array, also for a single match.
            labelIndex = np.flatnonzero(label[:, i])
            randInd = np.random.permutation(labelIndex.shape[0])
            perclass.append(labelIndex[randInd[:M]])
        ind = np.concatenate(perclass)
        labelout = label[ind]
    else:
        # Fixed seed so the labeled pool is the same on every call.
        np.random.seed(seed)
        portionlabeled = min(batchsize // 2, noflabeledsamples * nofclasses)
        M = portionlabeled // nofclasses
        indsupervised = []
        unlabeledparts = []
        for i in range(nofclasses):
            labelIndex = np.flatnonzero(label[:, i])
            randInd = np.random.permutation(labelIndex.shape[0])
            indsupervised.append(labelIndex[randInd[:noflabeledsamples]])
            unlabeledparts.append(labelIndex[randInd[noflabeledsamples:]])
        indunsupervised = np.concatenate(unlabeledparts)
        # Back to a non-deterministic RNG state for the per-batch draws.
        np.random.seed()
        ind = np.concatenate(
            [np.random.permutation(pool)[:M] for pool in indsupervised])
        indunsupervised = np.random.permutation(indunsupervised)

        total = nofclasses * (batchsize // nofclasses)
        labelout = np.zeros((total, nofclasses))
        # Use the number of labeled indices actually drawn; it is smaller
        # than portionlabeled when that is not a multiple of nofclasses.
        labelout[:ind.shape[0]] = label[ind, :]
        ind = np.concatenate((ind, indunsupervised[:total - ind.shape[0]]))
    return ind.astype(int), labelout

def optLabelSpread(fit_encodings, fit_labels, train_encodings, train_labels, lower=10, upper=60):
    '''Coarse-to-fine search for the best ``n_neighbors`` of LabelSpreading.

    Round 1 scores ``lower, lower+10, ...`` up to ``upper``. If the smallest
    candidate wins, every value from 1 up to ``lower`` is checked and the best
    is returned; if the largest wins it is returned directly. Otherwise
    round 2 checks the round-1 winner +/-5 and round 3 checks the round-2
    winner +/-1 and +/-2. Scores already computed are reused, not refitted.

    Each candidate is scored by fitting a knn LabelSpreading model on
    ``fit_encodings``/``fit_labels`` and calling ``score`` on
    ``train_encodings``/``train_labels``.

    Returns:
        (best_n_neighbors, best_score)
    '''
    def score(nn):
        # Fit one model with nn neighbours and return its score.
        label_spreader = LabelSpreading(kernel='knn', gamma=150,alpha=.9,n_neighbors=nn,max_iter=100)
        label_spreader.fit(fit_encodings, fit_labels)
        return label_spreader.score(train_encodings, train_labels)

    # Round 1: coarse grid in steps of 10.
    r1 = list(range(lower, upper+1, 10))
    s1 = [score(nn) for nn in r1]
    r1_best = int(np.argmax(s1))

    # Smallest candidate is best: check every value below it and return.
    # (Previously hard-coded to 1..9, which was only right for lower=10.)
    if r1_best == 0:
        r2 = list(range(1, r1[0])) + [r1[0]]
        s2 = [score(nn) for nn in r2[:-1]] + [s1[0]]
        best = int(np.argmax(s2))
        return r2[best], s2[best]

    # Largest candidate is best: we don't want to check higher, just return.
    if r1_best == len(r1)-1:
        return r1[-1], s1[r1_best]

    # Round 2: winner +/- 5; the middle score is already known.
    mid = r1[r1_best]
    r2 = [mid-5, mid, mid+5]
    s2 = [score(mid-5), s1[r1_best], score(mid+5)]

    # Round 3: winner +/- 1 and +/- 2; the middle score is already known.
    r2_best = int(np.argmax(s2))
    mid = r2[r2_best]
    r3 = [mid-2, mid-1, mid, mid+1, mid+2]
    s3 = [score(mid-2), score(mid-1), s2[r2_best], score(mid+1), score(mid+2)]

    r3_best = int(np.argmax(s3))
    return r3[r3_best], s3[r3_best]

############################################################
# Plotting Functions
############################################################
def plotAE(dataZ_train, ae_check, idx):
    """Show four input images next to their autoencoder outputs.

    Draws a 2x4 grid of grayscale panels: for k = 0..3 the image
    ``dataZ_train[idx][k]`` is followed by ``ae_check[k]``. Axis ticks are
    hidden. Returns the created figure.
    """
    fig = figure(figsize=(10,4))
    for k in range(4):
        pair = (dataZ_train[idx][k], ae_check[k])
        # Panels are numbered 1..8 row by row: input first, then its output.
        for offset, image in enumerate(pair, start=1):
            subplot(2, 4, 2 * k + offset)
            imshow(image.squeeze(), cmap='gray')
            xticks([])
            yticks([])
    tight_layout()
    return fig

def plotTSNE(embedding_tsne, labelX_train, labelZ_train, xidx, zidx):
    """Scatter-plot source and target t-SNE embeddings side by side.

    Assumes ``embedding_tsne`` has the X (source) and Z (target) data
    concatenated together, with the X data first: the first
    ``xidx.shape[0]`` rows are plotted as source and the remaining rows as
    target. Points are colored by the argmax of their label rows
    (``labelX_train[xidx]`` / ``labelZ_train[zidx]``) and both panels are
    given identical axis limits. Returns the created figure.
    """
    n_source = xidx.shape[0]

    fig = figure(figsize=(10,5))
    ax1 = subplot(1,2,1)
    scatter(embedding_tsne[:n_source,0],embedding_tsne[:n_source,1],
                c = np.argmax(np.squeeze(labelX_train[xidx]),axis=1),cmap='tab10')
    xticks([])
    yticks([])
    title('Source')

    ax2 = subplot(1,2,2)
    # Target rows begin where the source rows end. Slicing from
    # zidx.shape[0] was only correct when both sets had the same size.
    scatter(embedding_tsne[n_source:,0],embedding_tsne[n_source:,1],
                c = np.argmax(np.squeeze(labelZ_train[zidx]),axis=1),cmap='tab10')
    xticks([])
    yticks([])
    title('Target')

    # set equal axes: take the union of both panels' limits
    ymin, ymax = ylim()
    xmin, xmax = xlim()
    sca(ax1)
    ymin2, ymax2 = ylim()
    xmin2, xmax2 = xlim()
    ymin = min(ymin, ymin2)
    ymax = max(ymax, ymax2)
    xmin = min(xmin, xmin2)
    xmax = max(xmax, xmax2)

    xlim([xmin, xmax])
    ylim([ymin, ymax])
    sca(ax2)
    xlim([xmin, xmax])
    ylim([ymin, ymax])

    return fig

############################################################
# From DRCN paper's GitHub: https://github.com/ghif/drcn
############################################################

def get_impulse_noise(X, level):
    """Zero out entries of ``X`` at random.

    Each entry is kept with probability ``1 - level`` and set to zero
    otherwise, using an independent Bernoulli mask of the same shape as ``X``.
    """
    keep_mask = np.random.binomial(1, 1. - level, size=X.shape)
    return X * keep_mask

def get_gaussian_noise(X, std):
    """Add Gaussian noise with standard deviation ``std`` and clip to [0, 1].

    Each output entry is drawn from a normal distribution centred on the
    corresponding entry of ``X``.
    """
    # X: [n, c, d1, d2] images in [0, 1]
    noisy = np.random.normal(loc=X, scale=std)
    return np.clip(noisy, 0., 1.)

def get_flipped_pixels(X):
    """Invert intensities (``1 - X``) and clip the result to [0, 1]."""
    # X: [n, c, d1, d2] images in [0, 1]
    return np.clip(1. - X, 0., 1.)

def clip_relu(x):
    """Clamp ``x`` element-wise to [0, 1] with backend max/min ops."""
    return K.minimum(K.maximum(x, 0), 1)

def augment_dynamic(X, ratio_i=0.2, ratio_g=0.2, ratio_f=0.2, gstd=0.5, ilevel=0.5):
    """Return a copy of ``X`` with a random subset of samples corrupted.

    A fraction ``ratio_i + ratio_g + ratio_f`` of the samples (first axis of
    ``X``) is chosen without replacement and split, in that order, into
    samples receiving impulse noise (level ``ilevel``), Gaussian noise
    (standard deviation ``gstd``) and pixel flipping. ``X`` itself is not
    modified.

    When all three ratios are zero an unmodified copy is returned.
    """
    batch_size = X.shape[0]

    ratio_n = ratio_i + ratio_g + ratio_f
    if ratio_n == 0:
        # Nothing to corrupt; also avoids dividing by zero below.
        return np.copy(X)

    num_noise = int(batch_size * ratio_n)
    idx_noise = np.random.choice(range(batch_size), num_noise, replace=False)

    # Share of the noisy samples assigned to each corruption type.
    num_impulse = int(num_noise * (ratio_i / ratio_n))
    num_gaussian = int(num_noise * (ratio_g / ratio_n))
    num_flip = int(num_noise * (ratio_f / ratio_n))

    # Consecutive, non-overlapping slices of the chosen indices.
    gaussian_start = num_impulse
    flip_start = gaussian_start + num_gaussian
    idx_impulse = idx_noise[:num_impulse]
    idx_gaussian = idx_noise[gaussian_start:flip_start]
    idx_flip = idx_noise[flip_start:flip_start + num_flip]

    Xn = np.copy(X)
    Xn[idx_impulse] = get_impulse_noise(Xn[idx_impulse], ilevel)
    Xn[idx_gaussian] = get_gaussian_noise(Xn[idx_gaussian], gstd)
    Xn[idx_flip] = get_flipped_pixels(Xn[idx_flip])
    return Xn

#########################################################################
# Modified code from https://jkjung-avt.github.io/keras-image-cropping/
#########################################################################
def random_crop(img, random_crop_size):
    """Cut a random window of size ``random_crop_size`` out of ``img``.

    ``img`` must be channel-first with exactly 3 channels;
    ``random_crop_size`` is ``(height, width)`` of the window. The window
    position is drawn uniformly among all positions that fit in the image.
    """
    # Note: image_data_format is 'channel_first'
    assert img.shape[0] == 3
    height = img.shape[1]
    width = img.shape[2]
    crop_h, crop_w = random_crop_size
    # Horizontal offset is drawn before the vertical one.
    left = np.random.randint(0, width - crop_w + 1)
    top = np.random.randint(0, height - crop_h + 1)
    return img[:, top:top + crop_h, left:left + crop_w]


def crop_generator(batches, crop_length):
    '''
    Wrap a Keras ImageGen (Iterator) and yield random square crops of side
    ``crop_length`` from every image of every batch it produces.

    When ``batches.y`` is not None each item from ``batches`` is unpacked as
    ``(images, labels)`` and ``(crops, labels)`` is yielded; otherwise each
    item is taken as the images alone and only the crops are yielded.
    '''
    def crop_batch(batch_x):
        # One independent random crop per image, 3 channels, channel-first.
        n_images = batch_x.shape[0]
        crops = np.zeros((n_images, 3, crop_length, crop_length))
        for i in range(n_images):
            crops[i] = random_crop(batch_x[i], (crop_length, crop_length))
        return crops

    if batches.y is not None:
        while True:
            batch_x, batch_y = next(batches)
            yield (crop_batch(batch_x), batch_y)
    else:
        while True:
            yield crop_batch(next(batches))
