import time
from typing import Optional
import networkx as nx
import numpy as np
from torch.utils.data import Dataset
from causallearn.search.ConstraintBased.PC import pc
import wandb

from .base._base_model import BaseModel
import numpy as np
import pandas as pd



def createFullyConnectedGraph(topological_order):
    """Build the complete DAG that is consistent with a topological order.

    Every node gets an edge to each node that appears after it in
    ``topological_order``. Entries are used directly as row/column indices
    into an ``n x n`` matrix (``n = len(topological_order)``), so they must
    lie in ``[0, n)``; presumably the order is a permutation of
    ``0..n-1`` -- confirm at call sites.

    Args:
        topological_order: Sequence of node indices, earliest first.

    Returns:
        A float ``np.ndarray`` of shape ``(n, n)`` with ``1`` at
        ``[earlier, later]`` for every ordered pair and ``0`` elsewhere.
    """
    size = len(topological_order)
    adjacency = np.zeros((size, size))
    for position, source in enumerate(topological_order):
        # Connect this node to every node that comes after it in the order.
        for target in topological_order[position + 1:]:
            adjacency[source, target] = 1
    return adjacency

class PC(BaseModel):
    """Causal-discovery model that wraps causallearn's PC algorithm.

    ``train`` runs ``pc`` on the first tensor of a dataset and converts the
    directed edges it finds into an acyclic adjacency matrix, which
    ``get_adjacency_matrix`` then returns.
    """

    def __init__(self):
        super().__init__()
        # Filled in by ``train``; stays ``None`` until the model is trained.
        self._adj_matrix = None

    def train(
        self,
        dataset: Dataset,
        log_wandb: bool = False,
        wandb_project: str = "pc",
        wandb_config_dict: Optional[dict] = None,
        **model_kwargs,
    ):
        """Run PC on ``dataset`` and store the resulting adjacency matrix.

        Args:
            dataset: Dataset exposing ``tensors`` (e.g. ``TensorDataset``).
                Only ``tensors[0]`` is used; its columns are treated as the
                graph's variables (presumably rows are samples -- confirm).
            log_wandb: If True, start a Weights & Biases run before fitting.
            wandb_project: W&B project name used when ``log_wandb`` is True.
            wandb_config_dict: Config dict recorded on the W&B run
                (an empty dict when ``None``).
            **model_kwargs: Forwarded unchanged to causallearn's ``pc``
                (e.g. ``alpha``, ``indep_test``).

        Side effects:
            Sets ``self._adj_matrix`` and ``self._train_runtime_in_sec``.
            When ``log_wandb`` is True a W&B run is started; it is not
            finished here (NOTE(review): confirm the caller closes it).
        """
        # detach()/cpu() let tensors that require grad or live on a GPU be
        # converted; both are no-ops for a plain CPU tensor.
        data = dataset.tensors[0].detach().cpu().numpy()

        if log_wandb:
            wandb.init(
                project=wandb_project,
                name="pc",
                config=wandb_config_dict or {},
            )

        # perf_counter is monotonic, so the measured runtime cannot be
        # skewed by system-clock adjustments the way time.time() can.
        start = time.perf_counter()
        causal_graph = pc(data, **model_kwargs)
        directed_edges = causal_graph.find_fully_directed()
        self._adj_matrix = self._to_acyclic_adjacency(directed_edges, data.shape[1])
        self._train_runtime_in_sec = time.perf_counter() - start

    @staticmethod
    def _to_acyclic_adjacency(directed_edges, num_nodes):
        """Greedily add ``i -> j`` edges, skipping any that would close a cycle.

        Edges are considered in the order given, so an earlier edge wins
        over a later one that would conflict with it. Only the edges passed
        in are kept (presumably PC's undirected CPDAG edges are absent from
        ``find_fully_directed()`` -- confirm).

        Returns:
            A ``(num_nodes, num_nodes)`` array with ``[i, j] == 1`` for each
            kept edge and ``0`` elsewhere.
        """
        graph = nx.DiGraph()
        # Add all nodes first so isolated variables still get a row/column
        # and nx.to_numpy_array orders rows/columns as 0..num_nodes-1.
        graph.add_nodes_from(range(num_nodes))
        for i, j in directed_edges:
            # An existing path j ~> i means adding i -> j would form a cycle.
            if not nx.has_path(graph, j, i):
                graph.add_edge(i, j)
        return nx.to_numpy_array(graph)

    def get_adjacency_matrix(self, threshold: bool = True) -> Optional[np.ndarray]:
        """Return the adjacency matrix learned by ``train``.

        ``threshold`` is accepted but unused: the matrix built in ``train``
        holds only 0/1 entries already. It presumably exists to match the
        ``BaseModel`` interface -- confirm.

        Returns ``None`` if ``train`` has not been called yet.
        """
        return self._adj_matrix
