import io
import itertools
import numpy.random as rand
from pathlib import Path

import algo
from config import Config, EvalConfig, AlgoConfig, GraphConfig, DataConfig
import eval
import graph
import data
import utils


def _instance_files(graph_id, data_id, rep):
    """Return ``(data_file, graph_file)`` for one instance, sampling them if missing.

    Both files live under ``instances/`` and are named after the instance id
    ``f"{graph_id}_{data_id}_{rep}"``. When either file is absent, a graph is
    drawn with ``graph.sample`` and data with ``data.sample``, using an RNG
    seeded from the instance id via ``utils.seed_from_string``; both are then
    written to disk. Existing files are reused as-is.
    """
    instance_id = f"{graph_id}_{data_id}_{rep}"
    data_file = Path(f"instances/{instance_id}_data.csv")
    graph_file = Path(f"instances/{instance_id}_graph.txt")
    # Check both files: if only the data file existed, reading the missing
    # graph file afterwards would fail.
    if not (data_file.exists() and graph_file.exists()):
        data_file.parent.mkdir(parents=True, exist_ok=True)
        rng = rand.default_rng(seed=utils.seed_from_string(instance_id))
        g = graph.sample(graph_id, rng)
        g.write_to_file(graph_file)
        X = data.sample(data_id, g, rng)
        data.write_to_file(X, data_file)
    return data_file, graph_file


def _csv_row(*fields):
    """Join ``fields`` into one comma-separated line (no trailing newline).

    Equivalent to the text ``print(*fields, sep=",")`` emits: every field is
    converted with ``str``.
    """
    return ",".join(map(str, fields))


def run(config: Config):
    """Evaluate every (algorithm, graph, data, repetition) combination into a CSV.

    The header is ``algo,graph,data,rep,<config.eval.metrics...>,runtime``.
    For each combination, the instance files are loaded (and sampled first if
    missing), ``algo.run`` is invoked, and one row is produced:

    * success (``out_graph`` is not None): metrics from ``eval.metrics``; the
      row is also echoed to stdout as progress output;
    * failure with ``runtime == "timeout"``: the whole (algo, graph, data)
      group is discarded and its remaining repetitions are skipped;
    * any other failure (``out_graph`` is None): the row is recorded with
      ``eval.empty_metrics(config)``.

    Rows are buffered per (algo, graph, data) group and written to
    ``config.eval.output_path`` (opened with ``"w"``, so earlier contents are
    replaced) once the group's last repetition has been processed.

    Args:
        config: Experiment configuration providing the algorithm/graph/data id
            lists and the ``eval`` settings (output path, metrics, repetitions).
    """
    algo_ids = config.algo_ids()
    graph_ids = config.graph_ids()
    data_ids = config.data_ids()
    repetitions = config.eval.repetitions

    output_path = Path(config.eval.output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, "w") as output_file:
        print(
            _csv_row("algo", "graph", "data", "rep", *config.eval.metrics, "runtime"),
            file=output_file,
        )

        timed_out = False
        buffer = io.StringIO()

        # ``rep`` is the innermost factor, so ``rep == 0`` marks a new group.
        for algo_id, graph_id, data_id, rep in itertools.product(
            algo_ids, graph_ids, data_ids, range(repetitions)
        ):
            if rep == 0:
                timed_out = False
                buffer = io.StringIO()

            if timed_out:
                # A timeout discards the whole group; skip its remaining reps.
                continue

            data_file, graph_file = _instance_files(graph_id, data_id, rep)
            true_dag = graph.read_from_file(graph_file)
            X = data.read_from_file(data_file)

            out_graph, reported_score, runtime = algo.run(
                config, algo_id, data_file, true_dag
            )

            if out_graph is None and runtime == "timeout":
                timed_out = True
                continue

            if out_graph is None:
                # Non-timeout failure: keep the row, with empty metrics.
                row = _csv_row(
                    algo_id,
                    graph_id,
                    data_id,
                    rep,
                    *eval.empty_metrics(config),
                    runtime,
                )
                print(row, file=buffer)
            else:
                metrics = eval.metrics(
                    config, algo_id, true_dag, out_graph, X, reported_score
                )
                row = _csv_row(algo_id, graph_id, data_id, rep, *metrics, runtime)
                print(row, file=buffer)
                print(row)  # progress output for successful runs

            # Flush after the group's last repetition. This must also run when
            # that repetition failed (non-timeout); otherwise every buffered
            # row of the group would be lost.
            if rep == repetitions - 1:
                output_file.write(buffer.getvalue())
                output_file.flush()


def er_golem():
    """Run GOLEM alone on 50-node graphs with average degree 8.

    Uses ``AlgoConfig(timeout=1800.0)`` and writes ``results/golem.csv``.
    """
    eval_cfg = EvalConfig(
        output_path=Path("results/golem.csv"),
        algos=("golem",),
    )
    graph_cfg = GraphConfig(num_nodes=(50,), avg_degree=(8,))
    algo_cfg = AlgoConfig(timeout=1800.0)
    run(Config(eval=eval_cfg, graph=graph_cfg, algo=algo_cfg))


def er_small():
    """Sweep sample sizes (50 to 1600) for FLOP on 20-node, degree-4 graphs.

    Each setting is repeated 50 times; results go to ``results/tmpflop.csv``.
    """
    # NOTE: only "flop" is enabled. "exact", "dagma", "boss", "pc" and "ges"
    # were listed here as well but are currently disabled. run() opens the
    # output file with "w", so re-running replaces earlier results.
    sample_sizes = (50, 100, 200, 400, 800, 1600)
    eval_cfg = EvalConfig(
        output_path=Path("results/tmpflop.csv"),
        algos=("flop",),
        repetitions=50,
    )
    graph_cfg = GraphConfig(num_nodes=(20,), avg_degree=(4,))
    algo_cfg = AlgoConfig(timeout=1800.0)
    data_cfg = DataConfig(num_samples=sample_sizes)
    run(Config(eval=eval_cfg, graph=graph_cfg, algo=algo_cfg, data=data_cfg))


def er_basic():
    """Compare the full method set (including LiNGAM) on 50-node, degree-8 graphs.

    Uses ``AlgoConfig(timeout=1800.0)`` and writes ``results/default.csv``.
    """
    methods = (
        "grasp",
        "boss",
        "true",
        "flop",
        "pc",
        "ges",
        "dagma",
        "lingam",
    )
    run(
        Config(
            eval=EvalConfig(output_path=Path("results/default.csv"), algos=methods),
            graph=GraphConfig(num_nodes=(50,), avg_degree=(8,)),
            algo=AlgoConfig(timeout=1800.0),
        )
    )


def uniform_noise():
    """Compare methods (including LiNGAM) on data configured with uniform noise.

    Graph and algorithm settings stay at their defaults; results go to
    ``results/uniform.csv``.
    """
    methods = ("true", "flop", "boss", "pc", "ges", "dagma", "lingam")
    eval_cfg = EvalConfig(output_path=Path("results/uniform.csv"), algos=methods)
    data_cfg = DataConfig(noise=("uniform",))
    run(Config(eval=eval_cfg, data=data_cfg))


def onion_data():
    """Run the default method set with the ``"onion"`` data correction.

    Results go to ``results/onion.csv``.
    """
    eval_cfg = EvalConfig(output_path=Path("results/onion.csv"))
    data_cfg = DataConfig(correction=("onion",))
    run(Config(eval=eval_cfg, data=data_cfg))


def raw_data():
    """Run the default method set with the ``"raw"`` data correction.

    Results go to ``results/raw.csv``.
    """
    eval_cfg = EvalConfig(output_path=Path("results/raw.csv"))
    data_cfg = DataConfig(correction=("raw",))
    run(Config(eval=eval_cfg, data=data_cfg))


def sf():
    """Run the default setup on ``graph_type="sf"`` graphs.

    Results go to ``results/sf.csv``.
    """
    eval_cfg = EvalConfig(output_path=Path("results/sf.csv"))
    graph_cfg = GraphConfig(graph_type="sf")
    run(Config(eval=eval_cfg, graph=graph_cfg))


def er_dense():
    """Dense setting: 25-node graphs with average degree 16, 1000 or 50000 samples.

    FLOP restarts range over (0, 20, 100, 500). BOSS restarts range over
    (0, 20, 100), because 500 restarts are not feasible for BOSS here.
    Results go to ``results/dense.csv``.
    """
    eval_cfg = EvalConfig(
        output_path=Path("results/dense.csv"),
        algos=("true", "flop", "boss", "exact", "pc", "ges", "dagma"),
    )
    algo_cfg = AlgoConfig(
        timeout=1800.0,
        restarts_flop=(0, 20, 100, 500),
        restarts_boss=(0, 20, 100),
    )
    graph_cfg = GraphConfig(num_nodes=(25,), avg_degree=(16,))
    data_cfg = DataConfig(num_samples=(1_000, 50_000))
    run(Config(eval=eval_cfg, algo=algo_cfg, graph=graph_cfg, data=data_cfg))


def er_large():
    """Scalability run on 50 to 500 nodes (step 50), average degree 16, n = 1000.

    Compares FLOP, its ``flop_baseline_naivegs`` / ``flop_baseline_lazygs``
    variants, BOSS and ``true``, with restarts disabled for FLOP and BOSS.
    Results go to ``results/large.csv``.
    """
    methods = (
        "flop",
        "flop_baseline_naivegs",
        "flop_baseline_lazygs",
        "boss",
        "true",
    )
    node_counts = tuple(range(50, 501, 50))  # (50, 100, ..., 500)
    eval_cfg = EvalConfig(output_path=Path("results/large.csv"), algos=methods)
    algo_cfg = AlgoConfig(timeout=1800.0, restarts_flop=(0,), restarts_boss=(0,))
    graph_cfg = GraphConfig(num_nodes=node_counts, avg_degree=(16,))
    data_cfg = DataConfig(num_samples=(1_000,))
    run(Config(eval=eval_cfg, algo=algo_cfg, graph=graph_cfg, data=data_cfg))


def bnlearn():
    """Run the default setup on ``graph_type="bnlearn"`` graphs.

    Results go to ``results/bnlearn.csv``.
    """
    eval_cfg = EvalConfig(output_path=Path("results/bnlearn.csv"))
    graph_cfg = GraphConfig(graph_type="bnlearn")
    run(Config(eval=eval_cfg, graph=graph_cfg))


def chain_graph():
    """Run on chain graphs, comparing FLOP with and without a random start.

    Restarts are disabled for both FLOP and BOSS; results go to
    ``results/chain.csv``.
    """
    algo_cfg = AlgoConfig(
        random_start_flop=(True, False), restarts_boss=(0,), restarts_flop=(0,)
    )
    eval_cfg = EvalConfig(output_path=Path("results/chain.csv"))
    # A plain string, like the graph_type passed by the other experiments.
    graph_cfg = GraphConfig(graph_type="chain")
    run(Config(algo=algo_cfg, eval=eval_cfg, graph=graph_cfg))


def causalAssembly():
    """Run DAGMA on ``"causalAssembly"`` graphs and external data (5000 samples).

    Results go to ``results/causalAssembly.csv``.
    """
    # NOTE: only "dagma" is enabled; "true", "flop", "boss", "pc", "ges" and
    # "lingam" were also listed here but are commented out. run() opens the
    # output file with "w", so this run replaces any earlier results there.
    eval_cfg = EvalConfig(
        output_path=Path("results/causalAssembly.csv"),
        algos=("dagma",),
    )
    graph_cfg = GraphConfig(graph_type="causalAssembly")
    data_cfg = DataConfig(external="causalAssembly", num_samples=(5000,))
    run(Config(eval=eval_cfg, graph=graph_cfg, data=data_cfg))


def nonlinear():
    """Compare methods on ``"mlp"`` and ``"gp"`` relations.

    Uses 25-node graphs with average degree 4 and includes
    ``dagma_nonlinear`` and LiNGAM. Results go to ``results/nonlinear.csv``.
    """
    methods = (
        "dagma_nonlinear",
        "true",
        "flop",
        "exact",
        "boss",
        "pc",
        "ges",
        "dagma",
        "lingam",
    )
    eval_cfg = EvalConfig(output_path=Path("results/nonlinear.csv"), algos=methods)
    data_cfg = DataConfig(relations=("mlp", "gp"))
    graph_cfg = GraphConfig(num_nodes=(25,), avg_degree=(4,))
    run(Config(eval=eval_cfg, data=data_cfg, graph=graph_cfg))


def er_bic():
    """Sweep FLOP's ``lambda_bic`` setting from 0.5 to 4.0 in steps of 0.5.

    Uses 50-node graphs with average degree 8; results go to
    ``results/default_bic.csv``.
    """
    lambdas = (0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0)
    eval_cfg = EvalConfig(
        output_path=Path("results/default_bic.csv"),
        algos=("flop",),
    )
    graph_cfg = GraphConfig(num_nodes=(50,), avg_degree=(8,))
    algo_cfg = AlgoConfig(timeout=1800.0, lambda_bic=lambdas)
    run(Config(eval=eval_cfg, graph=graph_cfg, algo=algo_cfg))


def er_perturbations():
    """Grid over FLOP's ``perturbations_flop`` and ``restarts_flop`` settings.

    Uses 50-node graphs with average degree 16; results go to
    ``results/default_perturb.csv``.
    """
    perturbations = (0.25, 0.5, 0.75, 1.0, 4 / 3, 2.0, 4.0)
    restarts = (0, 20, 50, 100)
    eval_cfg = EvalConfig(
        output_path=Path("results/default_perturb.csv"),
        algos=("flop",),
    )
    graph_cfg = GraphConfig(num_nodes=(50,), avg_degree=(16,))
    algo_cfg = AlgoConfig(
        timeout=1800.0,
        perturbations_flop=perturbations,
        restarts_flop=restarts,
    )
    run(Config(eval=eval_cfg, graph=graph_cfg, algo=algo_cfg))


def further_methods():
    """Run GRaSP, XGES and LiNGAM on 50-node graphs with average degree 8.

    Results go to ``results/further.csv``.
    """
    eval_cfg = EvalConfig(
        output_path=Path("results/further.csv"),
        algos=("grasp", "xges", "lingam"),
    )
    graph_cfg = GraphConfig(num_nodes=(50,), avg_degree=(8,))
    algo_cfg = AlgoConfig(timeout=1800.0)
    run(Config(eval=eval_cfg, graph=graph_cfg, algo=algo_cfg))


def sachs():
    """Compare methods on the ``"sachs"`` graph with ``external="sachs"`` data.

    Results go to ``results/sachs.csv``.
    """
    methods = (
        "true",
        "exact",
        "flop",
        "boss",
        "pc",
        "ges",
        "dagma",
        "lingam",
    )
    eval_cfg = EvalConfig(output_path=Path("results/sachs.csv"), algos=methods)
    graph_cfg = GraphConfig(graph_type="sachs")
    data_cfg = DataConfig(external="sachs")
    run(Config(eval=eval_cfg, graph=graph_cfg, data=data_cfg))


if __name__ == "__main__":
    # Run every experiment in order; each one writes its own results file.
    for experiment in (
        er_small,
        er_basic,
        uniform_noise,
        onion_data,
        raw_data,
        sf,
        chain_graph,
        er_dense,
        er_large,
        bnlearn,
        sachs,
        er_golem,
        causalAssembly,
        nonlinear,
        er_perturbations,
        further_methods,
        er_bic,
    ):
        experiment()
