""" 
https://towardsdatascience.com/coding-a-2-layer-neural-network-from-scratch-in-python-4dd022d19fd2 
"""
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import time




def sech(x):
    """Hyperbolic secant ``1 / cosh(x)``, elementwise, without overflow.

    sech is even, so it is evaluated as ``2 e^{-|x|} / (1 + e^{-2|x|})``.
    The exponent is never positive, so large ``|x|`` yields 0 instead of
    the ``inf / inf = nan`` produced by ``2 e^x / (e^{2x} + 1)``.
    """
    e = np.exp(-np.abs(x))
    return 2 * e / (1 + e * e)

def matrix_power(M, exponent):
    """Raise a diagonalizable square matrix to a (possibly fractional) power.

    Computes ``V diag(lambda**exponent) V^{-1}`` from an eigendecomposition.
    Symmetric input uses ``np.linalg.eigh``, whose eigenvectors are
    orthonormal, so ``V^{-1} = V.T``. Any other input uses ``np.linalg.eig``
    with an explicit inverse, because ``V.T`` is not the inverse there.
    With real eigenvalues, a fractional power of a negative eigenvalue gives
    nan and a negative power of a zero eigenvalue gives inf.

    Args:
        M: square matrix.
        exponent: scalar power (e.g. ``-0.5`` for an inverse square root).

    Returns:
        The matrix ``M**exponent``; complex if ``M`` has complex eigenvalues.
    """
    M = np.asarray(M)
    if np.allclose(M, M.T):
        eigvals, V = np.linalg.eigh(M)
        # V * d scales column j of V by d[j], i.e. V @ diag(d).
        return (V * eigvals ** exponent) @ V.T
    eigvals, V = np.linalg.eig(M)
    return (V * eigvals ** exponent) @ np.linalg.inv(V)

def topk_idx(input, k):
    """Return the positions of the ``k`` largest entries of ``input[0]``.

    Only the first element along axis 0 (``input[0]``) is searched. The
    result is an integer array with one row per selected entry, largest
    first, holding that entry's multi-index into ``input[0]``. For the usual
    ``(1, n)`` input it is a ``(k, 1)`` column of positions. Fewer than ``k``
    rows come back when ``input[0]`` has fewer than ``k`` entries.
    """
    row = input[0]
    descending = np.argsort(row.ravel())[::-1]
    chosen = descending[:k]
    coords = np.unravel_index(chosen, row.shape)
    return np.column_stack(coords)

# def random_sampling(X, Y, sample_size=1):
#     random_idx = np.random.choice(Y.shape[1], sample_size)
#     random_sampling_X = X[:, random_idx]
#     random_sampling_Y = Y[:, random_idx]
#     return random_sampling_X, random_sampling_Y




def Sigmoid(Z):
    """Logistic function ``1 / (1 + exp(-Z))``, applied elementwise."""
    denominator = 1 + np.exp(-Z)
    return 1 / denominator

def dSigmoid(Z):
    """Derivative of the logistic function: ``s * (1 - s)`` with ``s = 1 / (1 + exp(-Z))``."""
    sig = np.reciprocal(1 + np.exp(-Z))
    complement = 1 - sig
    return sig * complement

def Relu(Z):
    """Rectified linear unit: elementwise ``max(0, Z)``."""
    floor = 0
    return np.maximum(floor, Z)
    
def dRelu(Z):
    """(Sub)gradient of ReLU: 1 where ``Z > 0``, 0 elsewhere.

    Returns a NEW array with ``Z``'s dtype; the input is not modified (the
    previous masked assignment overwrote the caller's array). Python and
    NumPy scalars are accepted and give a 0-d array.
    """
    Z = np.asarray(Z)
    return (Z > 0).astype(Z.dtype)

def Tanh5(Z):
    """Steepened hyperbolic tangent ``tanh(5 * Z)`` (slope 5 at the origin)."""
    scaled = 5 * Z
    return np.tanh(scaled)

def dTanh5(Z):
    """Derivative of ``tanh(5 * Z)``: ``5 * sech(5 * Z)**2``, elementwise.

    sech is evaluated as ``2 e^{-|u|} / (1 + e^{-2|u|})`` with ``u = 5 Z``.
    Large ``|Z|`` therefore gives 0 rather than nan from ``inf / inf``.
    """
    e = np.exp(-np.abs(5 * Z))
    sech_5z = 2 * e / (1 + e * e)
    return 5 * sech_5z ** 2




# optimizers
def adagrad(weights, grads, G0, lr=1e-2, epsilon=1e-6):
    """One AdaGrad step with a per-element step size.

    Args:
        weights: current parameter values.
        grads: gradient of the loss w.r.t. ``weights``.
        G0: running sum of squared gradients from previous calls (0 initially).
        lr: base learning rate.
        epsilon: added to the root of the accumulator to avoid division by 0.

    Returns:
        ``(new_weights, G1)`` with ``G1 = G0 + grads**2``; pass ``G1`` back
        as ``G0`` on the next call.
    """
    accumulated = G0 + np.square(grads)
    scale = np.sqrt(accumulated) + epsilon
    return weights - (lr * grads) / scale, accumulated

def adadelta(weights, grads, G0, rho=0.95, lr=1e-2, epsilon=1e-6):
    """Full-matrix, RMSProp-style preconditioned step (named "adadelta" here).

    ``G1 = rho * G0 + (1 - rho) * grads.T @ grads`` is an exponential moving
    average of the gradient's column Gram matrix. The update is
    ``weights - lr * grads @ (G1 + epsilon * I)^(-1/2)``.

    NOTE: unlike Zeiler's AdaDelta, there is no running average of squared
    updates, so ``lr`` sets the step size directly.

    Args:
        weights: 2-D parameter matrix.
        grads: gradient with the same shape as ``weights``.
        G0: previous Gram-matrix average (0 initially).
        rho: decay rate of the moving average.
        lr: learning rate.
        epsilon: diagonal damping added before the inverse square root.

    Returns:
        ``(new_weights, G1)``; pass the undamped ``G1`` back as ``G0`` next time.
    """
    G1 = rho * G0 + (1 - rho) * np.matmul(grads.T, grads)
    # Damping must go on the diagonal. Adding epsilon to every entry lifts
    # only one eigen-direction. The Gram matrix can be rank-deficient (fewer
    # gradient rows than columns), and its remaining ~0 eigenvalues would then
    # blow up to inf/nan under the -1/2 power.
    damped = G1 + epsilon * np.eye(G1.shape[0])
    weights = weights - np.matmul(lr * grads, matrix_power(damped, -0.5))
    return weights, G1

def adam(weights, grads, G0, lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
    """One Adam step.

    Args:
        weights: current parameter values.
        grads: gradient w.r.t. ``weights``.
        G0: optimizer state dict with ``"m0"`` (first-moment estimate),
            ``"v0"`` (second-moment estimate) and ``"t0"`` (step index used
            for bias correction; the caller starts it at 1).
        lr, beta1, beta2, epsilon: usual Adam hyper-parameters.

    Returns:
        ``(new_weights, G1)`` where ``G1`` is a new state dict with the updated
        moments and ``"t0"`` incremented by one.
    """
    step = G0["t0"]
    first = beta1 * G0["m0"] + (1 - beta1) * grads
    second = beta2 * G0["v0"] + (1 - beta2) * np.multiply(grads, grads)
    # bias-corrected moment estimates
    first_hat = first / (1 - beta1 ** step)
    second_hat = second / (1 - beta2 ** step)
    new_weights = weights - lr * np.divide(first_hat, np.sqrt(second_hat) + epsilon)
    return new_weights, {"m0": first, "v0": second, "t0": step + 1}

    


# early stop
def early_stop(loss_record, threshold=1e-3, patience=10):
    """Return True when the loss has stopped improving.

    Takes the last ``patience`` recorded losses and compares the first of
    them with each of them (itself included). Stops when none is lower than
    that first value by ``threshold`` or more. Always returns False while at
    most ``patience + 5`` losses have been recorded (a short warm-up).

    Args:
        loss_record: scalar losses, oldest first. A list of Python floats,
            a list of NumPy scalars, or a 1-D array all work. The old
            ``float - list`` form raised TypeError on plain floats.
        threshold: minimum decrease that counts as an improvement.
        patience: size of the window of recent losses that is inspected.

    Returns:
        bool
    """
    losses = np.asarray(loss_record, dtype=float)
    n = len(losses)
    if n <= patience + 5:
        return False
    window = losses[n - patience:]
    improvement = window[0] - window
    return bool(np.all(improvement < threshold))




# MLP
class Fc2:
    """Two-layer fully connected binary classifier trained with aggregate losses.

    Forward pass: ``A1 = relu(W1 @ X + b1)`` and ``A2 = W2 @ A1 + b2``. The
    output layer is linear and the prediction is ``sign(A2)``. Data is stored
    column-wise: ``X`` is indexed ``(features, samples)`` and ``Y``
    ``(1, samples)`` (see ``self.dims`` and ``self.n``). Labels are presumably
    +/-1, given the margin ``Y * A2`` and the ``np.sign`` predictions. Confirm
    this against the dataset.

    Args:
        X: inputs, ``(features, samples)``.
        Y: labels, ``(1, samples)``.
        hidden_dim: width of the hidden layer.
        k: the k of the top-k aggregate losses.
        lr: learning rate. Also used for the plain-GD update of ``lamda``.
        l2reg_c: inverse L2 strength. Every parameter's gradient gets
            ``param / l2reg_c`` added.
        delta: smoothing width of the "stk" / "sgd_stk" smoothed hinge.
        loss_type: per-sample loss, "hinge" or "logit".
        aggregate_loss_type: "average", "maximum", "atk", "matk", "stk",
            "sgd_matk" or "sgd_stk".
        optimizer: "gd", "adagrad", "adadelta" or "adam".
        stop_early: stop training once ``early_stop`` fires on the training
            loss curve.
        verbose: print training loss and accuracy every 100 iterations.
    """

    # Aggregate losses that carry the auxiliary scalar parameter ``lamda``.
    _LAMDA_TYPES = ("matk", "sgd_matk", "stk", "sgd_stk")

    def __init__(self, X, Y, hidden_dim, k, lr, l2reg_c, delta,
                 loss_type="hinge",
                 aggregate_loss_type="atk",
                 optimizer="gd",
                 stop_early=True,
                 verbose=False) -> None:
        self.X = X
        self.Y = Y
        self.ch = dict()      # cache written by forward(): Z1, A1, Z2, A2, losses
        self.grad = dict()
        self.param = dict()   # W1, b1, W2, b2, lamda
        self.loss_curve = list()
        self.acc_score_curve_train = list()
        self.acc_score_curve_test = list()
        self.k = k
        self.lr = lr
        self.delta = delta
        self.l2reg_c = l2reg_c
        self.loss_type = loss_type
        self.aggregate_loss_type = aggregate_loss_type
        self.optimizer = optimizer
        self.stop_early = stop_early
        self.verbose = verbose
        self.dims = [self.X.shape[0], hidden_dim, 1]
        # Sample count used in the k / n terms; gd() resets it to the
        # training-split size.
        self.n = self.X.shape[1]

    def nInit(self):
        """(Re)initialize the parameters.

        Weights get Gaussian entries scaled by ``1 / sqrt(fan_in)``. Biases
        start at zero. ``lamda`` gets a standard-normal draw for the
        lamda-based aggregate losses and nan otherwise.
        """
        self.Yh = np.zeros((1, self.Y.shape[1]))
        self.param['W1'] = np.random.randn(self.dims[1], self.dims[0]) / np.sqrt(self.dims[0])
        self.param['b1'] = np.zeros((self.dims[1], 1))
        self.param['W2'] = np.random.randn(self.dims[2], self.dims[1]) / np.sqrt(self.dims[1])
        self.param['b2'] = np.zeros((self.dims[2], 1))
        if self.aggregate_loss_type in self._LAMDA_TYPES:
            self.param["lamda"] = np.random.randn()
        else:
            self.param["lamda"] = np.nan

    def losses(self, Y):
        """Per-sample losses of the cached output ``self.ch['A2']`` against ``Y``.

        Returns:
            Array shaped like ``Y * self.ch['A2']``.

        Raises:
            ValueError: unknown ``self.loss_type``.
        """
        margin = Y * self.ch['A2']
        if self.loss_type == "logit":
            # log(1 + exp(-margin)) without overflowing exp for large -margin.
            return np.logaddexp(0, -margin)
        if self.loss_type == "hinge":
            return Relu(1 - margin)
        raise ValueError("unknown loss_type: %r" % (self.loss_type,))

    def forward(self, X, Y):
        """Run the network on ``X`` and cache the intermediates and losses.

        Stores ``Z1``, ``A1``, ``Z2``, ``A2`` and the per-sample ``losses``
        (against ``Y``) in ``self.ch``, overwriting the previous call's cache.
        Returns ``A2``.
        """
        Z1 = self.param['W1'].dot(X) + self.param['b1']
        A1 = Relu(Z1)
        Z2 = self.param['W2'].dot(A1) + self.param['b2']
        A2 = Z2  # linear output; Tanh5(Z2) would also need dTanh5 in backward
        self.ch['Z1'], self.ch['A1'] = Z1, A1
        self.ch['Z2'], self.ch['A2'] = Z2, A2
        self.ch['losses'] = self.losses(Y)
        return A2

    def _topk_indicator(self, k):
        """0/1 array shaped like the cached losses that marks their k largest entries."""
        losses = self.ch['losses']
        indicator = np.zeros_like(losses)
        np.put(indicator, topk_idx(input=losses, k=k), 1)
        return indicator

    def _aggregate_grads(self):
        """Gradient of the aggregate loss w.r.t. the per-sample losses and lamda.

        With per-sample losses ``l_i``, the objectives are as follows. The
        weight gradients in backward() are additionally divided by the
        batch size.

        * "average": sum_i l_i
        * "maximum": max_i l_i
        * "atk":     sum of the k largest l_i
        * "matk":    sum_i [l_i - lamda]_+ + k * lamda
        * "stk":     "matk" with [u]_+ replaced by the smooth
                     (u + delta * (sqrt((u / delta)**2 + 1) - 1)) / 2
        * "sgd_matk" / "sgd_stk": only one uniformly sampled loss contributes,
          so ``d_lamda = k / n - d_j`` is an unbiased estimate of the
          full-batch lamda gradient. NOTE(review): the weight gradient of
          that single sample is still divided by the batch size in
          _param_grads; confirm that step scaling is intended.

        Returns:
            ``(d_loss, d_lamda)``. ``d_loss`` is shaped like
            ``self.ch['losses']``. ``d_lamda`` is a scalar for the lamda-based
            types and None otherwise.

        Raises:
            ValueError: unknown ``self.aggregate_loss_type``.
        """
        losses = self.ch['losses']
        kind = self.aggregate_loss_type
        if kind == "average":
            return np.ones_like(losses), None
        if kind == "maximum":
            return self._topk_indicator(1), None
        if kind == "atk":
            return self._topk_indicator(self.k), None
        if kind not in self._LAMDA_TYPES:
            raise ValueError("unknown aggregate_loss_type: %r" % (kind,))

        shifted = losses - self.param["lamda"]
        if kind in ("stk", "sgd_stk"):
            # derivative of the smoothed hinge above; lies in (0, 1)
            d_loss = 0.5 * (np.divide(shifted, self.delta * np.sqrt((shifted / self.delta) ** 2 + 1)) + 1)
        else:
            d_loss = dRelu(shifted)
        if kind.startswith("sgd_"):
            sample = np.zeros_like(losses)
            sample.flat[np.random.randint(sample.size)] = 1
            d_loss = d_loss * sample
            d_lamda = -d_loss.sum() + self.k / self.n
        else:
            d_lamda = -d_loss.mean() + self.k / self.n
        return d_loss, d_lamda

    def _loss_grad(self, Y):
        """Derivative of each per-sample loss w.r.t. its output ``A2``."""
        margin = Y * self.ch['A2']
        if self.loss_type == "logit":
            # d/dA2 log(1 + exp(-Y*A2)) = -Y / (1 + exp(Y*A2)); logaddexp avoids overflow.
            return -Y * np.exp(-np.logaddexp(0, margin))
        if self.loss_type == "hinge":
            return dRelu(1 - margin) * (-Y)
        raise ValueError("unknown loss_type: %r" % (self.loss_type,))

    def _param_grads(self, X, d_Z2):
        """Batch-averaged, L2-regularized gradients of W1, b1, W2 and b2."""
        n_batch = X.shape[1]
        d_A1 = self.param["W2"].T.dot(d_Z2)
        # ReLU derivative as a boolean mask, which leaves the cached Z1 intact.
        d_Z1 = d_A1 * (self.ch['Z1'] > 0)
        grads = {
            "W1": d_Z1.dot(X.T) / n_batch,
            "b1": d_Z1.sum(axis=1, keepdims=True) / n_batch,
            "W2": d_Z2.dot(self.ch['A1'].T) / n_batch,
            "b2": d_Z2.sum(axis=1, keepdims=True) / n_batch,
        }
        # gradient of the L2 penalty sum(param**2) / (2 * l2reg_c)
        return {name: g + self.param[name] / self.l2reg_c for name, g in grads.items()}

    def _apply_update(self, grads, G):
        """Replace W1, b1, W2, b2 in ``self.param`` using ``self.optimizer``.

        Optimizer state lives in ``G`` under keys "G0_w1", "G0_b1", "G0_w2"
        and "G0_b2"; updated entries are written back into ``G``.
        """
        if self.optimizer == "gd":
            for name, grad in grads.items():
                self.param[name] = self.param[name] - self.lr * grad
            return G
        step_fns = {"adagrad": adagrad, "adadelta": adadelta, "adam": adam}
        if self.optimizer not in step_fns:
            raise ValueError("unknown optimizer: %r" % (self.optimizer,))
        step = step_fns[self.optimizer]
        for name, grad in grads.items():
            key = "G0_" + name.lower()
            self.param[name], G[key] = step(weights=self.param[name], grads=grad, G0=G[key], lr=self.lr)
        return G

    def backward(self, X, Y, G):
        """Take one optimization step using the cache left by ``forward(X, Y)``.

        Args:
            X: the inputs passed to the preceding ``forward`` call.
            Y: the matching labels.
            G: optimizer state (unused by "gd"); updated in place.

        Returns:
            The optimizer state ``G``.

        Raises:
            ValueError: unknown loss, aggregate loss or optimizer type. It is
                raised before any parameter changes.
        """
        d_loss, d_lamda = self._aggregate_grads()
        # The output layer is linear (A2 = Z2), so dLoss/dZ2 == dLoss/dA2.
        d_Z2 = d_loss * self._loss_grad(Y)
        G = self._apply_update(self._param_grads(X, d_Z2), G)
        if d_lamda is not None:
            # lamda always uses plain GD with the same L2 penalty as the weights.
            d_lamda = d_lamda + self.param["lamda"] / self.l2reg_c
            self.param["lamda"] = self.param["lamda"] - self.lr * d_lamda
        return G

    def _init_optimizer_state(self):
        """Return fresh per-parameter optimizer state for ``self.optimizer``."""
        if self.optimizer in ("adagrad", "adadelta", "gd"):
            return {"G0_w1": 0, "G0_w2": 0, "G0_b1": 0, "G0_b2": 0, "G0_lamda": 0}
        if self.optimizer == "adam":
            return {key: {"m0": 0, "v0": 0, "t0": 1} for key in ("G0_w1", "G0_w2", "G0_b1", "G0_b2")}
        raise ValueError("unknown optimizer: %r" % (self.optimizer,))

    def gd(self, max_iterations=5000):
        """Train from a fresh initialization and return a summary dict.

        The parameters and the recorded curves are reset, and the data is
        split 50/50 into train/test with ``train_test_split``. Then up to
        ``max_iterations`` full-batch steps run on the training half. Each
        iteration records the training loss (mean per-sample loss before
        the step) and the train/test accuracy of ``sign(A2)``.

        Returns:
            dict with:
              * the time ``t`` spent in training forward/backward;
              * the iteration count;
              * the hyper-parameters;
              * the last training output ``A2_train`` (None if no iteration ran);
              * the loss and accuracy curves and the final accuracies
                (nan if no iteration ran).
        """
        self.nInit()
        G = self._init_optimizer_state()
        self.loss_curve = list()
        self.acc_score_curve_train = list()
        self.acc_score_curve_test = list()
        X_train, X_test, Y_train, Y_test = train_test_split(self.X.T, self.Y.T, test_size=0.5)
        X_train, X_test, Y_train, Y_test = X_train.T, X_test.T, Y_train.T, Y_test.T
        # The top-k objectives range over the training split, so the k / n
        # terms in backward() must use its size, as "atk" effectively does.
        self.n = X_train.shape[1]

        t = 0.0
        i = 0
        A2_train = None
        acc_score_train = acc_score_test = float("nan")
        while i < max_iterations:
            i += 1
            tic = time.perf_counter()
            A2_train = self.forward(X_train, Y_train)
            # Read the training loss now, because the test forward below
            # overwrites self.ch.
            train_loss = self.ch['losses'].mean()
            G = self.backward(X_train, Y_train, G)
            t += time.perf_counter() - tic

            A2_test = self.forward(X_test, Y_test)
            acc_score_train = accuracy_score(Y_train.T, np.sign(A2_train).T)
            acc_score_test = accuracy_score(Y_test.T, np.sign(A2_test).T)
            self.loss_curve.append(train_loss)
            self.acc_score_curve_train.append(acc_score_train)
            self.acc_score_curve_test.append(acc_score_test)

            if self.stop_early and early_stop(self.loss_curve):
                break
            if self.verbose and i % 100 == 0:
                print("Training [Loss] and [Acc] after iteration %i: %f, %f" %(i, train_loss, acc_score_train))
        result_dict = {"t": float(t),
                       "iterations": i,
                       "k": self.k,
                       "delta": self.delta,
                       "l2reg_c": self.l2reg_c,
                       "lr": self.lr,
                       "A2_train": A2_train,
                       "loss_curve": self.loss_curve,
                       "acc_score_train": float(acc_score_train),
                       "acc_score_test": float(acc_score_test),
                       "acc_score_curve_train": self.acc_score_curve_train,
                       "acc_score_curve_test": self.acc_score_curve_test
                       }
        return result_dict




import torch
def load_data(dataset_name, path="./datasets/benchmark_atk/data/"):
    """Load the ``(instances, labels)`` pair saved for ``dataset_name``.

    Reads ``<path>ins_<name>.pt`` and then ``<path>lable_<name>.pt`` with
    ``torch.load``. ``path`` is concatenated as-is, so it must end with a
    separator. The "lable_" spelling presumably matches the files on disk;
    verify before renaming. ``torch.load`` unpickles, so only load trusted
    files.
    """
    ins_file = path + "ins_%s.pt" % (dataset_name)
    label_file = path + "lable_%s.pt" % (dataset_name)
    return torch.load(ins_file), torch.load(label_file)

ins, label = load_data(dataset_name="segment0")
print(ins)
# Standardize each feature column. A constant column has zero std and would
# become nan (0 / 0), poisoning every prediction; dividing it by 1 leaves it at 0.
ins_std = torch.std(ins, dim=0)
ins_std[ins_std == 0] = 1
ins = (ins - torch.mean(ins, dim=0)) / ins_std
X, Y = ins.numpy().T, label.numpy().T

nn = Fc2(X=X, Y=Y, hidden_dim=10, k=558, loss_type="hinge", aggregate_loss_type="sgd_stk", optimizer="adagrad", stop_early=True, verbose=True, lr=0.1, delta=0.001, l2reg_c=1e3)
result_dict = nn.gd(max_iterations=20000)

print(result_dict["iterations"])
# gd() reports train and test accuracy separately; there is no "acc_score" key.
print(result_dict["acc_score_train"])
print(result_dict["acc_score_test"])
