"""
Code for the Submition to ICLR 2022
Causal Discovery via Cholesky Factorization
"""


import numpy as np


class DAGLearner(object):
    """
    Learn a linear DAG by greedily building a Cholesky factorisation of the
    (optionally diagonally augmented) sample covariance matrix.

    Variables are appended to a causal order one at a time; at each step the
    remaining variable with the smallest selection score is chosen (by
    default the score is its residual variance given the variables already
    ordered). Edge weights are read from the inverse of the resulting
    upper-triangular factor and small entries are hard-thresholded.
    The object holds no state; every method works on its arguments only.
    """

    def reconstruct_cholesky(self, L, z, l):
        """
        Grow an inverse upper-triangular factor by one row and one column.

        If ``L`` equals ``U^{-1}`` for a (k, k) upper-triangular ``U``, the
        grown factor is ``[[U, u], [0, l]]`` and ``z = U^{-1} u``, then the
        returned matrix is its inverse ``[[U^{-1}, -z / l], [0, 1 / l]]``.

        :param L: (k, k) current inverse factor ``U^{-1}``.
        :param z: (k, 1) column ``U^{-1} u``.
        :param l: new diagonal entry of ``U`` (scalar or length-1 array).
        :return: (k + 1, k + 1) inverse of the grown factor.
        """
        l = np.reshape(l, (1, 1))  # accept a scalar as well as a length-1 array
        last_col = np.concatenate([-z / l, 1 / l], axis=0)
        new_ImT = np.concatenate([L, np.zeros_like(z).transpose()], axis=0)
        return np.concatenate([new_ImT, last_col], axis=1)

    def sparsity(self, y):
        """
        Column-wise L1 norm.

        :param y: 2-D array.
        :return: 1-D array holding ``sum(|y[:, j]|)`` for every column ``j``.
        """
        return np.sum(np.abs(y), axis=0)

    def move_to_ind(self, l, to_ind, from_ind):
        """
        Move ``l[from_ind]`` to position ``to_ind`` in place, shifting the
        elements ``l[to_ind:from_ind]`` one slot to the right.

        :param l: mutable sequence; it is modified in place.
        :param to_ind: destination index; callers must pass ``to_ind <= from_ind``
            (otherwise ``l[to_ind]`` is simply overwritten).
        :param from_ind: source index.
        :return: ``l`` itself.
        """
        l_from_elem = l[from_ind]
        for i in range(from_ind, to_ind, -1):
            l[i] = l[i - 1]
        l[to_ind] = l_from_elem
        return l

    def detect_hierarchy_cholesky(self, X, consider_sparse=False,
                                  diag_aug_rate=0., sparse_only=False, hard_thred=0.35):
        """
        Estimate a causal order and a weighted adjacency matrix from data.

        ``C`` is the sample covariance of ``X`` (centred, ``n - 1``
        normalisation) and ``ridge = diag_aug_rate * log(d) / n`` is added to
        its diagonal. Starting from the variable of smallest variance, the
        order is grown greedily: for every remaining variable ``j``,
        ``l_j = sqrt(C_jj + ridge - ||y_j||^2)`` is its next Cholesky diagonal
        entry (its residual standard deviation given the ordered variables),
        and the variable with the smallest score is appended.

        :param X: (n, d) data matrix, one sample per row.
        :param consider_sparse: score = ``sqrt(|l^2 - delta|) * L1(z)``, where
            ``delta`` is the mean squared diagonal of the current factor and
            ``z = ImT @ y``. Takes precedence over ``sparse_only``.
        :param diag_aug_rate: multiplier of the diagonal augmentation ``ridge``.
        :param sparse_only: score = ``L1(z)`` only.
        :param hard_thred: base threshold; it is shifted by
            ``0.1 * (0.5 - n / (650 * log(d)))`` before being applied.
        :return: ``(cind, A)``. ``cind`` is the estimated order (column indices
            of ``X``). ``A`` is a (d, d) strictly upper-triangular matrix whose
            row/column ``k`` refers to variable ``cind[k]``; it holds the raw
            off-diagonal entries of ``-ImT`` (not normalised by ``ImT``'s
            diagonal) after hard-thresholding.
        """
        X = X - np.mean(X, axis=0)
        sample_num, feature_dim = X.shape
        X = X / np.sqrt(sample_num - 1)
        C = X.transpose() @ X  # sample covariance
        C_diag = C.diagonal()
        # Diagonal augmentation of C; exactly 0.0 when diag_aug_rate == 0.
        ridge = diag_aug_rate * np.log(feature_dim) / sample_num

        # Start the order with the variable of smallest variance.
        first = np.argmin(C_diag)
        cind = self.move_to_ind(list(range(feature_dim)), 0, first)
        # ImT is the inverse of the upper Cholesky factor of the augmented,
        # reordered covariance. The ridge goes *inside* the square root, as it
        # does for every later diagonal entry below.
        ImT = 1 / np.sqrt([[C_diag[first] + ridge]])
        y = None
        for ind in range(1, feature_dim):
            # Cross-covariance between ordered (rows) and remaining (cols) variables.
            XpTXf = C[np.ix_(cind[:ind], cind[ind:])]
            if y is None:
                y = ImT.transpose() @ XpTXf
            else:
                # Only the newest row of y has to be computed; the older rows
                # are reused after dropping the column of the variable that
                # was ordered in the previous step.
                keep = [c for c in range(y.shape[1]) if c != next_cind]
                y = np.concatenate([y[:, keep], ImT[:, -1:].transpose() @ XpTXf], axis=0)

            resid_var = C_diag[cind[ind:]] + ridge - np.linalg.norm(y, axis=0) ** 2
            # NOTE: a negative resid_var (numerically singular C) gives NaN here,
            # and np.argmin below would then select that NaN entry.
            l = np.sqrt(resid_var)

            if consider_sparse:
                delta = np.mean(1 / (np.diagonal(ImT) ** 2))
                score = np.sqrt(np.abs(l ** 2 - delta)) * self.sparsity(ImT @ y)
            elif sparse_only:
                score = self.sparsity(ImT @ y)
            else:
                score = np.abs(l ** 2)

            next_cind = np.argmin(score)
            cind = self.move_to_ind(cind, ind, next_cind + ind)
            z_next = ImT @ y[:, next_cind: next_cind + 1]
            ImT = self.reconstruct_cholesky(ImT, z_next, l[next_cind: next_cind + 1])

        sigmainv = np.diag(ImT) * np.sqrt(1 + ridge)
        A = np.triu(-ImT, k=1)
        # With a single variable A has no off-diagonal entries, and log(1) == 0
        # would make the threshold adjustment divide by zero.
        if feature_dim > 1:
            hard_thred = hard_thred + 0.1 * (0.5 - (sample_num / (np.log(feature_dim) * 650)))
            # Column j is compared against a threshold scaled by ImT[j, j].
            A[np.abs(A) < (hard_thred * sigmainv)] = 0

        return cind, A

    def learning(self, X, consider_sparse=False, diag_aug_rate=0., sparse_only=False, hard_thred=0.35):
        """
        Estimate the weighted adjacency matrix in the original variable order.

        :param X: (n, d) data matrix, one sample per row.
        :param consider_sparse: see ``detect_hierarchy_cholesky``.
        :param diag_aug_rate: see ``detect_hierarchy_cholesky``.
        :param sparse_only: see ``detect_hierarchy_cholesky``.
        :param hard_thred: see ``detect_hierarchy_cholesky``.
        :return: (d, d) matrix ``W_est`` with ``W_est[cind[a], cind[b]] == A[a, b]``,
            i.e. rows/columns indexed like the columns of ``X``.
        """
        cind, A = self.detect_hierarchy_cholesky(X, consider_sparse, diag_aug_rate, sparse_only, hard_thred)
        # rind[i] is the position of variable i in the estimated order.
        rind = np.argsort(cind)
        return A[np.ix_(rind, rind)]


if __name__ == '__main__':
    # Benchmark driver. For every configuration below it simulates a DAG and
    # linear-SEM data through the project-local ``utils`` module, runs
    # DAGLearner on the data for each random seed, and prints mean (std) over
    # the seeds of the metrics returned by ``utils.count_accuracy``, plus the
    # mean time cost and the fraction of estimates rejected by ``utils.is_dag``.
    import utils
    import time

    algos = ['diag']
    noise_scales = [1.]
    graph_types = ["ER", "SF"]
    sem_types = ['gauss']
    d_list = [50, 100, 1000]  # node number

    # [name, consider_sparse ('VS'), sparse_only ('S'), diag_aug_rate ('+', "gamma * R")]
    experiments_list = [
        # ["V", False, False, 0.0],
        # ["S", False, True, 0.0],
        # ["VS", True, False, 0.0],
        ["V+", False, False, 1.0],
        # ["S+", False, True, 3.0],
        # ["VS+", True, False, 3.0]
    ]
    random_seed_list = range(10)  # run each experiment on ten random seeds
    n_list = [3000]  # sample number
    edge_level = [2, 5]  # edge level; passed to utils.simulate_dag as s0 = e * d

    def run_seeds(learner, d, s0, n, graph_type, sem_type, noise_scale,
                  consider_sparse, sparse_only, diag_aug_rate):
        """Run one configuration on every seed; return per-seed metric lists."""
        metrics = {'time': [], 'no_dag': [], 'fdr': [], 'shd': [], 'tpr': [], 'nnz': []}
        for seed in random_seed_list:
            utils.set_random_seed(seed)
            B_true = utils.simulate_dag(d, s0, graph_type)
            W_true = utils.simulate_parameter(B_true)
            X, W_true = utils.simulate_linear_sem(W_true, n, sem_type, noise_scale=noise_scale)
            start = time.time()
            W_est = learner.learning(X, consider_sparse, diag_aug_rate, sparse_only)
            metrics['time'].append(time.time() - start)
            metrics['no_dag'].append(0. if utils.is_dag(W_est) else 1.)
            acc = utils.count_accuracy(B_true, W_est != 0)
            for key in ('fdr', 'shd', 'tpr', 'nnz'):
                metrics[key].append(acc[key])
        return metrics

    def report(experiment_name, metrics):
        """Print one summary line: mean (std) of each metric over the seeds."""
        def mean_std(key):
            values = np.array(metrics[key])
            return float(np.mean(values)), float(np.std(values))

        print("model: %s time_cost: %.4f fdr: %.4f (%.4f) shd: %.2f (%.2f) "
              "tpr: %.4f (%.4f) nnz: %.2f (%.2f) no_dag: %.2f"
              % ((experiment_name, float(np.mean(metrics['time'])))
                 + mean_std('fdr') + mean_std('shd') + mean_std('tpr') + mean_std('nnz')
                 + (float(np.mean(metrics['no_dag'])),)))

    daglearner = DAGLearner()  # holds no state, so one instance serves every run
    for algo in algos:
        print("algo: ", algo)
        for graph_type in graph_types:
            print("graph type: ", graph_type)
            for sem_type in sem_types:
                print("sem type: ", sem_type)
                for noise_scale in noise_scales:
                    print("noise scale: ", noise_scale)
                    for d in d_list:
                        print("node num: ", d)
                        for e in edge_level:
                            print("edge level: ", e)
                            for n in n_list:
                                print("sample num: ", n)
                                for experiment_name, consider_sparse, sparse_only, diag_aug_rate \
                                        in experiments_list:
                                    metrics = run_seeds(daglearner, d, e * d, n, graph_type,
                                                        sem_type, noise_scale, consider_sparse,
                                                        sparse_only, diag_aug_rate)
                                    report(experiment_name, metrics)





