import csv
from pathlib import Path

import matplotlib.pyplot as plt


def load(
    experiment_dir: Path,
    take_stretched: bool,
    sizes: tuple[int, ...] = (100, 250, 500, 1000, 10000),
) -> tuple[list[int], list[float]]:
    """Load per-size optimality gaps for one experiment.

    Reads ``experiment_dir / "evaluations.csv"`` without skipping a header
    row. Each non-blank row must have at least four columns: instance name,
    gap (as a fraction), an unused column, and a scaling-factor flag that is
    compared against the string ``"True"``. Extra trailing columns are
    ignored, and blank lines are skipped.

    A row is kept only if all of these hold:

    * its instance name does not contain ``"TSPLIB"``;
    * its flag equals *take_stretched*;
    * its size is in *sizes*.

    The size is the second ``-``-separated field of the instance name, and
    ``"tsp-test"`` counts as size 100.

    Args:
        experiment_dir: Directory containing ``evaluations.csv``.
        take_stretched: Keep rows whose scaling-factor flag equals this value.
        sizes: Instance sizes to keep.

    Returns:
        ``(sizes, gaps)`` as parallel lists sorted by size. The sort is
        stable, so rows of equal size keep their file order. Gaps are
        multiplied by 100, i.e. expressed in percent.

    Raises:
        ValueError: a non-blank row has fewer than four columns, or a size or
            gap field is not numeric.
        IndexError: a non-TSPLIB instance name (other than ``"tsp-test"``)
            contains no ``-``.
    """
    wanted = set(sizes)
    csv_path = experiment_dir / "evaluations.csv"
    points: list[tuple[int, float]] = []

    with open(csv_path, newline="") as csvfile:
        reader = csv.reader(csvfile)
        for row in reader:
            if not row:
                # csv.reader yields [] for blank lines; skip them rather than
                # failing on the tuple unpack below.
                continue
            if len(row) < 4:
                raise ValueError(
                    f"{csv_path}:{reader.line_num}: expected at least 4 columns, got {len(row)}"
                )
            instance, gap, _, scaling_factor = row[:4]
            if "TSPLIB" in instance:
                continue

            # Parse before filtering on the flag/size so malformed rows are
            # reported regardless of which subset is requested.
            size = int(instance.split("-")[1]) if instance != "tsp-test" else 100
            gap_percent = float(gap) * 100
            stretched = scaling_factor == "True"

            if stretched == take_stretched and size in wanted:
                points.append((size, gap_percent))

    # Sort by size so the dashed line drawn through these points in the plot is
    # monotone in x instead of following arbitrary CSV row order.
    points.sort(key=lambda point: point[0])
    return [size for size, _ in points], [gap for _, gap in points]


if __name__ == "__main__":
    import argparse
    import itertools

    parser = argparse.ArgumentParser()
    parser.add_argument("-o", "--output-file", type=Path, required=True, help="Where to save the plot")
    parser.add_argument("experiment_dirs", type=Path, nargs="+", help="Experiments to plot")
    args = parser.parse_args()

    # Relative path: resolved against the current working directory, so the
    # script presumably must be launched from the repository root — confirm.
    plt.style.use("./analysis/paper.mplstyle")
    fig, ax = plt.subplots(figsize=(3, 2))
    markers = ["*", "+", "H", "s", "d"]

    # Cycle the markers so every experiment is plotted; a plain zip() against
    # the five-element list would silently drop experiments after the fifth.
    for exp_dir, marker in zip(args.experiment_dirs, itertools.cycle(markers)):
        # Only the non-stretched evaluations are plotted.
        instances, gaps = load(exp_dir, take_stretched=False)
        # Dashed connecting line first, then markers in the same colour; only
        # the scatter carries a label, so each experiment gets one legend entry.
        (line,) = ax.plot(instances, gaps, "--")
        ax.scatter(instances, gaps, marker=marker, color=line.get_color(), label=exp_dir.name)

    ax.set_xlabel("TSP Size")
    ax.set_ylabel("Optimal Gap (%)")

    ax.set_xscale("log")
    ax.set_yscale("log")

    # Fixed ticks: the x ticks match the instance sizes load() keeps.
    ax.set_yticks([1, 10], labels=["1", "10"])
    ax.set_xticks([100, 250, 500, 1000, 10000], labels=["100", "250", "500", "1 000", "10 000"])

    ax.legend()
    fig.tight_layout()
    fig.savefig(str(args.output_file))
