import csv
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler, PolynomialFeatures

# Type aliases for the fields of one sample row of the sweep CSV.
Size = int  # number of cities in a TSP instance
ScalingFactor = float
Cost = float  # tour length obtained for a given size and scaling factor
CSVData = list[tuple[Size, ScalingFactor, Cost]]  # rows as returned by load()


def fit_polynome(factors: list[ScalingFactor], costs: list[Cost]) -> Pipeline:
    """Fit a quadratic regression of cost on scaling factor.

    The factor is min-max scaled, expanded into degree-2 polynomial features
    and regressed with ordinary least squares.

    Returns:
        The fitted pipeline, whose steps are named "scaler", "poly" and "reg".
    """
    steps = [
        ("scaler", MinMaxScaler()),
        ("poly", PolynomialFeatures(degree=2)),
        ("reg", LinearRegression()),
    ]
    # scikit-learn expects a 2-D design matrix: one sample per row, one feature.
    features = np.asarray(factors).reshape(-1, 1)
    targets = np.asarray(costs)
    return Pipeline(steps).fit(features, targets)


def fit_scaling_law(factors: list[ScalingFactor], sizes: list[Size]) -> LinearRegression:
    """Fit ``log(factor) = intercept + slope * log(size)`` by least squares.

    Returns:
        The fitted LinearRegression, trained on log-sizes and log-factors.
    """
    log_sizes = np.log(np.array(sizes).reshape(-1, 1))
    log_factors = np.log(np.array(factors))
    return LinearRegression().fit(log_sizes, log_factors)


def _best_scaled_factor(model: LinearRegression) -> float:
    """Return the scaled (min-max space) factor that minimises the fitted quadratic.

    ``model`` is the regression step fitted on ``[1, s, s**2]`` features, so
    the curve is ``intercept + b*s + c*s**2`` with ``b = coef_[1]`` and
    ``c = coef_[2]``.
    """
    b, c = model.coef_[1], model.coef_[2]
    if c > 0:
        # Convex parabola: the vertex is the global minimum (may lie outside [0, 1]).
        return -b / (2 * c)
    # Concave or flat fit: the vertex is a maximum or undefined (c == 0 gives
    # inf/nan). Fall back to whichever end of the observed range (scaled 0 or 1)
    # has the lower predicted cost; cost(1) - cost(0) == b + c.
    return 1.0 if b + c < 0 else 0.0


def plot_factors(ax: Axes, samples: CSVData, n_cities: Size) -> float:
    """Plot cost against scaling factor for one TSP size and mark the best factor.

    Scatters the samples, overlays a quadratic fit over the observed factor
    range, and marks the fitted minimum with a point labelled "best".

    Args:
        ax: Axes to draw on.
        samples: Rows of ``(n_cities, scaling_factor, cost)``.
        n_cities: TSP size whose samples are plotted.

    Returns:
        The scaling factor, in original units, at the fitted minimum.

    Raises:
        ValueError: If ``samples`` contains no row for ``n_cities``.
    """
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

    selected = [(s, c) for n, s, c in samples if n == n_cities]
    if not selected:
        raise ValueError(f"No samples found for TSP-{n_cities}")
    factors = [s for s, _ in selected]
    costs = [c for _, c in selected]
    ax.scatter(factors, costs, color=colors[0], alpha=0.80, marker=".")

    # Fit cost as a quadratic of the scaled factor and draw it over the data range.
    pipeline = fit_polynome(factors, costs)
    model = pipeline.named_steps["reg"]

    x = np.linspace(min(factors), max(factors), 100)
    ax.plot(x, pipeline.predict(x.reshape(-1, 1)), color=colors[1])

    # Locate the minimum in scaled space, then map it back to factor units.
    s = _best_scaled_factor(model)
    x = pipeline.named_steps["scaler"].inverse_transform([[s]])
    y = model.predict([[1, s, s**2]])
    ax.scatter(x, y, color=colors[1], label="best")
    ax.set_title(f"TSP-{n_cities}")

    return float(x[0][0])


def plot_scaling_law(ax: Axes, factors: list[ScalingFactor], sizes: list[Size]):
    """Draw the best scaling factor against TSP size as a line on ``ax``."""
    first_color = plt.rcParams["axes.prop_cycle"].by_key()["color"][0]
    ax.plot(sizes, factors, color=first_color)
    ax.set_xlabel("TSP Size")
    ax.set_ylabel("Best Scaling Factor")


def load(experiment_dir: Path) -> CSVData:
    """Read the sweep samples from ``<experiment_dir>/random-solve.csv``.

    Each non-empty row must hold exactly three fields: the number of cities,
    the scaling factor and the cost. Blank lines are skipped.

    Args:
        experiment_dir: Directory containing ``random-solve.csv``.

    Returns:
        One ``(n_cities, scaling_factor, cost)`` tuple per non-empty row,
        in file order.

    Raises:
        FileNotFoundError: If the CSV file does not exist.
        ValueError: If a row does not have three fields or they fail to parse.
    """
    csv_path = experiment_dir / "random-solve.csv"
    # An explicit check rather than ``assert``: assertions are stripped under ``python -O``.
    if not csv_path.exists():
        raise FileNotFoundError(f"CSV file not found: {csv_path}")

    with open(csv_path, "r", newline="", encoding="utf-8") as csvfile:
        # csv.reader yields [] for blank lines; skip them instead of failing to unpack.
        rows = (row for row in csv.reader(csvfile) if row)
        data = [
            (int(n_cities), float(scaling_factor), float(cost))
            for n_cities, scaling_factor, cost in rows
        ]

    return data


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Estimate the best scaling factors for a given model using its sweeps")
    parser.add_argument("-e", "--experiment-dir", required=True, type=Path, help="Path of the training run")
    parser.add_argument("-fw", "--figsize-width", type=float, default=4)
    parser.add_argument("-fh", "--figsize-height", type=float, default=4)
    parser.add_argument("cities", type=int, nargs="+", help="Cities to plot")
    args = parser.parse_args()

    plt.style.use("./analysis/paper.mplstyle")
    data = load(args.experiment_dir)

    # Estimate the best factor for every TSP size in the sweep. plot_factors
    # draws as a side effect, so use a throwaway figure and close it afterwards.
    # np.unique already returns its values sorted in ascending order.
    unique_cities = np.unique([n for n, _, _ in data])
    scratch_fig, scratch_ax = plt.subplots()
    best_factors = [plot_factors(scratch_ax, data, n_cities) for n_cities in unique_cities]
    plt.close(scratch_fig)

    with open(args.experiment_dir / "best-factors.csv", "w", newline="") as csvfile:
        csv.writer(csvfile).writerows(
            [int(n_cities), factor] for n_cities, factor in zip(unique_cities, best_factors)
        )

    # squeeze=False always yields a 2-D array of Axes, so a single requested
    # size still gives an indexable/iterable collection after ravel().
    fig, axes = plt.subplots(
        1,
        len(args.cities),
        figsize=(args.figsize_width, args.figsize_height),
        squeeze=False,
    )
    axes = axes.ravel()

    for n_cities, ax in zip(args.cities, axes):
        plot_factors(ax, data, n_cities)
        ax.set_xlabel("Scaling Factor")
        ax.set_yticks([])

    axes[0].set_ylabel("Tour Length")
    axes[0].legend(loc="upper left")

    fig.tight_layout()
    fig.savefig(str(args.experiment_dir / "scaling-factors.pdf"))
