import os

from modules.args import get_args
from modules.algorithms.base import CausalDiscovery
from modules.utils import get_data

import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd

def plot_graph(adj, path, figsize=(10, 10)):
    """Draw the directed graph described by ``adj`` and save it to ``path``.

    Args:
        adj: Adjacency data accepted by ``pandas.DataFrame`` (e.g. a 2-D
            array). It is wrapped in a DataFrame and passed to
            ``networkx.DiGraph``, which interprets a DataFrame as an
            adjacency matrix whose row/column labels become node labels.
        path: Destination image file; matplotlib picks the format from the
            file extension. Missing parent directories are created.
        figsize: Figure size in inches as ``(width, height)``. Defaults to
            the previously hard-coded ``(10, 10)``.
    """
    parent = os.path.dirname(path)
    if parent:
        # savefig does not create directories; make sure the target exists.
        os.makedirs(parent, exist_ok=True)

    graph = nx.DiGraph(pd.DataFrame(adj))
    # Use an explicit figure/axes pair so this call neither relies on nor
    # disturbs whatever figure is "current" in pyplot's global state.
    fig, ax = plt.subplots(figsize=figsize)
    try:
        nx.draw_networkx(graph, ax=ax)
        fig.savefig(path)
    finally:
        # Release only the figure created here, even if drawing or saving
        # failed, so repeated calls do not accumulate open figures.
        plt.close(fig)

# --- Script body -------------------------------------------------------------
# Load data via get_data, run CausalDiscovery on it, and save plots of the
# adjacency returned by get_data (A) and the predicted adjacency (A_pred).

args = get_args()

# Root folder under which "plots/" is written. The original hard-coded path is
# kept as the default; set the DAS_BASE_FOLDER environment variable to run the
# script on another machine without editing the source.
base_folder = os.environ.get(
    "DAS_BASE_FOLDER", "/home/francescom/Research/DAS-Extension/src"
)

# Experiment configuration, read from the parsed command-line arguments.
noise_type = args.noise_type
graph_type = args.graph_type
noise_std = args.nstd
d = args.d
s0 = args.s0
N = args.N
GP = args.GP
real_res = args.real_res

X, R, A = get_data(graph_type, d, s0, N, noise_std=noise_std, noise_type=noise_type, GP=GP)
if not real_res:
    # R is only handed to the algorithm when real_res is requested.
    # NOTE(review): presumably R holds true residuals from the generator —
    # confirm against get_data.
    R = None

# Every parsed CLI argument is forwarded to the algorithm as a keyword option.
algorithm = CausalDiscovery(X, R, A, **vars(args))
# Only A_pred is used below; the remaining outputs are unused in this script.
A_pred, top_order, order_time, exe_time = algorithm.inference()

# savefig does not create missing directories, so ensure plots/ exists.
plots_folder = os.path.join(base_folder, "plots")
os.makedirs(plots_folder, exist_ok=True)

plot_graph(A, os.path.join(plots_folder, "graph.png"))
plot_graph(A_pred, os.path.join(plots_folder, "pred-graph.png"))