import pytest
import os
import random
import math
import numpy as np
import networkx as nx
from sklearn.linear_model import LinearRegression

from benchmark.data_generators import VanillaGenerator, MeasureErrorGenerator, TiminoGenerator, UnfaithfulGenerator, ConfundedGenerator
from utils._data import DataSimulator
from utils._random_graphs import fully_connected, GaussianRandomPartition

# Module-wide seed; both NumPy's and Python's global RNGs are fixed at import time.
seed = 42
np.random.seed(seed)
random.seed(seed)


##################### FIXTURES #####################

@pytest.fixture
def dag_sample():
    """Return the adjacency matrix of a fixed 6-node DAG.

    Entry ``A[row, col] == 1`` marks an edge row -> col. The edges are
    0->2, 0->5, 1->2, 2->3, 2->4 and 3->4.
    """
    sources = [0, 0, 1, 2, 2, 3]
    targets = [2, 5, 2, 3, 4, 4]
    A = np.zeros((6, 6), dtype=int)
    A[sources, targets] = 1
    return A


##################### Test acyclicity and order #####################
def test_acyclicity_nontrivial_order():
    """Test that generated graphs are DAGs with a nontrivial topological order.

    For every graph size in [5, 100), a graph is generated and checked to be
    acyclic. It must also contain at least one edge that points "backwards"
    with respect to the identity node ordering.
    """
    def num_errors(order, adj):
        # Count edges later -> earlier, where "later"/"earlier" refer to
        # positions in `order` (adj[row, col] is the edge row -> col).
        err = 0
        for i in range(len(order)):
            err += adj[order[i+1:], order[i]].sum()
        return err

    def node_to_size(n):
        if n <= 5:
            return "small"
        elif n <= 10:
            return "medium"
        elif n == 20:
            return "large20"
        return "large50"

    for num_nodes in range(5, 100):
        graph_size = node_to_size(num_nodes)
        # SF generation raises for "small" graphs (see test_sf_small_error),
        # so only ER can be drawn for them; otherwise the test is flaky.
        if graph_size == "small":
            graph_type = "ER"
        else:
            graph_type = np.random.choice(["ER", "SF"])
        generator = VanillaGenerator(
            graph_type=graph_type,
            num_nodes=num_nodes,
            graph_size=graph_size,
            graph_density="dense",
            num_samples=10,
            noise_distr="gauss",
            noise_std_support=(.5, 1.),
            seed=seed,
        )
        generator.simulate_dag()
        A = generator.adjacency
        trivial_order = range(num_nodes)
        assert num_errors(trivial_order, A) > 0
        assert nx.is_directed_acyclic_graph(nx.from_numpy_array(A, create_using=nx.DiGraph))


##################### Test ER graphs generation #####################

def test_er_small_sparse():
    """Test generation of small sparse ER graphs (5 nodes).

    Expected values over 100 randomly seeded graphs:
        - number of nodes: 5
        - minimum number of edges: 2 (enforced by the generator)
        - maximum number of edges: <= 5 (probability of 5 or more edges is 0.01)
        - average number of edges: in [2, 3]
        - regularity of edges: on average, the difference between the max
          in-degree and the max out-degree is at most 1
    """
    n_graphs = 100
    run_seeds = np.random.randint(10000, size=n_graphs)
    logs = {
        "num_nodes": [],
        "max_income_degree": [],
        "max_outcome_degree": [],
        "num_edges": [],
    }
    # `run_seed` rather than `seed`, so the module-level seed is not shadowed.
    for run_seed in run_seeds:
        generator = VanillaGenerator(
            graph_type="ER",
            num_nodes=5,
            graph_size="small",
            graph_density="sparse",
            num_samples=10,
            noise_distr="Gauss",
            noise_std_support=(.5, 1.),
            seed=run_seed,
        )
        generator.simulate_dag()
        A = generator.adjacency
        logs["num_nodes"].append(A.shape[0])
        logs["num_edges"].append(A.sum())
        # Column sums are in-degrees, row sums are out-degrees.
        logs["max_income_degree"].append(np.max(A.sum(axis=0)))
        logs["max_outcome_degree"].append(np.max(A.sum(axis=1)))

    num_edges = np.array(logs["num_edges"])
    degree_gap = np.abs(np.array(logs["max_income_degree"]) - np.array(logs["max_outcome_degree"]))

    assert set(logs["num_nodes"]) == {5}  # every graph has exactly 5 nodes
    assert num_edges.mean() <= 3
    assert num_edges.min() == 2
    assert num_edges.max() <= 5
    assert degree_gap.mean() <= 1  # regularity condition


def test_er_small_dense():
    """Test generation of small dense ER graphs (5 nodes).

    Test specifics over 100 randomly seeded graphs:
        - number of nodes: 5
        - maximum number of edges: >= 8 (probability of 8 or more edges is 0.2)
        - average number of edges: in [4, 6]
        - regularity of edges: on average, the difference between the max
          in-degree and the max out-degree is at most 1
    """
    n_graphs = 100
    run_seeds = np.random.randint(10000, size=n_graphs)
    logs = {
        "num_nodes": [],
        "max_income_degree": [],
        "max_outcome_degree": [],
        "num_edges": [],
    }
    # `run_seed` rather than `seed`, so the module-level seed is not shadowed.
    for run_seed in run_seeds:
        generator = VanillaGenerator(
            graph_type="ER",
            num_nodes=5,
            graph_size="small",
            graph_density="dense",
            num_samples=10,
            noise_distr="Gauss",
            noise_std_support=(.5, 1.),
            seed=run_seed,
        )
        generator.simulate_dag()
        A = generator.adjacency
        logs["num_nodes"].append(A.shape[0])
        logs["num_edges"].append(A.sum())
        # Column sums are in-degrees, row sums are out-degrees.
        logs["max_income_degree"].append(np.max(A.sum(axis=0)))
        logs["max_outcome_degree"].append(np.max(A.sum(axis=1)))

    num_edges = np.array(logs["num_edges"])
    degree_gap = np.abs(np.array(logs["max_income_degree"]) - np.array(logs["max_outcome_degree"]))

    assert set(logs["num_nodes"]) == {5}  # every graph has exactly 5 nodes
    assert 4 <= num_edges.mean() <= 6  # mean
    assert num_edges.max() >= 8  # density
    assert degree_gap.mean() <= 1  # regularity condition


def test_er_medium_sparse():
    """Test generation of medium sparse ER graphs (10 nodes).

    Test specifics over 100 randomly seeded graphs:
        - number of nodes: 10
        - density parameter: 1
        - mean, maximum and minimum number of edges: 10 +- 1
        - regularity of edges: on average, the difference between the max
          in-degree and the max out-degree is at most 1
    """
    n_graphs = 100
    run_seeds = np.random.randint(10000, size=n_graphs)
    logs = {
        "num_nodes": [],
        "max_income_degree": [],
        "max_outcome_degree": [],
        "num_edges": [],
    }
    # `run_seed` rather than `seed`, so the module-level seed is not shadowed.
    for run_seed in run_seeds:
        generator = VanillaGenerator(
            graph_type="ER",
            num_nodes=10,
            graph_size="medium",
            graph_density="sparse",
            num_samples=10,
            noise_distr="Gauss",
            noise_std_support=(.5, 1.),
            seed=run_seed,
        )
        generator.simulate_dag()
        A = generator.adjacency
        logs["num_nodes"].append(A.shape[0])
        logs["num_edges"].append(A.sum())
        # Column sums are in-degrees, row sums are out-degrees.
        logs["max_income_degree"].append(np.max(A.sum(axis=0)))
        logs["max_outcome_degree"].append(np.max(A.sum(axis=1)))

    num_edges = np.array(logs["num_edges"])
    degree_gap = np.abs(np.array(logs["max_income_degree"]) - np.array(logs["max_outcome_degree"]))

    assert generator.get_density_param() == 1
    assert set(logs["num_nodes"]) == {10}  # every graph has exactly 10 nodes
    assert 9 <= num_edges.mean() <= 11  # mean edges
    assert 9 <= num_edges.max() <= 11  # max edges
    assert 9 <= num_edges.min() <= 11  # min edges
    assert degree_gap.mean() <= 1  # regularity condition


def test_er_medium_dense():
    """Test generation of medium dense ER graphs (10 nodes).

    Test specifics over 100 randomly seeded graphs:
        - number of nodes: 10
        - density parameter: 2
        - mean, maximum and minimum number of edges: 20 +- 1
        - regularity of edges: on average, the difference between the max
          in-degree and the max out-degree is at most 1.5
    """
    n_graphs = 100
    run_seeds = np.random.randint(10000, size=n_graphs)
    logs = {
        "num_nodes": [],
        "max_income_degree": [],
        "max_outcome_degree": [],
        "num_edges": [],
    }
    # `run_seed` rather than `seed`, so the module-level seed is not shadowed.
    for run_seed in run_seeds:
        generator = VanillaGenerator(
            graph_type="ER",
            num_nodes=10,
            graph_size="medium",
            graph_density="dense",
            num_samples=10,
            noise_distr="Gauss",
            noise_std_support=(.5, 1.),
            seed=run_seed,
        )
        generator.simulate_dag()
        A = generator.adjacency
        logs["num_nodes"].append(A.shape[0])
        logs["num_edges"].append(A.sum())
        # Column sums are in-degrees, row sums are out-degrees.
        logs["max_income_degree"].append(np.max(A.sum(axis=0)))
        logs["max_outcome_degree"].append(np.max(A.sum(axis=1)))

    num_edges = np.array(logs["num_edges"])
    degree_gap = np.abs(np.array(logs["max_income_degree"]) - np.array(logs["max_outcome_degree"]))

    assert generator.get_density_param() == 2
    assert set(logs["num_nodes"]) == {10}  # every graph has exactly 10 nodes
    assert 19 <= num_edges.mean() <= 21  # mean edges
    assert 19 <= num_edges.max() <= 21  # max edges
    assert 19 <= num_edges.min() <= 21  # min edges
    assert degree_gap.mean() <= 1.5  # regularity condition


def test_er_large_sparse():
    """Test generation of large sparse ER graphs (20 nodes).

    Test specifics over 100 randomly seeded graphs:
        - number of nodes: 20
        - density parameter: 1
        - mean, maximum and minimum number of edges: 20 +- 1
        - regularity of edges: on average, the difference between the max
          in-degree and the max out-degree is at most 1.5
    """
    n_graphs = 100
    run_seeds = np.random.randint(10000, size=n_graphs)
    logs = {
        "num_nodes": [],
        "max_income_degree": [],
        "max_outcome_degree": [],
        "num_edges": [],
    }
    # `run_seed` rather than `seed`, so the module-level seed is not shadowed.
    for run_seed in run_seeds:
        generator = VanillaGenerator(
            graph_type="ER",
            num_nodes=20,
            graph_size="large20",
            graph_density="sparse",
            num_samples=10,
            noise_distr="Gauss",
            noise_std_support=(.5, 1.),
            seed=run_seed,
        )
        generator.simulate_dag()
        A = generator.adjacency
        logs["num_nodes"].append(A.shape[0])
        logs["num_edges"].append(A.sum())
        # Column sums are in-degrees, row sums are out-degrees.
        logs["max_income_degree"].append(np.max(A.sum(axis=0)))
        logs["max_outcome_degree"].append(np.max(A.sum(axis=1)))

    num_edges = np.array(logs["num_edges"])
    degree_gap = np.abs(np.array(logs["max_income_degree"]) - np.array(logs["max_outcome_degree"]))

    assert generator.get_density_param() == 1
    assert set(logs["num_nodes"]) == {20}  # every graph has exactly 20 nodes
    assert 19 <= num_edges.mean() <= 21  # mean edges
    assert 19 <= num_edges.max() <= 21  # max edges
    assert 19 <= num_edges.min() <= 21  # min edges
    assert degree_gap.mean() <= 1.5  # regularity condition


def test_er_large_dense():
    """Test generation of large dense ER graphs (20 nodes).

    Test specifics over 100 randomly seeded graphs:
        - number of nodes: 20
        - density parameter: 4
        - mean, maximum and minimum number of edges: 80 +- 1
        - regularity of edges: on average, the difference between the max
          in-degree and the max out-degree is at most 2
    """
    n_graphs = 100
    run_seeds = np.random.randint(10000, size=n_graphs)
    logs = {
        "num_nodes": [],
        "max_income_degree": [],
        "max_outcome_degree": [],
        "num_edges": [],
    }
    # `run_seed` rather than `seed`, so the module-level seed is not shadowed.
    for run_seed in run_seeds:
        generator = VanillaGenerator(
            graph_type="ER",
            num_nodes=20,
            graph_size="large20",
            graph_density="dense",
            num_samples=10,
            noise_distr="Gauss",
            noise_std_support=(.5, 1.),
            seed=run_seed,
        )
        generator.simulate_dag()
        A = generator.adjacency
        logs["num_nodes"].append(A.shape[0])
        logs["num_edges"].append(A.sum())
        # Column sums are in-degrees, row sums are out-degrees.
        logs["max_income_degree"].append(np.max(A.sum(axis=0)))
        logs["max_outcome_degree"].append(np.max(A.sum(axis=1)))

    num_edges = np.array(logs["num_edges"])
    degree_gap = np.abs(np.array(logs["max_income_degree"]) - np.array(logs["max_outcome_degree"]))

    assert generator.get_density_param() == 4
    assert set(logs["num_nodes"]) == {20}  # every graph has exactly 20 nodes
    assert 79 <= num_edges.mean() <= 81  # mean edges
    assert 79 <= num_edges.max() <= 81  # max edges
    assert 79 <= num_edges.min() <= 81  # min edges
    assert degree_gap.mean() <= 2  # regularity condition


def test_er_50_density_param():
    """Test that for 50 nodes the ER density parameter is set correctly.

    Test specifics:
        - sparse: m=2
        - dense: m=8
    """
    for graph_density, expected_m in (("sparse", 2), ("dense", 8)):
        generator = VanillaGenerator(
            graph_type="ER",
            num_nodes=50,
            graph_size="large50",
            graph_density=graph_density,
            num_samples=10,
            noise_distr="Gauss",
            noise_std_support=(.5, 1.),
            seed=seed,
        )
        assert generator.get_density_param() == expected_m, (
            f"ER large50 {graph_density}: expected density parameter {expected_m}"
        )


##################### Test SF graphs generation #####################
def test_sf_small_error():
    """Test that requesting a small (5 nodes) SF graph raises an AssertionError."""
    generator = VanillaGenerator(
        graph_type="SF",
        num_nodes=5,
        graph_size="small",
        graph_density="sparse",
        num_samples=10,
        noise_distr="Gauss",
        noise_std_support=(.5, 1.),
        seed=seed,
    )
    # Only the call expected to raise goes inside the context manager.
    with pytest.raises(AssertionError):
        generator.simulate_dag()

def test_sf_medium_sparse():
    """Test generation of medium sparse SF graphs (10 nodes).

    Test specifics over 100 randomly seeded graphs:
        - number of nodes: 10
        - density parameter: 1
        - mean, maximum and minimum number of edges: 10 +- 1
        - low input degree: mean of the max in-degree <= 2
        - large output degree: mean of the max out-degree >= 4
        - irregularity: mean difference between max in- and out-degree >= 3
    """
    n_graphs = 100
    run_seeds = np.random.randint(10000, size=n_graphs)
    logs = {
        "num_nodes": [],
        "max_income_degree": [],
        "max_outcome_degree": [],
        "num_edges": [],
    }
    # `run_seed` rather than `seed`, so the module-level seed is not shadowed.
    for run_seed in run_seeds:
        generator = VanillaGenerator(
            graph_type="SF",
            num_nodes=10,
            graph_size="medium",
            graph_density="sparse",
            num_samples=10,
            noise_distr="Gauss",
            noise_std_support=(.5, 1.),
            seed=run_seed,
        )
        generator.simulate_dag()
        A = generator.adjacency
        logs["num_nodes"].append(A.shape[0])
        logs["num_edges"].append(A.sum())
        # Column sums are in-degrees, row sums are out-degrees.
        logs["max_income_degree"].append(np.max(A.sum(axis=0)))
        logs["max_outcome_degree"].append(np.max(A.sum(axis=1)))

    num_edges = np.array(logs["num_edges"])
    max_in = np.array(logs["max_income_degree"])
    max_out = np.array(logs["max_outcome_degree"])

    assert generator.get_density_param() == 1
    assert set(logs["num_nodes"]) == {10}  # every graph has exactly 10 nodes
    assert 9 <= num_edges.mean() <= 11  # mean edges
    assert 9 <= num_edges.max() <= 11  # max edges
    assert 9 <= num_edges.min() <= 11  # min edges
    assert np.abs(max_in - max_out).mean() >= 3  # irregular graph
    assert max_in.mean() <= 2  # low input degree
    assert max_out.mean() >= 4  # large output degree


def test_sf_large_sparse():
    """Test generation of large sparse SF graphs (20 nodes).

    Test specifics over 100 randomly seeded graphs:
        - number of nodes: 20
        - density parameter: 1
        - mean, maximum and minimum number of edges: 20 +- 1
        - low input degree: mean of the max in-degree <= 1
        - large output degree: mean of the max out-degree >= 10
        - irregularity: mean difference between max in- and out-degree >= 5
    """
    n_graphs = 100
    run_seeds = np.random.randint(10000, size=n_graphs)
    logs = {
        "num_nodes": [],
        "max_income_degree": [],
        "max_outcome_degree": [],
        "num_edges": [],
    }
    # `run_seed` rather than `seed`, so the module-level seed is not shadowed.
    for run_seed in run_seeds:
        generator = VanillaGenerator(
            graph_type="SF",
            num_nodes=20,
            graph_size="large20",
            graph_density="sparse",
            num_samples=10,
            noise_distr="Gauss",
            noise_std_support=(.5, 1.),
            seed=run_seed,
        )
        generator.simulate_dag()
        A = generator.adjacency
        logs["num_nodes"].append(A.shape[0])
        logs["num_edges"].append(A.sum())
        # Column sums are in-degrees, row sums are out-degrees.
        logs["max_income_degree"].append(np.max(A.sum(axis=0)))
        logs["max_outcome_degree"].append(np.max(A.sum(axis=1)))

    num_edges = np.array(logs["num_edges"])
    max_in = np.array(logs["max_income_degree"])
    max_out = np.array(logs["max_outcome_degree"])

    assert generator.get_density_param() == 1
    assert set(logs["num_nodes"]) == {20}  # every graph has exactly 20 nodes
    assert 19 <= num_edges.mean() <= 21  # mean edges
    assert 19 <= num_edges.max() <= 21  # max edges
    assert 19 <= num_edges.min() <= 21  # min edges
    assert np.abs(max_in - max_out).mean() >= 5  # irregular graph
    assert max_in.mean() <= 1  # low input degree
    assert max_out.mean() >= 10  # large output degree


def test_sf_50_density_param():
    """Test that for 50 nodes the SF density parameter is set correctly.

    Test specifics:
        - sparse: m=2
        - dense: m=8
    """
    for graph_density, expected_m in (("sparse", 2), ("dense", 8)):
        generator = VanillaGenerator(
            graph_type="SF",
            num_nodes=50,
            graph_size="large50",
            graph_density=graph_density,
            num_samples=10,
            noise_distr="Gauss",
            noise_std_support=(.5, 1.),
            seed=seed,
        )
        assert generator.get_density_param() == expected_m, (
            f"SF large50 {graph_density}: expected density parameter {expected_m}"
        )


##################### Test GRP graphs generation #####################
def test_gpr_sample_number_of_clusters():
    """Test the number of clusters sampled by GaussianRandomPartition per graph size.

    Expected:
        - 10 nodes: 2 clusters
        - 20 and 30 nodes: between 3 and 5 clusters
        - 50 nodes: between 4 and 6 clusters
    """
    p_in = 0.4
    p_out = 0.04

    # 10 nodes: the cluster count is expected to always be 2, so one draw is checked.
    d = 10
    grp = GaussianRandomPartition(d, p_in, p_out)
    assert grp._sample_number_of_clusters() == 2, f"Medium graphs with {d} nodes must have 2 clusters!"

    # Larger graphs: the count is checked over repeated draws, against an inclusive range.
    expected_ranges = {20: (3, 5), 30: (3, 5), 50: (4, 6)}
    for d, (low, high) in expected_ranges.items():
        grp = GaussianRandomPartition(d, p_in, p_out)
        for _ in range(10):
            n_clusters = grp._sample_number_of_clusters()
            assert low <= n_clusters <= high, (
                f"Graphs with {d} nodes must have n_clusters in {list(range(low, high + 1))}!"
            )


def test_sample_cluster_sizes():
    """Test that every sampled cluster has at least 3 nodes, for all graph sizes."""
    p_in = 0.4
    p_out = 0.04

    for d in [10, 20, 30, 50]:
        for _ in range(10):
            grp = GaussianRandomPartition(d, p_in, p_out)
            n_clusters = grp._sample_number_of_clusters()
            cluster_sizes = grp._sample_cluster_sizes(n_clusters)
            assert np.min(cluster_sizes) >= 3, (
                f"Unexpected behaviour! There is a cluster with less than 3 nodes "
                f"(d={d}, cluster sizes={cluster_sizes})!"
            )


def test_clusters_er_behaviour():
    """Check that ER clusters are approximately regular graphs.

    For each graph size, a partition into clusters is sampled and an ER graph
    is generated for every cluster. On average over the clusters, the
    difference between the max in-degree and the max out-degree must be at
    most 2.
    """
    import torch
    # Presumably the cluster sampler draws from torch's RNG — seeded for reproducibility.
    torch.manual_seed(seed)
    p_in = 0.4
    p_out = 0.04

    for num_nodes in [10, 20, 30, 50]:
        max_income_degrees = []
        max_outcome_degrees = []

        grp = GaussianRandomPartition(num_nodes, p_in, p_out)
        n_clusters = grp._sample_number_of_clusters()
        clusters_size = grp._sample_cluster_sizes(n_clusters)
        for c in clusters_size:
            A = grp._sample_er_cluster(cluster_size=c)
            # Column sums are in-degrees, row sums are out-degrees.
            max_income_degrees.append(np.max(A.sum(axis=0)))
            max_outcome_degrees.append(np.max(A.sum(axis=1)))

        degree_gap = abs(np.array(max_income_degrees) - np.array(max_outcome_degrees))
        assert np.mean(degree_gap) <= 2, "ER clusters showing irregular behaviour!"


def test_sparsity_between_clusters():
    """Check that the disjoint union keeps inter-cluster connections sparse.

    Clusters are merged one at a time with ``_disjoint_union``. After each
    merge, the upper-right block ``A[:n_before, n_before:]`` may contain at
    most 2 edges. Here ``n_before`` is the node count before the merge, so the
    block holds the edges from previously present nodes to the newly added
    ones (presumably appended at the end — confirm in ``_disjoint_union``).
    """
    p_in = 0.4
    p_out_by_size = {10: 0.06, 20: 0.06, 30: 0.03, 50: 0.03}

    for num_nodes, p_out in p_out_by_size.items():
        grp = GaussianRandomPartition(num_nodes, p_in, p_out)
        sizes = grp._sample_cluster_sizes(grp._sample_number_of_clusters())

        # Start from the first cluster; every remaining cluster is merged in below.
        A = grp._sample_er_cluster(sizes[0])
        n_before = A.shape[0]
        for cluster_size in np.delete(sizes, [0]):
            A = grp._disjoint_union(A, cluster_size)
            between = A[:n_before, n_before:]
            assert between.sum() <= 2
            n_before = A.shape[0]

###################### Test FullyConnectedGenerator ######################
def test_fully_connected_toporder():
    """Test that fully connected graphs do not follow the identity node order.

    Under the order 0, 1, ..., d-1 at least one edge must point from a later
    node to an earlier one, i.e. node indices are not a topological order.
    """
    def num_errors(order, adj):
        # Edges later -> earlier w.r.t. positions in `order` (adj[row, col] is row -> col).
        return sum(adj[order[pos + 1:], node].sum() for pos, node in enumerate(order))

    for num_nodes in range(5, 20):
        adjacency = fully_connected(num_nodes)
        assert num_errors(range(num_nodes), adjacency) > 0


def test_fully_connected_graph():
    """Test that fully_connected returns a DAG with n*(n-1)/2 edges, the maximum possible."""
    for requested in range(5, 20):
        A = fully_connected(requested)
        n = A.shape[0]
        assert np.sum(A) == n * (n - 1) / 2

###################### Test ConfoundedGenerator ######################
def test_given_p_confounded_when_generating_graphs_then_avg_confounded_pairs_consistent():
    """Test that the average number of confounded pairs is consistent with rho.

    Data are generated ``n_runs`` times from one ConfundedGenerator. Each run
    counts the nonzero entries in the leading ``num_nodes x num_nodes`` block
    of ``confounded_adjacency``. The mean count, divided by ``num_nodes``,
    must be within 0.05 of ``p_confounded``.
    """
    p_confounded = 0.2
    num_nodes = 20
    n_runs = 10
    generator = ConfundedGenerator(
        rho=p_confounded,
        graph_type="ER",
        num_nodes=num_nodes,
        graph_size="large20",
        graph_density="sparse",
        num_samples=2,
        noise_distr="Gauss",
        noise_std_support=(.5, 1.),
        seed=seed,
    )

    # The loop variable must not be named `seed`: that would make `seed` local
    # to this function, and `seed=seed` above would raise UnboundLocalError.
    n_confounded_pairs = []
    for _ in range(n_runs):
        generator.generate_data()
        confounded_A = generator.confounded_adjacency
        n_confounded_pairs.append(np.sum(confounded_A[:num_nodes, :num_nodes]))

    # NOTE(review): normalising by num_nodes assumes rho is the expected number
    # of confounded pairs per observed node — confirm against ConfundedGenerator.
    assert abs(np.mean(n_confounded_pairs) / num_nodes - p_confounded) < 0.05

def test_given_p_confounded_when_generating_graphs_then_no_edge_between_latents():
    """Placeholder: intended to check that latent confounders are not connected to each other."""
    # Not implemented yet: skip explicitly instead of reporting a vacuous pass.
    pytest.skip("TODO: test not implemented yet")

def test_given_p_confounded_when_generating_graphs_then_confounders_are_sources():
    """Placeholder: intended to check that latent confounders have no parents."""
    # Not implemented yet: skip explicitly instead of reporting a vacuous pass.
    pytest.skip("TODO: test not implemented yet")



###################### Test MeasureErrorGenerator ######################
def test_measure_error_generator(dag_sample):
    """Test that MeasureErrorGenerator adds noise with the requested relative scale.

    For each gamma, data sampled on ``dag_sample`` are corrupted with
    ``add_noise``. Per column, the test then checks that:
        - gamma_tilde = std(noise) / std(X) is within 0.05 of gamma;
        - std(X_noisy) / std(X) is within 0.05 of sqrt(1 + gamma_tilde**2),
          which holds when the added noise is (approximately) uncorrelated with X.
    """
    import torch
    torch.manual_seed(seed)

    graph_type = "ER"
    noise_std_support = (0.5, 1)
    noise_distr = "gauss"
    num_nodes = dag_sample.shape[0]
    num_samples = 1000
    # Map gamma to the share of variance explained by the error,
    # gamma**2 / (1 + gamma**2). The values are for reference only; the
    # assertions use the keys.
    gamma_var_map = {
        0.2 : 0.04,
        0.4 : 0.14,
        0.6 : 0.26,
        0.8 : 0.39
    }

    sampler = DataSimulator(num_samples, 1, noise_std_support, noise_distr, dag_sample, GP=True, lengthscale=1)
    for gamma in gamma_var_map:
        generator = MeasureErrorGenerator(
            gamma, graph_type, num_nodes, "medium", "sparse", num_samples, noise_distr, noise_std_support
        )

        # The data are random: the 0.05 tolerances rely on num_samples being large.
        X = sampler.sample()
        x_std = torch.std(X, dim=0)
        # add_noise receives a copy so that X stays the clean reference.
        X_noisy = generator.add_noise(torch.clone(X))
        x_noisy_std = torch.std(X_noisy, dim=0)

        noise_std = (X_noisy - X).std(dim=0)
        gamma_tilde = noise_std / x_std
        assert all(abs(gamma_tilde - gamma) < 0.05)
        assert all(abs(x_noisy_std / x_std - torch.sqrt(1 + gamma_tilde**2)) < 0.05)



###################### Test LinearSCMGenerator ######################
def test_linear_mechanism(dag_sample):
    """Check that linear regression fits a linear mechanism better than the GP one.

    Ten times, a 1-D input X ~ N(0, 1) is passed through both ``sampleGP``
    (nonlinear) and ``sample_linear_mechanism``. A LinearRegression is trained
    on the first 2/3 of the rows. Its held-out R^2 must be higher on the
    linear target.
    """
    noise_std_support = (0.5, 1)
    noise_distr = "gauss"
    num_samples = 1000
    split = int(num_samples*(2/3))

    def holdout_r2(X, y):
        # Fit on rows [:split], report R^2 on rows [split:].
        model = LinearRegression().fit(X[:split, :], y[:split])
        return model.score(X[split:, :], y[split:])

    sampler = DataSimulator(num_samples, 1, noise_std_support, noise_distr, dag_sample, GP=True, lengthscale=1)
    for _ in range(10):
        X = np.random.normal(0, 1, (num_samples, 1))
        y_nonlinear = sampler.sampleGP(X)
        y_linear = sampler.sample_linear_mechanism(X)
        r2_nonlinear = holdout_r2(X, y_nonlinear)
        r2_linear = holdout_r2(X, y_linear)
        assert r2_linear > r2_nonlinear

###################### Test TiminoGenerator ######################
def test_timino_lagged_effect(dag_sample):
    """Test that TiMINo linearly adds the lagged effect of the previous observation.

    The second argument to ``make_timino`` is [1, 1]; presumably these are
    per-variable lag coefficients — confirm in TiminoGenerator. Row t of the
    result must equal the sum of input rows 0..t, i.e. the running sum.
    ``dag_sample`` is not used by this test.
    """
    import torch
    n, d = (10, 2)
    X = torch.arange(n * d).reshape(n, d)

    # Only make_timino() is under test, so every configuration argument is None.
    timino = TiminoGenerator(
        graph_type=None, num_nodes=None, graph_size=None, graph_density=None, num_samples=None, noise_distr=None, noise_std_support=None
    )
    X_timino = timino.make_timino(torch.clone(X), np.array([1, 1]))

    running_sum = torch.cumsum(X, dim=0)
    for t, expected_row in enumerate(running_sum):
        assert torch.equal(X_timino[t], expected_row)


###################### Test UnfaithfulGenerator ######################
def test_fully_connected_moral_triplets():
    r"""Test the number of moral colliders found in a fully connected DAG.

    A fully connected DAG on d nodes has d(d-1)/2 edges. The p-th node in
    topological order has p parents, and every pair of them forms a collider.
    The expected count is therefore a sum of binomial coefficients:
    \sum_{p=0}^{d-1} p(p-1)/2.
    """
    d = 10
    generator = UnfaithfulGenerator(
        p_unfaithful=1,
        graph_type="FC",
        num_nodes=d,
        graph_size="medium",
        graph_density="full",
        num_samples=1,
        noise_distr="gauss",
        noise_std_support=(.5, 1)
    )

    generator.simulate_dag()
    assert np.sum(generator.adjacency) == d*(d-1)/2

    moral_colliders = generator.find_moral_colliders()
    # Integer arithmetic: p*(p-1) is always even, so // 2 is exact.
    num_colliders = sum(p * (p - 1) // 2 for p in range(d))

    assert len(moral_colliders) == num_colliders


def test_consistent_unfaithful_adj():
    """Check make_unfaithful_adj on a small DAG with many moralities.

    The adjacency is overwritten with a fully connected upper-triangular DAG on
    ``d`` nodes. With ``p_unfaithful = 1``, the unfaithful adjacency returned
    must match the hand-computed ``target_adj``.
    """
    d = 5
    p_unfaithful = 1
    generator = UnfaithfulGenerator(
        p_unfaithful=p_unfaithful,
        graph_type="FC",
        num_nodes=d,
        graph_size="small",
        graph_density="full",
        num_samples=1,
        noise_distr="gauss",
        noise_std_support=(.5, 1)
    )

    # Use d (not a literal) so the matrix always matches num_nodes.
    generator.adjacency = np.triu(np.ones((d, d)), k=1)
    unfaithful_adj, _ = generator.make_unfaithful_adj()

    target_adj = np.array([
        [0,1,0,0,0],
        [0,0,1,1,1],
        [0,0,0,1,0],
        [0,0,0,0,1],
        [0,0,0,0,0]
    ])
    assert np.allclose(unfaithful_adj, target_adj)


############################ Other tests #############################
def test_same_graph_different_sample_size():
    """Test that, under the same data configuration, datasets with different
    sample sizes share the same ground-truth graph.

    Reads pre-generated CSVs from the EFS dataset folder. The test is skipped
    when that folder is not available.
    """
    base_folder = os.path.join(os.sep, "efs", "data", "hyperparameters", "ER", "gauss", "vanilla", "vanilla")
    if not os.path.isdir(base_folder):
        pytest.skip(f"Dataset folder not available: {base_folder}")
    configs = [
        "100_large20_dense",
        "1000_large20_dense",
    ]
    # 20 datasets for each config
    for idx in range(20):
        data100 = np.genfromtxt(os.path.join(base_folder, configs[0], f"groundtruth{idx}.csv"), delimiter=",")
        data1000 = np.genfromtxt(os.path.join(base_folder, configs[1], f"groundtruth{idx}.csv"), delimiter=",")
        assert np.array_equal(data100, data1000)

def test_correct_graph_size():
    """Test that the stored ground-truth graphs have the expected number of nodes.

    For each configuration, every ``groundtruth{idx}.csv`` must be a matrix
    with ``graph_size`` rows. The test is skipped when the dataset folder is
    not available.
    """
    # NOTE(review): unlike test_same_graph_different_sample_size, this path has
    # no "hyperparameters" component — confirm which layout is current.
    base_folder = os.path.join(os.sep, "efs", "data",  "ER", "gauss", "vanilla", "vanilla")
    if not os.path.isdir(base_folder):
        pytest.skip(f"Dataset folder not available: {base_folder}")
    configs = {
        # "100_small_dense" : 5,
        # "100_medium_dense" : 10,
        "100_large20_sparse" : 20,
        # "100_large30_dense" : 30,
        # "100_large50_dense" : 50,
    }
    # 20 datasets for each config
    for graph, graph_size in configs.items():
        for idx in range(20):
            data = np.genfromtxt(os.path.join(base_folder, graph, f"groundtruth{idx}.csv"), delimiter=",")
            d = data.shape[0]
            assert d == graph_size

def test_correct_sample_size():
    """Test that the stored datasets have the expected number of samples.

    For each configuration, every ``data{idx}.csv`` must have ``num_samples``
    rows. The test is skipped when the dataset folder is not available.
    """
    base_folder = os.path.join(os.sep, "efs", "data",  "ER", "gauss", "vanilla", "vanilla")
    if not os.path.isdir(base_folder):
        pytest.skip(f"Dataset folder not available: {base_folder}")
    configs = {
        "100_large20_sparse" : 100,
        # "1000_large20_sparse" : 1000,
    }
    # 20 datasets for each config
    for graph, num_samples in configs.items():
        for idx in range(20):
            data = np.genfromtxt(os.path.join(base_folder, graph, f"data{idx}.csv"), delimiter=",")
            n = data.shape[0]
            assert n == num_samples

