#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Dec  7 17:41:24 2020

@author: pooya
"""

import numpy as np
from scipy import stats
from sklearn.cluster import KMeans
import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp
from scipy.spatial.distance import cdist
# Switch TensorFlow to TF1-style behaviour; the code below relies on
# placeholders, explicit sessions and tf.train optimizers.
tf.disable_v2_behavior()
# Module-wide session; centroid_optimizer runs all of its graph operations in it.
sess = tf.Session()

class NN_model:
    """1-nearest-neighbour predictor backed by a stored training set."""

    def __init__(self, Xtr, Ytr):
        # Keep references to the training data and remember both shapes.
        self.Xtr, self.Ytr = Xtr, Ytr
        self.dim_Xtr = Xtr.shape
        self.dim_Ytr = Ytr.shape

    def predict(self, X):
        """Return the target row of the closest training sample for each row of X.

        Closeness is the Euclidean (Minkowski, p=2) distance; ties go to the
        training sample with the lowest index.
        """
        pairwise = cdist(X, self.Xtr, 'minkowski', p=2.)
        closest = np.argmin(pairwise, axis=1)
        return self.Ytr[closest, :]
    
def SRNN_init(X,Y,K_y,K_x):
    """Build initial centroids, giving each target cluster a proportional share.

    The targets Y are grouped into K_y clusters with k-means.  Inside every
    target cluster the inputs are clustered again; the number of centroids is
    proportional to the cluster's share of the samples (about K_x in total).
    Each centroid starts with the mean target of its cluster and is then
    re-labelled by `assignment` from the samples that are actually nearest.

    Parameters
    ----------
    X : array, one sample per row.
    Y : 2-D array of targets, one row per sample.
    K_y : number of k-means clusters computed on Y.
    K_x : approximate total number of centroids to create.

    Returns
    -------
    (X_C, Y_C) : centroid positions and the target row attached to each.
    """
    kmeans_y = KMeans(n_clusters=K_y, random_state=0).fit(Y)
    Y_assign = kmeans_y.labels_
    classes = np.unique(Y_assign)
    # Average number of samples represented by one centroid.
    ratio = X.shape[0] / K_x

    # Collect per-class pieces and concatenate once.  This no longer relies on
    # label 0 being the first class visited to create the accumulators.
    centroid_parts = []
    label_parts = []
    for i in classes:
        members = (Y_assign == i)
        Xtr_c = X[members, :]
        Ytr_c = Y[members, :]
        # Never ask k-means for more clusters than there are samples.
        K_x0 = min(Xtr_c.shape[0], int(np.ceil(Xtr_c.shape[0] / ratio)))
        kmeans = KMeans(n_clusters=K_x0, random_state=0).fit(Xtr_c)
        centers = kmeans.cluster_centers_
        centroid_parts.append(centers)
        # Every centroid of this class starts with the class-mean target.
        label_parts.append(np.ones((centers.shape[0], 1)) * np.mean(Ytr_c, axis=0))
        print('i=', i)

    X_C = np.concatenate(centroid_parts, axis=0)
    Y_C = np.concatenate(label_parts, axis=0)

    # Replace the class means by the mean target of each centroid's own samples.
    Y_C = assignment(X, Y, X_C, Y_C)

    return X_C, Y_C

def SRNN_init2(X,Y,K_y,K_x):
    """Build initial centroids, giving every target cluster the same budget.

    The targets Y are grouped into K_y clusters with k-means.  Inside every
    target cluster the inputs are clustered into ceil(K_x / K_y) centroids
    (fewer when the cluster has fewer samples than that).  Each centroid
    starts with the mean target of its cluster and is then re-labelled by
    `assignment` from the samples that are actually nearest to it.

    Parameters
    ----------
    X : array, one sample per row.
    Y : 2-D array of targets, one row per sample.
    K_y : number of k-means clusters computed on Y.
    K_x : approximate total number of centroids to create.

    Returns
    -------
    (X_C, Y_C) : centroid positions and the target row attached to each.
    """
    kmeans_y = KMeans(n_clusters=K_y, random_state=0).fit(Y)
    Y_assign = kmeans_y.labels_
    classes = np.unique(Y_assign)
    K_per_class = int(np.ceil(K_x / K_y))

    # Collect per-class pieces and concatenate once.  This no longer relies on
    # label 0 being the first class visited to create the accumulators.
    centroid_parts = []
    label_parts = []
    for i in classes:
        members = (Y_assign == i)
        Xtr_c = X[members, :]
        Ytr_c = Y[members, :]
        # k-means cannot produce more clusters than it has samples.
        K_x0 = min(Xtr_c.shape[0], K_per_class)
        kmeans = KMeans(n_clusters=K_x0, random_state=0).fit(Xtr_c)
        centers = kmeans.cluster_centers_
        centroid_parts.append(centers)
        # Every centroid of this class starts with the class-mean target.
        label_parts.append(np.ones((centers.shape[0], 1)) * np.mean(Ytr_c, axis=0))
        print('i=', i)

    X_C = np.concatenate(centroid_parts, axis=0)
    Y_C = np.concatenate(label_parts, axis=0)

    # Replace the class means by the mean target of each centroid's own samples.
    Y_C = assignment(X, Y, X_C, Y_C)

    return X_C, Y_C

def assignment(Xtr,Ytr,X_C,Y_C):
    """Set each centroid's target to the mean target of its nearest samples.

    Every row of Xtr is attached to its closest centroid in X_C (Euclidean
    distance).  For each centroid that received at least one sample, the
    matching row of Y_C is overwritten with the mean of those samples' targets.
    Centroids with no samples keep their current row and a message is printed.

    Parameters
    ----------
    Xtr : array, one sample per row.
    Ytr : 2-D array of targets, one row per sample.
    X_C : array of centroid positions, one per row.
    Y_C : array of centroid targets; modified in place.

    Returns
    -------
    The same Y_C object that was passed in.
    """
    dist = cdist(Xtr, X_C, 'minkowski', p=2.)
    nearest = np.argmin(dist, axis=1)
    for i in range(X_C.shape[0]):
        labels = Ytr[nearest == i, :]
        # Test for members first: averaging an empty slice only yields NaN
        # plus a RuntimeWarning, and that result was never used anyway.
        if labels.shape[0] > 0:
            Y_C[i] = np.mean(labels, axis=0)
        else:
            print('zero node detected')
    return Y_C

class SRNN_Model:
    """Nearest-centroid predictor whose centroid positions are refined by `fit`.

    The model stores centroid positions (`X_c`) and one target row per
    centroid (`Y_c`).  A prediction is the target row of the closest centroid
    under the Euclidean distance.
    """

    def __init__(self, Xtr, Ytr):
        # The arrays are stored by reference; `fit` updates them in place.
        self.X_c = Xtr
        self.dim_X_c = Xtr.shape
        self.dim_Y_c = Ytr.shape
        self.Y_c = Ytr

    @staticmethod
    def _mean_squared_error(Y_true, Y_pred, n_samples):
        """Sum of squared errors over all entries, divided by n_samples."""
        return np.sum(np.sum(np.power((Y_true - Y_pred), 2), axis=1)) / n_samples

    def predict(self, X):
        """Return the target row of the nearest centroid for each row of X."""
        return self.predict_dist(self.dist_cal(X))

    def dist_cal(self, X):
        """Return the (n_samples, n_centroids) Euclidean distance matrix."""
        return cdist(X, self.X_c, 'minkowski', p=2.)

    def predict_dist(self, dist):
        """Predict from a precomputed sample-to-centroid distance matrix."""
        nearest = np.argmin(dist, axis=1)
        return self.Y_c[nearest, :]

    def fit(self, Xtr, Ytr, iter = 10, rej_th = 1e-8, Xv=None, Yv=None, X_test=None, Y_test=None, lr = 1e-3):
        """Refine the centroids one at a time to lower the training error.

        For centroid i every training sample is labelled +1 when predicting
        with centroid i gives a smaller squared error than predicting with the
        nearest of the other centroids, and -1 otherwise.  The absolute error
        difference is the sample weight, and samples whose weight does not
        exceed `rej_th` are dropped.  `centroid_optimizer` then moves the
        centroid, after which the distance column and the centroid's target
        are refreshed.

        Parameters
        ----------
        Xtr, Ytr : training inputs and 2-D targets.
        iter : pass budget.  NOTE(review): the counter is incremented before
            it is compared with `it <= iter`, so iter + 1 passes are run.
        rej_th : minimum weight for a sample to be used for a centroid.
        Xv, Yv : accepted but not used by this method.
        X_test, Y_test : optional held-out data; when both are given the
            test error is printed alongside the training error.
        lr : learning rate passed to `centroid_optimizer`; multiplied by 0.98
            after every fifth pass.

        Returns
        -------
        None.  `self.X_c` and `self.Y_c` are updated in place.
        """
        X_C = self.X_c  # alias: row views below point into self.X_c
        Y_C = self.Y_c  # alias: assignment_step_dist writes into self.Y_c
        K = X_C.shape[0]
        n_train = Xtr.shape[0]
        # Test error is only reported when both test arrays were supplied.
        report_test = X_test is not None and Y_test is not None

        Error = self._mean_squared_error(Ytr, self.predict(Xtr), n_train)
        print('Init error:', Error)
        ReNN_graph = None
        # Backup slot for the centroid being optimised.
        X_c_0 = np.zeros((1, self.dim_X_c[1]))

        # Sample-to-centroid distances, kept up to date column by column.
        dists = self.dist_cal(Xtr)
        it = 0
        while it <= iter:
            it = it + 1
            for i in range(K):
                # Prediction and distance of the best competitor of centroid i.
                dist_no_i = np.delete(dists, i, 1)
                Y_C_no_i = np.delete(self.Y_c, i, 0)
                X_C_i = X_C[i, :]
                Y_C_i = Y_C[i, :]
                Y_1, r_next, _ = gen_predict_dist(dist_no_i, Y_C_no_i)

                Err_i = np.sum((Y_C_i - Ytr) ** 2, axis=1)  # squared error
                Err_n = np.sum((Y_1 - Ytr) ** 2, axis=1)
                label = ((Err_i - Err_n) < 0) * 1
                label[label == 0] = -1  # 1 for i and -1 for rest
                Weight = abs(Err_i - Err_n)
                select = Weight > rej_th
                X_train = Xtr[select, :]
                Y_train = label[select]
                Weight = Weight[select]
                r_tr = r_next[select]

                X_c_0[0, :] = X_C_i
                # Only optimise when at least one sample prefers centroid i.
                if np.sum(Y_train == 1) > 0:
                    X_C_i_, ReNN_graph = centroid_optimizer(X_C_i, X_train, Y_train, r_tr, Weight, ReNN_graph, lr)
                    if np.sum(np.isnan(X_C_i_)) == 0:
                        self.X_c[i, :] = X_C_i_
                    else:
                        # The optimiser produced NaNs: restore the backup.
                        self.X_c[i, :] = X_c_0[0, :]

                dists[:, i] = np.sqrt(np.sum(np.power(Xtr - self.X_c[i, :], 2), axis=1))

                Y_C = assignment_step_dist(dists, Y_C, Ytr, i)
                if np.mod(i, 10) == 0:
                    print('it=',it ,'centroid num', i,'samples num:', X_train.shape)
                    Error_0 = self._mean_squared_error(Ytr, self.predict_dist(dists), n_train)
                    if Error_0 > Error:
                        print('Error increasing')
                    print(Error_0)

                    Error = Error_0
                    if report_test:
                        Error_0_test = self._mean_squared_error(Y_test, self.predict(X_test), X_test.shape[0])
                        print(Error_0_test)

            Error_0 = self._mean_squared_error(Ytr, self.predict(Xtr), n_train)
            print('printing error')
            print(Error_0)
            if report_test:
                Error_0_test = self._mean_squared_error(Y_test, self.predict(X_test), X_test.shape[0])
                print(Error_0_test)
            if np.mod(it, 5) == 0:
                lr = lr * 0.98
                

def gen_predict_dist(dist, Y_c):
    """Predict from a distance matrix and report the winning centroid.

    Returns a 3-tuple: the Y_c row of the closest centroid for every sample,
    the distance to that centroid, and the centroid's column index in dist.
    """
    closest_idx = np.argmin(dist, axis=1)
    closest_dist = np.min(dist, axis=1)
    predictions = Y_c[closest_idx, :]
    return predictions, closest_dist, closest_idx

def assignment_step_dist(dist, Y_c, Y, i):
    """Refresh the target of centroid i from the samples nearest to it.

    Parameters
    ----------
    dist : (n_samples, n_centroids) distance matrix.
    Y_c : array of centroid targets; row i is modified in place.
    Y : 2-D array of sample targets.
    i : index of the centroid to update.

    Returns
    -------
    The same Y_c object.  Row i is the mean target of the samples whose
    closest centroid is i, or is left unchanged (with a printed message)
    when no sample is closest to centroid i.
    """
    nearest = np.argmin(dist, axis=1)
    labels = Y[nearest == i, :]
    # Test for members first: averaging an empty slice only yields NaN plus
    # a RuntimeWarning, and that result was never used anyway.
    if labels.shape[0] > 0:
        Y_c[i, :] = np.mean(labels, axis=0)
    else:
        print('zero node detected')

    return Y_c

class make_graph:
    """TF1 graph holding one trainable centroid and its surrogate loss.

    Placeholders (all float64 except the learning rate):
      X11 / W1          -- samples that should belong to the centroid, and weights
      X_11 / r_11 / W_1 -- samples that should not, their margin radius, and weights
      r11               -- declared for feeding; not referenced by the loss
      learning_rate     -- scalar float32 step size

    Loss: out = out1 + out2, where
      out1 = sum(W1 * ||X11 - X_C||)
      out2 = sum(W_1 * relu(r_11 - ||X_11 - X_C||))
    and it is minimised with a momentum optimizer (momentum 0.9).
    """

    def __init__(self, X_0):
        n_features = X_0.shape[1]

        self.learning_rate = tf.placeholder(tf.float32, shape=[])

        self.W1 = tf.placeholder(tf.float64, [None, 1], name='W1')
        self.W_1 = tf.placeholder(tf.float64, [None, 1], name='W_1')

        # The centroid itself, initialised from X_0.
        self.X_C = tf.Variable(X_0, name='centroid', trainable=True)

        self.X11 = tf.placeholder(tf.float64, [None, n_features], name='X1')
        self.X_11 = tf.placeholder(tf.float64, [None, n_features], name='X_1')
        self.r11 = tf.placeholder(tf.float64, [None, 1], name='r1')
        self.r_11 = tf.placeholder(tf.float64, [None, 1], name='r_1')

        # Offsets of each sample from the centroid and their Euclidean lengths.
        self.dist1 = self.X11 - self.X_C
        self.dist_1 = self.X_11 - self.X_C
        self.norm_dist1 = tf.norm(self.dist1, ord='euclidean', axis=1, keepdims=True)
        self.norm_dist_1 = tf.norm(self.dist_1, ord='euclidean', axis=1, keepdims=True)

        # Weighted distance to the samples the centroid should capture.
        self.out1 = tf.reduce_sum(self.norm_dist1 * self.W1)
        # Weighted hinge on samples that fall inside their margin radius.
        self.out2 = tf.reduce_sum(tf.nn.relu(self.r_11 - self.norm_dist_1) * self.W_1)

        self.out = self.out1 + self.out2
        self.optimiser = tf.train.MomentumOptimizer(self.learning_rate, 0.9).minimize(self.out)
        self.init_op = tf.global_variables_initializer()

def centroid_optimizer(X_0, Xtr, Ytr,r, W, ReNN = None, lr = 1e-4):
    """Move one centroid by gradient descent on the surrogate loss.

    Samples labelled +1 should end up closer to the centroid than their
    competitor radius; samples labelled -1 should stay outside theirs.  The
    centroid is trained for a fixed number of steps on the graph built by
    `make_graph`.  Every 10 steps the candidate is scored with
    `gen_centroid_error` on all samples and the best one seen is kept.

    Parameters
    ----------
    X_0 : 1-D array, starting centroid position (not modified).
    Xtr : array of samples, one per row.
    Ytr : 1-D array of +1 / -1 labels, aligned with Xtr.
    r : 1-D array of radii, aligned with Xtr.
    W : 1-D array of sample weights, aligned with Xtr.
    ReNN : graph wrapper returned by an earlier call, or None to build one.
    lr : initial learning rate; multiplied by 0.98 every 10 steps.

    Returns
    -------
    (X_C_, ReNN) : the best centroid found, shape (1, n_features), and the
    graph wrapper to pass to the next call.
    """
    positive = (Ytr == 1)
    negative = (Ytr == -1)
    X1 = Xtr[positive, :]
    X_1 = Xtr[negative, :]
    W1 = W[positive]

    W_1 = W[negative]
    W_1 = W_1.reshape((W_1.shape[0], 1))
    r1 = r[positive]
    r1 = r1.reshape((r1.shape[0], 1))
    r_1 = r[negative]
    r_1 = r_1.reshape((r_1.shape[0], 1))
    # Work on a copy: reshape alone may return a view, and the in-place write
    # below would then silently modify the caller's array.
    X_0 = np.array(X_0).reshape((1, X_0.shape[0]))

    X_C_0 = np.zeros((1, X_0.shape[1]))
    X_C_ = np.zeros((1, X_0.shape[1]))

    steps = 200
    # Start from the weighted mean of the positive samples when it scores
    # better than the supplied starting point.
    X_mean = np.average(X1, weights=W1, axis=0)
    W1 = W1.reshape((W1.shape[0], 1))
    E00 = gen_centroid_error(X_mean, X1, X_1, r1, r_1, W1, W_1)
    E11 = gen_centroid_error(X_0, X1, X_1, r1, r_1, W1, W_1)
    if E00 < E11:
        X_0[0, :] = X_mean

    ###### create the graph
    if ReNN is None:
        ReNN = make_graph(X_0)
    # Build the "load centroid" op once per graph and feed the value through a
    # placeholder.  Calling Variable.assign(ndarray) on every invocation would
    # add a new constant and assign op to the graph each time.
    centroid_feed = getattr(ReNN, '_centroid_feed', None)
    if centroid_feed is None:
        centroid_feed = tf.placeholder(ReNN.X_C.dtype.base_dtype, shape=X_0.shape)
        ReNN._centroid_feed = centroid_feed
        ReNN._centroid_assign = ReNN.X_C.assign(centroid_feed)

    with sess.as_default():
        # Re-initialising also resets the optimizer's momentum slots.
        sess.run(ReNN.init_op)
        sess.run(ReNN._centroid_assign, feed_dict={centroid_feed: X_0})
        N1 = X1.shape[0]
        N_1 = X_1.shape[0]
        X_C_0[0, :] = sess.run(ReNN.X_C)
        X_C_[0, :] = X_C_0[0, :]
        E0 = gen_centroid_error(X_C_0, X1, X_1, r1, r_1, W1, W_1)
        # K scales the negative-sample radii; it grows from 0 in increments
        # of 10/steps every 10 steps.
        K = 0
        b1 = np.random.permutation(N1)
        b2 = np.random.permutation(N_1)
        # NOTE(review): j is reset below but never incremented, so every step
        # trains on the same first tenth of each permutation.  Left unchanged
        # to preserve existing results -- confirm whether cycling through the
        # batches was intended.
        j = 0
        batch_size_N = int(np.ceil(N1 / 10))
        batch_size_N1 = int(np.ceil(N_1 / 10))
        for step in range(steps):
            sel1 = b1[j * batch_size_N:(j + 1) * batch_size_N]
            sel_1 = b2[j * batch_size_N1:(j + 1) * batch_size_N1]
            X1_b = X1[sel1, :]
            r1_b = r1[sel1, :]
            W1_b = W1[sel1, :]
            X_1_b = X_1[sel_1, :]
            r_1_b = K * r_1[sel_1, :]
            W_1_b = W_1[sel_1, :]

            sess.run([ReNN.optimiser, ReNN.out], feed_dict={ReNN.X11: X1_b, ReNN.X_11: X_1_b, ReNN.r11: r1_b, ReNN.r_11: r_1_b, ReNN.W1: W1_b, ReNN.W_1: W_1_b, ReNN.learning_rate: lr})

            if step % 10 == 0 and step > 1:
                lr = lr * 0.98
                j = 0
                K = K + 10 / steps
                X_C_0[0, :] = sess.run(ReNN.X_C)
                # Score the candidate on the full sample set, not the batch.
                E1 = gen_centroid_error(X_C_0, X1, X_1, r1, r_1, W1, W_1)
                if E1 < E0:
                    X_C_[0, :] = X_C_0[0, :]
                    E0 = E1

    return X_C_, ReNN

def gen_centroid_error (X_c, X1, X_1, r1, r_1, W1, W_1):
    """Weighted count of samples on the wrong side of their radius.

    A row of X1 counts (with weight W1) when its Euclidean distance to X_c is
    greater than r1; a row of X_1 counts (with weight W_1) when its distance
    is smaller than r_1.  The two weighted totals are added.
    """
    def _distance_to_centroid(points):
        # Column vector of Euclidean distances from each row to X_c.
        return np.sqrt(np.sum(np.power(points - X_c, 2), axis=1, keepdims=True))

    too_far = _distance_to_centroid(X1) > r1
    too_close = _distance_to_centroid(X_1) < r_1
    return np.sum(too_far * W1) + np.sum(too_close * W_1)

def surrogate_loss(X_c,X1,X_1,r1,r_1,W1,W_1,mu):
    """Continuous surrogate of the centroid error.

    The first term is the W1-weighted sum of Euclidean distances from the rows
    of X1 to X_c.  The second term sums, over rows of X_1 that lie closer to
    X_c than r_1 * mu, the W_1-weighted shortfall (r_1 * mu - distance).
    r1 is accepted but not used.

    Returns (total, first_term, second_term).
    """
    attract = np.sum(np.sum((X1 - X_c) ** 2, axis=1, keepdims=True) ** (0.5) * W1)

    margin = r_1 * mu
    inside = np.sqrt(np.sum(np.power(X_1 - X_c, 2), axis=1, keepdims=True)) < margin
    r_ik = np.sqrt(np.sum((X_1 - X_c) ** 2, axis=1, keepdims=True))
    repel = np.sum((margin - r_ik) * inside * W_1)
    return attract + repel, attract, repel