import csv
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np


def load(experiment_dir: Path, size: int, rescaled: bool) -> float:
    """Return the gap, in percent, recorded for one instance size.

    Reads ``evaluations.csv`` from *experiment_dir*. Each non-empty row must
    have at least four columns: the instance name in column 0, the gap in
    column 1 and a flag in column 3 that counts as true only when it is the
    literal string ``"True"``. Column 2 and any columns after the fourth are
    ignored. Rows whose instance name contains ``"TSPLIB"`` are skipped.

    The instance size is the integer following the first ``-`` in the
    instance name, except for ``"tsp-test"``, which is treated as size 100.

    Args:
        experiment_dir: Directory containing ``evaluations.csv``.
        size: Instance size to look up.
        rescaled: Required value of the column-3 flag.

    Returns:
        The gap of the first matching row, multiplied by 100.

    Raises:
        IndexError: If no row matches both *size* and *rescaled*.
    """
    csv_path = experiment_dir / "evaluations.csv"
    with open(csv_path, newline="") as csvfile:
        for row in csv.reader(csvfile):
            if not row:
                # csv.reader yields [] for blank lines; nothing to unpack.
                continue
            instance, gap, _, scaling_factor, *_ = row
            if "TSPLIB" in instance:
                continue
            instance_size = (
                100 if instance == "tsp-test" else int(instance.split("-")[1])
            )
            if instance_size != size:
                continue
            if (scaling_factor == "True") == rescaled:
                return float(gap) * 100

    # Kept as IndexError so callers that caught the old bare
    # ``evaluations[0]`` failure keep working, but with a usable message.
    raise IndexError(
        f"no evaluation with size={size} and rescaled={rescaled} in {csv_path}"
    )


if __name__ == "__main__":
    # Script entry point: draw a bar chart comparing the gap of the CoordNS
    # baseline with PENS-A (with and without rescaling) for one instance
    # size, optionally overlaying a dotted reference line, and save it.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-o", "--output-file", type=Path, required=True, help="Where to save the plot")
    parser.add_argument("-nc", "--n-cities", type=int, default=10000, help="Size of the instances to consider")
    parser.add_argument("-fw", "--figure-width", type=int, default=6)
    parser.add_argument("-fh", "--figure-height", type=int, default=3)
    parser.add_argument("-s", "--sota", type=float, default=None, help="SotA dotted line")
    args = parser.parse_args()

    # The style sheet path is relative to the current working directory.
    plt.style.use("./analysis/paper.mplstyle")
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    fig, ax = plt.subplots(figsize=(args.figure_width, args.figure_height))

    baseline = load(Path("./outputs/CoordNS"), args.n_cities, rescaled=False)
    alibi = load(Path("./outputs/PENS-A"), args.n_cities, rescaled=False)
    alibi_rescaled = load(Path("./outputs/PENS-A"), args.n_cities, rescaled=True)

    names = ["CoordNS", "ALiBi", "ALiBi\n+ rescaling"]
    y = [baseline, alibi, alibi_rescaled]
    x = np.arange(len(y))
    ax.bar(x, y, color=colors[1])

    # --sota is optional (default None). Passing None to axhline fails, so
    # the reference line is only drawn when a value was supplied.
    if args.sota is not None:
        ax.axhline(y=args.sota, linestyle="dotted", color=colors[0], linewidth=2, label="Previous SOTA")

    # Built outside the f-string: reusing the enclosing quote character
    # inside an f-string expression is a syntax error before Python 3.12.
    size_label = "10 000" if args.n_cities == 10000 else str(args.n_cities)
    ax.set_title(f"TSP-{size_label}")
    ax.set_ylabel("Optimal Gap (%)")
    ax.set_xticks(x)
    ax.set_xticklabels(names)
    ax.tick_params(axis="x", length=0)

    # The reference line is the only labelled artist; without it there is
    # nothing to put in a legend.
    if args.sota is not None:
        ax.legend()

    fig.tight_layout()
    fig.savefig(str(args.output_file), bbox_inches="tight", pad_inches=0.02)
