import contextlib
import os
import subprocess

import graph


def graph_from_tetrad(file_path):
    """Parse a Tetrad/causal-cmd text result into a ``graph.CausalGraph`` CPDAG.

    The parser is a small state machine over the file's lines:

    * the line right after ``Graph Nodes:`` is taken as a ``;``-separated
      node list, and its number of entries becomes ``p``;
    * every line after ``Graph Edges:`` up to the first blank line (or EOF)
      is an edge such as ``1. X3 --> X7``. A node id is the node name with
      its first character dropped, converted with ``int``. This assumes
      names of the form ``X<k>`` (TODO confirm for non-default column names).

    Args:
        file_path: Path of the text file written by causal-cmd.

    Returns:
        ``graph.CausalGraph(p, dir_edges, undir_edges, "cpdag")``, where
        ``dir_edges`` holds ``(tail, head)`` pairs and ``undir_edges`` holds
        ``(a, b)`` pairs, both in file order.

    Raises:
        ValueError: if the file has no ``Graph Nodes:`` section, an edge line
            has fewer than four fields, or an edge uses a mark other than
            ``---``, ``-->`` or ``<--``.
    """
    p = None
    dir_edges = []
    undir_edges = []
    status = "waiting"
    with open(file_path, "r", encoding="utf-8") as file:
        for raw_line in file:
            line = raw_line.strip()
            if line.startswith("Graph Nodes:"):
                status = "read_nodes"
                continue
            if status == "read_nodes":
                # The single line after the header lists all nodes, ';'-separated.
                p = line.count(";") + 1
                status = "waiting"
                continue
            if line.startswith("Graph Edges:"):
                status = "read_edges"
                continue
            if status != "read_edges":
                continue
            if not line:
                break  # a blank line terminates the edge list
            fields = line.split()
            if len(fields) < 4:
                raise ValueError(f"malformed edge line in {file_path}: {line!r}")
            # fields: ["<index>.", "X<a>", "<mark>", "X<b>", ...]
            a = int(fields[1][1:])
            mark = fields[2]
            b = int(fields[3][1:])
            if mark == "---":
                undir_edges.append((a, b))
            elif mark == "-->":
                dir_edges.append((a, b))
            elif mark == "<--":
                # Arrowhead on the first node: store as (tail, head).
                dir_edges.append((b, a))
            else:
                # A CPDAG cannot represent e.g. "<->" or circle endpoints.
                raise ValueError(
                    f"unsupported edge mark {mark!r} in {file_path}: {line!r}"
                )

    if p is None:
        raise ValueError(f"no 'Graph Nodes:' section found in {file_path}")
    return graph.CausalGraph(p, dir_edges, undir_edges, "cpdag")


def tetrad_cleanup(output_file):
    """Delete *output_file* and ``causal-cmd.log`` (relative to the cwd).

    Files that do not exist are skipped, so the call is safe to repeat. It is
    also safe when another run removes the same file concurrently: each
    removal is attempted directly instead of being checked with
    ``os.path.exists`` first, which left a window for ``FileNotFoundError``.

    Args:
        output_file: Path of the causal-cmd result file to remove.
    """
    for path in (output_file, "causal-cmd.log"):
        with contextlib.suppress(FileNotFoundError):
            os.remove(path)


def _run_tetrad(path_to_tetrad, data_path, algorithm, prefix, options, flags=()):
    """Run one causal-cmd search via ``java -jar`` and return the parsed CPDAG.

    This is the shared driver behind the ``tetrad_*`` wrappers. The command
    line is built in this order::

        java -jar <path_to_tetrad> --algorithm <algorithm> --dataset <data_path>
            --delimiter comma --data-type continuous <options...>
            --prefix <prefix> <flags...>

    After a successful run, ``<prefix>_out.txt`` is read from the current
    working directory with ``graph_from_tetrad``. It is then removed, together
    with ``causal-cmd.log``, by ``tetrad_cleanup``.

    Args:
        path_to_tetrad: Path to the Tetrad command-line jar.
        data_path: Comma-delimited dataset of continuous variables.
        algorithm: Value for ``--algorithm``.
        prefix: Value for ``--prefix``; also determines the result file name.
        options: Arguments inserted before ``--prefix`` (``--name value`` pairs).
        flags: Arguments appended after ``--prefix`` (value-less switches).

    Returns:
        The graph returned by ``graph_from_tetrad``.

    Raises:
        subprocess.CalledProcessError: if java exits non-zero. Output is
            captured, so the exception's ``stdout``/``stderr`` hold it. No
            cleanup runs on this path, so any output/log files remain for
            inspection.
    """
    output_file = f"{prefix}_out.txt"
    command = [
        "java",
        "-jar",
        path_to_tetrad,
        "--algorithm",
        algorithm,
        "--dataset",
        data_path,
        "--delimiter",
        "comma",
        "--data-type",
        "continuous",
        *options,
        "--prefix",
        prefix,
        *flags,
    ]
    subprocess.run(command, capture_output=True, check=True, text=True)
    g = graph_from_tetrad(output_file)
    tetrad_cleanup(output_file)
    return g


def tetrad_pc(path_to_tetrad, data_path, alpha):
    """Run Tetrad PC with a Fisher-Z test at ``alpha``; return the CPDAG."""
    return _run_tetrad(
        path_to_tetrad,
        data_path,
        algorithm="pc",
        prefix="pc",
        options=["--test", "fisher-z-test", "--alpha", str(alpha)],
        flags=["--default"],
    )


def tetrad_ges(path_to_tetrad, data_path, penalty):
    """Run Tetrad FGES (SEM-BIC score, penalty discount *penalty*); return the CPDAG."""
    return _run_tetrad(
        path_to_tetrad,
        data_path,
        algorithm="fges",
        prefix="ges",
        options=["--score", "sem-bic-score", "--penaltyDiscount", str(penalty)],
        flags=["--precomputeCovariances", "--default"],
    )


def tetrad_boss(path_to_tetrad, data_path, penalty, num_starts):
    """Run Tetrad BOSS (SEM-BIC score) with *num_starts* starts; return the CPDAG."""
    # NOTE(review): unlike the other wrappers this passes no "--default". It
    # also enables the resampling switches without an explicit number of
    # resamples. Confirm this is intended.
    return _run_tetrad(
        path_to_tetrad,
        data_path,
        algorithm="boss",
        prefix="boss",
        options=[
            "--score",
            "sem-bic-score",
            "--penaltyDiscount",
            str(penalty),
            "--numStarts",
            str(num_starts),
        ],
        flags=[
            "--precomputeCovariances",
            "--resamplingWithReplacement",
            "--addOriginalDataset",
        ],
    )


def tetrad_grasp(path_to_tetrad, data_path, penalty):
    """Run Tetrad GRaSP (SEM-BIC score plus Fisher-Z test); return the CPDAG."""
    return _run_tetrad(
        path_to_tetrad,
        data_path,
        algorithm="grasp",
        prefix="grasp",
        options=[
            "--score",
            "sem-bic-score",
            "--test",
            "fisher-z-test",
            "--penaltyDiscount",
            str(penalty),
        ],
        flags=["--default"],
    )


if __name__ == "__main__":
    import sys

    if len(sys.argv) < 5:
        print(
            "Usage: python run_tetrad.py <path_to_tetrad> <data.csv> <algorithm> <parameter (penalty/alpha)>"
        )
        sys.exit(1)

    path_to_tetrad = sys.argv[1]
    data_path = sys.argv[2]
    algorithm = sys.argv[3]
    parameter = sys.argv[4]

    match algorithm:
        case "pc":
            g = tetrad_pc(path_to_tetrad, data_path, parameter)
        case "ges":
            g = tetrad_ges(path_to_tetrad, data_path, parameter)
        case "boss":
            num_starts = sys.argv[5]
            g = tetrad_boss(path_to_tetrad, data_path, parameter, num_starts)
        case "grasp":
            g = tetrad_grasp(path_to_tetrad, data_path, parameter)
        case other:
            raise ValueError("no method implemented for algorithm " + other)
    g.write()
    tetrad_cleanup(f"{algorithm}_out.txt")
