import time
import json
import os
import logging
import numpy as np
import copy
import torch
import quaternion
from my_config import config as cfg

from scipy.sparse import csc_matrix
from scipy.spatial.transform import Rotation


# Module-level logger shared by all functions below.
logger = logging.getLogger(__name__)

# Fixed seed so that any numpy.random usage is reproducible across runs.
SEED = 1
np.random.seed(SEED)

# Gravitational acceleration along -z (presumably m/s^2).
gravity = np.array((0.0, 0.0, -9.81))

# Tiny epsilon added to norms before dividing, guarding against division by zero.
eps_norm = 1.0e-29


def split_data(nodes, edges, split_list_in, nodes_norm, edges_norm, max_E_G):
    """
    Split stacked node/edge arrays per rod and derive each rod's constraint parameters.

    Several simulations (rods) are stacked along axis 0. A rod with ``num`` nodes owns
    ``num`` consecutive node rows and ``2 * num - 2`` consecutive edge rows.

    Args:
        nodes: stacked node features ``[total_nodes, dim]``; columns 13, 14 and 15 of a
            rod's first row are read as radius, density and length.
        edges: stacked edge features ``[total_edges, dim]``; columns 0:3 of the first
            ``num - 1`` rows are the rest Darboux vector, columns 3/4 of the first row
            the (scaled) Young's and torsion moduli.
        split_list_in: number of nodes of each rod, in stacking order.
        nodes_norm: per-column node scale, multiplied in when ``cfg.flag_normed``.
        edges_norm: per-column edge scale, multiplied in when ``cfg.flag_normed``.
        max_E_G: ``(E_scale, G_scale)`` multiplied onto the stored moduli.

    Returns:
        ``(nodes, edges, constraint_paras)``. ``nodes`` and ``edges`` are lists with one
        array per rod. ``constraint_paras`` holds one
        ``[num_nodes, rod_radius, density, length, rest_Darboux, E, G, I1, J, alpha]``
        entry per rod.
    """
    # convert the nodes/edges back to physical units
    if cfg.flag_normed:
        nodes = np.multiply(nodes, np.array(nodes_norm))
        edges = np.multiply(edges, np.array(edges_norm))

    # Each element of the split node list is the initial condition of one simulation.
    # Split points are the running totals over every rod except the last one.
    # A rod with num nodes has 2*num-2 edges.
    node_splits = []
    edge_splits = []
    node_total = 0
    edge_total = 0
    for count in list(split_list_in)[:-1]:
        node_total += count
        edge_total += 2 * count - 2
        node_splits.append(node_total)
        edge_splits.append(edge_total)

    # Only a single rod needs no split. Testing the split-point list's length
    # instead would also match two rods and return them merged.
    if not node_splits:
        nodes = [nodes]
        edges = [edges]
    else:
        nodes = np.split(nodes, node_splits, axis=0)
        edges = np.split(edges, edge_splits, axis=0)

    # compute the constraint parameters of each simulation
    constraint_paras = []
    for node, edge in zip(nodes, edges):
        num_nodes = node.shape[0]
        rod_radius = node[0, 13]
        density = node[0, 14]
        length = node[0, 15]
        rest_Darboux = edge[0:num_nodes-1, 0:3]
        E = edge[0, 3] * max_E_G[0]
        G = edge[0, 4] * max_E_G[1]
        I1 = 0.25 * density * np.pi * np.power(rod_radius, 4)
        # attention, there is not length multiplied here
        J = 2.0 * I1
        # presumably per-constraint compliances: tiny for the 3 stretch rows,
        # inverse bending/torsion stiffness for the 3 rotational rows
        alpha = 1.0e-8*np.ones(6)
        alpha[3] = 1.0/(E * I1)
        alpha[4] = 1.0/(E * I1)
        alpha[5] = 1.0/(G * J)
        constraint_paras.append([num_nodes, rod_radius, density, length, rest_Darboux, E, G, I1, J, alpha])

    return nodes, edges, constraint_paras


def split_nodes(nodes, split_list_in, nodes_norm):
    """
    Split stacked node features into one array per rod.

    Args:
        nodes: stacked node features ``[total_nodes, dim]``.
        split_list_in: number of nodes of each rod, in stacking order.
        nodes_norm: per-column scale, multiplied in when ``cfg.flag_normed`` is set.

    Returns:
        List with one node array per rod (a one-element list for a single rod).
    """
    # convert the nodes back to physical units
    if cfg.flag_normed:
        nodes = np.multiply(nodes, np.array(nodes_norm))

    # split points: running node totals over every rod except the last one
    split_points = []
    total = 0
    for count in list(split_list_in)[:-1]:
        total += count
        split_points.append(total)

    # Only a single rod needs no split; testing the split-point list's length
    # instead would also match two rods and return them merged.
    if not split_points:
        return [nodes]
    return np.split(nodes, split_points, axis=0)


def split_edges(edges, split_list_in, edges_norm):
    """
    Split stacked edge features into one array per rod.

    A rod with ``num`` nodes owns ``2 * num - 2`` consecutive edge rows.

    Args:
        edges: stacked edge features ``[total_edges, dim]``.
        split_list_in: number of nodes of each rod, in stacking order.
        edges_norm: per-column scale, multiplied in when ``cfg.flag_normed`` is set.

    Returns:
        List with one edge array per rod (a one-element list for a single rod).
    """
    # convert the edges back to physical units
    if cfg.flag_normed:
        edges = np.multiply(edges, np.array(edges_norm))

    # split points: running edge totals over every rod except the last one
    split_points = []
    total = 0
    for count in list(split_list_in)[:-1]:
        total += 2 * count - 2
        split_points.append(total)

    # Only a single rod needs no split; testing the split-point list's length
    # instead would also match two rods and return them merged.
    if not split_points:
        return [edges]
    return np.split(edges, split_points, axis=0)


def one_step_integration(in_state_graph, step_size, nodes_norm, split_list, flag_damp=False):
    """
    Explicitly integrate all nodes over one time step (the prediction before projection).

    Operates on a deep copy of ``in_state_graph``; the input graph is not modified.
    If ``cfg.flag_normed`` is set, node features are first multiplied by ``nodes_norm``.
    The integration, including ``gravity``, then runs in physical units. The result is
    written back without re-normalising.
    (Original note: this step runs when flag_integration is set; otherwise the graph
    net does it.)

    Node columns read here: 0:3 position, 3:6 velocity, 6:10 quaternion (w, x, y, z),
    10:13 angular velocity, and the last column as a per-node weight. Nodes whose last
    column is <= 1e-6 get zero velocity and no position or rotation change.

    Args:
        in_state_graph: graph object exposing ``nodes`` and ``update_nodes`` (project type).
        step_size: time step.
        nodes_norm: per-column scale used when ``cfg.flag_normed`` is set.
        split_list: unused here.
        flag_damp: unused here.

    Returns:
        A new graph whose nodes are ``[position(3), quaternion(4), last column(1)]``.
        Linear and angular velocities are not part of the output.
    """
    input_graph = copy.deepcopy(in_state_graph)

    # bring the features to physical units if they are stored normalised
    if cfg.flag_normed:
        real_nodes = np.multiply(input_graph.nodes, np.array(nodes_norm))
    else:
        real_nodes = input_graph.nodes

    # pos/vel are views into real_nodes, so the in-place updates below modify it
    pos = real_nodes[..., 0:3]
    vel = real_nodes[..., 3:6]
    # w is 1 where the last column exceeds 1e-6 and 0 otherwise (heaviside(0, 0) == 0)
    w = np.heaviside(real_nodes[..., -1] - 1e-6, 0)
    w = np.expand_dims(w, axis=-1)
    # Using numpy array broadcasting, [N, 1] X [3] X scalar
    vel += w*gravity*step_size
    vel *= w
    # semi-implicit Euler: the position uses the velocity updated just above
    pos += w*vel*step_size

    # https://github.com/moble/quaternion
    # first-order update q += 0.5 * dt * (0, omega) * q; omega is zeroed where w == 0
    q = real_nodes[..., 6:10]
    omega = real_nodes[..., 10:13]
    omegaQ = np.zeros_like(q)
    omegaQ[..., 1:4] = omega*w
    omegaQ = quaternion.as_quat_array(omegaQ)
    q = quaternion.as_quat_array(q)
    q += 0.5*step_size*(omegaQ*q)
    q = quaternion.as_float_array(q)
    # renormalise to unit length; 1e-18 guards against division by zero
    q = q/(1e-18 + np.linalg.norm(q, ord=2, axis=-1, keepdims=True))

    # output layout: [position(3), quaternion(4), last column(1)]
    new_nodes = np.concatenate(
                [pos, q, real_nodes[..., -1:]],
                axis=-1)
    input_graph.update_nodes(new_nodes)
    return input_graph


def constraints_project_corr(model, input_graph, it):
    """
    Run the trained graph network on ``input_graph`` to predict the constraint correction.

    Args:
        model: network called as
            ``model(nodes, edges, globals, senders, receivers, nodes_index, edges_index, False, it)``
            and returning ``(node_corrections, edge_lambdas, _)``.
        input_graph: graph exposing numpy arrays ``globals``, ``edges``, ``nodes``,
            ``senders``, ``receivers``, ``nodes_index``, ``edges_index`` and the methods
            ``update_nodes`` / ``update_edges``. It is deep-copied, not modified.
        it: current time-step index, forwarded to the model.

    Returns:
        A copy of ``input_graph`` whose nodes hold the predicted corrections and whose
        edges hold the predicted lambdas (numpy, size-1 dimensions squeezed away).
    """
    predicted_graph = copy.deepcopy(input_graph)
    use_gpu = cfg.flag_gpu

    def to_batched_tensor(array, as_index=False):
        # index arrays become int64 tensors, features float32; unsqueeze adds the batch dim
        tensor = torch.from_numpy(array)
        tensor = tensor.type(torch.LongTensor) if as_index else tensor.float()
        tensor = tensor.unsqueeze(0)
        return tensor.cuda() if use_gpu else tensor

    # named graph_globals to avoid shadowing the builtin ``globals``
    graph_globals = to_batched_tensor(predicted_graph.globals)
    edges = to_batched_tensor(predicted_graph.edges)
    nodes = to_batched_tensor(predicted_graph.nodes)
    senders = to_batched_tensor(predicted_graph.senders, as_index=True)
    receivers = to_batched_tensor(predicted_graph.receivers, as_index=True)
    nodes_index = to_batched_tensor(predicted_graph.nodes_index, as_index=True)
    edges_index = to_batched_tensor(predicted_graph.edges_index, as_index=True)

    output_corr, output_lambda_v, _ = model(nodes, edges, graph_globals, senders, receivers,
                                            nodes_index, edges_index, False, it)

    if use_gpu:
        output_corr = output_corr.cpu()
        output_lambda_v = output_lambda_v.cpu()
    output_corr = output_corr.detach().numpy().squeeze()
    output_lambda_v = output_lambda_v.detach().numpy().squeeze()

    predicted_graph.update_nodes(output_corr)
    predicted_graph.update_edges(output_lambda_v)
    return predicted_graph


def computeMatrixL(q):
    """Return the 4x3 matrix ``L`` of a scalar-first quaternion ``q = (w, x, y, z)``.

    For a 3-vector ``v``, ``L @ v`` equals ``0.5 * q * (0, v)`` (quaternion product).
    It maps a rotation-vector correction onto a quaternion increment.
    """
    w, x, y, z = (q[idx] for idx in range(4))
    rows = [
        [-x, -y, -z],
        [w, -z, y],
        [z, w, -x],
        [-y, x, w],
    ]
    half = 0.5
    return np.array(rows) * half


def _quat_wxyz_to_matrix(q):
    """Return the 3x3 rotation matrix of a scalar-first quaternion ``(w, x, y, z)``."""
    # scipy's from_quat expects scalar-last (x, y, z, w), hence the roll
    rot = Rotation.from_quat(np.roll(q, -1))
    # ``as_dcm`` was removed in SciPy 1.6; ``as_matrix`` exists since SciPy 1.4
    if hasattr(rot, "as_matrix"):
        return rot.as_matrix()
    return rot.as_dcm()


def _cg_solve(H_den, rhs, x_init, tol, max_iter):
    """Conjugate-gradient solve of ``H_den @ x = rhs`` starting at ``x_init``.

    Iterates while ``||r|| / ||rhs|| > tol`` and, if ``max_iter`` is not None, for at most
    ``max_iter`` iterations. A zero right-hand side returns the exact solution ``0``
    immediately; the relative residual is undefined there and the loop could otherwise
    spin forever. Returns ``(x, iterations)`` with ``x`` shaped like ``x_init``.
    """
    rhs_norm = np.linalg.norm(rhs)
    if rhs_norm == 0:
        return np.zeros_like(x_init), 0
    x = x_init
    r = rhs - np.dot(H_den, x)
    p = r.copy()
    k = 0
    while np.linalg.norm(r) / rhs_norm > tol:
        if max_iter is not None and k >= max_iter:
            break
        Hp = np.dot(H_den, p)
        rr = np.dot(r.transpose(), r)
        a_k = (rr / np.dot(p.transpose(), Hp))[0, 0]
        x = x + a_k * p
        r_new = r - a_k * Hp
        beta = (np.dot(r_new.transpose(), r_new) / rr)[0, 0]
        p = r_new + beta * p
        r = r_new
        k += 1
    return x, k


def constraints_project_mix(stepped_nodes, corr_nodes, lambda_vs, constraint_paras, step_size,
                            nodes_norm, edges_norm, mask_value=1.0, is_pure=False, is_project=False,
                            cg_tol=0.001, max_cg_iter=None):
    """
    Project the rod constraints of every rod, warm-started from the network prediction.

    For each rod, a sparse linear system in the unknowns (per-node position/rotation
    corrections, per-segment lambda increments) is assembled. It is solved by conjugate
    gradient, starting from the network guess built from ``corr_nodes`` / ``lambda_vs``.
    The resulting corrections are applied to a copy of the rod's ``stepped_nodes``.
    The first node of each rod is never moved (``nFix = 1``).

    Args:
        stepped_nodes: list of per-rod node arrays. Columns 0:3 are position, 3:7 the
            quaternion (w, x, y, z), and the last column a weight; nodes with weight
            <= 1e-6 are not moved.
        corr_nodes: list of per-rod predicted node corrections (presumably 6 values per
            node; the masks below assume this).
        lambda_vs: list of per-rod predicted lambdas.
        constraint_paras: per-rod parameter lists as produced by ``split_data``.
        step_size: time step.
        nodes_norm, edges_norm: unused here.
        mask_value: scalar multiplied onto the initial guess.
        is_pure: if True, skip CG and use the (masked) network guess as the solution.
        is_project: if True, average components 0/1 and zero components 3/5 of the guess.
        cg_tol: relative-residual tolerance of the CG loop (previously hard-coded 0.001).
        max_cg_iter: optional cap on CG iterations; ``None`` (default) means uncapped.
            The system has positive mass entries and negative compliance entries on its
            diagonal, so it is indefinite and plain CG is not guaranteed to converge.

    Returns:
        ``(projected_nodes, k, norm_ratio, deltaX, deltaLambda)``:
        - ``projected_nodes``: the projected per-rod node arrays.
        - ``k``, ``norm_ratio``: the CG iteration count and the relative distance between
          the initial guess and the solution, both for the last rod processed
          (``0`` / ``0.0`` for empty input).
        - ``deltaX``: the stacked ``[*, 6]`` node corrections, with a zero row for each
          rod's fixed node.
        - ``deltaLambda``: the stacked ``[*, 6]`` lambda increments.
    """
    num_iteration = 1
    projected_nodes = []
    delta_x_parts = []
    delta_lambda_parts = []
    # reported for the last rod; defined up-front so empty input does not raise
    k = 0
    norm_ratio = 0.0

    for rod_id in range(len(stepped_nodes)):
        node = copy.deepcopy(stepped_nodes[rod_id])
        corr_node = copy.deepcopy(corr_nodes[rod_id])
        lambda_v = copy.deepcopy(lambda_vs[rod_id])

        if is_project:
            corr_node_new = np.copy(corr_node)
            lambda_v_new = np.copy(lambda_v)

            corr_node_new[:, 0] = (corr_node[:, 0] + corr_node[:, 1])*0.5
            corr_node_new[:, 1] = (corr_node[:, 0] + corr_node[:, 1])*0.5
            corr_node_new[:, 3] = 0
            corr_node_new[:, 5] = 0

            lambda_v_new[:, 0] = (lambda_v[:, 0] + lambda_v[:, 1])*0.5
            lambda_v_new[:, 1] = (lambda_v[:, 0] + lambda_v[:, 1])*0.5
            lambda_v_new[:, 3] = 0
            lambda_v_new[:, 5] = 0

            init_guess = np.concatenate([corr_node_new[1:].reshape(-1), lambda_v_new.reshape(-1)])
        else:
            init_guess = np.concatenate([corr_node[1:].reshape(-1), lambda_v.reshape(-1)])

        num_nodes, rod_radius, density, length, rest_Darboux, E, G, I1, J, alpha = constraint_paras[rod_id]

        weight = density * np.pi * rod_radius * rod_radius * length
        node_I1 = 0.25 * weight * rod_radius * rod_radius
        one_over_dt2 = 1.0/step_size/step_size/length

        nFix = 1
        # unknowns: 6 per free node, followed by 6 lambdas per segment
        offset = 6 * (num_nodes - nFix)
        matSize = 6 * (2 * num_nodes - 1 - nFix)

        cons_lambda_com = np.zeros((num_nodes - 1, 6))

        for iter_s in range(num_iteration):
            col_id = []
            row_id = []
            val = []
            rhs_val = np.zeros((matSize, 1))
            # iterate over nodes
            for i in range(num_nodes):
                if node[i, -1] > 1e-6:
                    # weight
                    col_id += [6 * (i - nFix), 6 * (i - nFix) + 1, 6 * (i - nFix) + 2]
                    row_id += [6 * (i - nFix), 6 * (i - nFix) + 1, 6 * (i - nFix) + 2]
                    val += [weight, weight, weight]

                    col_id += [6 * (i - nFix) + 3, 6 * (i - nFix) + 4, 6 * (i - nFix) + 5]
                    row_id += [6 * (i - nFix) + 3, 6 * (i - nFix) + 4, 6 * (i - nFix) + 5]
                    val += [node_I1, node_I1, node_I1*2]

                if i == num_nodes-1:
                    continue

                x0 = node[i, 0:3]
                x1 = node[i + 1, 0:3]
                q0 = node[i, 3:7]
                q1 = node[i + 1, 3:7]
                # our quaternion is (w, x, y, z); the helper converts for scipy
                R0 = _quat_wxyz_to_matrix(q0)*length*0.5
                R1 = _quat_wxyz_to_matrix(q1)*length*0.5

                L0 = computeMatrixL(q0)
                L1 = computeMatrixL(q1)
                L1TL0 = -4.0/length * np.dot(L1.transpose(), L0)

                if node[i, -1] > 1e-6:
                    # -1
                    col_id += [offset + 6 * i, offset + 6 * i + 1, offset + 6 * i + 2]
                    row_id += [6 * (i - nFix), 6 * (i - nFix) + 1, 6 * (i - nFix) + 2]
                    val += [-1, -1, -1]

                    col_id += [6 * (i - nFix), 6 * (i - nFix) + 1, 6 * (i - nFix) + 2]
                    row_id += [offset + 6 * i, offset + 6 * i + 1, offset + 6 * i + 2]
                    val += [-1, -1, -1]

                    # R0
                    col_id += [offset + 6 * i, offset + 6 * i + 1, offset + 6 * i + 2]
                    row_id += [6 * (i - nFix) + 3, 6 * (i - nFix) + 3, 6 * (i - nFix) + 3]
                    val += [R0[0, 1], R0[1, 1], R0[2, 1]]

                    col_id += [6 * (i - nFix) + 3, 6 * (i - nFix) + 3, 6 * (i - nFix) + 3]
                    row_id += [offset + 6 * i, offset + 6 * i + 1, offset + 6 * i + 2]
                    val += [R0[0, 1], R0[1, 1], R0[2, 1]]

                    col_id += [offset + 6 * i, offset + 6 * i + 1, offset + 6 * i + 2]
                    row_id += [6 * (i - nFix) + 4, 6 * (i - nFix) + 4, 6 * (i - nFix) + 4]
                    val += [-R0[0, 0], -R0[1, 0], -R0[2, 0]]

                    col_id += [6 * (i - nFix) + 4, 6 * (i - nFix) + 4, 6 * (i - nFix) + 4]
                    row_id += [offset + 6 * i, offset + 6 * i + 1, offset + 6 * i + 2]
                    val += [-R0[0, 0], -R0[1, 0], -R0[2, 0]]

                    for j in range(3):
                        for kk in range(3):
                            col_id += [offset + 6 * i + 3 + j, 6 * (i - nFix) + 3 + kk]
                            row_id += [6 * (i - nFix) + 3 + kk, offset + 6 * i + 3 + j]
                            val += [-L1TL0[j, kk], -L1TL0[j, kk]]

                # 1
                col_id += [offset + 6 * i, offset + 6 * i + 1, offset + 6 * i + 2]
                row_id += [6 * (i + 1 - nFix), 6 * (i + 1 - nFix) + 1, 6 * (i + 1 - nFix) + 2]
                val += [1, 1, 1]

                col_id += [6 * (i + 1 - nFix), 6 * (i + 1 - nFix) + 1, 6 * (i + 1 - nFix) + 2]
                row_id += [offset + 6 * i, offset + 6 * i + 1, offset + 6 * i + 2]
                val += [1, 1, 1]

                # R1
                col_id += [offset + 6 * i, offset + 6 * i + 1, offset + 6 * i + 2]
                row_id += [6 * (i + 1 - nFix) + 3, 6 * (i + 1 - nFix) + 3, 6 * (i + 1 - nFix) + 3]
                val += [R1[0, 1], R1[1, 1], R1[2, 1]]

                col_id += [6 * (i + 1 - nFix) + 3, 6 * (i + 1 - nFix) + 3, 6 * (i + 1 - nFix) + 3]
                row_id += [offset + 6 * i, offset + 6 * i + 1, offset + 6 * i + 2]
                val += [R1[0, 1], R1[1, 1], R1[2, 1]]

                col_id += [offset + 6 * i, offset + 6 * i + 1, offset + 6 * i + 2]
                row_id += [6 * (i + 1 - nFix) + 4, 6 * (i + 1 - nFix) + 4, 6 * (i + 1 - nFix) + 4]
                val += [-R1[0, 0], -R1[1, 0], -R1[2, 0]]

                col_id += [6 * (i + 1 - nFix) + 4, 6 * (i + 1 - nFix) + 4, 6 * (i + 1 - nFix) + 4]
                row_id += [offset + 6 * i, offset + 6 * i + 1, offset + 6 * i + 2]
                val += [-R1[0, 0], -R1[1, 0], -R1[2, 0]]

                for j in range(3):
                    for kk in range(3):
                        col_id += [offset + 6 * i + 3 + j, 6 * (i + 1 - nFix) + 3 + kk]
                        row_id += [6 * (i + 1 - nFix) + 3 + kk, offset + 6 * i + 3 + j]
                        val += [L1TL0[kk, j], L1TL0[kk, j]]

                # alpha matrix
                for j in range(6):
                    col_id += [offset + 6*i + j]
                    row_id += [offset + 6*i + j]
                    val += [-one_over_dt2 * alpha[j]]

                # rhs for positions
                rhs_val[offset + 6 * i: offset + 6 * i + 3, 0] = R0[:, 2] + x0 + R1[:, 2] - x1

                # rhs for quaternion: current Darboux vector minus the rest one
                Q0 = quaternion.as_quat_array(q0)
                Q1 = quaternion.as_quat_array(q1)
                omega = 2.0/length * quaternion.as_float_array(Q0.conjugate() * Q1)[..., 1:4]
                omega0 = rest_Darboux[i]
                omega = omega - omega0

                # rhs for alpha times lambda
                rhs_val[offset + 6 * i + 3: offset + 6 * i + 6, 0] = omega
                for j in range(6):
                    rhs_val[offset + 6 * i + j, 0] += one_over_dt2 * alpha[j] * cons_lambda_com[i, j]

            # Every off-diagonal entry above is added together with its mirror image, so H
            # is symmetric and the (col, row) argument order does not matter.
            H = csc_matrix((val, (col_id, row_id)), shape=(matSize, matSize))
            m = mask_value
            mask = np.array([m,m,m,m,m,m]*(num_nodes-1)+[m,m,m,m,m,m]*(num_nodes-1))
            init_guess = init_guess*mask

            H_den = H.toarray()
            x_start = init_guess.reshape(-1, 1)  # this is the switch between two modes
            if is_pure:
                solution, k = x_start, 0
            else:
                solution, k = _cg_solve(H_den, rhs_val, x_start, cg_tol, max_cg_iter)
            deltaXandLambda = np.squeeze(solution)
            norm_ratio = np.linalg.norm(init_guess-deltaXandLambda)/np.linalg.norm(deltaXandLambda)

            # update lambda
            cons_lambda_com[:num_nodes - nFix, :] += deltaXandLambda[offset:offset + 6 * (num_nodes - nFix)].reshape(-1, 6)

            # update positions and quaternions of the free nodes
            for i in range(nFix, num_nodes):
                q0 = node[i, 3:7]
                L0 = computeMatrixL(q0)
                if node[i, -1] > 1e-6:
                    node[i, 0:3] += deltaXandLambda[6 * (i-nFix):6 * (i-nFix) + 3]
                    node[i, 3:7] += np.dot(L0, deltaXandLambda[6 * (i-nFix) + 3:6 * (i-nFix) + 6])
                    node[i, 3:7] = node[i, 3:7] / (eps_norm + np.linalg.norm(node[i, 3:7], ord=2))

        # record the final solution of this rod (zero row for the fixed node)
        delta_x_parts.append(np.zeros((1, 6)))
        delta_x_parts.append(deltaXandLambda[:offset].reshape(-1, 6))
        delta_lambda_parts.append(deltaXandLambda[offset:offset + 6 * (num_nodes - nFix)].reshape(-1, 6))
        projected_nodes.append(node)

    deltaX = np.concatenate(delta_x_parts) if delta_x_parts else np.zeros((0, 6))
    deltaLambda = np.concatenate(delta_lambda_parts) if delta_lambda_parts else np.zeros((0, 6))
    return projected_nodes, k, norm_ratio, deltaX, deltaLambda


def update_vel(origin_nodes, projected_nodes, step_size, node_include_velocity=False):
    """
    Rebuild full per-rod node states, including velocities, after constraint projection.

    Velocities are recovered by finite differences over ``step_size``:
    - linear velocity is ``(x_new - x_old) / dt``, zeroed for nodes whose column 13 is
      <= 1e-6;
    - angular velocity is the vector part of ``2 * q_new * conj(q_old) / dt``, computed
      with both quaternions normalised first.

    Args:
        origin_nodes: per-rod node arrays before the step. Columns 0:3 are position and
            6:10 the quaternion; columns 13 onwards are copied through unchanged.
        projected_nodes: per-rod node arrays after projection. The quaternion is read
            from columns 6:10 if ``node_include_velocity``, else from columns 3:7.
        step_size: time step.
        node_include_velocity: selects where the quaternion sits in ``projected_nodes``.

    Returns:
        List of per-rod arrays laid out as
        ``[pos(3), vel(3), quat(4), omega(3), origin_node[:, 13:]]``.
    """
    quat_slice = slice(6, 10) if node_include_velocity else slice(3, 7)

    def normalised(quats):
        # row-wise unit quaternions; eps_norm keeps the division finite
        return quats / (eps_norm + np.linalg.norm(quats, ord=2, axis=-1, keepdims=True))

    def rebuild(before, after):
        # 1 where column 13 exceeds 1e-6, else 0; shape [N, 1] for broadcasting
        free = np.expand_dims(np.heaviside(before[..., 13] - 1e-6, 0), axis=-1)

        position = after[:, 0:3]
        velocity = (position - before[:, 0:3]) * free / step_size

        q_after = normalised(after[:, quat_slice])
        q_before_conj = quaternion.as_quat_array(normalised(before[:, 6:10])).conjugate()
        spin = 2 * quaternion.as_quat_array(q_after) * q_before_conj / step_size
        angular_velocity = quaternion.as_float_array(spin)[:, 1:4]  # x, y, z components

        return np.concatenate([position, velocity, q_after, angular_velocity,
                               before[:, 13:]], axis=-1)

    return [rebuild(before, after) for before, after in zip(origin_nodes, projected_nodes)]


def roll_out_physics_gc_from_net_solver(model, graph, position_norm, max_E_G,
        split_list_in, steps, step_size, nodes_norm, edges_norm, mask_value=1.0, is_pure=False,
        model_type='gn', is_project=False):
    """
    Roll out ``steps`` simulation steps, mixing a network guess with constraint projection.

    Each step does the following:
    1. Explicit integration (``one_step_integration``).
    2. A network guess of the constraint correction: ``constraints_project_corr`` for
       ``model_type == "gn"``, ``model.update_graph`` for ``"direct"``.
    3. Refinement by ``constraints_project_mix``.
    4. Velocity recovery by ``update_vel``.
    Timing and CG statistics are logged. The per-step CG iteration counts are dumped as
    JSON to ``out/dummy.txt`` unless that file already exists.
    Per the original note, only the case without ``cfg.flag_correction`` is supported.

    Args:
        model: the network (``"gn"``), or an object with ``init_values`` /
            ``update_graph`` / ``update_init`` (``"direct"``).
        graph: initial graph; it is deep-copied, not modified.
        position_norm: list of position scale factors, extended with ones to build the
            normalisation vector of the network input/output.
        max_E_G: moduli scale factors forwarded to ``split_data``.
        split_list_in: number of nodes of each rod.
        steps: number of steps to simulate.
        step_size: time step.
        nodes_norm: node feature normalisation factors.
        edges_norm: edge feature normalisation factors.
        mask_value: forwarded to ``constraints_project_mix``.
        is_pure: forwarded to ``constraints_project_mix``.
        is_project: forwarded to ``constraints_project_mix``.
        model_type: ``"gn"`` or ``"direct"``.

    Returns:
        ``(nodes_per_step, iter_record, avg_cg_iterations)``:
        - ``nodes_per_step``: node states for steps ``0..steps``.
        - ``iter_record``: the CG iteration count of every step.
        - ``avg_cg_iterations``: their mean (0.0 when ``steps == 0``).

    Raises:
        ValueError: if ``split_list_in`` is None or ``model_type`` is unknown.
    """
    # validate before split_list_in / model_type are first used
    if split_list_in is None:
        raise ValueError("split list input should not be empty")
    if model_type not in ("gn", "direct"):
        raise ValueError("model_type must be 'gn' or 'direct', got {!r}".format(model_type))

    graph = copy.deepcopy(graph)
    _, _, constraint_paras = split_data(graph.nodes, graph.edges, split_list_in, nodes_norm, edges_norm, max_E_G)

    nodes_per_step = np.zeros(shape=(steps+1, *graph.NODES_SHAPE))
    nodes_per_step[0, ...] = graph.nodes

    # cumulative node split points, forwarded to one_step_integration
    split_list = [split_list_in[0]]
    for i in range(1, len(split_list_in) - 1):
        split_list.append(split_list_in[i] + split_list[i - 1])

    flag_corr_norm = True

    start_time = time.time()
    k_sum = 0
    norm_ratio_sum = 0
    # normalisation of the network's input state and output correction
    norm_vector = position_norm + [1, 1, 1]
    iter_record = []
    integration_time = 0
    net_time = 0
    cg_time = 0
    update_time = 0
    time_mark = time.time()
    if model_type == "direct":
        model.init_values(graph)

    for it in range(steps):
        graph.nodes = nodes_per_step[it, ...]
        nodes = split_nodes(graph.nodes, split_list_in, nodes_norm)
        # integration
        input_graph = one_step_integration(graph, step_size, nodes_norm, split_list)
        stepped_nodes = split_nodes(input_graph.nodes, split_list_in, nodes_norm)
        if flag_corr_norm:
            input_graph.nodes = np.multiply(input_graph.nodes, 1.0/np.array(norm_vector+[1, 1]))

        integration_time += time.time() - time_mark
        time_mark = time.time()

        # getting a guess
        if model_type == "gn":
            predicted_graph = constraints_project_corr(model, input_graph, it)
        else:
            predicted_graph = model.update_graph(input_graph)

        if flag_corr_norm:
            predicted_graph.nodes = np.multiply(predicted_graph.nodes, np.array(nodes_norm*np.array(norm_vector)))
            predicted_graph.edges = np.multiply(predicted_graph.edges, np.array(edges_norm))
        corr_nodes = split_nodes(predicted_graph.nodes, split_list_in, nodes_norm)
        lambda_v = split_edges(predicted_graph.edges, split_list_in, edges_norm)

        net_time += time.time() - time_mark
        time_mark = time.time()

        # cg iteration
        projected_nodes, k_ter, norm_ratio_iter, deltaX, deltaLambda = constraints_project_mix(
            stepped_nodes, corr_nodes, lambda_v, constraint_paras, step_size, nodes_norm,
            edges_norm, mask_value, is_pure, is_project)
        if model_type == "direct":
            model.update_init(deltaX, deltaLambda)
        k_sum += k_ter
        iter_record.append(k_ter)
        norm_ratio_sum += norm_ratio_iter

        cg_time += time.time() - time_mark
        time_mark = time.time()

        # update velocity
        updated_nodes = update_vel(nodes, projected_nodes, step_size)

        if cfg.flag_normed:
            # scale each rod separately: rods may have different node counts, which a
            # single np.multiply over the whole (ragged) list cannot handle
            inv_nodes_norm = 1 / np.array(nodes_norm)
            updated_nodes = [np.multiply(rod_nodes, inv_nodes_norm) for rod_nodes in updated_nodes]
        nodes_per_step[it + 1, ...] = np.concatenate(updated_nodes, axis=0)

        update_time += time.time() - time_mark
        time_mark = time.time()

    run_time = time.time() - start_time
    # guard the averages against steps == 0
    avg_cg_iter = k_sum / steps if steps else 0.0
    avg_norm_ratio = norm_ratio_sum / steps if steps else 0.0
    logger.info("\nSimulation run time：  {}".format(run_time))
    logger.info("Average cg iteration:  {}".format(avg_cg_iter))
    logger.info("Average norm ratio:  {}".format(avg_norm_ratio))
    logger.info("Integraion time: {}, ".format(integration_time))
    logger.info("net time: {}, ".format(net_time))
    logger.info("cg time: {}, ".format(cg_time))
    logger.info("update time: {}, ".format(update_time))
    output_path = "out/dummy.txt"
    if not os.path.exists(output_path):
        logger.info("Dumping to:  {}".format(output_path))
        output_dir = os.path.dirname(output_path)
        if output_dir:
            # open() does not create missing parent directories
            os.makedirs(output_dir, exist_ok=True)
        with open(output_path, 'w') as outputfile:
            json.dump(iter_record, outputfile)
    else:
        logger.info("File exists:  {}".format(output_path))
    return nodes_per_step, iter_record, avg_cg_iter
