import numpy as np
import torch
from sklearn.metrics import log_loss, accuracy_score, precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split
import time
import random


# Synthetic 200-dimensional setup: the negative class mean is the origin and
# the positive class mean is 1 on the first 10 coordinates, 0 elsewhere.
mu_negative = np.zeros(200)
mu_positive = np.array([1] * 10 + [0] * 190)
# Covariance with geometric decay: Spl[i, j] = 0.8 ** |i - j|.
Spl = np.power(0.8, np.abs(np.subtract.outer(np.arange(200.0), np.arange(200.0))))
# Reference weight vector mu_positive @ Spl^{-1}; entries smaller than 1e-6 in
# magnitude are treated as numerical noise and set to exactly zero so that the
# support of W_star is well defined.
W_star = np.dot(mu_positive, np.linalg.inv(Spl))
W_star = np.where(np.abs(W_star) < 1e-6, 0, W_star)
# print("W*:\n", W_star)





def sech(x):
    """Return the hyperbolic secant ``1 / cosh(x)`` element-wise.

    The value is computed as ``2*exp(-|x|) / (1 + exp(-2|x|))``. Only
    decaying exponentials are evaluated, so large ``|x|`` underflows to 0
    instead of producing ``inf / inf = nan``.
    """
    decay = np.exp(-np.abs(x))
    return np.divide(2 * decay, 1 + decay * decay)

def matrix_power(M, exponent):
    """Raise a square, diagonalizable matrix to a (possibly fractional) power.

    The power is computed from the eigen-decomposition
    ``M = V diag(lambda) V^{-1}`` as ``V diag(lambda**exponent) V^{-1}``.

    Args:
        M: square matrix (array-like).
        exponent: scalar power applied to the eigenvalues.

    Returns:
        ndarray with the same shape as ``M``.
    """
    M = np.asarray(M)
    is_hermitian = (
        M.ndim == 2
        and M.shape[0] == M.shape[1]
        and np.allclose(M, M.conj().T)
    )
    if is_hermitian:
        # eigh guarantees orthonormal eigenvectors, so the inverse of V is
        # its conjugate transpose even when eigenvalues repeat.
        eigenvalues, V = np.linalg.eigh(M)
        V_inv = V.conj().T
    else:
        # General case: eigenvectors are not orthogonal, V.T is NOT V^{-1}.
        eigenvalues, V = np.linalg.eig(M)
        V_inv = np.linalg.inv(V)
    # Scaling the columns of V is equivalent to V @ diag(eigenvalues**exponent).
    return np.matmul(V * eigenvalues**exponent, V_inv)

def topk_idx(input, k):
    """Return the coordinates of the ``k`` largest entries of ``input[0]``.

    Only the first element along axis 0 is inspected. The result has one
    row per selected entry (largest first) and one column per dimension of
    ``input[0]``.
    """
    first_slice = input[0]
    ascending = np.argsort(first_slice.ravel())
    # Reverse to get descending order, then keep the leading k positions.
    largest_first = ascending[::-1][:k]
    coordinates = np.unravel_index(largest_first, first_slice.shape)
    return np.column_stack(coordinates)

def shuffle_in_unison(a, b):
    """Permute the columns (axis 1) of ``a`` and ``b`` with the same permutation.

    One permutation is drawn from NumPy's global random state; column ``j``
    of each input is moved to column ``permutation[j]`` of the corresponding
    output, so paired columns stay paired. The inputs are not modified.

    Returns:
        Tuple ``(shuffled_a, shuffled_b)`` of new arrays with the shapes and
        dtypes of ``a`` and ``b``.
    """
    assert a.shape[1] == b.shape[1]
    permutation = np.random.permutation(a.shape[1])
    shuffled_a = np.empty(a.shape, dtype=a.dtype)
    shuffled_b = np.empty(b.shape, dtype=b.dtype)
    # Vectorised scatter: out[:, permutation[j]] = in[:, j] for every j.
    shuffled_a[:, permutation] = a
    shuffled_b[:, permutation] = b
    return shuffled_a, shuffled_b

def f1_score(W_hat, W_star):
    """F1 score between the supports (non-zero index sets) of two vectors.

    NOTE: this definition shadows the ``f1_score`` imported from
    ``sklearn.metrics`` at the top of this file.

    Args:
        W_hat: estimated weight vector.
        W_star: reference weight vector.

    Returns:
        The harmonic mean of support precision and recall, ``0.0`` when both
        supports are non-empty but disjoint, and ``nan`` when either support
        is empty.
    """
    supp_W_hat = set(np.nonzero(W_hat)[0].tolist())
    supp_W_star = set(np.nonzero(W_star)[0].tolist())
    if not supp_W_hat or not supp_W_star:
        return np.nan
    n_common = len(supp_W_hat & supp_W_star)
    if n_common == 0:
        # precision == recall == 0: the harmonic mean would be 0/0.
        return 0.0
    precision = n_common / len(supp_W_hat)
    recall = n_common / len(supp_W_star)
    return 2 * precision * recall / (precision + recall)




def Sign(Z):
    """Threshold scores at 0.5 into hard 0/1 labels.

    Values ``<= 0.5`` become 0 and values ``> 0.5`` become 1. The result is
    a new array; the caller's ``Z`` is left untouched.
    """
    labels = np.array(Z, copy=True)
    above = labels > 0.5
    labels[labels <= 0.5] = 0
    labels[above] = 1
    return labels

def Sigmoid(Z):
    """Logistic function ``1 / (1 + exp(-Z))``, applied element-wise."""
    denominator = 1 + np.exp(-Z)
    return 1 / denominator

def dSigmoid(Z):
    """Derivative of the logistic function: ``s * (1 - s)`` with ``s = sigmoid(Z)``."""
    sigmoid_value = 1 / (1 + np.exp(-Z))
    return sigmoid_value * (1 - sigmoid_value)

def Relu(Z):
    """Rectified linear unit: element-wise ``max(0, Z)``."""
    rectified = np.maximum(0, Z)
    return rectified
    
def dRelu(Z):
    """Derivative of ReLU: 0 where ``Z < 0`` and 1 elsewhere (including ``Z == 0``)."""
    is_negative = Z < 0
    slope = np.ones_like(Z)
    slope[is_negative] = 0
    return slope

def LReLU(Z, delta=0.01):
    """Leaky ReLU: ``Z`` where ``Z > 0``, otherwise ``delta * Z``."""
    leaked = Z * delta
    return np.where(Z > 0, Z, leaked)

def dLReLU(Z, delta=0.01):
    """Derivative of leaky ReLU: ``delta`` where ``Z < 0`` and 1 elsewhere."""
    is_negative = Z < 0
    slope = np.ones_like(Z)
    slope[is_negative] = delta
    return slope

def Swish(Z, delta=10):
    """Swish activation ``Z * sigmoid(delta * Z)``, applied element-wise."""
    gate = 1 / (1 + np.exp(-delta * Z))
    return Z * gate

def dSwish(Z, delta=10):
    """Derivative of Swish: ``S + delta * Z * S * (1 - S)`` with ``S = sigmoid(delta * Z)``."""
    gate = 1 / (1 + np.exp(-delta * Z))
    curvature = delta * Z * gate * (1 - gate)
    return gate + curvature

def ELU(Z, delta=1.0):
    """Exponential linear unit: ``Z`` where ``Z > 0``, else ``delta * (exp(Z) - 1)``.

    ``np.where`` evaluates both branches for every element, so the exponent
    is clamped to ``<= 0`` first; otherwise large positive inputs overflow
    in ``exp`` even though that branch is discarded for them.
    """
    negative_part = np.minimum(Z, 0)
    return np.where(Z > 0, Z, delta * (np.exp(negative_part) - 1))

def dELU(Z, delta=1.0):
    """Derivative of ELU: 1 where ``Z > 0``, else ``delta * exp(Z)``.

    The exponent is clamped to ``<= 0`` because ``np.where`` evaluates both
    branches; without the clamp large positive inputs overflow in ``exp``.
    """
    negative_part = np.minimum(Z, 0)
    return np.where(Z > 0, 1, delta * np.exp(negative_part))




def SRelu(Z, delta=0.05):
    """Smooth ReLU: ``0.5 * (Z + delta * (sqrt(1 + (Z/delta)**2) - 1))``."""
    scaled = Z / delta
    smooth_abs = delta * (np.sqrt(1 + scaled ** 2) - 1)
    return 0.5 * (Z + smooth_abs)

def dSRelu(Z, delta=0.05):
    """Derivative of the smooth ReLU: ``0.5 * (Z / sqrt(Z**2 + delta**2) + 1)``."""
    radius = np.sqrt(Z ** 2 + delta ** 2)
    return 0.5 * (np.divide(Z, radius) + 1)

def SoftPlus(Z, delta=1):
    """Softplus ``log(1 + exp(delta * Z)) / delta``, applied element-wise.

    Evaluated with ``np.logaddexp(0, delta * Z)`` so that large arguments
    return approximately ``Z`` instead of overflowing to ``inf``.
    """
    return np.logaddexp(0, delta * Z) / delta

def dSoftPlus(Z, delta=1):
    """Derivative of softplus, i.e. the logistic function of ``delta * Z``.

    Computed as ``exp(-logaddexp(0, -delta * Z))``, which stays finite for
    arbitrarily large ``|Z|``; the naive ``exp(x) / (1 + exp(x))`` yields
    ``inf / inf = nan`` once ``exp`` overflows.
    """
    return np.exp(-np.logaddexp(0, -delta * Z))

def CrossEntropyLosses(A, Y, eps=None):
    """Per-element binary cross-entropy ``-(Y*log(A) + (1-Y)*log(1-A))``.

    Args:
        A: predicted probabilities.
        Y: targets, broadcastable against ``A``.
        eps: probabilities are clipped to ``[eps, 1 - eps]`` before taking
            logarithms. Defaults to the machine epsilon of ``A``'s floating
            dtype (float64 epsilon for non-floating input).

    Returns:
        Array of element-wise losses (not reduced).
    """
    A = np.asarray(A)
    if eps is None:
        float_dtype = A.dtype if np.issubdtype(A.dtype, np.floating) else np.float64
        eps = np.finfo(float_dtype).eps
    # Without clipping, a saturated prediction gives 0 * log(0) = nan even
    # when the prediction is correct.
    A = np.clip(A, eps, 1 - eps)
    losses = -(Y * np.log(A) + (1 - Y) * np.log(1 - A))
    return losses




# optimizers
def adagrad(weights, grads, G0, lr=1e-2, epsilon=1e-6):
    """One AdaGrad step.

    Args:
        weights: current parameters.
        grads: gradient with respect to ``weights``.
        G0: running sum of squared gradients from previous steps.
        lr: learning rate.
        epsilon: added to the root of the accumulator to avoid division by zero.

    Returns:
        Tuple ``(updated_weights, updated_accumulator)``.
    """
    accumulated = G0 + np.multiply(grads, grads)
    step = np.divide(lr * grads, np.sqrt(accumulated) + epsilon)
    return weights - step, accumulated

def adadelta(weights, grads, G0, rho=0.95, lr=1e-2, epsilon=1e-6):
    """One step of a full-matrix, exponentially averaged preconditioned update.

    Despite the name, the code keeps a decayed average of the gradient outer
    product ``grads.T @ grads`` and multiplies the gradient by the inverse
    square root of that matrix.

    Args:
        weights: current parameters; the matrix products below require a
            row layout ``(rows, d)`` matching ``grads``.
        grads: gradient with the same shape as ``weights``.
        G0: previous averaged outer-product matrix (``0`` on the first step).
        rho: decay factor of the running average.
        lr: learning rate.
        epsilon: damping added to the diagonal before inversion.

    Returns:
        Tuple ``(updated_weights, updated_average)``.
    """
    G1 = rho * G0 + (1 - rho) * np.matmul(grads.T, grads)
    # The outer product is rank-deficient; damping must go on the diagonal
    # (epsilon * I) to make the matrix positive definite. Adding epsilon to
    # every entry is only another low-rank term and leaves it singular.
    damped = G1 + epsilon * np.eye(G1.shape[0])
    weights = weights - np.matmul(lr * grads, matrix_power(damped, -0.5))
    return weights, G1

def adam(weights, grads, G0, lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
    """One Adam step.

    Args:
        weights: current parameters.
        grads: gradient with respect to ``weights``.
        G0: optimiser state dict with keys ``"m0"`` (first moment), ``"v0"``
            (second moment) and ``"t0"`` (step counter used for bias
            correction in this call).
        lr, beta1, beta2, epsilon: usual Adam hyper-parameters.

    Returns:
        Tuple ``(updated_weights, new_state)`` where ``new_state`` has the
        same keys as ``G0`` and its counter advanced by one.
    """
    first_prev, second_prev, step = G0["m0"], G0["v0"], G0["t0"]
    first = beta1 * first_prev + (1 - beta1) * grads
    second = beta2 * second_prev + (1 - beta2) * np.multiply(grads, grads)
    # Bias-corrected moment estimates.
    first_hat = first / (1 - beta1 ** step)
    second_hat = second / (1 - beta2 ** step)
    updated = weights - lr * np.divide(first_hat, (np.sqrt(second_hat) + epsilon))
    return updated, {"m0": first, "v0": second, "t0": step + 1}

    


# early stop
def early_stop(acc_score_curve_val, threshold=1e-6, patience=500):
    """Decide whether validation accuracy has stopped improving.

    Once more than ``2 * patience`` scores have been recorded, the last
    ``patience`` scores are compared with the first score of that window;
    training should stop when none of them exceeds it by ``threshold``.

    Args:
        acc_score_curve_val: sequence (list or array) of validation scores,
            oldest first.
        threshold: minimum improvement that counts as progress.
        patience: length of the trailing window that is inspected.

    Returns:
        ``True`` if training should stop, ``False`` otherwise.
    """
    record_length = len(acc_score_curve_val)
    if record_length <= patience * 2:
        return False
    # Convert explicitly: a list slice of Python floats does not support
    # subtraction of a scalar.
    window = np.asarray(
        acc_score_curve_val[record_length - patience:record_length], dtype=float
    )
    return bool(np.max(window - window[0]) < threshold)

# convergence
def convergence(weights_0, weights_1, threshold=1e-8):
    """Return the fraction of entries whose absolute change is below ``threshold``.

    NOTE(review): the result is a number in [0, 1], not a boolean. When it
    is used directly as a condition it is truthy as soon as a single entry
    has settled -- confirm that this is what callers intend.
    """
    settled = np.abs(weights_0 - weights_1) < threshold
    return np.mean(settled)




# Single fully-connected sigmoid layer (logistic-regression classifier; no hidden layers)
class Fc:
    """Single fully-connected sigmoid layer trained on an aggregate loss.

    The model computes ``A = Sigmoid(W.dot(X) + b)`` with ``W`` of shape
    ``(1, n_features)``. Data matrices are laid out features x samples:
    ``X.shape[0]`` sizes ``W`` and ``X.shape[1]`` is the sample count.

    Per-sample cross-entropy losses are combined according to
    ``aggregate_loss_type``:

    * ``"average"`` -- every sample contributes.
    * ``"maximum"`` -- only the largest loss contributes.
    * ``"atk"`` -- the ``k`` largest losses contribute.
    * ``"matk"`` / ``"sgd_matk"`` -- hinge ``[loss - lamda]_+`` with a
      threshold ``lamda`` updated by block coordinate descent / per batch.
    * ``"smooth_matk"`` / ``"smooth_sgd_matk"`` -- as above with the hinge
      replaced by the smooth function named by ``smooth_method``.

    Args:
        X_train, Y_train, X_val, Y_val, X_test, Y_test: data splits.
        k: number of top losses (``atk``), rank of the threshold loss
            (``matk``) and numerator of the ``k / n`` term in the ``lamda``
            gradient.
        lr: learning rate.
        l2reg_c: regularisation constant. Despite the name, the active
            penalty term is the L1 sub-gradient ``sign(param) / (2 * l2reg_c)``.
        delta: smoothing parameter forwarded to the smooth hinge derivative.
        batch_size: mini-batch size, used by the ``sgd_*`` variants only.
        loss_type: only ``"cross_entropy"`` is supported.
        aggregate_loss_type: see above.
        optimizer: ``"gd"``, ``"adagrad"`` or ``"adam"``.
        smooth_method: ``"srelu"``, ``"softplus"``, ``"leaky_relu"``,
            ``"swish"`` or ``"elu"``.
        stop_early: stop when validation accuracy plateaus.
        verbose: print progress every 5 epochs.
    """

    _AGGREGATE_LOSS_TYPES = (
        "maximum", "average", "atk",
        "matk", "smooth_matk",
        "sgd_matk", "smooth_sgd_matk",
    )

    def __init__(self, X_train, Y_train, X_val, Y_val, X_test, Y_test, k, lr, l2reg_c, delta,
                 batch_size=1,
                 loss_type="cross_entropy",
                 aggregate_loss_type="atk",
                 optimizer="adagrad",
                 smooth_method="softplus",
                 stop_early=True,
                 verbose=False) -> None:
        self.X_train, self.Y_train = X_train, Y_train
        self.X_val, self.Y_val = X_val, Y_val
        self.X_test, self.Y_test = X_test, Y_test
        self.ch = dict()      # cache of the last forward pass (Z, A, losses)
        self.param = dict()   # trainable parameters: W, b, lamda
        self.should_i_break = False
        # records
        self.checkpoints = list()
        self.loss_curve = list()
        self.acc_score_curve_train = list()
        self.acc_score_curve_val = list()
        self.acc_score_curve_test = list()
        self.W_track = list()
        self.acc_score_train, self.precision_score_train, self.recall_score_train, self.f1_score_train = None, None, None, None
        self.acc_score_val, self.precision_score_val, self.recall_score_val, self.f1_score_val = None, None, None, None
        self.acc_score_test, self.precision_score_test, self.recall_score_test, self.f1_score_test = None, None, None, None
        self.F_1 = 0
        # hyper-parameters
        self.k = k
        self.lr = lr
        self.delta = delta
        self.l2reg_c = l2reg_c
        self.batch_size = batch_size
        self.smooth_method = smooth_method
        # settings
        self.loss_type = loss_type
        self.aggregate_loss_type = aggregate_loss_type
        self.optimizer = optimizer
        self.stop_early = stop_early
        self.verbose = verbose
        self.dims = [self.X_train.shape[0], 1]
        self.n = self.X_train.shape[1]

    def nInit(self, load_weights=False, path="./checkpoints/best_on_val/"):
        """Initialise ``W``, ``b`` and ``lamda``, randomly or from ``path``.

        With ``load_weights=True`` each parameter is read with ``torch.load``
        from ``path + name + ".pt"``. Otherwise ``W`` is drawn from a normal
        distribution scaled by ``1/sqrt(n_features)``, ``b`` is zero, and
        ``lamda`` is a random scalar for the MATk variants and ``nan``
        otherwise.
        """
        if load_weights:
            for param_name in ["W", "b", "lamda"]:
                self.param[param_name] = torch.load(path+param_name+".pt")
        else:
            self.param["W"] = np.random.randn(self.dims[1], self.dims[0]) / np.sqrt(self.dims[0])
            self.param["b"] = np.zeros((self.dims[1], 1))
            if self.aggregate_loss_type in ["matk", "smooth_matk", "sgd_matk", "smooth_sgd_matk"]:
                self.param["lamda"] = np.random.randn()
            else:
                self.param["lamda"] = np.nan

    def _snapshot_params(self):
        """Return a copy of ``self.param`` that later updates cannot alter."""
        return {
            name: value.copy() if isinstance(value, np.ndarray) else value
            for name, value in self.param.items()
        }

    def record1(self, X_train, Y_train, X_val, Y_val, X_test, Y_test):
        """Append the current loss, accuracies and a parameter checkpoint.

        The loss recorded is the mean of the losses cached by the most
        recent ``forward(..., grad=True)`` call.
        """
        A2_train = self.forward(X=X_train, Y=Y_train, grad=False)
        A2_val = self.forward(X_val, Y_val, grad=False)
        A2_test = self.forward(X_test, Y_test, grad=False)
        Yh_train = Sign(A2_train)
        Yh_val = Sign(A2_val)
        Yh_test = Sign(A2_test)
        self.loss_curve.append(self.ch['losses'].mean())
        self.acc_score_train = accuracy_score(Y_train.T, Yh_train.T)
        self.acc_score_curve_train.append(self.acc_score_train)
        self.acc_score_val = accuracy_score(Y_val.T, Yh_val.T)
        self.acc_score_curve_val.append(self.acc_score_val)
        self.acc_score_test = accuracy_score(Y_test.T, Yh_test.T)
        self.acc_score_curve_test.append(self.acc_score_test)
        # Store a copy: appending self.param itself would make every
        # checkpoint an alias of the one live dict, i.e. of the final weights.
        self.checkpoints.append(self._snapshot_params())

    def record2(self, X_train, Y_train, X_val, Y_val, X_test, Y_test):
        """Update ``self.F_1`` with the support-recovery F1 of the current ``W``.

        The data arguments are accepted for signature compatibility and are
        not used.
        """
        F_1_now = f1_score(W_hat=self.param["W"][0], W_star=W_star)
        if self.F_1 < F_1_now:
            self.F_1 = F_1_now

    def forward(self, X, Y, grad=True):
        """Run the layer on ``X`` and return the predicted probabilities.

        When ``grad`` is true, the pre-activations, activations and
        per-sample losses are cached in ``self.ch`` for the backward pass.
        """
        Z = self.param['W'].dot(X) + self.param['b']
        A = Sigmoid(Z)
        losses = CrossEntropyLosses(A, Y)
        if grad:
            self.ch['Z'], self.ch['A'], self.ch['losses'] = Z, A, losses
        return A

    def _topk_indicator(self, k):
        """Return a 0/1 array shaped like the cached losses marking the top ``k``."""
        losses = self.ch['losses']
        indicator = np.zeros_like(losses)
        np.put(indicator, topk_idx(input=losses, k=k), 1)
        return indicator

    def _aggregate_loss_weights(self):
        """Derivative of the aggregate loss w.r.t. each cached per-sample loss."""
        losses = self.ch['losses']
        kind = self.aggregate_loss_type
        if kind in ("smooth_sgd_matk", "smooth_matk"):
            smooth_derivatives = {
                "srelu": dSRelu,
                "softplus": dSoftPlus,
                "leaky_relu": dLReLU,
                "swish": dSwish,
                "elu": dELU,
            }
            if self.smooth_method not in smooth_derivatives:
                raise ValueError("unsupported smooth_method: %r" % (self.smooth_method,))
            derivative = smooth_derivatives[self.smooth_method]
            return derivative(losses - self.param["lamda"], delta=self.delta)
        if kind in ("matk", "sgd_matk"):
            return dRelu(losses - self.param["lamda"])
        if kind == "atk":
            return self._topk_indicator(self.k)
        if kind == "average":
            return np.ones_like(losses)
        if kind == "maximum":
            return self._topk_indicator(1)
        raise ValueError("unsupported aggregate_loss_type: %r" % (kind,))

    def backward1(self, X, Y, G):
        """Update ``W`` and ``b`` from the cached forward pass.

        Args:
            X, Y: the batch that was passed to the preceding ``forward``.
            G: optimiser state dict with keys ``"G0_W"`` and ``"G0_b"``.

        Returns:
            Tuple ``(G, dAggregateLoss_dloss)``: the updated optimiser state
            and the per-sample aggregate-loss weights (needed by
            ``backward2``).
        """
        dAggregateLoss_dloss = self._aggregate_loss_weights()
        if self.loss_type != "cross_entropy":
            raise ValueError("unsupported loss_type: %r" % (self.loss_type,))
        # For sigmoid + cross-entropy, dloss/dA * dA/dZ simplifies to A - Y.
        # Using the closed form avoids dividing by A or 1 - A, which yields
        # inf/nan once the sigmoid saturates.
        dAggregateLoss_dZ = dAggregateLoss_dloss * (self.ch['A'] - Y)
        n_samples = X.shape[1]
        dAggregateLoss_W = np.dot(dAggregateLoss_dZ, X.T) / n_samples
        dAggregateLoss_b = dAggregateLoss_dZ.sum(axis=1, keepdims=True) / n_samples
        # l1-regularization sub-gradient (applied to the bias as well)
        dAggregateLoss_W = dAggregateLoss_W + np.sign(self.param["W"])/(2*self.l2reg_c)
        dAggregateLoss_b = dAggregateLoss_b + np.sign(self.param["b"])/(2*self.l2reg_c)
        # update
        if self.optimizer == "gd":
            self.param["W"] = self.param["W"] - self.lr * (dAggregateLoss_W)
            self.param["b"] = self.param["b"] - self.lr * (dAggregateLoss_b)
        elif self.optimizer == "adagrad":
            self.param["W"], G["G0_W"] = adagrad(weights=self.param["W"], G0=G["G0_W"], grads=dAggregateLoss_W, lr=self.lr)
            self.param["b"], G["G0_b"] = adagrad(weights=self.param["b"], G0=G["G0_b"], grads=dAggregateLoss_b, lr=self.lr)
        elif self.optimizer == "adam":
            self.param["W"], G["G0_W"] = adam(weights=self.param["W"], G0=G["G0_W"], grads=dAggregateLoss_W, lr=self.lr)
            self.param["b"], G["G0_b"] = adam(weights=self.param["b"], G0=G["G0_b"], grads=dAggregateLoss_b, lr=self.lr)

        # Hard-threshold tiny weights so that the support of W is exact.
        self.param["W"][np.abs(self.param["W"]) < 1e-6] = 0
        return G, dAggregateLoss_dloss

    def backward2(self, dAggregateLoss_dloss):
        """Update the MATk threshold ``lamda``.

        The ``sgd_*`` variants take a gradient step using the aggregate-loss
        weights of the last batch; the block-coordinate variants set
        ``lamda`` to the k-th largest cached loss. Other aggregate types
        have no ``lamda`` and are left untouched.
        """
        kind = self.aggregate_loss_type
        if kind in ("smooth_sgd_matk", "sgd_matk"):
            dAggregateLoss_dlamda = -dAggregateLoss_dloss.sum() + self.k / self.n
            self.param["lamda"] = self.param["lamda"] - self.lr * dAggregateLoss_dlamda
        elif kind in ("smooth_matk", "matk"):
            self.param["lamda"] = np.sort(self.ch['losses'])[:, -self.k]

    def gd(self, max_epochs=5000):
        """Train the layer and return a summary of the best checkpoint.

        Args:
            max_epochs: upper bound on the epoch counter. The block
                coordinate variants advance the counter inside their inner
                loop and may overshoot it.

        Returns:
            Dict with keys ``"t"`` (seconds spent in forward/backward),
            ``"acc_score_train"``, ``"acc_score_val"``, ``"acc_score_test"``
            (taken at the epoch with the best validation accuracy), ``"W"``
            (weights of that epoch) and ``"F_1"``.

        Raises:
            ValueError: for an unsupported optimizer or aggregate loss type,
                or when no epoch was recorded.
        """
        # Validate up front: an unknown aggregate type would otherwise never
        # advance the epoch counter and loop forever.
        if self.optimizer in ("adagrad", "gd"):
            G = {"G0_W": 0, "G0_b": 0, "G0_lamda": 0}
        elif self.optimizer == "adam":
            G = {"G0_W": {"m0": 0, "v0": 0, "t0": 1},
                 "G0_b": {"m0": 0, "v0": 0, "t0": 1},
                 }
        else:
            raise ValueError("unsupported optimizer: %r" % (self.optimizer,))
        if self.aggregate_loss_type not in self._AGGREGATE_LOSS_TYPES:
            raise ValueError("unsupported aggregate_loss_type: %r" % (self.aggregate_loss_type,))

        self.nInit()
        X_train, Y_train, X_val, Y_val, X_test, Y_test = self.X_train, self.Y_train, self.X_val, self.Y_val, self.X_test, self.Y_test
        t = 0
        i = 0
        while i < max_epochs:
            # early stop
            if self.stop_early:
                if self.should_i_break:
                    break
                # Pass an ndarray so the windowed subtraction inside
                # early_stop works regardless of the scalar type stored.
                if early_stop(np.asarray(self.acc_score_curve_val), threshold=1e-8):
                    break
            # forward and bp
            tic = time.time()
            if self.aggregate_loss_type in ["maximum", "average", "atk"]:
                i += 1
                self.forward(X_train, Y_train)
                G, dAggregateLoss_dloss = self.backward1(X_train, Y_train, G)
                self.record1(X_train, Y_train, X_val, Y_val, X_test, Y_test)
            elif self.aggregate_loss_type in ["matk", "smooth_matk"]:
                # block coordinate descent: optimise (W, b) with lamda
                # fixed, then refresh lamda
                bcd_min_epochs = 10
                bcd_max_epochs = 500
                weights_0 = self.param["W"]
                for j in range(1, bcd_max_epochs + 1):
                    i += 1
                    self.forward(X_train, Y_train)
                    G, dAggregateLoss_dloss = self.backward1(X_train, Y_train, G)
                    self.record1(X_train, Y_train, X_val, Y_val, X_test, Y_test)
                    if j > bcd_min_epochs and convergence(weights_0=weights_0, weights_1=self.param["W"], threshold=1e-6):
                        break
                    weights_0 = self.param["W"]
                self.backward2(dAggregateLoss_dloss)
            elif self.aggregate_loss_type in ["sgd_matk", "smooth_sgd_matk"]:
                i += 1
                # stochastic gradient descent over a freshly shuffled epoch
                X_train, Y_train = shuffle_in_unison(X_train, Y_train)
                batch_size = self.batch_size
                n_train = Y_train.shape[1]
                for start in range(0, n_train, batch_size):
                    stop = start + batch_size
                    batched_Xj, batched_Yj = X_train[:, start:stop], Y_train[:, start:stop]
                    self.forward(batched_Xj, batched_Yj)
                    G, dAggregateLoss_dloss = self.backward1(batched_Xj, batched_Yj, G)
                    self.backward2(dAggregateLoss_dloss)
                # NOTE: the training accuracy recorded here is measured on
                # the last mini-batch only, not on the full training set.
                self.record1(batched_Xj, batched_Yj, X_val, Y_val, X_test, Y_test)

            toc = time.time()
            t += toc-tic
            # verbose
            if self.verbose and i % 5 == 0:
                print("Training [Loss] and [Acc] after %i epochs: %f, %f" %(i, self.ch['losses'].mean(), self.acc_score_test))

        # back to the best-performance-on-validation-dataset checkpoint
        if not self.acc_score_curve_val:
            raise ValueError("no epoch was recorded; max_epochs must be positive")
        best = int(np.argmax(self.acc_score_curve_val))
        self.acc_score_val = self.acc_score_curve_val[best]
        self.acc_score_train = self.acc_score_curve_train[best]
        self.acc_score_test = self.acc_score_curve_test[best]
        self.param = self.checkpoints[best]
        # record2
        self.record2(X_train, Y_train, X_val, Y_val, X_test, Y_test)
        # result
        result_dict = {
            "t": float(t),
            "acc_score_train": float(self.acc_score_train),
            "acc_score_val": float(self.acc_score_val),
            "acc_score_test": float(self.acc_score_test),
            "W": self.param["W"],
            "F_1": self.F_1,
        }
        return result_dict
    



def load_data(path="./data/"):
    """Load the feature and label objects stored under ``path``.

    ``path`` is used as a plain string prefix, so it must end with a path
    separator. The two files read are ``data.pt`` and ``lable.pt`` (sic).

    Returns:
        Tuple ``(data, label)`` as returned by ``torch.load``.
    """
    data, label = (torch.load(path + file_name) for file_name in ("data.pt", "lable.pt"))
    return data, label

# X, Y = load_data()
# X, Y = X.T, Y.T

# nn = Fc(X=X, Y=Y, 
#          k=8099, 
#          loss_type="cross_entropy", 
#          aggregate_loss_type="smooth_sgd_matk", 
#          optimizer="adam", 
#          stop_early=True, 
#          verbose=True, 
#          lr=0.005, 
#          delta=1, 
#          l2reg_c=1e4,  
#          batch_size=512, 
#          smooth_method="softplus")
# result_dict = nn.gd(max_epochs=3000)

# print(result_dict["epochs"])
# print(result_dict["acc_score_test"])
# print(result_dict["W"])


# print(result_dict["F_1"])