import io
from pathlib import Path
import subprocess
import time

import graph
from graph import CausalGraph
from config import Config


def read_params(tokens):
    """Parse ``key=value`` tokens into a dict mapping keys to string values.

    Each token is split at its first ``=`` only, so values may themselves
    contain ``=``. If a key repeats, the later token wins. A token without
    any ``=`` raises ValueError.
    """
    params = {}
    for token in tokens:
        key, value = token.split("=", 1)
        params[key] = value
    return params


def get_cmd(algo, params, data_file: Path):
    """Build the argv list that runs *algo* on *data_file*.

    Args:
        algo: Algorithm name, e.g. ``"flop"``, ``"boss"``, ``"dagma"``, or the
            pseudo-algorithm ``"true"``.
        params: Parameter dict (as produced by ``read_params``). Values are
            inserted into the command verbatim. Which keys are required
            depends on *algo*: for example ``"flop"`` needs ``lambda``,
            ``perturbations``, ``restarts`` and ``randomstart``, and ``"pc"``
            needs ``alpha``.
        data_file: Path to the input data file; passed to the command as a string.

    Returns:
        The command as a list suitable for ``subprocess.run``.

    Raises:
        ValueError: If *algo* is not a known algorithm.
        KeyError: If a parameter required by *algo* is missing from *params*.
    """
    data = str(data_file)
    # All Tetrad-based algorithms go through the same wrapper script and jar,
    # followed by the algorithm name and its arguments.
    tetrad_prefix = [
        "uv",
        "run",
        "run_tetrad.py",
        "tetrad/causal-cmd-1.12.0-jar-with-dependencies.jar",
        data,
    ]

    if algo == "flop":
        cmd = [
            "./flop/target/release/flop",
            data,
            params["lambda"],
            params["perturbations"],
            "--restarts",
            params["restarts"],
        ]
        if params["randomstart"] == "True":
            cmd.append("--random-start")
    elif algo in ("flop_baseline_naivegs", "flop_baseline_lazygs"):
        cmd = [
            "./flop-baselines/target/release/flop-baselines",
            data,
            params["lambda"],
            "--restarts",
            params["restarts"],
        ]
        # The two baselines share one binary; only the naive variant adds a flag.
        if algo == "flop_baseline_naivegs":
            cmd.append("--naivegs")
    elif algo == "exact":
        cmd = ["./exact/target/release/exact", data, params["lambda"]]
    elif algo == "boss":
        # Passes restarts + 1. Presumably this is the total number of runs
        # (the initial run plus the restarts); confirm against run_tetrad.py.
        cmd = [*tetrad_prefix, "boss", params["lambda"], str(int(params["restarts"]) + 1)]
    elif algo == "pc":
        cmd = [*tetrad_prefix, "pc", params["alpha"]]
    elif algo == "ges":
        cmd = [*tetrad_prefix, "ges", params["lambda"]]
    elif algo == "grasp":
        cmd = [*tetrad_prefix, "grasp", params["lambda"]]
    elif algo == "dagma":
        cmd = ["uv", "run", "run_dagma.py", data, params["lambda"]]
    elif algo == "dagma_nonlinear":
        cmd = ["uv", "run", "run_dagma_nonlinear.py", data, params["lambda"]]
    elif algo == "lingam":
        cmd = ["uv", "run", "run_lingam.py", data]
    elif algo == "xges":
        cmd = ["uv", "run", "run_xges.py", data, params["lambda"]]
    elif algo == "r2":
        cmd = ["uv", "run", "run_r2.py", data]
    elif algo == "golem":
        cmd = ["uv", "run", "run_golem.py", data]
    elif algo == "true":
        cmd = ["true"]
    else:
        raise ValueError(f"error: algorithm {algo} undefined")
    return cmd


def run(config: Config, id: str, data_file: Path, true_dag: CausalGraph):
    """Run the algorithm described by ``id`` on ``data_file`` and parse its output.

    ``id`` is split on ``"-"``. The first token names the algorithm and the
    remaining tokens are ``key=value`` parameters (parsed by ``read_params``),
    so parameter values cannot themselves contain ``"-"``. The command built
    by ``get_cmd`` is executed under the external ``timeout`` utility, with
    ``config.algo.timeout`` as its duration argument.

    Args:
        config: Experiment configuration; only ``config.algo.timeout`` is read here.
        id: Algorithm/parameter identifier, e.g. ``"ges-lambda=2"``. The name
            shadows the builtin ``id``; it is kept for caller compatibility.
        data_file: Input data file handed to the algorithm.
        true_dag: Ground-truth graph. It is returned as-is for ``"true"`` and
            its node count gates whether ``"exact"`` is attempted.

    Returns:
        A tuple ``(out_graph, loss, total_time)``:

        - ``out_graph``: the result of ``graph.read`` on the command's stdout.
          It is ``true_dag`` for ``"true"`` and ``None`` on timeout or error.
        - ``loss``: a float parsed from the first stdout line for ``"dagma"``;
          otherwise ``None``.
        - ``total_time``: elapsed seconds as a float (``0.0`` for ``"true"``),
          or the string ``"timeout"`` or ``"error"``.
    """
    # Exact search is not attempted above this many nodes; such runs are
    # reported as timeouts without executing anything.
    max_exact_nodes = 25
    # GNU `timeout` exits with this status when the wrapped command timed out.
    timeout_exit_status = 124

    tokens = id.split("-")
    algo = tokens[0]
    params = read_params(tokens[1:])
    cmd = get_cmd(algo, params, data_file)

    # The "true" pseudo-algorithm simply returns the ground-truth graph.
    if algo == "true":
        return true_dag, None, 0.0
    if algo == "exact" and true_dag.num_nodes() > max_exact_nodes:
        return None, None, "timeout"

    # perf_counter is monotonic and high-resolution; time.time() can jump if
    # the system clock is adjusted while the algorithm runs.
    start_time = time.perf_counter()
    result = subprocess.run(
        ["timeout", str(config.algo.timeout), *cmd],
        capture_output=True,
        text=True,
    )
    total_time = time.perf_counter() - start_time

    if result.returncode == timeout_exit_status:
        return None, None, "timeout"
    if result.returncode != 0:
        # Surface the failing command's output to aid debugging.
        print(result.stdout)
        print(result.stderr)
        return None, None, "error"

    output = io.StringIO(result.stdout)
    # For dagma the first stdout line is the loss; the rest is the graph.
    loss = float(output.readline()) if algo == "dagma" else None
    out_graph = graph.read(output)
    return out_graph, loss, total_time
