import os
import pytest
import numpy as np
import pandas as pd
import networkx as nx
from utils._utils import dag_to_cpdag
from utils._metrics import cpdag_fpr, cpdag_tpr, dag_fpr, dag_tpr


##################### FIXTURES #####################
@pytest.fixture
def data_dir():
    """Directory holding the ground-truth DAG adjacency matrices (``data<i>.csv``).

    According to the original author, the directory contains 5 x 5 adjacency
    matrices for 10 distinct DAGs.

    The repository base directory defaults to ``/home/ec2-user/causal-benchmark``.
    Set the ``CAUSAL_BENCHMARK_DIR`` environment variable to point elsewhere so
    the suite can run on machines other than that EC2 host. An unset or empty
    variable falls back to the default.
    """
    default_base = os.path.join(os.sep, "home", "ec2-user", "causal-benchmark")
    # `or` so that an exported-but-empty variable still uses the default.
    base_dir = os.environ.get("CAUSAL_BENCHMARK_DIR") or default_base
    return os.path.join(base_dir, "tmp", "test_data")

@pytest.fixture
def dag_sample():
    """Six-node DAG as an integer adjacency matrix (``A[i, j] == 1`` means i -> j).

    The matrix is assembled from an explicit edge list. Before it is returned,
    networkx checks (as a directed graph) that it is acyclic.
    """
    edge_list = [(0, 2), (0, 5), (1, 2), (2, 3), (2, 4), (3, 4), (5, 4)]
    adjacency = np.zeros((6, 6), dtype=int)
    for src, dst in edge_list:
        adjacency[src, dst] = 1

    digraph = nx.from_numpy_array(adjacency, create_using=nx.DiGraph)
    assert nx.is_directed_acyclic_graph(digraph)
    return adjacency


@pytest.fixture
def cpdag_sample(dag_sample):
    """CPDAG obtained by passing the ``dag_sample`` fixture through ``dag_to_cpdag``."""
    cpdag = dag_to_cpdag(dag_sample)
    return cpdag


##################### CPDAG Tests #####################
def _load_truth_graphs(data_dir):
    """Yield ``(file_name, A_truth, cpdag_truth)`` for each ``data*.csv`` in *data_dir*.

    Only files that actually exist and match ``data*.csv`` are read. Other
    entries in the directory are ignored rather than causing a missing-file
    error. If no data file is present this fails immediately, so the tests
    below cannot pass vacuously on an empty directory.
    """
    file_names = sorted(
        name
        for name in os.listdir(data_dir)
        if name.startswith("data") and name.endswith(".csv")
    )
    assert file_names, f"no data*.csv files found in {data_dir}"
    for file_name in file_names:
        A_truth = np.genfromtxt(os.path.join(data_dir, file_name), delimiter=",")
        yield file_name, A_truth, dag_to_cpdag(A_truth)


def test_tpr_all_correct(data_dir):
    """TPR is 1.0 when the true CPDAG is used as the prediction."""
    for file_name, _, cpdag_truth in _load_truth_graphs(data_dir):
        assert cpdag_tpr(cpdag_truth, cpdag_truth) == 1.0, file_name


def test_fpr_all_correct(data_dir):
    """FPR is 0.0 when the true CPDAG is used as the prediction."""
    for file_name, _, cpdag_truth in _load_truth_graphs(data_dir):
        assert cpdag_fpr(cpdag_truth, cpdag_truth) == 0.0, file_name


def test_tpr_empty(data_dir):
    """TPR is 0.0 when the prediction is the empty (all-zeros) graph."""
    for file_name, A_truth, cpdag_truth in _load_truth_graphs(data_dir):
        assert cpdag_tpr(np.zeros(A_truth.shape), cpdag_truth) == 0.0, file_name


def test_fpr_empty(data_dir):
    """FPR is 0.0 when the prediction is the empty (all-zeros) graph.

    An empty prediction contains no edges, so it cannot contain false positives.
    """
    for file_name, A_truth, cpdag_truth in _load_truth_graphs(data_dir):
        assert cpdag_fpr(np.zeros(A_truth.shape), cpdag_truth) == 0.0, file_name


def test_dtop_smaller_shd():
    """Placeholder test: not implemented yet.

    The name suggests it should check that DTOP yields a smaller SHD, using
    the raw DAS inference logs below (TODO: confirm the intended assertion).
    It has no assertions yet, so it skips explicitly instead of reporting a
    misleading pass.
    """
    raw_logs_path = "/home/ec2-user/causal-benchmark/tmp/logs/inference/ER/vanilla/vanilla/das/raw_das.csv"
    # TODO: load raw_logs_path and assert the intended SHD relation.
    pytest.skip(f"test not implemented yet (would read {raw_logs_path})")

