import os
import pytest
import numpy as np
import pandas as pd
import networkx as nx
from utils._utils import dag_to_cpdag
from utils._metrics import cpdag_fpr, cpdag_tpr, dag_fpr, dag_tpr, dag_fp


from benchmark.data_generators import VanillaGenerator, MeasureErrorGenerator, TiminoGenerator, UnfaithfulGenerator, ConfundedGenerator
from utils._data import DataSimulator
from utils._random_graphs import fully_connected, GaussianRandomPartition


##################### FIXTURES #####################
@pytest.fixture
def data_dir():
    """Directory with 5 x 5 adjacency matrices for 10 distinct DAGs.

    The repository root defaults to ``home/ec2-user/causal-benchmark`` under
    the filesystem root, which is the location this suite was originally
    written against. Set the ``CAUSAL_BENCHMARK_DIR`` environment variable to
    point at a different checkout so the tests can run on other machines.

    Returns:
        The path string ``<root>/tmp/test_data``.
    """
    default_root = os.path.join(os.sep, "home", "ec2-user", "causal-benchmark")
    # An unset or empty variable falls back to the original hard-coded root,
    # so existing environments keep working unchanged.
    base_dir = os.environ.get("CAUSAL_BENCHMARK_DIR") or default_root
    return os.path.join(base_dir, "tmp", "test_data")

@pytest.fixture
def dag_sample():
    """Small hand-written 6-node DAG.

    Returns:
        A 6 x 6 integer adjacency matrix in which ``A[i, j] == 1`` encodes the
        directed edge ``i -> j`` (the convention ``nx.from_numpy_array`` uses
        when building an ``nx.DiGraph``).
    """
    edge_list = [(0, 2), (0, 5), (1, 2), (2, 3), (2, 4), (3, 4), (4, 5)]
    adjacency = np.zeros((6, 6), dtype=int)
    sources, targets = zip(*edge_list)
    adjacency[list(sources), list(targets)] = 1

    # Sanity check: the hand-written edge list must not contain a cycle.
    digraph = nx.from_numpy_array(adjacency, create_using=nx.DiGraph)
    assert nx.is_directed_acyclic_graph(digraph)
    return adjacency

@pytest.fixture
def dag_with_fp():
    """6-node DAG containing every edge of ``dag_sample`` plus 4 extra edges.

    The 4 extra edges are the false positives relative to the
    ``dag_sample`` graph; no edge of ``dag_sample`` is missing.

    Returns:
        A 6 x 6 integer adjacency matrix in which ``A[i, j] == 1`` encodes the
        directed edge ``i -> j``.
    """
    true_edges = [(0, 2), (0, 5), (1, 2), (2, 3), (2, 4), (3, 4), (4, 5)]
    # 4 false positives
    spurious_edges = [(0, 3), (0, 4), (1, 4), (2, 5)]

    adjacency = np.zeros((6, 6), dtype=int)
    for src, dst in true_edges + spurious_edges:
        adjacency[src, dst] = 1

    # Sanity check: adding the spurious edges must keep the graph acyclic.
    digraph = nx.from_numpy_array(adjacency, create_using=nx.DiGraph)
    assert nx.is_directed_acyclic_graph(digraph)
    return adjacency


def test_given_p_confounded_when_generating_graphs_then_avg_confounded_pairs_consistent():
    """Check that ConfundedGenerator's empirical confounding rate matches ``rho``.

    For each of ``num_seeds`` seeds, a ``num_nodes``-node ER graph is generated
    with ``rho=p_confounded``. The confounding rate is estimated from the block
    ``confounded_adjacency[:num_nodes, num_nodes:]``, normalised by
    ``num_nodes * (num_nodes - 1)``. The mean estimate over seeds must lie
    strictly within ``tolerance`` of ``p_confounded``.
    """
    p_confounded = 0.1
    num_nodes = 10
    num_seeds = 10
    tolerance = 0.01
    # Number of ordered pairs of distinct observed nodes; loop-invariant.
    # NOTE(review): if each latent confounder marks both of its children in
    # this block, every confounded pair is counted twice, so dividing by
    # n * (n - 1) yields the fraction of unordered pairs confounded. Confirm
    # against ConfundedGenerator.
    num_pairs = num_nodes**2 - num_nodes

    p_confounded_hat = []
    for seed in range(num_seeds):
        generator = ConfundedGenerator(
            rho=p_confounded,
            graph_type="ER",
            num_nodes=num_nodes,
            graph_size="medium",
            graph_density="sparse",
            num_samples=2,
            noise_distr="gauss",
            noise_std_support=(.5, 1.),
            seed=seed
        )
        generator.generate_data()
        confounded_A = generator.confounded_adjacency
        # NOTE(review): assumes columns num_nodes: hold the latent confounders
        # and rows :num_nodes the observed nodes; confirm the layout.
        p_confounded_hat.append(np.sum(confounded_A[:num_nodes, num_nodes:]) / num_pairs)

    mean_hat = np.mean(p_confounded_hat)
    assert abs(mean_hat - p_confounded) < tolerance, (
        f"mean estimated confounding rate {mean_hat} is not within {tolerance} "
        f"of rho={p_confounded}; per-seed estimates: {p_confounded_hat}"
    )