import os
import yaml
import json
import networkx as nx
from causal_discovery_llm import causal_discovery_llm
from utils.causalgraph_utils import save_graph_to_file
from utils.metrics import eval_metrics
import dotenv
# Pull variables from the .env file located by dotenv.find_dotenv() into the
# process environment, so the os.getenv() lookups below can see them.
dotenv_file = dotenv.find_dotenv()
dotenv.load_dotenv(dotenv_file)

# SELECTED_MODEL_NAME chooses the backend. The value "openai" means the concrete
# model name comes from MODEL_NAME; any other value (including None when the
# variable is unset) is used as the model name as-is.
selected_model_name = os.getenv("SELECTED_MODEL_NAME")
model_name = (
    os.getenv("MODEL_NAME")
    if selected_model_name == "openai"
    else selected_model_name
)


def _load_ground_truth_edges(dataset_file):
    """Load the ground-truth edge list that accompanies *dataset_file*.

    The ground truth is read from ``<dataset path without extension>_ground_truth.json``,
    which must be a JSON object with an ``"edges"`` entry.

    Args:
        dataset_file: Path of the dataset the ground truth belongs to.

    Returns:
        The value stored under ``"edges"`` in the ground-truth JSON file.

    Raises:
        FileNotFoundError: If the ground-truth file does not exist.
        KeyError: If the JSON object has no ``"edges"`` entry.
    """
    ground_truth_file = os.path.splitext(dataset_file)[0] + "_ground_truth.json"
    with open(ground_truth_file, "r", encoding="utf-8") as f:
        return json.load(f)["edges"]


def main():
    """Run LLM-based causal discovery on a user-supplied dataset and evaluate it.

    Steps:
    1. Read ``src/config.yaml``.
    2. Prompt for the dataset path.
    3. Run ``causal_discovery_llm``.
    4. Save the learned graph as JSON.
    5. Load the matching ``*_ground_truth.json`` file.
    6. Write evaluation metrics into the directory of the saved graph.
    """
    # Load config.yaml
    yaml_file = "src/config.yaml"
    with open(yaml_file, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)

    # strip() tolerates stray whitespace pasted along with the path.
    dataset_file = input("Please enter the dataset path and file name: ").strip()

    causal_discovery_state = causal_discovery_llm(dataset_file=dataset_file, config=config)
    learned_graph_data = causal_discovery_state["causal_graph"]

    # Save the causal graph to a file, and get the file name
    graph_file_name = save_graph_to_file(
        input_file_name=dataset_file,
        causal_graph=learned_graph_data,
        filetype="json",
    )
    print(f"Learned graph saved to {graph_file_name}")

    ground_truth_edges = _load_ground_truth_edges(dataset_file)

    # Build the ground-truth / learned graph pair in the representation the
    # selected discovery mode uses.
    if config["fci_conquer"]:
        from utils.causalgraph_utils import edgelist_to_generalgraph
        ground_truth_graph = edgelist_to_generalgraph(ground_truth_edges)
        learned_causal_graph = learned_graph_data
    else:
        ground_truth_graph = nx.DiGraph(ground_truth_edges)
        # The state is a mapping (it was indexed by "causal_graph" above);
        # wrap the learned graph data in a DiGraph to match the ground truth.
        learned_causal_graph = nx.DiGraph(learned_graph_data)

    # Evaluate the metrics and save them next to the learned graph
    metrics_file_path = eval_metrics(
        save_path=os.path.dirname(graph_file_name),
        ground_truth_graph=ground_truth_graph,
        learned_causal_graph=learned_causal_graph,
        dataset_name=os.path.basename(dataset_file),
        discovery_time=causal_discovery_state["elapsed_time"],
        input_token_count=causal_discovery_state["input_token_count"],
        output_token_count=causal_discovery_state["output_token_count"],
        tool_calls=causal_discovery_state["tool_calls"],
        model_name=model_name,
        tools=config["tool_list"],
    )

    print(f"Metrics saved to {metrics_file_path}")


if __name__ == "__main__":
    main()

