from collections import OrderedDict
import numpy as np

"""
Collection of directed acyclic graphs for testing the quality 
"""


# X -> T -> Y, plus X -> Y (X is an observed confounder of T and Y). Identifiable
def gen_backdoor():
    """Build the 'backdoor' test graph.

    Observed nodes are X (0), T (1) and Y (2), with edges X -> T, X -> Y and
    T -> Y, so X is an observed confounder of the T -> Y effect. T and Y each
    have a private latent parent (-1 and -2 respectively); X has none.

    Returns:
        Tuple ``(name, graph, do_var, target_var, n_latent, latent_dim)``,
        where ``graph`` maps each node index to its list of parents
        (negative entries are latent variables).
    """
    parents_of = OrderedDict()
    parents_of[0] = []           # X: root, no parents
    parents_of[1] = [0, -1]      # T <- X, private latent -1
    parents_of[2] = [0, 1, -2]   # Y <- X, T, private latent -2

    intervened, outcome = 1, 2   # intervene on T, observe Y
    num_latents, noise_dim = 2, 1
    return 'backdoor', parents_of, intervened, outcome, num_latents, noise_dim


#      - - - -
#    /        \
#   T -> M -> Y     Identifiable
def gen_frontdoor():
    """Build the 'frontdoor' test graph.

    Observed nodes are T (0), M (1) and Y (2), with T -> M -> Y. Latent -1
    is a parent of both T and Y (unobserved confounding); latents -2, -3 and
    -4 are private parents of T, M and Y respectively.

    Returns:
        Tuple ``(name, graph, do_var, target_var, n_latent, latent_dim)``.
    """
    edges = (
        (0, [-1, -2]),       # T <- latent -1 (shared with Y), latent -2
        (1, [0, -3]),        # M <- T, latent -3
        (2, [1, -1, -4]),    # Y <- M, latent -1 (shared with T), latent -4
    )
    return 'frontdoor', OrderedDict(edges), 0, 2, 4, 1


#           --
#         /   \
#   X -> T -> Y
def gen_iv():
    """Build the 'iv' (instrumental-variable) test graph.

    Observed nodes are X (0), T (1) and Y (2), with X -> T -> Y. Latent -2
    is a parent of both T and Y (unobserved confounding). Latents -1 and -3
    are private parents of T and Y, and X has no latent parent.

    Returns:
        Tuple ``(name, graph, do_var, target_var, n_latent, latent_dim)``.
    """
    graph = OrderedDict()
    for node, parents in ((0, []), (1, [0, -1, -2]), (2, [1, -2, -3])):
        graph[node] = parents
    do_var, target_var = 1, 2
    return ('iv', graph, do_var, target_var, 3, 1)


#       -------
#      /    -  \
#    /     / \ \
#   T -> X ->  Y    Non-identifiable
def gen_leaky():
    """Build the 'leaky' test graph.

    Observed nodes are T (0), X (1) and Y (2), with T -> X -> Y. Latent -1
    is shared by T and Y, and latent -3 is shared by X and Y. Latents -2, -4
    and -5 are private parents of T, X and Y respectively.

    Returns:
        Tuple ``(name, graph, do_var, target_var, n_latent, latent_dim)``.
    """
    parent_lists = (
        [-1, -2],            # T (node 0)
        [0, -3, -4],         # X (node 1) <- T
        [1, -1, -3, -5],     # Y (node 2) <- X
    )
    # enumerate assigns node indices 0, 1, 2 in insertion order.
    graph = OrderedDict(enumerate(parent_lists))
    return 'leaky', graph, 0, 2, 5, 1


#    ----
#   /    \
#  T -> Y     Non-identifiable (bow graph: T and Y share a latent confounder)
def gen_bow():
    """Build the 'bow' test graph.

    Observed nodes are T (0) and Y (1), with T -> Y. Latent -1 is a parent of
    both T and Y (unobserved confounding). Latents -3 and -2 are private
    parents of T and Y respectively.

    Returns:
        Tuple ``(name, graph, do_var, target_var, n_latent, latent_dim)``.
    """
    confounder = -1  # latent parent common to T and Y
    graph = OrderedDict([
        (0, [confounder, -3]),       # T
        (1, [0, confounder, -2]),    # Y <- T
    ])
    treatment, outcome = 0, 1
    return ('bow', graph, treatment, outcome, 3, 1)


def gen_params(key):
    """Return the parameters of the test graph registered under ``key``.

    Args:
        key: Graph name; one of ``'backdoor'``, ``'frontdoor'``, ``'bow'``,
            ``'leaky'`` or ``'iv'``.

    Returns:
        The tuple ``(name, graph, do_var, target_var, n_latent, latent_dim)``
        produced by the matching ``gen_*`` function.

    Raises:
        KeyError: If ``key`` is not a registered graph name. The message
            lists the valid names.
    """
    generators = {
        'backdoor': gen_backdoor,
        'frontdoor': gen_frontdoor,
        'bow': gen_bow,
        'leaky': gen_leaky,
        'iv': gen_iv,
    }
    try:
        generator = generators[key]
    except KeyError:
        # Keep the KeyError type callers may already catch, but say what
        # the valid choices are instead of echoing only the bad key.
        raise KeyError(
            'unknown graph {!r}; expected one of {}'.format(
                key, sorted(generators))
        ) from None
    # Call outside the try so a KeyError raised inside a generator is not
    # misreported as an unknown graph name.
    return generator()