from typing import Union, Dict
import numpy as np
import networkx as nx
import json
import os
from dodiscover import EquivalenceClass


# ---------------------------------------------------
######## Manager of log messages on console ########
# ---------------------------------------------------
class ConsoleManager:
    """Namespace of static helpers that print progress messages to stdout.

    Several helpers print with ``end=" "`` and ``flush=True`` so the line is
    left open; presumably it is closed by a later ``done_msg()`` call.
    """

    @staticmethod
    def data_config_msg(data_config):
        """Print a banner, a header line and the data-generation config."""
        banner = "##############################################"
        header = "Data generation with the following parameters:"
        for line in (banner, header, data_config):
            print(line)
        print()

    @staticmethod
    def data_storing(data_dir, index):
        """Announce, without a newline, that dataset ``index`` is being stored."""
        message = f"Storing dataset {index} in {data_dir}..."
        print(message, end=" ", flush=True)

    @staticmethod
    def metadata_storing(metadata_path):
        """Announce, without a newline, that logs are written to ``metadata_path``."""
        message = f"Storing logs in {metadata_path}..."
        print(message, end=" ", flush=True)

    @staticmethod
    def data_generation_mgs(index, num_datasets):
        """Announce generation of dataset ``index`` (0-based, shown 1-based)."""
        position = index + 1
        print(f"Generating dataset {position}/{num_datasets}...", end=" ", flush=True)

    @staticmethod
    def method_experiments_msg(method_name: str):
        """Print a section header for the experiments of ``method_name``."""
        title = method_name.upper()
        print(f"\n################# Running {title} #################")

    @staticmethod
    def suspended_msg(msg):
        """Print ``msg`` followed by an ellipsis and leave the line open."""
        print(f"{msg}...", end=" ", flush=True)

    @staticmethod
    def run_seed_msg(method_name: str, param_name: str, param_value: float, index, num_datasets):
        """Report progress of one experiment; add a blank line after the last one."""
        position = index + 1
        print(f"{method_name.upper()} - {param_name} {param_value} :  experiment {position}/{num_datasets}")
        if position == num_datasets:
            print()

    @staticmethod
    def done_msg():
        """Print the completion marker."""
        print("Done!")



# -----------------------------------------------------
######## Containers of the methods' parameters ########
# -----------------------------------------------------


# class MethodParameters(metaclass=ABCMeta):
#     """Abstract class acting as container of parameters of benchmark causal discovery methods.
#     It should ease passing the parameters to ExperimentManager class.
#     The advantage with respect to a Python Dictionary, is that required parameters of an algorithm
#     are asked as argument of the constructor, which doesn't allow for missing parameters 
#     """

#     @abstractmethod
#     def parameters(self):
#         """List of parameters
#         TODO: Implement handling of more than one tunable parameter
#         """
#         raise NotImplementedError()
    
#     @abstractmethod
#     def parameters_name(self):
#         """
#         Name of the tunable parameters
#         TODO: Implement handling of more than one tunable parameter
#         """


#################### MethodParams ####################
# TODO: generalize to more than one parameter
class MethodParameters:
    """Mutable holder for a method's tunable parameter values and their name.

    Both ``parameters`` and ``parameters_name`` start as ``None``; they are
    read/write properties backed by private attributes.
    """

    def __init__(self):
        # Nothing is configured until the caller assigns the properties.
        self._parameters = self._parameters_name = None

    @property
    def parameters(self):
        """The parameter values (whatever object the caller assigned)."""
        return self._parameters

    @parameters.setter
    def parameters(self, new_values):
        self._parameters = new_values

    @property
    def parameters_name(self):
        """The name of the tunable parameter."""
        return self._parameters_name

    @parameters_name.setter
    def parameters_name(self, new_name):
        self._parameters_name = new_name


# #################### PC ####################
# class PCParameters(MethodParameters):
#     def __init__(
#         self,
#         alpha_values : List[float]
#     ):
#         """Arguments are the parameter for causal-learn implementation of GES.
#         Arguments not in the list will take default value

#         Parameters
#         ----------
#         lambdas : List[float]
#             The conditional independence alpha values (usually for the tuning)
#         """
#         self.alpha_values = alpha_values

#     def parameters(self):
#         return self.alpha_values

#     def parameters_name(self):
#         return "alpha"


# #################### ORDER BASED METHODS ####################
# # -----------------------------------------------------
# # NoGAM, SCORE, CAM (and partially DAS) share same
# # parameters. Handle those with a single class
# # NOTE: DAS is currently limited to a single parameter.
# # TODO: add second parameter to DAS. 
# # -----------------------------------------------------
# class OrderBasedParameters(MethodParameters):
#     def __init__(
#         self,
#         alpha_values : List[float]
#     ):
#         """Arguments are the parameter for causal-learn implementation of GES.
#         Arguments not in the list will take default value

#         Parameters
#         ----------
#         lambdas : List[float]
#             The conditional independence alpha values (usually for the tuning)
#         """
#         self.alpha_values = alpha_values

#     def parameters(self):
#         return self.alpha_values

#     def parameters_name(self):
#         return "alpha"


# class DiffanParameters(MethodParameters):
#     def __init__(
#             self,
#             # learning_rates : List[float],
#             # batch_sizes : List[int],
#             alpha_values : List[float]
#     ):
#         # self.learning_rates = learning_rates
#         # self.batch_sizes = batch_sizes
#         self.alpha_values = alpha_values

#     def parameters(self):
#         # diffan_params = {
#         #     "lr" : self.learning_rates,
#         #     "bs" : self.batch_sizes,
#         #     "alpha" : self.alpha_values
#         # }
#         # return diffan_params
#         return self.alpha_values
    
#     def parameters_name(self):
#         """Short for learning rate, batch size, alpha value
#         """
#         # return "lr_bs_alpha"
#         return "alpha"



# -----------------------------------------------
# Wrappers of causal-learn methods
# Allows to expose DoDiscover API for each method
# not implemented in the DoDiscover library
# -----------------------------------------------


# For GES, use my DoDiscover implementation, since it is
# faster, better performing, and allows easier parameter specification


# --------------------------------------
# Functions for PDAG to CPDAG conversion

    # The following functions implement the conversion from PDAG to
    # CPDAG that is carried out after each transition to a different
    # equivalence class, after the selection and application of the
    # highest scoring insert/delete/turn operator. It consists of the
    # successive application of three algorithms, all described in
    # Appendix C (pages 552,553) of Chickering's 2002 GES paper
    # (www.jmlr.org/papers/volume3/chickering02b/chickering02b.pdf).
    #
    # The algorithms are:

    #   1. Obtaining a consistent extension of a PDAG, implemented in
    #   the function pdag_to_dag.
    #
    #   2. Obtaining a total ordering of the edges of the extension
    #   resulting from (1). It is summarized in Fig. 13 of
    #   Chickering's paper and implemented in the function
    #   order_edges.
    #
    #   3. Labelling the edges as compelled or reversible, by which we
    #   can easily obtain the CPDAG. It is summarized in Fig. 14 of
    #   Chickering's paper and implemented in the function label_edges.

    # The above are put together in the function pdag_to_cpdag

    # NOTE!!!: Algorithm (1) is from the 1992 paper "A simple
    # algorithm to construct a consistent extension of a partially
    # oriented graph" by Dorit Dor and Michael Tarsi. There is an
    # ERROR in the summarized version in Chickering's paper. In
    # particular, the condition that N_x U Pa_x is a clique is not
    # equivalent to the condition from Dor & Tarsi that every neighbor
    # of X should be adjacent to all of X's adjacent nodes. The
    # condition summarized in Chickering is more restrictive (i.e. it
    # also asks that the parents of X are adjacent to each other), but
    # this only results in an error for some graphs, and was only
    # uncovered during exhaustive testing.

# The complete pipeline: pdag -> dag -> ordered -> labelled -> cpdag


def pdag_to_cpdag(pdag):
    """
    Convert a PDAG into the CPDAG of its Markov equivalence class.

    The PDAG is first extended to a consistent DAG (``pdag_to_dag``), which
    is then turned into a CPDAG (``dag_to_cpdag``).

    Parameters
    ----------
    pdag : np.array
        adjacency matrix of the PDAG: pdag[i,j] != 0 encodes i -> j, and
        an additional pdag[j,i] != 0 makes it the undirected edge i - j.

    Returns
    -------
    cpdag : np.array
        the adjacency matrix of the corresponding CPDAG

    Raises
    ------
    ValueError
        Propagated from ``pdag_to_dag`` when the PDAG admits no consistent
        extension.
    """
    return dag_to_cpdag(pdag_to_dag(pdag))

# dag -> ordered -> labelled -> cpdag


def dag_to_cpdag(G):
    """
    Build the CPDAG of the Markov equivalence class of a DAG.

    Edges are first totally ordered (``order_edges``), then labelled as
    compelled (1) or reversible (-1) by ``label_edges``. Compelled edges are
    kept directed; reversible edges become undirected (both entries set to 1).

    Parameters
    ----------
    G : np.array
        the adjacency matrix of the given graph, where G[i,j] != 0 iff i -> j

    Returns
    -------
    cpdag : np.array
        the adjacency matrix of the corresponding CPDAG

    Raises
    ------
    ValueError
        If ``G`` is not a DAG (raised by ``order_edges``).
    """
    labelled = label_edges(order_edges(G))
    cpdag = np.zeros_like(labelled)
    # Compelled edges keep their direction.
    compelled = labelled == 1
    cpdag[compelled] = labelled[compelled]
    # Reversible edges are written in both directions, i.e. undirected.
    rows, cols = np.where(labelled == -1)
    cpdag[rows, cols] = 1
    cpdag[cols, rows] = 1
    return cpdag


def pdag_to_dag(P, debug=False):
    """
    Find a consistent extension of the given PDAG. Raises a ValueError
    exception if the PDAG does not admit a consistent extension.

    Implements the node-removal procedure of Dor & Tarsi (1992): repeatedly
    pick a node that (1) has no outgoing directed edge and (2) whose
    undirected neighbors are adjacent to every other node adjacent to it,
    orient all of its undirected edges towards it, and remove it from the
    working copy of the PDAG.

    Parameters
    ----------
    P : np.array
        adjacency matrix representing the PDAG connectivity, where
        P[i,j] != 0 and P[j,i] == 0 => i -> j, and
        P[i,j] != 0 and P[j,i] != 0 => i - j.
    debug : bool, optional
        if debugging traces should be printed

    Returns
    -------
    G : np.array
        the adjacency matrix of a DAG which is a consistent extension
        (i.e. same v-structures and skeleton) of P. Directed edges of P
        keep their original values; edges oriented here are set to 1.

    Raises
    ------
    ValueError
        If at some iteration no remaining node satisfies both conditions,
        i.e. the PDAG admits no consistent extension.

    """
    # Start from the directed part of P; undirected edges are oriented below.
    G = only_directed(P)
    indexes = list(range(len(P)))  # To keep track of the real variable
    # indexes as we remove nodes from P
    while P.size > 0:
        print(P) if debug else None
        print(indexes) if debug else None
        # Select a node which
        #   1. has no outgoing directed edges in P (i.e. childless, is a sink)
        #   2. all its neighbors are adjacent to all its adjacent nodes
        found = False
        i = 0
        while not found and i < len(P):
            # Check condition 1
            sink = len(ch(i, P)) == 0
            # Check condition 2 (np.all of an empty list is True, so a node
            # without undirected neighbors satisfies it trivially)
            n_i = neighbors(i, P)
            adj_i = adj(i, P)
            adj_neighbors = np.all([adj_i - {y} <= adj(y, P) for y in n_i])
            print("   i:", i, ": n=", n_i, "adj=", adj_i, "ch=", ch(i, P)) if debug else None
            found = sink and adj_neighbors
            # If found, orient all incident undirected edges and
            # remove i from the subgraph
            if found:
                print("  Found candidate %d (%d)" % (i, indexes[i])) if debug else None
                # Orient all incident undirected edges towards i, using the
                # original (un-reduced) node indices
                real_i = indexes[i]
                real_neighbors = [indexes[j] for j in n_i]
                for j in real_neighbors:
                    G[j, real_i] = 1
                # Remove i and its incident (directed and undirected edges).
                # P is rebound to a reduced copy; the caller's array is not
                # modified.
                # NOTE(review): keeping rows of P aligned with `indexes`
                # relies on the set difference iterating in ascending order
                # (true for small non-negative ints in CPython);
                # `[j for j in range(len(P)) if j != i]` would make that explicit.
                all_but_i = list(set(range(len(P))) - {i})
                P = P[all_but_i, :][:, all_but_i]
                indexes.remove(real_i)  # to keep track of the real
                # variable indices
            else:
                i += 1
        # A node which satisfies conditions 1,2 exists iff the
        # PDAG admits a consistent extension
        if not found:
            raise ValueError("PDAG does not admit consistent extension")
    return G


def order_edges(G):
    """
    Find a total ordering of the edges in DAG G, as an intermediate
    step to obtaining the CPDAG representing the Markov equivalence class to
    which it belongs (Fig. 13 of Chickering, 2002). Raises a ValueError
    exception if G is not a DAG.

    Parameters
    ----------
    G : np.array
        the adjacency matrix of a graph G, where G[i,j] != 0 iff i -> j.

    Returns
    -------
    ordered : np.array
       an integer matrix with the same edges as G, where edge i -> j has
       label x iff ordered[i,j] = x; labels run from 1 to the number of
       edges, and non-edges are 0.

    Raises
    ------
    ValueError
        If G is not a DAG.

    """
    if not is_dag(G):
        raise ValueError("The given graph is not a DAG")
    # i.e. if i -> j, then i appears before j in order
    order = topological_ordering(G)
    # In the comments below, a "higher ordered" node appears earlier in
    # `order` and a "lower ordered" node appears later.
    # Unlabelled edges as marked with -1
    ordered = (G != 0).astype(int) * -1
    i = 1
    while (ordered == -1).any():
        # let y be the lowest ordered node that has an unlabelled edge
        # incident to it (the latest such node in `order`; any unlabelled
        # edge out of it would lead to an even later node, so its
        # unlabelled edges are incoming)
        froms, tos = np.where(ordered == -1)
        with_unlabelled = np.unique(np.hstack((froms, tos)))
        y = sort(with_unlabelled, reversed(order))[0]
        # let x be the highest ordered node s.t. the edge x -> y
        # exists and is unlabelled (the earliest such parent in `order`)
        unlabelled_parents_y = np.where(ordered[:, y] == -1)[0]
        x = sort(unlabelled_parents_y, order)[0]
        ordered[x, y] = i
        i += 1
    return ordered


def label_edges(ordered):
    """Given a DAG with edges labelled according to a total ordering,
    label each edge as being compelled or reversible (Fig. 14 of
    Chickering, 2002).

    Parameters
    ----------
    ordered : np.array
        the adjacency matrix of a graph, with the edges labelled
        according to a total ordering (as produced by ``order_edges``).

    Returns
    -------
    labelled : np.array
        the adjacency matrix of G but with labelled edges, where
          - labelled[i,j] = 1 iff i -> j is compelled, and
          - labelled[i,j] = -1 iff i -> j is reversible.

    Raises
    ------
    ValueError
        If ``ordered`` is not a DAG, or its nonzero labels are not exactly
        1, 2, ..., number of edges.

    """
    # Validate the input
    if not is_dag(ordered):
        raise ValueError("The given graph is not a DAG")
    no_edges = (ordered != 0).sum()
    if sorted(ordered[ordered != 0]) != list(range(1, no_edges + 1)):
        raise ValueError("The ordering of edges is not valid:", ordered[ordered != 0])
    # define labels: 1: compelled, -1: reversible, -2: unknown
    COM, REV, UNK = 1, -1, -2
    labelled = (ordered != 0).astype(int) * UNK
    # while there are unknown edges
    while (labelled == UNK).any():
        # let (x,y) be the unknown edge with lowest order
        # (i.e. the unknown edge with the LARGEST label in `ordered`, NOT
        # the smallest). With the labelling of order_edges, edges into x
        # carry larger labels than x -> y, so they are already processed.
        unknown_edges = (ordered * (labelled == UNK).astype(int)).astype(float)
        # Non-edges and already-labelled edges must never win the argmax
        unknown_edges[unknown_edges == 0] = -np.inf
        (x, y) = np.unravel_index(np.argmax(unknown_edges), unknown_edges.shape)
        # iterate over all edges w -> x which are compelled
        Ws = np.where(labelled[:, x] == COM)[0]
        end = False
        for w in Ws:
            # if w is not a parent of y, label all edges into y as
            # compelled, and finish this pass
            if labelled[w, y] == 0:
                labelled[list(pa(y, labelled)), y] = COM
                end = True
                break
            # otherwise, label w -> y as compelled
            else:
                labelled[w, y] = COM
        if not end:
            # if there exists an edge z -> y such that z != x and z is
            # not a parent of x, label all unknown edges (this
            # includes x -> y) into y with compelled; label with
            # reversible otherwise.
            z_exists = len(pa(y, labelled) - {x} - pa(x, labelled)) > 0
            unknown = np.where(labelled[:, y] == UNK)[0]
            # Sanity check: x -> y was selected as unknown above
            assert x in unknown
            labelled[unknown, y] = COM if z_exists else REV
    return labelled



# --------------------------------------------------------------------
# General utilities

def pywhy_cpdag_to_numpy(cpdag):
    """Convert a pywhy CPDAG into a dense NumPy adjacency matrix.

    Directed edges ``(i, j)`` set ``A[i, j] = 1``; undirected edges ``(i, j)``
    set both ``A[i, j]`` and ``A[j, i]`` to 1. Node labels are used directly
    as matrix indices, so they are presumably the integers ``0..d-1``.

    If the conversion fails with an ordinary error (e.g. ``cpdag`` is already
    a NumPy array and has no ``nodes()`` method), ``cpdag`` is returned
    unchanged.

    Parameters
    ----------
    cpdag : object
        Graph exposing ``nodes()`` and ``edges()``, where ``edges()`` can be
        indexed by ``"directed"`` and ``"undirected"``; any other object is
        passed through.

    Returns
    -------
    np.ndarray or object
        The (d, d) float adjacency matrix, or ``cpdag`` itself on failure.
    """
    try:
        d = len(cpdag.nodes())
        A = np.zeros((d, d))
        for i, j in cpdag.edges()["directed"]:
            A[i, j] = 1
        for i, j in cpdag.edges()["undirected"]:
            A[i, j] = 1
            A[j, i] = 1
        return A
    except Exception:
        # Pass non-convertible inputs through unchanged. Unlike a bare
        # ``except:``, this lets KeyboardInterrupt and SystemExit propagate.
        return cpdag


def pywhy_pag_to_numpy(pag):
    """Convert a pywhy PAG into a dense NumPy matrix of edge codes.

    Encoding, applied in this order (later kinds overwrite earlier ones on
    the same entry):
      - ``"directed"`` edge ``(i, j)``:   ``A[i, j] = 1``
      - ``"undirected"`` edge ``(i, j)``: ``A[i, j] = A[j, i] = 1``
      - ``"bidirected"`` edge ``(i, j)``: ``A[i, j] = A[j, i] = -1``
      - ``"circle"`` edge ``(i, j)``:     ``A[i, j] = 2``

    Node labels are used directly as matrix indices, so they are presumably
    the integers ``0..d-1``.
    """
    n_nodes = len(pag.nodes())
    adjacency = np.zeros((n_nodes, n_nodes))
    for tail, head in list(pag.edges()["directed"]):
        adjacency[tail, head] = 1
    for u, v in list(pag.edges()["undirected"]):
        adjacency[u, v] = adjacency[v, u] = 1
    for u, v in list(pag.edges()["bidirected"]):
        adjacency[u, v] = adjacency[v, u] = -1
    for u, v in list(pag.edges()["circle"]):
        adjacency[u, v] = 2
    return adjacency


def pywhy_dag_to_numpy(dag):
    """Convert a pywhy DAG into a dense NumPy adjacency matrix.

    Each directed edge ``(i, j)`` sets ``A[i, j] = 1``. Node labels are used
    directly as matrix indices, so they are presumably the integers
    ``0..d-1``.

    If the conversion fails with an ordinary error (e.g. ``dag`` is already
    a NumPy array and has no ``nodes()`` method), ``dag`` is returned
    unchanged.

    Parameters
    ----------
    dag : object
        Graph exposing ``nodes()`` and ``edges()``, where ``edges()`` can be
        indexed by ``"directed"``; any other object is passed through.

    Returns
    -------
    np.ndarray or object
        The (d, d) float adjacency matrix, or ``dag`` itself on failure.
    """
    try:
        d = len(dag.nodes())
        A = np.zeros((d, d))
        for i, j in dag.edges()["directed"]:
            A[i, j] = 1
        return A
    except Exception:
        # Pass non-convertible inputs through unchanged. Unlike a bare
        # ``except:``, this lets KeyboardInterrupt and SystemExit propagate.
        return dag


def dag_to_pag(A_dag):
    """Placeholder for a DAG -> PAG conversion; not implemented yet.

    The argument is ignored and ``None`` is returned.
    """
    return None



def is_dag(A):
    """Check whether the given adjacency matrix describes a DAG.

    The matrix is loaded into a networkx ``DiGraph`` (every nonzero entry
    becomes an edge) and tested for directed acyclicity. A pair of symmetric
    nonzero entries forms a 2-cycle, so PDAGs with undirected edges are
    reported as not being DAGs.

    Parameters
    ----------
    A : np.array
        the adjacency matrix of the graph, where A[i,j] != 0 => i -> j.

    Returns
    -------
    bool
        True if the adjacency corresponds to a DAG.
    """
    graph = nx.from_numpy_array(A, create_using=nx.DiGraph)
    return nx.is_directed_acyclic_graph(graph)

def pa(i, A):
    """Return the parents of node ``i`` in the (P)DAG ``A``.

    ``j`` is a parent of ``i`` when ``A[j, i] != 0`` and ``A[i, j] == 0``;
    nodes joined to ``i`` by an undirected edge are excluded.

    Parameters
    ----------
    i : int
        the node's index
    A : np.array
        the adjacency matrix of the graph, where A[i,j] != 0 => i -> j
        and A[i,j] != 0 & A[j,i] != 0 => i - j.

    Returns
    -------
    set
        the parent node indices
    """
    points_into_i = A[:, i] != 0
    no_edge_back = A[i, :] == 0
    return set(np.where(points_into_i & no_edge_back)[0])


def ch(i, A):
    """Return the children of node ``i`` in the (P)DAG ``A``.

    ``j`` is a child of ``i`` when ``A[i, j] != 0`` and ``A[j, i] == 0``;
    nodes joined to ``i`` by an undirected edge are excluded.

    Parameters
    ----------
    i : int
        the node's index
    A : np.array
        the adjacency matrix of the graph, where A[i,j] != 0 => i -> j
        and A[i,j] != 0 & A[j,i] != 0 => i - j.

    Returns
    -------
    set
        the child node indices
    """
    leaves_i = A[i, :] != 0
    no_edge_back = A[:, i] == 0
    return set(np.where(leaves_i & no_edge_back)[0])


def neighbors(i, A):
    """Return the nodes joined to ``i`` by an undirected edge.

    ``j`` is a neighbor when both ``A[i, j]`` and ``A[j, i]`` are nonzero.

    Parameters
    ----------
    i : int
        the node's index
    A : np.array
        the adjacency matrix of the graph, where A[i,j] != 0 => i -> j
        and A[i,j] != 0 & A[j,i] != 0 => i - j.

    Returns
    -------
    set
        the neighbor node indices
    """
    outgoing = A[i, :] != 0
    incoming = A[:, i] != 0
    return set(np.where(outgoing & incoming)[0])


def adj(i, A):
    """Return every node adjacent to ``i``, through a directed or undirected edge.

    ``j`` is adjacent when ``A[i, j]`` or ``A[j, i]`` is nonzero.

    Parameters
    ----------
    i : int
        the node's index
    A : np.array
        the adjacency matrix of the graph, where A[i,j] != 0 => i -> j
        and A[i,j] != 0 & A[j,i] != 0 => i - j.

    Returns
    -------
    set
        the adjacent node indices
    """
    outgoing = A[i, :] != 0
    incoming = A[:, i] != 0
    return set(np.where(outgoing | incoming)[0])


def sort(L, order=None):
    """Sort the elements of an iterable, optionally following a given order.

    Without ``order`` the built-in ``sorted`` is used. With ``order``,
    element ``i`` precedes ``j`` in the result if ``i`` precedes ``j`` in
    ``order``.

    Parameters
    ----------
    L : iterable
        the iterable to be sorted
    order : iterable or None, optional
        a given ordering; its elements are used as indices into a rank
        array, so it is expected to be a permutation of ``0..len(order)-1``.
        Defaults to None.

    Returns
    -------
    list
        the elements of L, sorted ascending or according to ``order``.
    """
    items = list(L)
    if order is None:
        return sorted(items)
    sequence = list(order)
    # rank[node] = position of node within `sequence`
    rank = np.zeros(len(sequence), dtype=int)
    rank[sequence] = np.arange(len(sequence))
    keyed = sorted((rank[item], item) for item in items)
    return [item for _, item in keyed]


def only_directed(P):
    """
    Keep only the directed edges of ``P``.

    An entry ``P[i, j]`` survives when it is nonzero and ``P[j, i]`` is zero.
    Surviving entries keep their original values (useful when ``P`` is a
    weight matrix); every other entry is zero.

    Parameters
    ----------
    P : np.array
        adjacency matrix of a graph

    Returns
    -------
    np.array
        matrix of the same shape and dtype as P holding only its directed edges
    """
    is_directed = np.logical_and(P != 0, P.T == 0)
    directed = np.zeros_like(P)
    np.copyto(directed, P, where=is_directed)
    return directed


def topological_ordering(A):
    """Return a topological ordering for the DAG with adjacency matrix A,
    using Kahn's 1962 algorithm.

    Every nonzero entry counts as an edge, so weighted adjacency matrices
    (including negative weights) are handled: edges are counted, never
    summed.

    Raises a ValueError exception if the given adjacency does not
    correspond to a DAG.

    Parameters
    ----------
    A : np.array
        the adjacency matrix of the graph, where A[i,j] != 0 => i -> j.

    Returns
    -------
    ordering : list of ints
        a topological ordering for the DAG (if i -> j, i comes before j)

    Raises
    ------
    ValueError :
        If the given adjacency does not correspond to a DAG.

    """
    # Check that there are no undirected edges (this also rejects
    # self-loops, whose entry is trivially symmetric)
    if only_undirected(A).any():
        raise ValueError("The given graph is not a DAG")
    # Run the algorithm from the 1962 paper "Topological sorting of
    # large networks" by AB Kahn
    A = A.copy()
    # Start from source nodes (no incoming edge). Count nonzero entries
    # instead of summing weights: weights of opposite sign could cancel and
    # make a node with parents look like a source.
    sources = list(np.where((A != 0).sum(axis=0) == 0)[0])
    ordering = []
    while len(sources) > 0:
        i = sources.pop()
        ordering.append(i)
        for j in ch(i, A):
            A[i, j] = 0
            if len(pa(j, A)) == 0:
                sources.append(j)
    # If A still contains edges there is at least one cycle (again test for
    # nonzero entries, since a weight sum could be zero or negative)
    if (A != 0).any():
        raise ValueError("The given graph is not a DAG")
    return ordering


def only_undirected(P):
    """
    Keep only the undirected edges of ``P``.

    An entry ``P[i, j]`` survives when both ``P[i, j]`` and ``P[j, i]`` are
    nonzero. Surviving entries keep their original values (useful when ``P``
    is a weight matrix); every other entry is zero.

    Parameters
    ----------
    P : np.array
        adjacency matrix of a graph

    Returns
    -------
    np.array
        matrix of the same shape and dtype as P holding only its undirected edges
    """
    is_undirected = np.logical_and(P != 0, P.T != 0)
    undirected = np.zeros_like(P)
    np.copyto(undirected, P, where=is_undirected)
    return undirected


#######################################################
#################### Generic utils ####################
#######################################################
def init_scancel_script(script_path: str):
    """Create (or truncate) ``script_path`` so it holds only a bash shebang line."""
    shebang = "#!/bin/bash\n"
    with open(script_path, "w") as handle:
        handle.write(shebang)

def is_cpdag(method: str):
    """
    Return True if the output of the given method is a CPDAG (GES and PC).
    """
    cpdag_methods = ("ges", "pc")
    return method in cpdag_methods


def tunable_parameters(method):
    """Return the name of the tunable parameter of a benchmarked method.

    ``"none"`` marks methods without a tunable parameter. Unknown methods
    raise ``KeyError``.
    """
    alpha_methods = ("pc", "score", "cam", "nogam", "das", "diffan", "grandag", "resit")
    parameter_free = ("lingam", "random", "varsort", "scoresort")
    lookup = {"ges": "lambda"}
    lookup.update(dict.fromkeys(alpha_methods, "alpha"))
    lookup.update(dict.fromkeys(parameter_free, "none"))
    return lookup[method]


def scenario_param_folder(scenario: str, param: Union[float, str]):
    """Return the folder name for a scenario.

    ``"vanilla"`` maps to itself; ``"confounded"`` and ``"linear"`` map to
    ``<scenario>_<param>``. Any other scenario fails the assertion.
    """
    assert scenario in ["vanilla", "confounded", "linear"]

    if scenario != "vanilla":
        return scenario + f"_{param}"
    return "vanilla"
    

def directed_np2nx(A: np.array):
    """Build a networkx ``DiGraph`` from the adjacency matrix ``A``."""
    graph = nx.from_numpy_array(A, create_using=nx.DiGraph)
    return graph


def is_a_collider(A: np.array, p1: int, p2: int, c: int) -> bool:
    """
    Check whether ``p1 -> c <- p2`` holds in ``A``.

    Only the two entries ``A[p1, c]`` and ``A[p2, c]`` are inspected, and
    both must equal exactly 1. Adjacency between ``p1`` and ``p2`` is not
    checked.

    Parameters
    ----------
    A : np.array
        Adj. matrix representation of a DAG
    p1 : int
        First parent of the potential collider
    p2 : int
        Second parent of the potential collider
    c : int
        Head of the potential collider
    """
    p1_into_c = A[p1, c] == 1
    return p1_into_c and A[p2, c] == 1


def moralize_triplet(A: np.array, p1: int, p2: int, c: int):
    """Moralize a triplet of nodes by linking the two parents of a collider.

    If ``p1 -> c <- p2`` (as checked by ``is_a_collider``), an edge between
    ``p1`` and ``p2`` is added to ``A`` in place: ``p1 -> p2`` if that keeps
    the graph acyclic, otherwise ``p2 -> p1`` if that keeps it acyclic. If
    neither orientation is acyclic, no edge is added.

    Parameters
    ----------
    A : np.array
        Adj. matrix representation of a DAG; modified in place.
    p1 : int
        First parent of the potential collider
    p2 : int
        Second parent of the potential collider
    c : int
        Head of the potential collider

    Returns
    -------
    np.array
        ``A`` itself.
    """
    def is_new_edge_acyclic(A, u, v):
        """Return True if adding u -> v to a copy of A leaves it acyclic."""
        A_clone = np.copy(A)
        A_clone[u, v] = 1
        return nx.is_directed_acyclic_graph(directed_np2nx(A_clone))

    if is_a_collider(A, p1, p2, c):
        if is_new_edge_acyclic(A, p1, p2):
            A[p1, p2] = 1
        # Fall back to the opposite orientation; the helper needs the
        # adjacency matrix as its first argument.
        elif is_new_edge_acyclic(A, p2, p1):
            A[p2, p1] = 1
    return A


def read_params_from_file(params_id, params_file):
    """Load the parameter configuration stored under ``params_id`` in a JSON file.

    The file is expected to map string IDs to configurations; ``params_id``
    is converted with ``str`` before the lookup, so a missing ID raises
    ``KeyError``.
    """
    with open(params_file, "r") as handle:
        configurations = json.load(handle)
    return configurations[str(params_id)]

def get_inference_params(tuning_base_dir, dataset_id, method, params_grid_file, reg_param_name, reg_param) -> Dict:
    """
    Return dictionary of parameters for the inference task.

    Tuning results are looked up in ``<tuning_base_dir>/dataset_<dataset_id>``.
    If that directory does not exist, the first existing
    ``<tuning_base_dir>/dataset_<i>`` for ``i`` in ``0..19`` is used instead.
    In the chosen directory, every file is read as JSON with ``"id"`` and
    ``"val_score"`` keys, and the configuration ID with the smallest
    ``val_score`` is selected. That configuration is read from
    ``params_grid_file``, and ``reg_param_name`` is overwritten with
    ``reg_param``. If no tuning directory exists, or it holds no results,
    return ``{reg_param_name : reg_param}``.

    Parameters
    ----------
    tuning_base_dir : str
        Directory containing one ``dataset_<i>`` sub-directory of tuning
        results per dataset for a specific combination of method and data.
        E.g. /efs/tmp/.../tuning/ER/gauss/vanilla/vanilla/diffan/100_small_dense/
    dataset_id : int
        Id of the inference dataset
    method : str
        The name of the inference method. If "lingam", return empty dictionary
    params_grid_file : str
        File with the param_grid of the current method, indexed by ID.
        It is only read when tuning results are found.
    reg_param_name : str
        Name of the parameter determining sparsity of the solution. E.g. "alpha", "lambda"
    reg_param : float
        Value of the regularization parameter

    Returns
    -------
    Dict
        The selected parameter configuration.
    """
    if method == "lingam":
        return dict()

    num_seeds = 20  # candidate dataset_<i> directories to fall back on

    # Find an existing tuning_dir, if any
    tuning_dir = os.path.join(tuning_base_dir, f"dataset_{dataset_id}")
    if not os.path.exists(tuning_dir):
        for i in range(num_seeds):
            candidate = os.path.join(tuning_base_dir, f"dataset_{i}")
            # Check existence: the joined path string itself is always truthy.
            if os.path.exists(candidate):
                tuning_dir = candidate
                break

    if not os.path.exists(tuning_dir):
        return {reg_param_name: reg_param}

    # Find the ID of the minimum validation score
    argmin_id = None
    min_val_score = np.inf
    for json_file in os.listdir(tuning_dir):
        with open(os.path.join(tuning_dir, json_file), "r") as f:
            json_content = json.load(f)
        params_id = json_content["id"]
        val_score = json_content["val_score"]
        if val_score < min_val_score:
            min_val_score = val_score
            argmin_id = params_id

    if argmin_id is None:
        # No usable tuning result: behave as if no tuning had been run.
        return {reg_param_name: reg_param}

    # Take parameters configuration from params_grid_file
    with open(params_grid_file, "r") as f:
        parameters = json.load(f)[str(argmin_id)]
    parameters[reg_param_name] = reg_param
    return parameters


def has_order(method):
    """Return True if ``method`` produces a topological order of the nodes."""
    ordering_methods = (
        "das", "score", "nogam", "cam", "diffan", "grandag",
        "lingam", "resit", "random", "varsort", "scoresort",
    )
    return method in ordering_methods


def methods(lingam: bool):
    """Return the benchmarked methods; "lingam" is included only if ``lingam`` is truthy."""
    names = ["random", "das", "cam", "score", "nogam", "diffan", "grandag", "resit"]
    if lingam:
        names.append("lingam")
    names.extend(["pc", "ges"])
    return names


def graph_types():
    """Return the random-graph families used in the benchmark."""
    families = ["ER", "SF", "GRP", "FC"]
    return families


def graph_densities(graph_type):
    """Return the density labels available for ``graph_type`` (None if unknown)."""
    if graph_type == "GRP":
        densities = ["cluster"]
    elif graph_type == "FC":
        densities = ["full"]
    elif graph_type in ["ER", "SF"]:
        densities = ["sparse", "dense"]
    else:
        densities = None
    return densities
    

def graph_sizes(graph_type):
    """Return the size labels available for ``graph_type`` (None if unknown)."""
    if graph_type in ["ER", "FC"]:
        sizes = ["small", "medium", "large20", "large50"]
    elif graph_type in ["GRP", "SF"]:
        sizes = ["medium", "large20", "large50"]
    else:
        sizes = None
    return sizes

def graph_sizes_to_nodes():
    """Map each graph-size label to its number of nodes."""
    return dict(small=5, medium=10, large20=20, large50=50)

def noise_distributions():
    """Return the noise-distribution labels used in the benchmark."""
    labels = ["gauss", "nonlin_weak", "nonlin_mid", "nonlin_strong"]
    return labels

def inference_scenarios(vanilla, lingam):
    """Return the inference scenarios, optionally without "vanilla" and/or "linear"."""
    all_scenarios = ["vanilla", "linear", "measure_err", "confounded", "timino", "unfaithful", "pnl"]
    return [
        name
        for name in all_scenarios
        if (vanilla or name != "vanilla") and (lingam or name != "linear")
    ]

def scenario_configs(scenario):
    """Return the configuration folder names of ``scenario`` (KeyError if unknown)."""
    configs_by_scenario = dict(
        vanilla=["vanilla"],
        timino=["timino"],
        linear=["linear_0.33", "linear_0.66", "linear_0.99"],
        unfaithful=["unfaithful_0.25", "unfaithful_0.5", "unfaithful_0.75", "unfaithful_1.0"],
        confounded=["confounded_0.1", "confounded_0.2"],
        measure_err=["measure_err_0.2", "measure_err_0.4", "measure_err_0.6", "measure_err_0.8"],
        pnl=["pnl_3.0"],
    )
    return configs_by_scenario[scenario]