from .node import Node
import ctypes
import json
import math
import random
from scipy.optimize import least_squares,minimize
import numpy as np
import sys
import hashlib
import time
from .config import MutationConfig, OptimizerConfig, default_lm_max_iters, normalize_operator_probs
from .precision import (
    dag_dtype_to_numpy_dtype,
    dag_dtype_to_precision_format,
    precision_format_to_numpy_dtype,
    ulp_distance,
)


def _rng_random(rng):
    if rng is None:
        return random.random()
    if isinstance(rng, np.random.Generator):
        return float(rng.random())
    return rng.random()


def _rng_randint(rng, low, high):
    if rng is None:
        return random.randint(low, high)
    if isinstance(rng, np.random.Generator):
        return int(rng.integers(low, high + 1))
    return rng.randint(low, high)


def _rng_choice(rng, options, probs=None, size=None):
    if rng is None:
        if size is not None:
            return np.random.choice(options, size=size, p=probs)
        return random.choices(options, weights=probs, k=1)[0]
    if isinstance(rng, np.random.Generator):
        return rng.choice(options, size=size, p=probs)
    if hasattr(rng, "choices"):
        if size is not None:
            count = int(np.prod(size)) if isinstance(size, tuple) else int(size)
            values = rng.choices(options, weights=probs, k=count)
            return np.asarray(values).reshape(size)
        return rng.choices(options, weights=probs, k=1)[0]
    if size is not None:
        return np.random.choice(options, size=size, p=probs)
    idx = int(_rng_random(rng) * len(options))
    return options[min(idx, len(options) - 1)]


def _rng_uniform(rng, low=0.0, high=1.0, size=None):
    if rng is None:
        return np.random.uniform(low, high, size)
    if isinstance(rng, np.random.Generator):
        return rng.uniform(low, high, size)
    if hasattr(rng, "uniform"):
        return rng.uniform(low, high, size)
    return np.random.uniform(low, high, size)


def _rng_normal(rng, size=None):
    if rng is None:
        if size is None:
            return np.random.randn()
        return np.random.randn(*size) if isinstance(size, tuple) else np.random.randn(size)
    if isinstance(rng, np.random.Generator):
        return rng.normal(size=size)
    if hasattr(rng, "normal"):
        return rng.normal(size=size)
    return np.random.randn(*size) if isinstance(size, tuple) else np.random.randn(size)

class DAG():
    """
    num_inputs: number of input nodes
    dtype: data type of the nodes
    dataset: dataset to be used for training
    """
    def __init__(self, num_inputs=1, dtype="float32", init_mutate=40, mutation_cfg=None, rng=None):
        """Build a minimal DAG and then randomise it with ``init_mutate`` mutations.

        The starting graph holds ``num_inputs`` input nodes (type 0) followed by
        one output node (type 1), with ``nodes[0]`` linked to the output (the
        first input when ``num_inputs >= 1``). ``mutate`` is then applied
        ``init_mutate`` times; ``id`` and ``parent_id`` are reset to -1 at the end.

        Args:
            num_inputs: number of input nodes.
            dtype: DAG dtype name, converted with ``dag_dtype_to_numpy_dtype``.
            init_mutate: number of ``mutate`` calls applied after construction.
            mutation_cfg: forwarded to every ``mutate`` call.
            rng: random source forwarded to every ``mutate`` call.
        """
        self.input_nodes = [Node(type = 0) for _ in range(num_inputs)]
        self.num_inputs = num_inputs
        self.output_node = Node(type = 1)
        # Node list layout: all inputs first, the output node last.
        self.nodes = self.input_nodes + [self.output_node]
        self.N = len(self.nodes)
        self.M = 1  # the single edge created by link() below
        self.dtype = dtype
        self.np_type = dag_dtype_to_numpy_dtype(dtype)
        self.constant_nodes = []
        self.optimization_error = -1
        self.sorted_nodes = None
        self.nodes[0].link(self.output_node)
        self.id = self.parent_id = -1
        self.topsort()
        for _ in range(init_mutate):
            self.mutate(mutation_cfg=mutation_cfg, rng=rng)
        # mutate() shifts id into parent_id; start the new individual unparented.
        self.id = self.parent_id = -1

    def save(self, filename):
        tmp_nodes = []
        for i in range(self.N):
            tmp_nodes.append({"type": self.nodes[i].type, "value": self.nodes[i].value, "prev": [self.nodes.index(node) for node in self.nodes[i].prev]})
        save_json = {"N": self.N, "nodes": tmp_nodes}
        with open(filename, "w") as f:
            json.dump(save_json, f)

    def load(self, filename):
        """Replace this DAG's structure with the JSON written by ``save``.

        Rebuilds ``nodes`` (type, value and ``prev`` links by index, in stored
        order), then recomputes ``input_nodes``/``num_inputs`` (type 0),
        ``constant_nodes`` (type 7), the edge count ``M``, ``output_node`` (first
        type-1 node, left unchanged if none) and the topological order. ``dtype``,
        ``np_type`` and the ids are not touched.
        """
        # ``with`` guarantees the handle is closed; the bare open() leaked it.
        with open(filename, "r") as f:
            load_json = json.load(f)
        self.N = load_json["N"]
        tmp_nodes = load_json["nodes"]
        self.nodes = [Node(type = 0) for _ in range(self.N)]
        for i in range(self.N):
            data = tmp_nodes[i]
            self.nodes[i].type = data["type"]
            self.nodes[i].value = data["value"]
            # Link in the stored order so operand order (matters for -, /) survives.
            for j in data["prev"]:
                self.nodes[j].link(self.nodes[i])
        self.input_nodes = [node for node in self.nodes if node.type == 0]
        self.num_inputs = len(self.input_nodes)
        self.constant_nodes = [node for node in self.nodes if node.type == 7]
        self.M = sum(len(node.next) for node in self.nodes)
        for node in self.nodes:
            if node.type == 1:
                self.output_node = node
                break
        self.topsort()

    def topsort(self):
        self.sorted_nodes = []
        for i in range(self.N):
            self.nodes[i].tmp_degree = len(self.nodes[i].prev)
        for i in range(self.N):
            if self.nodes[i].tmp_degree == 0:
                self.sorted_nodes.append(self.nodes[i])
        
        for i in range(self.N):
            tmp = self.sorted_nodes[i]
            for j in range(len(tmp.next)):
                tmp.next[j].tmp_degree -= 1
                if (tmp.next[j].tmp_degree == 0):
                    self.sorted_nodes.append(tmp.next[j])
    
    def compute_cost(self):
        return len(self.nodes) - len(self.constant_nodes) - len(self.input_nodes)-1

    def optimize(self, dataset, method="cma", random_init=False, ULP=False, optimizer_cfg=None, **kwargs):
        """Fit the constant nodes to ``dataset`` and return the resulting error.

        Args:
            dataset: ``(X, Y)`` pair. ``X`` is presumably a 2-D NumPy array (the
                generated ``f`` reads input ``i`` as ``x[:, i]``); ``Y`` holds the
                reference outputs.
            method: ``"cma"`` for 10 CMA-ES restarts, or any string containing
                ``"LM"`` for Levenberg-Marquardt multistart (followed by
                Nelder-Mead when it also contains ``"nelder_mead"``).
            random_init: accepted for API compatibility; unused.
            ULP: report the max ULP distance instead of relative error.
            optimizer_cfg: ``OptimizerConfig``; a default one is built when ``None``.
            **kwargs: ``rng``; ``eps`` (denominator guard of the relative error,
                default 1e-6); ``eval_dtype``; ``metric_type`` (``"ulp"``,
                ``"abs"`` or anything else for relative); ``metric_format``;
                legacy ``abs_error`` (presence selects the absolute metric); and
                ``fp16``/``fp32`` (presence sets the evaluation dtype, only
                honoured when the DAG has no constants).

        Returns:
            float: max relative / absolute / ULP error of the fitted DAG, or
            ``1e18`` when evaluation fails. Constant nodes keep the best values.

        Raises:
            ValueError: for an unknown ``method`` (only reached when constants exist).
        """
        X, Y = dataset
        rng = kwargs.get("rng")
        tol = kwargs.get("eps", 1e-6)
        if optimizer_cfg is None:
            optimizer_cfg = OptimizerConfig()
        penalty = 1e18
        residual_penalty = 1e6
        _mt = kwargs.get("metric_type", "")
        # Resolve metric flags once so every path (constant-free, CMA, LM) honours
        # the same options; CMA and LM candidate selection used to ignore
        # ``abs_error``.
        use_ulp = bool(ULP) or _mt == "ulp"
        use_abs = _mt == "abs" or ("abs_error" in kwargs)

        def _is_finite_array(values):
            arr = np.asarray(values)
            return np.all(np.isfinite(arr))

        def _safe_numpy_eval(input_x):
            # None signals "evaluation failed or produced non-finite values".
            try:
                with np.errstate(all='ignore'):
                    out = self.numpy_eval(input_x)
            except Exception:
                return None
            out = np.asarray(out)
            if not _is_finite_array(out):
                return None
            return out

        def _safe_f_eval(input_x, variables):
            try:
                with np.errstate(all='ignore'):
                    out = self.f(input_x, variables)
            except Exception:
                return None
            out = np.asarray(out)
            if not _is_finite_array(out):
                return None
            return out

        def _rel_metric(out, ref):
            # max |out - ref| / (|ref| + tol)
            if out is None:
                return float(penalty)
            ref_arr = np.asarray(ref)
            if not _is_finite_array(ref_arr):
                return float(penalty)
            with np.errstate(all='ignore'):
                ratio = np.abs(out - ref_arr) / (np.abs(ref_arr) + tol)
            if not _is_finite_array(ratio):
                return float(penalty)
            return float(np.max(ratio))

        def _abs_metric(out, ref):
            # max |out - ref|
            if out is None:
                return float(penalty)
            ref_arr = np.asarray(ref)
            if not _is_finite_array(ref_arr):
                return float(penalty)
            with np.errstate(all='ignore'):
                diff = np.abs(out - ref_arr)
            if not _is_finite_array(diff):
                return float(penalty)
            return float(np.max(diff))

        def _ulp_metric(out, ref_values, value_format=None):
            # max ULP distance, measured in ``value_format`` (DAG dtype by default)
            if out is None:
                return float(penalty)
            metric_format = value_format or dag_dtype_to_precision_format(self.dtype)
            metric_dtype = precision_format_to_numpy_dtype(metric_format)
            out_arr = np.asarray(out, dtype=metric_dtype)
            ref_arr = np.asarray(ref_values, dtype=metric_dtype)
            if not _is_finite_array(out_arr) or not _is_finite_array(ref_arr):
                return float(penalty)
            ulp_dists = ulp_distance(out_arr, ref_arr, metric_format)
            if not _is_finite_array(ulp_dists):
                return float(penalty)
            return float(np.max(ulp_dists))

        if len(self.constant_nodes) == 0:
            # Nothing to fit: evaluate the current expression once.
            self.gen_code()
            out = _safe_numpy_eval(X)
            tmp = _rel_metric(out, Y)
            metric_format = kwargs.get("metric_format")
            eval_dtype = kwargs.get("eval_dtype")
            eval_np_dtype = None
            if eval_dtype is not None:
                eval_np_dtype = dag_dtype_to_numpy_dtype(eval_dtype)
            elif "fp16" in kwargs:
                eval_np_dtype = np.dtype(np.float16)
            elif "fp32" in kwargs:
                eval_np_dtype = np.dtype(np.float32)
            if use_ulp:
                if eval_np_dtype is None:
                    eval_np_dtype = np.dtype(self.np_type)
                out = _safe_numpy_eval(X.astype(eval_np_dtype))
                tmp = _ulp_metric(out, Y.astype(eval_np_dtype), value_format=metric_format)
            elif use_abs:
                if eval_np_dtype is None:
                    eval_np_dtype = np.dtype(self.np_type)
                out = _safe_numpy_eval(X.astype(eval_np_dtype))
                tmp = _abs_metric(out, Y.astype(eval_np_dtype))
            elif ('fp32' in kwargs) or ('fp16' in kwargs) or (eval_dtype is not None):
                if eval_np_dtype is None:
                    eval_np_dtype = np.dtype(self.np_type)
                out = _safe_numpy_eval(X.astype(eval_np_dtype))
                tmp = _rel_metric(out, Y.astype(eval_np_dtype))
            return float(tmp)
        if method == "cma":
            import cma

            # Recompile self.f for the current structure: _cma_eval may call
            # self.f directly and _safe_f_eval would silently score a stale (or
            # missing) function as the wrong expression / penalty.
            self.gen_code()
            eval_dtype = kwargs.get("eval_dtype")
            eval_np_dtype = None
            if eval_dtype is not None:
                eval_np_dtype = dag_dtype_to_numpy_dtype(eval_dtype)
            metric_format = kwargs.get("metric_format") or dag_dtype_to_precision_format(eval_dtype or self.dtype)
            if eval_np_dtype is not None:
                X_cma = X.astype(eval_np_dtype)
                Y_cma = Y.astype(eval_np_dtype)
            else:
                X_cma = X
                Y_cma = Y

            def _cma_eval(variables):
                for i in range(len(self.constant_nodes)):
                    self.constant_nodes[i].value = variables[i]
                if eval_np_dtype is not None:
                    C_cast = np.array(variables).astype(eval_np_dtype)
                    out = _safe_f_eval(X_cma, C_cast)
                else:
                    out = _safe_numpy_eval(X_cma)
                return out

            def loss(variables):
                out = _cma_eval(variables)
                if use_ulp:
                    return _ulp_metric(out, Y_cma, value_format=metric_format)
                if use_abs:
                    return _abs_metric(out, Y_cma)
                return _rel_metric(out, Y_cma)

            p = len(self.constant_nodes)

            cma_cfg = optimizer_cfg.cma if optimizer_cfg else None
            sigma0 = cma_cfg.sigma0 if cma_cfg else 0.002
            options = {
                'popsize': cma_cfg.popsize if cma_cfg else 128,
                'maxiter': cma_cfg.maxiter if cma_cfg else 1000,
                'tolfun': 1e-8,
                'tolupsigma': 1e15,
                'tolfacupx': 1e15,
                'verb_log': 0,
                'seed': cma_cfg.seed if cma_cfg else 1234567,
                'verbose': -9
            }
            best = 1e+18
            # Start each constant at +/-10^k, k in [-8, 0].
            best_x = [((-1)**_rng_randint(rng, 0, 1))*(10**_rng_randint(rng, -8, 0)) for _ in range(p)]
            for _ in range(10):
                # Restart around the incumbent with a fresh +/-10^k kick per constant.
                x0 = [best_x[i] + ((-1)**_rng_randint(rng, 0, 1))*(10**_rng_randint(rng, -8, 0)) for i in range(p)]
                result = cma.fmin(loss, x0, sigma0, options)

                if result[1] < best:
                    best = result[1]
                    best_x = result[0]

            for i in range(p):
                self.constant_nodes[i].value = best_x[i]

            return loss(best_x)
        elif "LM" in method:
            eval_dtype = kwargs.get("eval_dtype")
            eval_np_dtype = None
            if eval_dtype is not None:
                eval_np_dtype = dag_dtype_to_numpy_dtype(eval_dtype)
            metric_format = kwargs.get("metric_format") or dag_dtype_to_precision_format(eval_dtype or self.dtype)
            if eval_np_dtype is not None:
                X_cast = X.astype(eval_np_dtype)
                Y_cast = Y.astype(eval_np_dtype)
            else:
                X_cast = X
                Y_cast = Y

            def residual(variables):
                # float64 (out - Y); a constant penalty vector stands in for
                # failed or non-finite evaluations so LM can keep going.
                if eval_np_dtype is not None:
                    out = _safe_f_eval(X_cast, variables.astype(eval_np_dtype))
                else:
                    out = _safe_f_eval(X, variables)
                if out is None:
                    return np.full(len(Y), residual_penalty, dtype=np.float64)
                with np.errstate(all='ignore'):
                    diff = np.asarray(out, dtype=np.float64) - np.asarray(Y_cast, dtype=np.float64)
                diff = np.asarray(diff, dtype=np.float64)
                if not _is_finite_array(diff):
                    return np.full(len(Y), residual_penalty, dtype=np.float64)
                return diff

            self.gen_code()
            p = len(self.constant_nodes)
            best = 1e+18
            best_x = 10**_rng_uniform(rng, -8, 0, p)
            signs = _rng_choice(rng, np.array([-1, 1]), None, size=p)
            best_x = signs * best_x

            lm_cfg = optimizer_cfg.lm if optimizer_cfg else None
            lm_enabled = True if lm_cfg is None else bool(lm_cfg.enabled)
            Total = (
                lm_cfg.max_iters
                if lm_cfg
                else default_lm_max_iters(dag_dtype_to_precision_format(kwargs.get("eval_dtype") or self.dtype))
            )
            pop_size = lm_cfg.pop_size if lm_cfg else 16
            random_noise = _rng_normal(rng, size=(Total * pop_size, p))
            if lm_enabled:
                for t in range(Total):
                    # Every start in round t perturbs the incumbent from the
                    # beginning of that round.
                    best_x0 = best_x
                    for tt in range(pop_size):
                        x0 = best_x0 + random_noise[t*pop_size+tt]
                        try:
                            res = least_squares(residual,
                                                x0, method='lm',
                                                max_nfev=lm_cfg.max_nfev if lm_cfg else 100)
                        except Exception:
                            # Skip a failed start. KeyboardInterrupt derives from
                            # BaseException and is never caught here.
                            continue

                        if res.fun is None:
                            residual_error = float(penalty)
                        else:
                            res_out = np.asarray(res.fun, dtype=np.float64)
                            if use_ulp or use_abs:
                                # ULP is approximated by the absolute residual
                                # here; the real ULP metric is computed below.
                                residual_error = float(np.max(np.abs(res_out))) if _is_finite_array(res_out) else float(penalty)
                            else:
                                with np.errstate(all='ignore'):
                                    ratio = np.abs(res_out) / (np.abs(np.asarray(Y_cast, dtype=np.float64)) + tol)
                                residual_error = float(np.max(ratio)) if _is_finite_array(ratio) else float(penalty)
                        if residual_error < best:
                            best = residual_error
                            best_x = res.x

            def loss_precision(variables):
                out = _safe_f_eval(X_cast, variables.astype(eval_np_dtype) if eval_np_dtype is not None else variables)
                return _rel_metric(out, Y_cast)
            def loss_abs_error(variables):
                out = _safe_f_eval(X_cast, variables.astype(eval_np_dtype) if eval_np_dtype is not None else variables)
                return _abs_metric(out, Y_cast)
            def loss_ulp(variables):
                C_eval = variables.astype(eval_np_dtype) if eval_np_dtype is not None else variables
                out = _safe_f_eval(X_cast, C_eval)
                return _ulp_metric(out, Y_cast, value_format=metric_format)
            def loss(variables):
                out = _safe_f_eval(X_cast, variables.astype(eval_np_dtype) if eval_np_dtype is not None else variables)
                return _rel_metric(out, Y_cast)

            if "nelder_mead" in method:
                nm_cfg = optimizer_cfg.nelder_mead if optimizer_cfg else None
                nm_enabled = True if nm_cfg is None else bool(nm_cfg.enabled)
                if nm_enabled:
                    nm_opts = {
                        'maxiter': nm_cfg.max_iters if nm_cfg else 100,
                        'disp': False,
                        'xatol': nm_cfg.xatol if nm_cfg else 1e-6,
                        "fatol": nm_cfg.fatol if nm_cfg else 1e-6,
                    }
                    if use_ulp:
                        res = minimize(loss_ulp, best_x, method='nelder-mead', options=nm_opts)
                    elif use_abs:
                        res = minimize(loss_abs_error, best_x, method='nelder-mead', options=nm_opts)
                    elif eval_dtype is not None:
                        res = minimize(loss_precision, best_x, method='nelder-mead', options=nm_opts)
                    else:
                        res = minimize(loss, best_x, method='nelder-mead', options=nm_opts)
                    best_x = res.x

            for i in range(len(self.constant_nodes)):
                self.constant_nodes[i].value = best_x[i]
            try:
                if use_ulp:
                    tmp = loss_ulp(best_x)
                elif use_abs:
                    tmp = loss_abs_error(best_x)
                elif eval_dtype is not None:
                    tmp = loss_precision(best_x)
                else:
                    tmp = loss(best_x)
            except Exception:
                return float(penalty)
            return float(tmp)
        else:
            raise ValueError("Invalid optimization method")
    


    def copy(self):
        """Return a structurally independent copy of this DAG.

        New ``Node`` objects are created with the same types and values and the
        same ``prev`` edges (operand order preserved). ``id``, ``parent_id``,
        ``M``, ``N`` and ``num_inputs`` are copied; ``optimization_error`` is not
        (the copy keeps the constructor default).
        """
        # init_mutate=0: the constructor's skeleton is discarded right below, so
        # running the default 40 mutations only wasted time and advanced the
        # global RNG behind the caller's back.
        new_dag = DAG(self.num_inputs, self.dtype, init_mutate=0)
        new_dag.nodes = []
        new_dag.id = self.id
        new_dag.parent_id = self.parent_id
        for node in self.nodes:
            new_node = Node(node.type)
            new_node.value = node.value
            new_dag.nodes.append(new_node)

        # Identity-keyed position map (first occurrence, like list.index) so each
        # edge lookup is O(1) instead of a linear scan.
        position = {}
        for idx, node in enumerate(self.nodes):
            position.setdefault(id(node), idx)
        for i in range(self.N):
            for parent in self.nodes[i].prev:
                new_dag.nodes[position[id(parent)]].link(new_dag.nodes[i])

        new_dag.constant_nodes = [node for node in new_dag.nodes if node.type == 7]
        new_dag.M = self.M
        new_dag.N = self.N
        new_dag.input_nodes = [node for node in new_dag.nodes if node.type == 0]
        new_dag.num_inputs = self.num_inputs
        for node in new_dag.nodes:
            if node.type == 1:
                new_dag.output_node = node
                break

        new_dag.topsort()
        return new_dag

    def mutate(self, mutation_cfg=None, rng=None):
        """Apply at most one random structural mutation.

        The DAG is first marked as a new child (``parent_id`` takes the old
        ``id``, ``id`` becomes -1). The mutation proceeds only when a draw from
        ``_rng_random`` is below ``mutation_cfg.rate``; the operator is then
        picked from insertion / deletion / reconnection with weights from
        ``mutation_cfg.operator_probs``. A graph with ``M <= 1`` edges always
        gets an insertion.
        """
        cfg = MutationConfig() if mutation_cfg is None else mutation_cfg
        self.parent_id, self.id = self.id, -1
        if _rng_random(rng) >= cfg.rate:
            return
        weights = normalize_operator_probs(cfg.operator_probs)
        names = ["insertion", "deletion", "reconnection"]
        if self.M > 1:
            chosen = _rng_choice(rng, names, [weights[name] for name in names])
        else:
            chosen = "insertion"
        handlers = {"insertion": self.insertion, "deletion": self.deletion}
        handlers.get(chosen, self.reconnection)(rng=rng)

    def successor(self, node):
        successor_list = [node]
        i = 0
        while (i < len(successor_list)):
            for j in range(len(successor_list[i].next)):
                if (successor_list[i].next[j] not in successor_list):
                    successor_list.append(successor_list[i].next[j])
            i += 1
        return successor_list
    
    def insertion(self, rng=None):
        """Insert a new binary operator node into a randomly chosen edge.

        An edge index in ``[0, M-1]`` selects an edge ``node -> tmp``
        (edges are enumerated node by node through each ``next`` list). That
        edge is rewired to ``node -> new_node -> tmp`` where ``new_node`` has
        type 2-5 (``+``, ``-``, ``*``, ``/`` in ``gen_code``). The second operand
        of ``new_node`` is a fresh constant node (type 7) when a 0..2 draw is 0,
        otherwise a random existing node that is not a successor of ``tmp``, so
        no cycle is introduced. Operand order is randomised. Finally
        ``remove_duplicate_constants`` folds operators left without an
        input-dependent operand and prunes the graph.

        Args:
            rng: random source passed to the ``_rng_*`` helpers and to ``Node``.
        """
        edge = _rng_randint(rng, 0, self.M - 1)
        #print("edge: ", edge)
        # Consume each node's out-degree from ``edge`` until the chosen edge
        # falls inside the current node's ``next`` list.
        for node in self.nodes:
            if (edge >= len(node.next)):
                edge -= len(node.next)
            else:
                tmp = node.next[edge]
                # tmp and everything downstream of it cannot feed new_node
                # without creating a cycle.
                successor_list = self.successor(tmp)
                non_successor_list = []
                for tmp_node in self.nodes:
                    if (not tmp_node in successor_list):
                        non_successor_list.append(tmp_node)
                # If node feeds tmp twice (e.g. x*x), index() rewires the first slot.
                pre_index = tmp.prev.index(node)
                new_node = Node(type=_rng_randint(rng, 2, 5))
                self.nodes.append(new_node)
                tmp.prev[pre_index] = new_node
                node.next[edge] = new_node
                new_node.next.append(tmp)

                # In an acyclic graph ``node`` itself is a non-successor of tmp,
                # so this list is never empty.
                tmp_len = len(non_successor_list)
                p = _rng_randint(rng, 0, 2)
                if (p == 0):
                    # NOTE(review): Node(type=7, rng=rng) presumably draws the
                    # constant's initial value from rng — confirm in node.py.
                    add_new_node = Node(type = 7, rng=rng)
                    self.nodes.append(add_new_node)
                    self.constant_nodes.append(add_new_node)
                else:
                    p = _rng_randint(rng, 0, tmp_len - 1)
                    add_new_node = non_successor_list[p]

                add_new_node.next.append(new_node)
                # Randomise operand order (matters for - and /).
                if (_rng_random(rng) < 0.5):
                    new_node.prev = [node, add_new_node]
                else:
                    new_node.prev = [add_new_node, node]


                # Net +2 edges: node->tmp became node->new_node, plus
                # new_node->tmp and add_new_node->new_node.
                self.M += 2
                self.N  = len(self.nodes)
                self.remove_duplicate_constants()
                break
                
    def remove_duplicate_constants(self):
        self.topsort()
        self.forward_notate()
        for node in self.nodes:
            if (node.type>=2) and (node.type<=6):
                flag = False
                for node1 in node.prev:
                    flag = flag or node1.input_reachable
                if (not flag):              
                    node.type = 7
                    node.value = 0
                    for node1 in node.prev:
                        node1.next.remove(node)
                    node.prev = []
        self.travel_graph_and_remove()


    def travel_graph_and_remove(self):
        for node in self.nodes:
            node.visited = False
        tmp_list = [self.output_node]
        self.output_node.visited = True
        i = 0
        while (i < len(tmp_list)):
            for j in range(len(tmp_list[i].prev)):
                if (not tmp_list[i].prev[j].visited):
                    tmp_list.append(tmp_list[i].prev[j])
                    tmp_list[i].prev[j].visited = True
            i += 1
        for i in range(len(self.input_nodes)):
            if (self.input_nodes[i] not in tmp_list):
                tmp_list.append(self.input_nodes[i])
        for node in self.nodes:
            if (node in tmp_list):
                tmp_next =[]
                for node1 in node.next:
                    if (node1 in tmp_list):
                        tmp_next.append(node1)
                node.next = tmp_next
        self.nodes = tmp_list
        self.N = len(self.nodes)
        self.M = 0
        for node in self.nodes:
            self.M += len(node.next)
        self.constant_nodes = []
        for node in self.nodes:
            if (node.type == 7):
                self.constant_nodes.append(node)
        self.topsort()
        
    
    def critical_node(self):
        for node in self.nodes:
            node.forward_path = 0
            node.backward_path = 0
        for node in self.input_nodes:
            node.forward_path = 1
        self.output_node.backward_path = 1
        for node in self.sorted_nodes:
            for node1 in node.next:
                node1.forward_path += node.forward_path
        for node in reversed(self.sorted_nodes):
            for node1 in node.prev:
                node1.backward_path += node.backward_path
        
        node_list = []
        tmp= self.output_node.forward_path
        for node in self.nodes:
            if (node.forward_path*node.backward_path == tmp):
                node_list.append(node)
        return node_list


    def forward_notate(self):
        for node in self.nodes:
            node.input_reachable = False
        for node in self.input_nodes:
            node.input_reachable = True
        for node in self.sorted_nodes:
            if (node.input_reachable):
                for node1 in node.next:
                    node1.input_reachable = True
       
    def deletion(self, rng=None):
        """Remove one random operator node (type 2-6) and rewire its consumers.

        Every consumer of the removed node is re-pointed at one of the removed
        node's own operands. When the node lies on every input-to-output path
        (``critical_node``), one randomly chosen consumer is re-pointed at an
        input-reachable operand, so a path from the inputs survives; the other
        consumers get any operand. ``travel_graph_and_remove`` then prunes the
        detached node and rebuilds the bookkeeping. Does nothing when there are
        no operator nodes.

        Args:
            rng: random source passed to the ``_rng_*`` helpers.
        """
        tmp_list = []
        # critical_node() also fills forward_path/backward_path; forward_notate()
        # fills input_reachable, used below.
        critical_node_list = self.critical_node()
        self.forward_notate()

        for node in self.nodes:
            if (node.type >=2) and (node.type <=6):
                tmp_list.append(node)

        if (len(tmp_list) == 0):
            return
        tmp_node = tmp_list[_rng_randint(rng, 0, len(tmp_list) - 1)]

        if (tmp_node in critical_node_list):
            # Operands of tmp_node that are reachable from an input.
            # NOTE(review): assumed non-empty for a critical node (holds while at
            # least one input->output path exists); randint(0, -1) raises otherwise.
            prev_reachable_list = []
            #print("tmp_node: ", self.nodes.index(tmp_node), tmp_node.input_reachable)
            for node in tmp_node.prev:
                #print("prev: ", self.nodes.index(node), node.input_reachable)
                if (node.input_reachable):
                    prev_reachable_list.append(node)
            # Consumer pp is the one guaranteed to receive an input-reachable operand.
            pp = _rng_randint(rng, 0, len(tmp_node.next) - 1)
            #print("pp: ", pp)
            for tt  in range(len(tmp_node.next)):
                node = tmp_node.next[tt]
                if (tt==pp):
                    # index() finds the first remaining slot, so a consumer that
                    # used tmp_node twice is rewired one slot per visit.
                    i = node.prev.index(tmp_node)
                    p = _rng_randint(rng, 0, len(prev_reachable_list) - 1)
                    node.prev[i] = prev_reachable_list[p]
                    prev_reachable_list[p].next.append(node)
                else:
                    i = node.prev.index(tmp_node)
                    p = _rng_randint(rng, 0, len(tmp_node.prev) - 1)
                    node.prev[i] = tmp_node.prev[p]
                    tmp_node.prev[p].next.append(node)
        else:
            # Not critical: other paths keep the output input-dependent, so any
            # operand may take tmp_node's place.
            for node in tmp_node.next:
                i = node.prev.index(tmp_node)
                p = _rng_randint(rng, 0, len(tmp_node.prev) - 1)
                node.prev[i] = tmp_node.prev[p]
                tmp_node.prev[p].next.append(node)

        # Detach tmp_node from its operands (once per edge, so duplicates are handled).
        for node in tmp_node.prev:
            node.next.remove(tmp_node)

        self.travel_graph_and_remove()
    
    def reconnection(self, rng=None):
        """Re-point the source of one random edge ``node -> tmp``.

        The new source is always a node that is not a successor of ``tmp``,
        which keeps the graph acyclic. If the edge lies on every input-to-output
        path (``forward_path(node) * backward_path(tmp)`` equals the output's
        total path count), the new source must be input-reachable and the graph
        is pruned directly. Otherwise it is any non-successor, or, when the
        0..k draw equals k (k = number of non-successors), a fresh constant node.
        ``remove_duplicate_constants`` then folds constant-only operators.

        Args:
            rng: random source passed to the ``_rng_*`` helpers and to ``Node``.
        """
        # Fill forward_path/backward_path and input_reachable for the checks below.
        self.critical_node()
        self.forward_notate()

        edge = _rng_randint(rng, 0, self.M - 1)
        # Consume each node's out-degree from ``edge`` until it lands in a ``next`` list.
        for node in self.nodes:
            if (edge >= len(node.next)):
                edge -= len(node.next)
            else:
                tmp = node.next[edge]

                successor_list = self.successor(tmp)
                non_successor_list = []
                for tmp_node in self.nodes:
                    if (not tmp_node in successor_list):
                        non_successor_list.append(tmp_node)


                # Critical edge: every input->output path passes through it.
                if (tmp.backward_path * node.forward_path == self.output_node.forward_path):
                    non_successor_list_input_reachable = []
                    for node1 in non_successor_list:
                        if (node1.input_reachable):
                            non_successor_list_input_reachable.append(node1)
                    p = _rng_randint(rng, 0, len(non_successor_list_input_reachable) - 1)
                    tmp.prev[tmp.prev.index(node)] = non_successor_list_input_reachable[p]
                    non_successor_list_input_reachable[p].next.append(tmp)
                    node.next.remove(tmp)
                    self.N = len(self.nodes)
                    self.travel_graph_and_remove()
                    # Exits the node loop, skipping remove_duplicate_constants below.
                    break


                else:
                    # Inclusive upper bound: p == len(...) means "use a new constant".
                    p = _rng_randint(rng, 0, len(non_successor_list))

                    if (p == len(non_successor_list)):
                        add_new_node = Node(type = 7, rng=rng)
                        self.nodes.append(add_new_node)
                        self.constant_nodes.append(add_new_node)
                    else:
                        add_new_node = non_successor_list[p]

                    tmp.prev[tmp.prev.index(node)] = add_new_node
                    add_new_node.next.append(tmp)
                    node.next.remove(tmp)
                    self.N = len(self.nodes)

                # Also prunes the graph and recomputes M via travel_graph_and_remove.
                self.remove_duplicate_constants()
                #self.travel_graph_and_remove()

                break

    def output_func(self, output_file = None):
        """Validate the DAG and assign human-readable names to its nodes.

        Input nodes (the first ``self.num_inputs`` of ``self.input_nodes``)
        are named ``x0, x1, ...``, constant nodes ``C0, C1, ...`` and
        arithmetic nodes (types 2-6) ``op0, op1, ...`` in the order of
        ``self.sorted_nodes`` (computed via ``self.topsort()`` if missing).

        Despite its name, this method currently writes no text to the output
        stream. It only opens it and then closes (file) or flushes (stdout)
        it. Passing ``output_file`` therefore creates or truncates that file
        and leaves it empty.

        Args:
            output_file: Path of the file to open for writing, or ``None`` to
                use ``sys.stdout``.

        Raises:
            AssertionError: if no node in ``self.nodes`` has an input node as
                a direct predecessor. (Raised explicitly, so it is not
                stripped under ``python -O``.)
        """
        fout = sys.stdout if output_file is None else open(output_file, "w")
        try:
            # Require at least one node to take an input node directly as an operand.
            uses_input = any(
                prev in self.input_nodes for node in self.nodes for prev in node.prev
            )
            if not uses_input:
                raise AssertionError("no node takes an input node as a predecessor")

            for i in range(self.num_inputs):
                self.input_nodes[i].name = "x" + str(i)

            if self.sorted_nodes is None:
                self.topsort()
            for i, node in enumerate(self.constant_nodes):
                node.name = "C" + str(i)

            op_count = 0
            for node in self.sorted_nodes:
                if node.type in (2, 3, 4, 5, 6):
                    node.name = "op" + str(op_count)
                    op_count += 1
        finally:
            # Always release a file we opened ourselves, even when validation
            # fails; stdout is only flushed, never closed.
            if output_file is not None:
                fout.close()
            else:
                fout.flush()
    
    def gen_code(self):
        """Build, store and compile a vectorised evaluator for the DAG.

        Side effects: calls ``self.topsort()``. Renames input nodes to
        ``x[:,i]``, constant nodes to ``C[i]`` and arithmetic nodes
        (types 2-6) to ``op0, op1, ...``. Stores the generated source in
        ``self.code`` and the compiled function in ``self.f``.

        The compiled ``f(x, C)`` evaluates the operations in the order of
        ``self.sorted_nodes`` and returns the value of
        ``self.output_node.prev[0]``. ``x`` is indexed as ``x[:, i]``, so it
        must support 2-D slicing (e.g. a NumPy array with one column per
        input).
        """
        self.topsort()
        # Expression template per arithmetic node type; operand k is node.prev[k].
        expr_for_type = {
            2: "{} + {}",
            3: "{} - {}",
            4: "{} * {}",
            5: "{} / {}",
            6: "{} * {} + {}",
        }
        for idx, inp in enumerate(self.input_nodes):
            inp.name = "x[:," + str(idx) + "]"
        for idx, const in enumerate(self.constant_nodes):
            const.name = "C[" + str(idx) + "]"

        pieces = ["def f(x,C):\n"]
        op_count = 0
        for node in self.sorted_nodes:
            template = expr_for_type.get(node.type)
            if template is None:
                continue  # node types other than 2-6 emit no statement
            node.name = "op" + str(op_count)
            op_count += 1
            arity = template.count("{}")
            rhs = template.format(*(parent.name for parent in node.prev[:arity]))
            pieces.append("    " + node.name + " = " + rhs + "\n")
        pieces.append("    return " + self.output_node.prev[0].name + "\n")

        source = "".join(pieces)
        self.code = source
        namespace = {}
        exec(source, {}, namespace)
        self.f = namespace["f"]
   
    def numpy_eval(self,X):
        """Evaluate the compiled DAG function ``self.f`` on ``X``.

        ``X`` and the current ``value`` of every constant node are converted
        to ``self.np_type``. The result of ``self.f(X, C)`` is returned
        unchanged. ``self.f`` is presumably the function compiled by
        ``gen_code`` (confirm it has been called first).
        """
        dtype = self.np_type
        inputs = np.asarray(X, dtype=dtype)
        constants = np.array([const.value for const in self.constant_nodes], dtype=dtype)
        return self.f(inputs, constants)
    
    def gen_code_jacobian(self):
        """Compile ``self.f_jacobian(x, C)``: DAG value plus gradient w.r.t. constants.

        The generated source, stored in ``self.code_jacobian``, first runs the
        same forward pass as ``gen_code``. It then runs a reverse-mode sweep
        over ``self.sorted_nodes`` that accumulates the adjoints ``J_opK``,
        ``J_x[i]`` and ``J_C[i]``. The compiled function returns
        ``(value, J_C)``, where ``J_C[i]`` is d value / d C[i]: an array over
        the rows of ``x``, or a plain scalar when it does not depend on ``x``
        (e.g. ``0`` for a constant that does not reach the output).

        Side effects: calls ``self.topsort()`` and overwrites ``name`` and
        ``J_name`` on input, constant and arithmetic (type 2-6) nodes.
        """
        self.topsort()
        # Forward expression per supported node type (2: a+b, 3: a-b, 4: a*b,
        # 5: a/b, 6: a*b+c); operand k is node.prev[k].
        forward = {
            2: "{0} + {1}",
            3: "{0} - {1}",
            4: "{0} * {1}",
            5: "{0} / {1}",
            6: "{0} * {1} + {2}",
        }
        lines = ["def f_jacobian(x,C):", "    n = x.shape[0]"]
        # Name leaves *before* any code references their adjoints, so the
        # output seed below also works when the output is fed by a leaf.
        for i, node in enumerate(self.input_nodes):
            node.name = "x[:," + str(i) + "]"
            node.J_name = "J_x[" + str(i) + "]"
        for i, node in enumerate(self.constant_nodes):
            node.name = "C[" + str(i) + "]"
            node.J_name = "J_C[" + str(i) + "]"
        lines.append("    J_C = [0 for i in range(" + str(len(self.constant_nodes)) + ") ]")
        lines.append("    J_x = [0 for i in range(" + str(len(self.input_nodes)) + ") ]")

        # Forward pass: compute every op value and zero-initialise its adjoint.
        total_op = 0
        for node in self.sorted_nodes:
            template = forward.get(node.type)
            if template is None:
                continue
            node.name = "op" + str(total_op)
            node.J_name = "J_" + node.name
            total_op += 1
            operands = [parent.name for parent in node.prev]
            lines.append("    " + node.name + " = " + template.format(*operands))
            lines.append("    " + node.J_name + " = 0")

        # Seed: d(out)/d(out) = 1.
        root = self.output_node.prev[0]
        lines.append("    " + root.J_name + " = " + root.J_name + " + 1")

        def accumulate(target, sign, term):
            # Emit: adjoint(target) <- adjoint(target) +/- term
            lines.append("    " + target.J_name + " = " + target.J_name + " " + sign + " " + term)

        # Reverse pass. Local derivatives use the operands' forward *values*
        # (``.name``), never their adjoints (``.J_name``).
        for node in reversed(self.sorted_nodes):
            if node.type not in forward:
                continue
            adj = node.J_name
            a, b = node.prev[0], node.prev[1]
            if node.type == 2:    # d(a+b) = da + db
                accumulate(a, "+", adj)
                accumulate(b, "+", adj)
            elif node.type == 3:  # d(a-b) = da - db
                accumulate(a, "+", adj)
                accumulate(b, "-", adj)
            elif node.type == 4:  # d(a*b) = b da + a db
                accumulate(a, "+", adj + " * " + b.name)
                accumulate(b, "+", adj + " * " + a.name)
            elif node.type == 5:  # d(a/b) = da / b - a db / b**2
                accumulate(a, "+", adj + " / " + b.name)
                accumulate(b, "-", adj + " * " + a.name + " / " + b.name + "**2")
            else:                 # type 6: d(a*b+c) = b da + a db + dc
                accumulate(a, "+", adj + " * " + b.name)
                accumulate(b, "+", adj + " * " + a.name)
                accumulate(node.prev[2], "+", adj)

        lines.append("    return " + root.name + ",J_C ")
        code = "\n".join(lines) + "\n"

        local_vals = {}
        self.code_jacobian = code
        exec(code, {}, local_vals)
        self.f_jacobian = local_vals["f_jacobian"]
