import csv
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np


def load(
    experiment_dir: Path,
    take_stretched: bool,
    *,
    sizes: tuple[int, ...] = (1000, 10000),
) -> tuple[list[int], list[float]]:
    """Load optimality gaps for one experiment from its ``evaluations.csv``.

    Each CSV row is read as ``instance, gap, <ignored>, scaling_factor, ...``;
    any further columns are ignored and blank lines are skipped. Rows whose
    instance name contains ``"TSPLIB"`` are dropped. The TSP size is parsed
    from the second ``-``-separated field of the instance name, except that
    ``"tsp-test"`` counts as size 100.

    Args:
        experiment_dir: Directory containing ``evaluations.csv``.
        take_stretched: When true keep only rows whose scaling-factor column
            is exactly ``"True"``; when false keep only the other rows.
        sizes: TSP sizes to keep; rows of any other size are discarded.

    Returns:
        ``(instance_sizes, gaps)`` as parallel lists sorted by size (stable,
        so rows of equal size keep file order), with each gap multiplied by
        100 (fraction -> percent).

    Raises:
        ValueError: if a non-blank row has fewer than four columns, or a size
            or gap field is not a valid number.
        IndexError: if a non-TSPLIB instance name other than ``"tsp-test"``
            contains no ``-``.
    """
    wanted = set(sizes)
    rows: list[tuple[int, float]] = []
    with open(experiment_dir / "evaluations.csv", newline="") as csvfile:
        for row in csv.reader(csvfile):
            if not row:
                continue  # blank line (e.g. trailing newline) -> csv yields []
            instance, gap, _, scaling_factor, *_ = row
            # TSPLIB names do not follow the "<prefix>-<size>" pattern, so they
            # must be dropped before the size is parsed.
            if "TSPLIB" in instance:
                continue
            size = 100 if instance == "tsp-test" else int(instance.split("-")[1])
            gap_percent = float(gap) * 100
            if (scaling_factor == "True") != take_stretched:
                continue
            if size in wanted:
                rows.append((size, gap_percent))
    rows.sort(key=lambda pair: pair[0])
    return [size for size, _ in rows], [gap for _, gap in rows]


def main(argv: "list[str] | None" = None) -> None:
    """Plot optimality gaps of one or more experiments as a grouped bar chart.

    For each experiment directory a bar shows the gap on rows loaded with
    ``take_stretched=False``, and a thin black bar at the same position shows
    the gap on rows loaded with ``take_stretched=True``.

    Args:
        argv: Command-line arguments; ``None`` means ``sys.argv[1:]``.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-o", "--output-file", type=Path, required=True, help="Where to save the plot")
    parser.add_argument("experiment_dirs", type=Path, nargs="+", help="Experiments to plot")
    args = parser.parse_args(argv)

    plt.style.use("./analysis/paper.mplstyle")
    fig, ax = plt.subplots(figsize=(6, 3))

    sizes = [1000, 10000]
    benchmarks = ["1 000", "10 000"]
    x = np.arange(len(benchmarks))
    # Place bars by TSP size rather than by list index, so an experiment that
    # lacks one size still lands under the correct tick instead of crashing.
    size_to_x = dict(zip(sizes, x))

    n_dirs = len(args.experiment_dirs)
    # Bars in a group are adjacent and centred on the tick; shrink them when
    # there are many experiments so a group never spans more than 0.75.
    bar_width = min(0.25, 0.75 / n_dirs)
    stretch_width = 0.05
    offsets = (np.arange(n_dirs) - (n_dirs - 1) / 2) * bar_width

    for exp_dir, offset in zip(args.experiment_dirs, offsets):
        instances, gaps = load(exp_dir, take_stretched=False)
        positions = [size_to_x[size] + offset for size in instances]
        ax.bar(positions, gaps, width=bar_width, label=exp_dir.name)

        instances, gaps = load(exp_dir, take_stretched=True)
        positions = [size_to_x[size] + offset for size in instances]
        ax.bar(positions, gaps, width=stretch_width, color="black")

    ax.set_xlabel("TSP Size")
    ax.set_ylabel("Optimal Gap (%)")
    ax.set_yscale("log")

    ax.set_xticks(x)
    ax.set_xticklabels(benchmarks)
    ax.tick_params(axis="x", length=0)

    ax.legend()

    fig.tight_layout()
    fig.savefig(str(args.output_file))


if __name__ == "__main__":
    main()
