from dataclasses import dataclass
import itertools
from pathlib import Path
import os
from typing import Tuple


@dataclass(frozen=True)
class EvalConfig:
    """Evaluation settings: output location, repetition count, metrics and algorithms."""

    # Destination file for results (presumably written as CSV by the caller).
    output_path: Path = Path("simulation.csv")
    # Number of repetitions; consumed outside this module.
    repetitions: int = 50
    # Metric names to evaluate.
    metrics: Tuple[str, ...] = (
        "shd",
        "aid",
        "bic",
        "dagma-loss",
    )
    # Algorithm names; each must be accepted by Config.algo_ids (which raises
    # ValueError for unknown names).
    algos: Tuple[str, ...] = ("true", "flop", "boss", "pc", "ges", "dagma")


@dataclass(frozen=True)
class AlgoConfig:
    """Hyper-parameter grids for the structure-learning algorithms.

    Tuple-valued fields are swept (Config.algo_ids takes their Cartesian
    product); scalar fields contribute a single value.
    """

    # Restart counts for "flop" and the flop baselines.
    restarts_flop: Tuple[int, ...] = (0, 20, 100)
    # Random-start flags for "flop".
    random_start_flop: Tuple[bool, ...] = (False,)
    # Perturbation settings for "flop".
    perturbations_flop: Tuple[float, ...] = (1.0,)
    # Restart counts for "boss".
    restarts_boss: Tuple[int, ...] = (0,)
    # BIC penalty values shared by flop, boss, ges, exact, grasp and xges.
    lambda_bic: Tuple[float, ...] = (2.0,)
    # Single significance level for "pc".
    alpha_pc: float = 0.01
    # Single regularization value for "dagma" and "dagma_nonlinear".
    lambda_dagma: float = 0.02
    # Not read in this module; presumably a per-run limit in seconds (confirm with caller).
    timeout: float = 600.0


@dataclass(frozen=True)
class GraphConfig:
    """Graph-generation settings swept by Config.graph_ids."""

    # Graph family; "bnlearn", "chain", "causalAssembly" and "sachs" are
    # special-cased, anything else is gridded over num_nodes x avg_degree.
    graph_type: str = "er"
    # Node counts to sweep (ignored for bnlearn/causalAssembly/sachs).
    num_nodes: Tuple[int, ...] = (50,)
    # Average degrees to sweep (only used by the generic grid branch).
    avg_degree: Tuple[int, ...] = (8,)


@dataclass(frozen=True)
class DataConfig:
    """Data-generation settings swept by Config.data_ids."""

    # Sample sizes to sweep.
    num_samples: Tuple[int, ...] = (1000,)
    # Noise types; only encoded in ids for linear relations.
    noise: Tuple[str, ...] = ("gaussian",)
    # Post-processing variants; only encoded in ids for linear relations.
    correction: Tuple[str, ...] = ("standardized",)
    # Functional relations between variables.
    relations: Tuple[str, ...] = ("linear",)
    # External dataset name ("causalAssembly" or "sachs"); empty selects the
    # generated-data grid.
    external: str = ""


@dataclass(frozen=True)
class Config:
    """Top-level experiment configuration.

    Bundles the evaluation, algorithm, graph and data sub-configurations and
    expands their parameter grids into string identifiers via ``algo_ids``,
    ``graph_ids`` and ``data_ids``.
    """

    eval: EvalConfig = EvalConfig()
    algo: AlgoConfig = AlgoConfig()
    graph: GraphConfig = GraphConfig()
    data: DataConfig = DataConfig()

    def algo_ids(self):
        """Expand ``eval.algos`` into one identifier per algorithm setting.

        Algorithms are visited in ``eval.algos`` order. Each name expands over
        the relevant ``AlgoConfig`` grid:

        - ``flop``: restarts_flop x random_start_flop x perturbations_flop x lambda_bic
        - ``flop_baseline_naivegs`` / ``flop_baseline_lazygs``: restarts_flop x lambda_bic
        - ``boss``: restarts_boss x lambda_bic
        - ``ges`` / ``exact`` / ``grasp`` / ``xges``: lambda_bic
        - ``pc``: the single alpha_pc
        - ``dagma`` / ``dagma_nonlinear``: the single lambda_dagma
        - ``lingam`` / ``r2`` / ``golem`` / ``true``: the bare name

        Returns:
            A list of identifier strings.

        Raises:
            ValueError: if an algorithm name is not one of the above.
        """
        algo = self.algo
        ids = []
        for algo_name in self.eval.algos:
            if algo_name == "flop":
                # The product order (restarts, random_start, perturbations,
                # lambda) fixes the id order, even though the id string lists
                # the fields in a different order.
                for restarts, random_start, perturbs, lbic in itertools.product(
                    algo.restarts_flop,
                    algo.random_start_flop,
                    algo.perturbations_flop,
                    algo.lambda_bic,
                ):
                    ids.append(
                        f"{algo_name}-restarts={restarts}-lambda={lbic}-randomstart={random_start}-perturbations={perturbs}"
                    )
            elif algo_name in ("flop_baseline_naivegs", "flop_baseline_lazygs", "boss"):
                # boss has its own restart grid; the flop baselines reuse flop's.
                restarts_grid = (
                    algo.restarts_boss if algo_name == "boss" else algo.restarts_flop
                )
                ids.extend(
                    f"{algo_name}-restarts={restarts}-lambda={lbic}"
                    for restarts, lbic in itertools.product(
                        restarts_grid, algo.lambda_bic
                    )
                )
            elif algo_name in ("ges", "exact", "grasp", "xges"):
                ids.extend(f"{algo_name}-lambda={lbic}" for lbic in algo.lambda_bic)
            elif algo_name == "pc":
                ids.append(f"{algo_name}-alpha={algo.alpha_pc}")
            elif algo_name in ("dagma", "dagma_nonlinear"):
                ids.append(f"{algo_name}-lambda={algo.lambda_dagma}")
            elif algo_name in ("lingam", "r2", "golem", "true"):
                # r2 and golem remain accepted, but the author notes they are
                # not used because they are not competitive.
                ids.append(algo_name)
            else:
                raise ValueError(f"algo {algo_name} not supported")
        return ids

    def graph_ids(self):
        """Expand the graph sub-config into a list of graph identifiers.

        - ``bnlearn``: ``"bnlearn-alarm"`` / ``"bnlearn-barley"`` for each of
          those stems found under ``bnlearn/`` (each id at most once).
        - ``chain``: ``"chain-<n>"`` per entry of num_nodes.
        - ``causalAssembly`` / ``sachs``: the graph type itself.
        - anything else: ``"<type>-<n>-<degree>"`` over num_nodes x avg_degree.
        """
        graph = self.graph
        graph_type = graph.graph_type
        if graph_type == "bnlearn":
            ids = []
            # Relative to the current working directory. os.walk yields nothing
            # if the directory is missing, so the result is then empty.
            for _, dirs, files in os.walk("bnlearn/"):
                # os.walk descends into sub-directories; sorting `dirs` in place
                # makes the traversal order deterministic.
                dirs.sort()
                for file in sorted(files):
                    base_name = os.path.splitext(file)[0]
                    if base_name in ("alarm", "barley"):
                        ids.append(f"bnlearn-{base_name}")
            # Several files can share a stem, which used to emit duplicate ids;
            # keep each id once, in first-seen order.
            return list(dict.fromkeys(ids))
        if graph_type == "chain":
            return [f"chain-{nn}" for nn in graph.num_nodes]
        if graph_type in ("causalAssembly", "sachs"):
            return [graph_type]
        return [
            f"{graph_type}-{nn}-{ad}"
            for nn, ad in itertools.product(graph.num_nodes, graph.avg_degree)
        ]

    def data_ids(self):
        """Expand the data sub-config into a list of unique dataset identifiers.

        - ``external == "causalAssembly"``: ``"<num_samples>-causalAssembly"``
          per sample size.
        - ``external == "sachs"``: the single id ``"sachs"``.
        - otherwise: the grid num_samples x relations x noise x correction.
          Linear data yields ``"<ns>-<noise>-<correction>"``; any other
          relation yields ``"<ns>-<relation>"`` (noise and correction are not
          encoded).

        Non-linear ids drop the noise/correction dimensions, so the raw grid
        can repeat an id. Duplicates are removed and the first occurrence is
        kept, so the order still follows the grid.
        """
        data = self.data
        if data.external == "causalAssembly":
            ids = [f"{ns}-{data.external}" for ns in data.num_samples]
        elif data.external == "sachs":
            ids = ["sachs"]
        else:
            ids = [
                f"{ns}-{no}-{co}" if rel == "linear" else f"{ns}-{rel}"
                for ns, rel, no, co in itertools.product(
                    data.num_samples,
                    data.relations,
                    data.noise,
                    data.correction,
                )
            ]
        # dict.fromkeys deduplicates while preserving insertion order.
        return list(dict.fromkeys(ids))
