import csv
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np


def load(experiment_dir: Path) -> dict[str, np.ndarray]:
    """Load ``history.csv`` from an experiment directory as per-column arrays.

    The first CSV row is taken as the header. Every following row is parsed
    as floats and appended column-wise. Blank rows are skipped, and fields
    beyond the header width are ignored.

    Args:
        experiment_dir: Directory expected to contain a ``history.csv`` file.

    Returns:
        A dict mapping each header name to a NumPy array holding that
        column's values in file order.

    Raises:
        FileNotFoundError: If ``history.csv`` does not exist in the directory.
        ValueError: If the file has no header row, a row has fewer fields
            than the header, or a field cannot be parsed as a float.
    """
    history_path = experiment_dir / "history.csv"
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
    if not history_path.exists():
        raise FileNotFoundError(f"History CSV file not found: {history_path}")

    with open(history_path, "r", newline="") as csvfile:
        reader = csv.reader(csvfile)
        columns = next(reader, None)
        if columns is None:
            # A completely empty file would otherwise leak a bare StopIteration.
            raise ValueError(f"History CSV file is empty: {history_path}")

        values = {column: [] for column in columns}
        for row in reader:
            if not row:
                # csv.reader yields [] for blank lines; nothing to record.
                continue
            if len(row) < len(columns):
                # A short row would leave the columns with different lengths,
                # silently misaligning them against each other.
                raise ValueError(
                    f"{history_path}: line {reader.line_num} has {len(row)} "
                    f"fields, expected {len(columns)}"
                )
            for col, value in zip(columns, row):
                values[col].append(float(value))

    return {col: np.array(series) for col, series in values.items()}


def _resolve_labels(names, experiment_dirs: list[Path]) -> list[str]:
    """Return one legend label per experiment directory.

    Args:
        names: Labels given on the command line, or ``None`` when omitted.
        experiment_dirs: Experiment directories, in plotting order.

    Returns:
        ``names`` as a list, or each directory's final path component when
        ``names`` is ``None``.

    Raises:
        ValueError: If ``names`` is given but its length differs from the
            number of experiment directories.
    """
    if names is None:
        return [exp_dir.name for exp_dir in experiment_dirs]
    if len(names) != len(experiment_dirs):
        # zip() would otherwise silently drop the unmatched experiments.
        raise ValueError(
            f"got {len(names)} names for {len(experiment_dirs)} experiment directories"
        )
    return list(names)


def main() -> None:
    """Plot the ``tsp-20`` and ``tsp-100`` history columns of each experiment.

    Parses the command line, loads every experiment's ``history.csv`` via
    ``load`` and writes a two-panel figure to the requested output file.
    """
    import argparse

    parser = argparse.ArgumentParser()
    # Required: savefig(str(None)) would otherwise write a file named "None".
    parser.add_argument("-o", "--output-file", type=Path, required=True, help="Output file image")
    parser.add_argument("-n", "--names", type=str, nargs="+", help="Names of the experiments")
    parser.add_argument("experiment_dirs", type=Path, nargs="+", help="Experiment directories containing 'history.csv' files")
    args = parser.parse_args()

    try:
        labels = _resolve_labels(args.names, args.experiment_dirs)
    except ValueError as err:
        parser.error(str(err))

    plt.style.use("./analysis/paper.mplstyle")
    fig, axes = plt.subplots(1, 2, figsize=(4.5, 2))
    # A list, not a dict keyed by directory name: two directories sharing a
    # basename would collapse into one entry and shift the labels.
    histories = [load(exp_dir) for exp_dir in args.experiment_dirs]

    # (axes, history column, panel title, y-limits)
    panels = (
        (axes[0], "tsp-20", "TSP-20", (0.01, 2)),
        (axes[1], "tsp-100", "TSP-100", (0.5, 20)),
    )

    for label, history in zip(labels, histories):
        # Plot every second recorded point.
        steps = history["step"][::2]
        for ax, column, _, _ in panels:
            ax.plot(steps, history[column][::2], label=label)

    for ax, _, title, y_limits in panels:
        ax.set_title(title)
        ax.set_xlabel("Training Step")
        ax.set_yscale("log")
        # Remove the major tick marks from both axes.
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_ylim(*y_limits)

    axes[0].set_ylabel("Optimal Gap")
    axes[1].legend()

    fig.tight_layout()
    fig.savefig(str(args.output_file))


if __name__ == "__main__":
    main()
