import math
import csv


# Maps each value of the CSV `algo` column to the TikZ colour used to draw it.
# The colour names (tbl...) are presumably defined in the paper's LaTeX
# preamble -- not visible here. Several keys deliberately share a colour,
# e.g. every FLOP_20 sweep entry below is "tblred".
colormap = {
    # Ground truth and FLOP variants.
    "true": "tblweirdgreen",
    "flop-restarts=0-lambda=2.0-randomstart=False": "tblpink",
    "flop-restarts=0-lambda=2.0-randomstart=False-perturbations=1.0": "tblpink",
    "flop-restarts=20-lambda=2.0-randomstart=False": "tblred",
    "flop-restarts=100-lambda=2.0-randomstart=False": "tblflamingo",
    "flop-restarts=100-lambda=2.0-randomstart=False-perturbations=1.0": "tblflamingo",
    "flop-restarts=500-lambda=2.0-randomstart=False": "tblrose",
    "flop-restarts=0-lambda=2.0-randomstart=True": "tblyellow",
    # FLOP baselines and competing methods.
    "flop_baseline_lazygs-restarts=0-lambda=2.0": "tblorange",
    "flop_baseline_naivegs-restarts=0-lambda=2.0": "tblapricot",
    "boss-restarts=0-lambda=2.0": "tblteal",
    "boss-restarts=20-lambda=2.0": "tblforest",
    "boss-restarts=100-lambda=2.0": "tblapple",
    "pc-alpha=0.01": "tblmustard",
    "ges-lambda=2.0": "tblblue",
    "dagma-lambda=0.02": "tblgray",
    "exact-lambda=2.0": "tblsky",
    "lingam": "tbllightgray",
    "dagma_nonlinear-lambda=0.02": "tblforest",
    # FLOP_20 sweep over lambda (perturbations fixed at 1.0).
    **{
        f"flop-restarts=20-lambda={lam}-randomstart=False-perturbations=1.0": "tblred"
        for lam in ("0.5", "1.0", "1.5", "2.0", "2.5", "3.0", "3.5", "4.0")
    },
    # FLOP_20 sweep over the perturbation factor (lambda fixed at 2.0).
    **{
        f"flop-restarts=20-lambda=2.0-randomstart=False-perturbations={pert}": "tblred"
        for pert in ("0.25", "0.5", "0.75", "1.3333333333333333", "2.0", "4.0")
    },
    "flop-restarts=20-lambda=0.5-randomstart=False-perturbations=0.0": "tblred",
    "grasp-lambda=2.0": "tblred",
    "xges-lambda=2.0": "tblred",
}


class Point:
    """A mutable 2-D point that also serves as a per-axis scale factor.

    Its text form ``(x, y)`` is valid TikZ coordinate syntax, which is why
    instances are interpolated straight into the generated drawing commands.
    """

    x: float
    y: float

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __add__(self, other):
        """Component-wise sum; this is what lets ``sum(points, Point(0, 0))`` work."""
        return Point(x=self.x + other.x, y=self.y + other.y)

    def __repr__(self):
        return f"({self.x}, {self.y})"

    # str() and repr() share the TikZ-ready "(x, y)" form.
    __str__ = __repr__

    def scaled_by(self, z):
        """Return a new point with x multiplied by ``z.x`` and y by ``z.y``."""
        return Point(x=self.x * z.x, y=self.y * z.y)


def mean(values):
    """Return the arithmetic mean of ``values``.

    ``values`` must be a non-empty sized collection of numbers; an empty one
    raises ``ZeroDivisionError``.
    """
    total = sum(values)
    count = len(values)
    return total / count


def stddev(values):
    """Return the sample standard deviation (n - 1 denominator) of ``values``.

    A single observation has no spread estimate, so 0.0 is returned. This
    lets callers that draw error bars cope with configurations that have only
    one run, instead of crashing with ZeroDivisionError.

    Raises:
        ValueError: if ``values`` is empty.
    """
    n = len(values)
    if n == 0:
        raise ValueError("stddev() requires at least one value")
    if n == 1:
        return 0.0
    m = sum(values) / n
    variance = sum((x - m) ** 2 for x in values) / (n - 1)  # sample variance
    return math.sqrt(variance)


def line_bar_plot(
    data_path, tikz_path, title, xcol, xlabel, ycol, ylabel, suppress_algos=("true",)
):
    """Write a TikZ line plot of per-algorithm mean ``ycol`` with +-1 sd bars.

    Rows of the CSV at ``data_path`` are grouped by their ``algo`` column.
    Each row's x value is the second ``-``-separated field of ``row[xcol]``
    (for ``xcol="graph"`` presumably the node count -- confirm against the
    CSV). Its y value is ``float(row[ycol])``. Rows with an empty ``ycol`` and
    rows whose algo is in ``suppress_algos`` are skipped.

    For every algorithm and distinct x value the plot draws:
    - a dot at the mean y;
    - a dashed segment from the previous mean (the first starts at the origin);
    - a capped vertical bar spanning mean +- sample standard deviation.

    The plot area is scaled to about 20 x 7.5 TikZ units. ``tikz_path`` is
    overwritten.

    Raises:
        ValueError: if no plottable rows remain.
        KeyError: if a plotted algorithm has no ``colormap`` entry.
    """

    def tick_label(value):
        """Render a tick value without float noise (0.30000000000000004 -> 0.3)."""
        value = round(value, 10)
        return int(value) if float(value).is_integer() else value

    data_points = {}
    with open(data_path, newline="", encoding="utf-8") as csvfile:
        for row in csv.DictReader(csvfile):
            if row["algo"] in suppress_algos or row[ycol] == "":
                continue
            data_points.setdefault(row["algo"], []).append(
                Point(float(row[xcol].split("-")[1]), float(row[ycol]))
            )

    if not data_points:
        raise ValueError(f"no plottable rows for {ycol!r} in {data_path}")

    pmin = Point(0.0, 0.0)
    max_x = max(p.x for points in data_points.values() for p in points)
    max_y = max(p.y for points in data_points.values() for p in points)
    # A non-positive maximum would divide by zero (and break log10) below;
    # fall back to a unit range instead.
    pmax = Point(max_x if max_x > 0 else 1.0, max_y if max_y > 0 else 1.0)
    scale_down = Point(20.0 / pmax.x, 7.5 / pmax.y)

    with open(tikz_path, "w", encoding="utf-8") as tikz_file:
        tikz_file.write(f"\\node at {Point(10.0, 8)} {{{title}}};\n")

        tikz_file.write(
            f"\\draw[axisgray,semithick,->,>={{[round,sep]Stealth}}] {pmin + Point(0.0, 0.0)} -- {pmin + Point(21.0, 0.0)};\n"
        )
        tikz_file.write(
            f"\\draw[axisgray,semithick,->,>={{[round,sep]Stealth}}] {pmin + Point(0.0, 0.0)} -- {pmin + Point(0.0, 7.5)};\n"
        )

        xstep = 50
        tick = 50
        while tick <= pmax.x:
            tikz_file.write(
                f"\\draw[axisgray,thin,font=\\small] {Point(tick * scale_down.x, 0.0)} -- {Point(tick * scale_down.x, -0.2)} node[below,black] {{{tick}}};\n"
            )
            tick += xstep

        ystep = 10 ** math.floor(math.log10(pmax.y))
        if pmax.y / ystep < 3:
            # Halve for a denser grid. Plain floor division would turn steps
            # of 1 or below into 0 and the tick loop would never end.
            ystep = ystep // 2 if ystep >= 2 else ystep / 2
        k = 1
        while k * ystep < pmax.y:
            tick = tick_label(k * ystep)
            tikz_file.write(
                f"\\draw[axisgray,thin,font=\\small] {Point(0.0, tick * scale_down.y)} -- {Point(-0.2, tick * scale_down.y)} node[left,black] {{{tick}}};\n"
            )
            k += 1

        tikz_file.write(f"\\node at {Point(10.0, -1.65)} {{{xlabel}}};\n")
        tikz_file.write(f"\\node at {Point(0.0, 8)} {{{ylabel}}};\n")

        eps = 3.0  # half-width of the error-bar caps, in data x units
        for algo, points in data_points.items():
            color = colormap[algo]
            grouped_points = {}
            for point in points:
                grouped_points.setdefault(point.x, []).append(point.y)

            last_point = Point(0.0, 0.0)
            for num_nodes, runtimes in sorted(grouped_points.items()):
                mn = mean(runtimes)
                # A single run has no spread; draw a zero-length bar.
                sd = stddev(runtimes) if len(runtimes) > 1 else 0.0
                centre = Point(num_nodes, mn).scaled_by(scale_down)
                upper = Point(num_nodes, mn + sd).scaled_by(scale_down)
                lower = Point(num_nodes, mn - sd).scaled_by(scale_down)
                tikz_file.write(
                    f"\\node[circle, inner sep = 0pt, minimum width = 1.5mm, fill={color}] at {centre} {{}};\n"
                )
                tikz_file.write(
                    f"\\draw[thick,dashed,color={color}] {last_point.scaled_by(scale_down)} -- {centre};\n"
                )
                tikz_file.write(
                    f"\\draw[semithick,color={color}] {centre} -- {upper};\n"
                )
                tikz_file.write(
                    f"\\draw[semithick,color={color}] {Point(num_nodes - eps, mn + sd).scaled_by(scale_down)} -- {Point(num_nodes + eps, mn + sd).scaled_by(scale_down)};\n"
                )
                tikz_file.write(
                    f"\\draw[semithick,color={color}] {centre} -- {lower};\n"
                )
                tikz_file.write(
                    f"\\draw[semithick,color={color}] {Point(num_nodes - eps, mn - sd).scaled_by(scale_down)} -- {Point(num_nodes + eps, mn - sd).scaled_by(scale_down)};\n"
                )
                last_point = Point(num_nodes, mn)


def scatter_mean_plot(
    data_path,
    tikz_path,
    title,
    xcol,
    xlabel,
    ycol,
    ylabel,
    shift_ylabel=0.0,
    suppress_algos=("true",),
    filter_samples=None,
    filter_graph=None,
    filter_data_type=None,
    shift_by=None,
):
    """Write a TikZ scatter plot of ``ycol`` against ``xcol``, one colour per algo.

    Every CSV row becomes a translucent dot. Each algorithm also gets a
    diamond at the mean of its points. The plot area is scaled to about
    10 x 7.5 TikZ units, and ``tikz_path`` is overwritten.

    Args:
        data_path: CSV with an ``algo`` column, the ``xcol``/``ycol`` columns,
            and ``data``/``graph`` columns when the matching filter is used.
        xcol, ycol: CSV columns for x and y. ``xcol`` may also be ``"lambda"``
            or ``"perturbations"``; these are read from the ``key=value``
            fields of the ``algo`` name instead of from a column.
        shift_ylabel: horizontal offset of the y-axis label.
        suppress_algos: algorithms not drawn. They can still serve as the
            ``shift_by`` reference.
        filter_samples: keep rows whose first ``-`` field of ``data`` parses
            to an int contained here.
        filter_graph: keep rows whose second ``-`` field of ``graph`` is here.
        filter_data_type: keep rows whose second ``-`` field of ``data`` is here.
        shift_by: if given, subtract this algorithm's i-th x value from the
            i-th x value of every algorithm. Assumes all algorithms list their
            runs in the same order.

    Also prints each algorithm's mean point. Summary ratios are printed when
    ``ycol == "shd"`` or ``xcol == "bic"``.

    Raises:
        ValueError: if no plottable points remain.
        KeyError: if ``shift_by`` has no rows or a drawn algorithm has no
            ``colormap`` entry.
    """

    def tick_label(value):
        """Render a tick value without float noise (0.30000000000000004 -> 0.3)."""
        value = round(value, 10)
        return int(value) if float(value).is_integer() else value

    data_points = {}
    with open(data_path, newline="", encoding="utf-8") as csvfile:
        for row in csv.DictReader(csvfile):
            if xcol in ("lambda", "perturbations"):
                # These values live in the algo name, e.g.
                # "flop-restarts=20-lambda=0.5-...". Look them up by key so
                # algos without the parameter (e.g. "true") yield a missing
                # value instead of an IndexError.
                params = dict(
                    field.split("=", 1)
                    for field in row["algo"].split("-")
                    if "=" in field
                )
                row[xcol] = params.get(xcol, "")
            if filter_samples is not None:
                if int(row["data"].split("-")[0]) not in filter_samples:
                    continue
            if filter_graph is not None:
                if row["graph"].split("-")[1] not in filter_graph:
                    continue
            if filter_data_type is not None:
                if row["data"].split("-")[1] not in filter_data_type:
                    continue
            if row[xcol] != "" and row[ycol] != "":
                point = Point(float(row[xcol]), float(row[ycol]))
            else:
                # Placeholder keeps every algorithm's i-th entry aligned with
                # the i-th entry of the shift reference.
                point = None
            data_points.setdefault(row["algo"], []).append(point)

    # Capture the reference before anything is shifted.
    reference = data_points[shift_by] if shift_by is not None else None
    plotted = {}
    for algo, points in data_points.items():
        if algo in suppress_algos:
            continue
        if reference is not None:
            points = [
                None
                if point is None or reference[i] is None
                else Point(point.x - reference[i].x, point.y)
                for i, point in enumerate(points)
            ]
        present = [point for point in points if point is not None]
        if present:
            plotted[algo] = present
    data_points = plotted

    if not data_points:
        raise ValueError(
            f"no plottable points for {ycol!r} vs {xcol!r} in {data_path}"
        )

    pmin = Point(0.0, 0.0)
    max_x = max(p.x for points in data_points.values() for p in points)
    max_y = max(p.y for points in data_points.values() for p in points)
    # A non-positive maximum (e.g. every algorithm has SHD 0) would divide by
    # zero and break log10 below; fall back to a unit range.
    pmax = Point(max_x if max_x > 0 else 1.0, max_y if max_y > 0 else 1.0)
    scale_down = Point(10.0 / pmax.x, 7.5 / pmax.y)

    with open(tikz_path, "w", encoding="utf-8") as tikz_file:
        tikz_file.write(f"\\node at {Point(5.0, 9)} {{{title}}};\n")

        tikz_file.write(
            f"\\draw[axisgray,semithick,->,>={{[round,sep]Stealth}}] {pmin + Point(-0.5, -0.5)} -- {pmin + Point(10.5, -0.5)};\n"
        )
        tikz_file.write(
            f"\\draw[axisgray,semithick,->,>={{[round,sep]Stealth}}] {pmin + Point(-0.5, -0.5)} -- {pmin + Point(-0.5, 7.5)};\n"
        )

        xstep = 10 ** math.floor(math.log10(pmax.x))
        if pmax.x / xstep < 3:
            # Halve for a denser grid. Plain floor division would turn steps
            # of 1 or below into 0 and the tick loop would never end.
            xstep = xstep // 2 if xstep >= 2 else xstep / 2
        k = 0
        while k * xstep < pmax.x:
            tick = tick_label(k * xstep)
            tikz_file.write(
                f"\\draw[axisgray,thin,font=\\small] {Point(tick * scale_down.x, -0.5)} -- {Point(tick * scale_down.x, -0.7)} node[below,black] {{{tick}}};\n"
            )
            k += 1

        ystep = 10 ** math.floor(math.log10(pmax.y))
        if pmax.y / ystep < 3:
            ystep = ystep // 2 if ystep >= 2 else ystep / 2
        k = 0
        while k * ystep < 0.95 * pmax.y:
            tick = tick_label(k * ystep)
            tikz_file.write(
                f"\\draw[axisgray,thin,font=\\small] {Point(-0.5, tick * scale_down.y)} -- {Point(-0.7, tick * scale_down.y)} node[left,black] {{{tick}}};\n"
            )
            k += 1

        tikz_file.write(f"\\node[align=left] at {Point(5.0, -2)} {{{xlabel}}};\n")
        tikz_file.write(
            f"\\node[align=left] at {Point(-0.5 + shift_ylabel, 8)} {{{ylabel}}};\n"
        )

        for algo, points in data_points.items():
            for point in points:
                tikz_file.write(
                    f"\\fill[{colormap[algo]}, opacity=0.3] {point.scaled_by(scale_down)} circle (0.1cm);\n"
                )

        for algo, points in data_points.items():
            # Named `centroid` so the module-level mean() is not shadowed.
            inv = 1 / len(points)
            centroid = sum(points, Point(0, 0)).scaled_by(Point(inv, inv))
            print(f"{algo} gets {ycol.upper()} {centroid}")
            tikz_file.write(
                f"\\node[diamond, draw, fill={colormap[algo]}, inner sep = 0cm, minimum size=0.25cm] at {centroid.scaled_by(scale_down)} {{}};\n"
            )

        if ycol == "shd":
            print(title)
            for algo, points in data_points.items():
                print(
                    f"ratio of true graph found for {algo} is {len([p for p in points if p.y == 0]) / len(points)}"
                )
        if xcol == "bic":
            print(title)
            for algo, points in data_points.items():
                print(
                    f"ratio of better or equal graph than {shift_by} found for {algo} is {len([p for p in points if p.x <= 0.0]) / len(points)}"
                )


def scatter_mean_bic_downwards_plot(
    data_path,
    tikz_path,
    title,
    xcol,
    xlabel,
    ycol,
    ylabel,
    shift_ylabel=0.0,
    suppress_algos=("true",),
    filter_samples=None,
    filter_graph=None,
    shift_by=None,
):
    """Write a TikZ scatter plot whose y axis grows downwards from 0.

    Meant for y values that are mostly <= 0, e.g. an algorithm's ``ycol``
    minus the ``shift_by`` reference's value on the same run. Every row
    becomes a translucent dot. Each algorithm also gets a diamond at the mean
    of its points. The plot area is about 10 x 7.5 TikZ units, and
    ``tikz_path`` is overwritten.

    Args:
        suppress_algos: algorithms not drawn. They are dropped while reading,
            except ``"true"`` and ``shift_by``, which may be needed as the
            shift reference.
        filter_samples: keep rows whose first ``-`` field of ``data`` parses
            to an int contained here.
        filter_graph: keep rows whose second ``-`` field of ``graph`` is here.
        shift_by: if given, subtract this algorithm's i-th y value from the
            i-th y value of every drawn algorithm. Assumes all algorithms list
            their runs in the same order.

    The ``"true"`` algorithm is never drawn. When ``ycol == "bic"``, the share
    of points with y <= 0 is printed per algorithm.

    Raises:
        ValueError: if no plottable points remain.
        KeyError: if ``shift_by`` has no rows or a drawn algorithm has no
            ``colormap`` entry.
    """

    def tick_label(value):
        """Render a tick value without float noise (0.30000000000000004 -> 0.3)."""
        value = round(value, 10)
        return int(value) if float(value).is_integer() else value

    data_points = {}
    with open(data_path, newline="", encoding="utf-8") as csvfile:
        for row in csv.DictReader(csvfile):
            if filter_samples is not None:
                if int(row["data"].split("-")[0]) not in filter_samples:
                    continue
            if filter_graph is not None:
                if row["graph"].split("-")[1] not in filter_graph:
                    continue
            algo = row["algo"]
            if algo in suppress_algos and algo not in ("true", shift_by):
                continue
            if row[xcol] != "" and row[ycol] != "":
                point = Point(float(row[xcol]), float(row[ycol]))
            else:
                # Placeholder keeps indices aligned with the shift reference;
                # skipping the row would pair later runs with the wrong one.
                point = None
            data_points.setdefault(algo, []).append(point)

    # Capture the reference before anything is shifted.
    reference = data_points[shift_by] if shift_by is not None else None
    plotted = {}
    for algo, points in data_points.items():
        # "true" and suppressed algos are only kept as the shift reference.
        if algo == "true" or algo in suppress_algos:
            continue
        if reference is not None:
            points = [
                None
                if point is None or reference[i] is None
                else Point(point.x, point.y - reference[i].y)
                for i, point in enumerate(points)
            ]
        present = [point for point in points if point is not None]
        if present:
            plotted[algo] = present
    data_points = plotted

    if not data_points:
        raise ValueError(
            f"no plottable points for {ycol!r} vs {xcol!r} in {data_path}"
        )

    min_y = min(p.y for points in data_points.values() for p in points)
    max_x = max(p.x for points in data_points.values() for p in points)
    # The y axis runs from 0 downwards. Clamp so both spans are positive and
    # the scaling / log10 below stay finite.
    pmin = Point(0.0, min_y if min_y < 0 else -1.0)
    pmax = Point(max_x if max_x > 0 else 1.0, 0.0)
    scale_down = Point(10.0 / pmax.x, 7.5 / (pmax.y - pmin.y))

    with open(tikz_path, "w", encoding="utf-8") as tikz_file:
        tikz_file.write(f"\\node at {Point(5.0, -9)} {{{title}}};\n")

        tikz_file.write(
            f"\\draw[axisgray,semithick,->,>={{[round,sep]Stealth}}] {Point(-0.5, 0.5)} -- {Point(10.5, 0.5)};\n"
        )
        tikz_file.write(
            f"\\draw[axisgray,semithick,->,>={{[round,sep]Stealth}}] {Point(-0.5, 0.5)} -- {Point(-0.5, -7.5)};\n"
        )

        xstep = 10 ** math.floor(math.log10(pmax.x))
        if pmax.x / xstep < 3:
            # Halve for a denser grid. Plain floor division would turn steps
            # of 1 or below into 0 and the tick loop would never end.
            xstep = xstep // 2 if xstep >= 2 else xstep / 2
        k = 0
        while k * xstep < pmax.x:
            tick = tick_label(k * xstep)
            tikz_file.write(
                f"\\draw[axisgray,thin,font=\\small] {Point(tick * scale_down.x, 0.5)} -- {Point(tick * scale_down.x, 0.7)} node[above,black] {{{tick}}};\n"
            )
            k += 1

        ystep = 10 ** math.floor(math.log10((pmax.y - pmin.y)))
        if -pmin.y / ystep < 3:
            ystep = ystep // 2 if ystep >= 2 else ystep / 2
        k = 0
        while k * ystep < -0.95 * pmin.y:
            tick = tick_label(k * ystep)
            tikz_file.write(
                f"\\draw[axisgray,thin,font=\\small] {Point(-0.5, -tick * scale_down.y)} -- {Point(-0.7, -tick * scale_down.y)} node[left,black] {{{-tick}}};\n"
            )
            k += 1

        tikz_file.write(f"\\node at {Point(5.0, 2)} {{{xlabel}}};\n")
        tikz_file.write(f"\\node at {Point(-0.5 + shift_ylabel, -8.0)} {{{ylabel}}};\n")

        for algo, points in data_points.items():
            for point in points:
                tikz_file.write(
                    f"\\fill[{colormap[algo]}, opacity=0.3] {point.scaled_by(scale_down)} circle (0.1cm);\n"
                )

        for algo, points in data_points.items():
            # Named `centroid` so the module-level mean() is not shadowed.
            inv = 1 / len(points)
            centroid = sum(points, Point(0, 0)).scaled_by(Point(inv, inv))
            tikz_file.write(
                f"\\node[diamond, draw, fill={colormap[algo]}, inner sep = 0cm, minimum size=0.25cm] at {centroid.scaled_by(scale_down)} {{}};\n"
            )
    if ycol == "bic":
        print(title)
        for algo, points in data_points.items():
            print(
                f"ratio of better or equal graph than {shift_by} found for {algo} is {len([p for p in points if p.y <= 0.0]) / len(points)}"
            )


def scatter_mean_bic_upwards_plot(
    data_path,
    tikz_path,
    title,
    xcol,
    xlabel,
    ycol,
    ylabel,
    shift_ylabel=0.0,
    suppress_algos=("true",),
    filter_samples=None,
    filter_graph=None,
    shift_by=None,
):
    """Write a TikZ scatter plot with an upward y axis starting at 0.

    Meant for y values that are mostly >= 0, e.g. an algorithm's ``ycol``
    minus the ``shift_by`` reference's value on the same run. Every row
    becomes a translucent dot. Each algorithm also gets a diamond at the mean
    of its points. The plot area is about 10 x 7.5 TikZ units, and
    ``tikz_path`` is overwritten.

    Args:
        suppress_algos: algorithms that are neither drawn nor used for the
            axis range. They can still serve as the ``shift_by`` reference.
        filter_samples: keep rows whose first ``-`` field of ``data`` parses
            to an int contained here.
        filter_graph: keep rows whose second ``-`` field of ``graph`` is here.
        shift_by: if given, subtract this algorithm's i-th y value from the
            i-th y value of every drawn algorithm. Assumes all algorithms list
            their runs in the same order.

    The ``"true"`` algorithm is never drawn. When ``ycol == "bic"``, the share
    of points with y <= 0 is printed per drawn algorithm.

    Raises:
        ValueError: if no plottable points remain.
        KeyError: if ``shift_by`` has no rows or a drawn algorithm has no
            ``colormap`` entry.
    """

    def tick_label(value):
        """Render a tick value without float noise (0.30000000000000004 -> 0.3)."""
        value = round(value, 10)
        return int(value) if float(value).is_integer() else value

    data_points = {}
    with open(data_path, newline="", encoding="utf-8") as csvfile:
        for row in csv.DictReader(csvfile):
            if filter_samples is not None:
                if int(row["data"].split("-")[0]) not in filter_samples:
                    continue
            if filter_graph is not None:
                if row["graph"].split("-")[1] not in filter_graph:
                    continue
            if row[xcol] != "" and row[ycol] != "":
                point = Point(float(row[xcol]), float(row[ycol]))
            else:
                # Placeholder keeps indices aligned with the shift reference;
                # skipping the row would pair later runs with the wrong one.
                point = None
            data_points.setdefault(row["algo"], []).append(point)

    # Capture the reference before anything is shifted.
    reference = data_points[shift_by] if shift_by is not None else None
    plotted = {}
    for algo, points in data_points.items():
        # Suppressed algos must not reach the axis-range computation below:
        # their unshifted values would distort the scaling.
        if algo == "true" or algo in suppress_algos:
            continue
        if reference is not None:
            points = [
                None
                if point is None or reference[i] is None
                else Point(point.x, point.y - reference[i].y)
                for i, point in enumerate(points)
            ]
        present = [point for point in points if point is not None]
        if present:
            plotted[algo] = present
    data_points = plotted

    if not data_points:
        raise ValueError(
            f"no plottable points for {ycol!r} vs {xcol!r} in {data_path}"
        )

    pmin = Point(0.0, 0.0)
    max_x = max(p.x for points in data_points.values() for p in points)
    max_y = max(p.y for points in data_points.values() for p in points)
    # Non-positive maxima would divide by zero and break log10 below.
    pmax = Point(max_x if max_x > 0 else 1.0, max_y if max_y > 0 else 1.0)
    scale_down = Point(10.0 / pmax.x, 7.5 / (pmax.y - pmin.y))

    with open(tikz_path, "w", encoding="utf-8") as tikz_file:
        tikz_file.write(f"\\node at {Point(5.0, 9)} {{{title}}};\n")

        tikz_file.write(
            f"\\draw[axisgray,semithick,->,>={{[round,sep]Stealth}}] {Point(-0.5, -0.5)} -- {Point(10.5, -0.5)};\n"
        )
        tikz_file.write(
            f"\\draw[axisgray,semithick,->,>={{[round,sep]Stealth}}] {pmin + Point(-0.5, -0.5)} -- {pmin + Point(-0.5, 7.5)};\n"
        )

        xstep = 10 ** math.floor(math.log10(pmax.x))
        if pmax.x / xstep < 3:
            # Halve for a denser grid. Plain floor division would turn steps
            # of 1 or below into 0 and the tick loop would never end.
            xstep = xstep // 2 if xstep >= 2 else xstep / 2
        k = 0
        while k * xstep < pmax.x:
            tick = tick_label(k * xstep)
            tikz_file.write(
                f"\\draw[axisgray,thin,font=\\small] {Point(tick * scale_down.x, -0.5)} -- {Point(tick * scale_down.x, -0.7)} node[below,black] {{{tick}}};\n"
            )
            k += 1

        ystep = 10 ** math.floor(math.log10((pmax.y - pmin.y)))
        if pmax.y / ystep < 3:
            ystep = ystep // 2 if ystep >= 2 else ystep / 2
        k = 0
        while k * ystep < 0.95 * pmax.y:
            tick = tick_label(k * ystep)
            tikz_file.write(
                f"\\draw[axisgray,thin,font=\\small] {Point(-0.5, tick * scale_down.y)} -- {Point(-0.7, tick * scale_down.y)} node[left,black] {{{tick}}};\n"
            )
            k += 1

        tikz_file.write(f"\\node[align=left] at {Point(5.0, -2)} {{{xlabel}}};\n")
        tikz_file.write(
            f"\\node[align=left] at {Point(-0.5 + shift_ylabel, 8.0)} {{{ylabel}}};\n"
        )

        for algo, points in data_points.items():
            for point in points:
                tikz_file.write(
                    f"\\fill[{colormap[algo]}, opacity=0.3] {point.scaled_by(scale_down)} circle (0.1cm);\n"
                )

        for algo, points in data_points.items():
            # Named `centroid` so the module-level mean() is not shadowed.
            inv = 1 / len(points)
            centroid = sum(points, Point(0, 0)).scaled_by(Point(inv, inv))
            tikz_file.write(
                f"\\node[diamond, draw, fill={colormap[algo]}, inner sep = 0cm, minimum size=0.25cm] at {centroid.scaled_by(scale_down)} {{}};\n"
            )
    if ycol == "bic":
        print(title)
        for algo, points in data_points.items():
            print(
                f"ratio of better or equal graph than {shift_by} found for {algo} is {len([p for p in points if p.y <= 0.0]) / len(points)}"
            )


# Runtime scaling on large ER graphs: mean runtime per node count with error bars.
line_bar_plot(
    data_path="results/large.csv",
    tikz_path="../paper/img/runtime.tikz",
    title="ER, average degree 16, 1000 samples",
    xcol="graph",
    xlabel="Number of nodes",
    ycol="runtime",
    ylabel="Time in $s$",
)

# Accuracy (SHD / AID) versus runtime on the synthetic and bnlearn benchmarks.
scatter_mean_plot(
    data_path="results/default.csv", tikz_path="../paper/img/default_shd.tikz",
    title="ER, 50 nodes, average degree 8",
    xcol="runtime", xlabel="Runtime in seconds", ycol="shd", ylabel="SHD",
)

scatter_mean_plot(
    data_path="results/default.csv", tikz_path="../paper/img/default_aid.tikz",
    title="ER, 50 nodes, average degree 8",
    xcol="runtime", xlabel="Runtime in seconds", ycol="aid", ylabel="AID",
)

scatter_mean_plot(
    data_path="results/chain.csv", tikz_path="../paper/img/chain_shd.tikz",
    title="Path, 50 nodes",
    xcol="runtime", xlabel="Runtime in seconds", ycol="shd", ylabel="SHD",
)

scatter_mean_plot(
    data_path="results/dense.csv", tikz_path="../paper/img/dense_shd.tikz",
    title="ER, 25 nodes, average degree 16",
    xcol="runtime", xlabel="Runtime in seconds", ycol="shd", ylabel="SHD",
    filter_samples=(50_000,),
)

scatter_mean_plot(
    data_path="results/sf.csv", tikz_path="../paper/img/sf_shd.tikz",
    title="SF, 50 nodes, density parameter 4",
    xcol="runtime", xlabel="Runtime in seconds", ycol="shd", ylabel="SHD",
)

scatter_mean_plot(
    data_path="results/bnlearn.csv", tikz_path="../paper/img/alarm_shd.tikz",
    title="Alarm network",
    xcol="runtime", xlabel="Runtime in seconds", ycol="shd", ylabel="SHD",
    filter_graph=("alarm",),
)

scatter_mean_plot(
    data_path="results/uniform.csv", tikz_path="../paper/img/uniform_shd.tikz",
    title="Uniform noise",
    xcol="runtime", xlabel="Runtime in seconds", ycol="shd", ylabel="SHD",
)

scatter_mean_plot(
    data_path="results/raw.csv", tikz_path="../paper/img/raw_shd.tikz",
    title="Unstandardized",
    xcol="runtime", xlabel="Runtime in seconds", ycol="shd", ylabel="SHD",
)

scatter_mean_plot(
    data_path="results/onion.csv", tikz_path="../paper/img/onion_shd.tikz",
    title="DAG-adaptation of the Onion",
    xcol="runtime", xlabel="Runtime in seconds", ycol="shd", ylabel="SHD",
)

scatter_mean_plot(
    data_path="results/dense.csv", tikz_path="../paper/img/dense_1000_shd.tikz",
    title="Dense ER graphs, 1000 samples",
    xcol="runtime", xlabel="Runtime in seconds", ycol="shd", ylabel="SHD",
    filter_samples=(1_000,),
)

scatter_mean_plot(
    data_path="results/large_accuracy.csv",
    tikz_path="../paper/img/large_accuracy_250.tikz",
    title="ER, 250 nodes, average degree 8",
    xcol="runtime", xlabel="Runtime in seconds", ycol="shd", ylabel="SHD",
    filter_graph=("250",),
)

scatter_mean_plot(
    data_path="results/large_accuracy.csv",
    tikz_path="../paper/img/large_accuracy_500.tikz",
    title="ER, 500 nodes, average degree 8",
    xcol="runtime", xlabel="Runtime in seconds", ycol="shd", ylabel="SHD",
    filter_graph=("500",),
)

scatter_mean_plot(
    data_path="results/bnlearn.csv", tikz_path="../paper/img/barley_shd.tikz",
    title="Barley network",
    xcol="runtime", xlabel="Runtime in seconds", ycol="shd", ylabel="SHD",
    filter_graph=("barley",),
)

scatter_mean_plot(
    data_path="results/bnlearn.csv", tikz_path="../paper/img/mildew_shd.tikz",
    title="Mildew network",
    xcol="runtime", xlabel="Runtime in seconds", ycol="shd", ylabel="SHD",
    filter_graph=("mildew",),
)

scatter_mean_plot(
    data_path="results/bnlearn.csv", tikz_path="../paper/img/pathfinder_shd.tikz",
    title="Pathfinder network",
    xcol="runtime", xlabel="Runtime in seconds", ycol="shd", ylabel="SHD",
    filter_graph=("pathfinder",),
)

# Pathfinder: BIC of each algorithm's graph relative to the ground truth.
# The y label matches the other BIC plots (true minus algo); it previously
# read "true - true".
scatter_mean_bic_downwards_plot(
    "results/bnlearn.csv",
    "../paper/img/pathfinder_bic.tikz",
    "Pathfinder network",
    "runtime",
    "Runtime in seconds",
    "bic",
    "$\\text{BIC}_{\\text{true}} - \\text{BIC}_{\\text{algo}}$",
    shift_ylabel=0.6,
    filter_graph=("pathfinder",),
    shift_by="true",
)

# AID versus runtime on the SF, Onion and large ER benchmarks.
scatter_mean_plot(
    data_path="results/sf.csv", tikz_path="../paper/img/sf_aid.tikz",
    title="SF, 50 nodes, density parameter 4",
    xcol="runtime", xlabel="Runtime in seconds", ycol="aid", ylabel="AID",
)

scatter_mean_plot(
    data_path="results/onion.csv", tikz_path="../paper/img/onion_aid.tikz",
    title="DAG-adaptation of the Onion",
    xcol="runtime", xlabel="Runtime in seconds", ycol="aid", ylabel="AID",
)

scatter_mean_plot(
    data_path="results/large_accuracy.csv",
    tikz_path="../paper/img/large_accuracy_250_aid.tikz",
    title="ER, 250 nodes, average degree 8",
    xcol="runtime", xlabel="Runtime in seconds", ycol="aid", ylabel="AID",
    filter_graph=("250",),
)

scatter_mean_plot(
    data_path="results/large_accuracy.csv",
    tikz_path="../paper/img/large_accuracy_500_aid.tikz",
    title="ER, 500 nodes, average degree 8",
    xcol="runtime", xlabel="Runtime in seconds", ycol="aid", ylabel="AID",
    filter_graph=("500",),
)

# Alarm network AID. Filter on the "alarm" graph, like the Alarm SHD plot;
# filtering on "barley" plotted Barley data under the Alarm title.
scatter_mean_plot(
    "results/bnlearn.csv",
    "../paper/img/alarm_aid.tikz",
    "Alarm network",
    "runtime",
    "Runtime in seconds",
    "aid",
    "AID",
    filter_graph=("alarm",),
)

# Remaining bnlearn AID plots and the causalAssembly SHD plot.
scatter_mean_plot(
    data_path="results/bnlearn.csv", tikz_path="../paper/img/barley_aid.tikz",
    title="Barley network",
    xcol="runtime", xlabel="Runtime in seconds", ycol="aid", ylabel="AID",
    filter_graph=("barley",),
)

scatter_mean_plot(
    data_path="results/bnlearn.csv", tikz_path="../paper/img/mildew_aid.tikz",
    title="Mildew network",
    xcol="runtime", xlabel="Runtime in seconds", ycol="aid", ylabel="AID",
    filter_graph=("mildew",),
)

scatter_mean_plot(
    data_path="results/bnlearn.csv", tikz_path="../paper/img/pathfinder_aid.tikz",
    title="Pathfinder network",
    xcol="runtime", xlabel="Runtime in seconds", ycol="aid", ylabel="AID",
    filter_graph=("pathfinder",),
)

scatter_mean_plot(
    data_path="results/causalAssembly.csv",
    tikz_path="../paper/img/causalAssembly_shd.tikz",
    title="\\texttt{causalAssembly}, 5000 samples",
    xcol="runtime", xlabel="Runtime in seconds", ycol="shd", ylabel="SHD",
    suppress_algos=("true", "lingam", "dagma-lambda=0.02"),
)

# BIC relative to the ground truth. ges-lambda=2.0 is currently shown; add it
# to suppress_algos to hide it.
scatter_mean_bic_downwards_plot(
    data_path="results/causalAssembly.csv",
    tikz_path="../paper/img/causalAssembly_bic.tikz",
    title="\\texttt{causalAssembly}, 5000 samples",
    xcol="runtime", xlabel="Runtime in seconds",
    ycol="bic", ylabel="$\\text{BIC}_{\\text{true}} - \\text{BIC}_{\\text{algo}}$",
    shift_ylabel=0.6,
    shift_by="true",
    suppress_algos=("true", "pc-alpha=0.01", "dagma-lambda=0.02", "lingam"),
)

scatter_mean_bic_upwards_plot(
    data_path="results/default.csv",
    tikz_path="../paper/img/default_bic.tikz",
    title="ER, 50 nodes, average degree 8",
    xcol="runtime", xlabel="Runtime in seconds",
    ycol="bic", ylabel="$\\text{BIC}_{\\text{true}} - \\text{BIC}_{\\text{algo}}$",
    shift_ylabel=0.6,
    shift_by="true",
    suppress_algos=("true", "pc-alpha=0.01", "dagma-lambda=0.02"),
)

scatter_mean_bic_upwards_plot(
    data_path="results/sf.csv",
    tikz_path="../paper/img/sf_bic.tikz",
    title="SF, 50 nodes, density parameter 4",
    xcol="runtime", xlabel="Runtime in seconds",
    ycol="bic", ylabel="$\\text{BIC}_{\\text{true}} - \\text{BIC}_{\\text{algo}}$",
    shift_ylabel=0.6,
    shift_by="true",
    suppress_algos=("true", "pc-alpha=0.01", "dagma-lambda=0.02"),
)

# TODO: the non-linear plots below are marked unfinished; no details recorded.
# SHD versus BIC distance to the exact optimum, on MLP- and GP-generated data.
scatter_mean_plot(
    data_path="results/nonlinear.csv",
    tikz_path="../paper/img/nonlinear_mlp.tikz",
    title="Non-linear (MLP)",
    xcol="bic", xlabel="$\\text{BIC}_{\\text{opt}} - \\text{BIC}_{\\text{algo}}$",
    ycol="shd", ylabel="SHD",
    filter_data_type=["mlp"],
    suppress_algos=[
        "exact-lambda=2.0",
        "flop-restarts=0-lambda=2.0-randomstart=False",
        "flop-restarts=20-lambda=2.0-randomstart=False",
    ],
    shift_by="exact-lambda=2.0",
)

scatter_mean_plot(
    data_path="results/nonlinear.csv",
    tikz_path="../paper/img/nonlinear_gp.tikz",
    title="Non-linear (GP)",
    xcol="bic", xlabel="$\\text{BIC}_{\\text{opt}} - \\text{BIC}_{\\text{algo}}$",
    ycol="shd", ylabel="SHD",
    filter_data_type=["gp"],
    suppress_algos=[
        "exact-lambda=2.0",
        "flop-restarts=0-lambda=2.0-randomstart=False",
        "flop-restarts=20-lambda=2.0-randomstart=False",
    ],
    shift_by="exact-lambda=2.0",
)

# FLOP_20 hyper-parameter ablations: SHD versus lambda and versus perturbations.
scatter_mean_plot(
    data_path="results/default_bic.csv",
    tikz_path="../paper/img/default_lambda.tikz",
    title="$\\text{FLOP}_{20}$ on ER, 50 nodes, average degree 8",
    xcol="lambda", xlabel="$\\lambda_{\\text{BIC}}$",
    ycol="shd", ylabel="SHD",
)

scatter_mean_plot(
    data_path="results/default_perturb.csv",
    tikz_path="../paper/img/default_perturb.tikz",
    title="$\\text{FLOP}_{20}$ on ER, 50 nodes, average degree 16",
    xcol="perturbations", xlabel="$x \\cdot \\ln n$ random swaps",
    ycol="shd", ylabel="SHD",
)

# Real-world / additional data sets.
scatter_mean_plot(
    data_path="results/sachs.csv", tikz_path="../paper/img/sachs.tikz",
    title="Sachs network",
    xcol="runtime", xlabel="Runtime in seconds", ycol="shd", ylabel="SHD",
)

scatter_mean_plot(
    data_path="results/further.csv", tikz_path="../paper/img/further.tikz",
    title="Further",
    xcol="runtime", xlabel="Runtime in seconds", ycol="shd", ylabel="SHD",
)

# Dense ER graphs, 1000 samples: BIC relative to the ground truth
# (here ges-lambda=2.0 is hidden as well).
scatter_mean_bic_downwards_plot(
    data_path="results/dense.csv",
    tikz_path="../paper/img/dense_bic.tikz",
    title="Dense ER graphs, 1000 samples",
    xcol="runtime", xlabel="Runtime in seconds",
    ycol="bic", ylabel="$\\text{BIC}_{\\text{true}} - \\text{BIC}_{\\text{algo}}$",
    shift_ylabel=0.6,
    shift_by="true",
    suppress_algos=(
        "true",
        "pc-alpha=0.01",
        "dagma-lambda=0.02",
        "lingam",
        "ges-lambda=2.0",
    ),
    filter_samples=(1_000,),
)
