"""Airfoil plotting utilities for the Unifoil (fully turbulent) dataset.

Reconstructs airfoil geometry from 14 modal coefficients using the POD basis,
and provides functions to plot one or more airfoils.
"""

import numpy as np
import matplotlib.pyplot as plt

from uq_diagcfm.paths import DATA_DIR

# Text table of the POD basis, read with np.loadtxt by _load_basis():
# row 0 holds the x-stations shared by every airfoil, the remaining rows hold
# one mode shape each.
BASIS_FILE = DATA_DIR / "unifoil_geometry" / "basis.txt"

# (xslice, modes) tuple filled in by _load_basis() on first use; None until then.
_basis_cache = None


def _load_basis():
    """Return the ``(xslice, modes)`` pair for the POD basis, loading it once.

    On the first call the whitespace-delimited table at ``BASIS_FILE`` is read
    with ``np.loadtxt``. Row 0 becomes ``xslice``, the x-stations shared by all
    airfoils (expected shape (281,)). The remaining rows become ``modes``, one
    mode shape per row (expected shape (14, 281)). The pair is memoized in
    the module-level ``_basis_cache``, and later calls return it without
    touching the file.
    """
    global _basis_cache
    if _basis_cache is not None:
        return _basis_cache
    table = np.loadtxt(BASIS_FILE)
    _basis_cache = (table[0, :], table[1:, :])
    return _basis_cache


def coeffs_to_coords(coeffs):
    """Convert modal coefficients to (x, y) airfoil coordinates.

    The y-coordinates are the linear combination ``coeffs @ modes`` of the POD
    mode shapes returned by ``_load_basis``. The x-stations are shared by every
    airfoil and are returned as-is.

    Args:
        coeffs: array-like of shape (14,) or (N, 14) — modal coefficients.
            The last axis must match the number of basis modes.

    Returns:
        If 1-D input: tuple (x, y) each of shape (281,).
        If 2-D input: tuple (x, y) where x has shape (281,) and y has shape (N, 281).

    Raises:
        ValueError: if ``coeffs`` is a scalar or its last axis does not match
            the number of basis modes.
    """
    xslice, modes = _load_basis()
    coeffs = np.asarray(coeffs)
    n_modes = modes.shape[0]
    if coeffs.ndim == 0 or coeffs.shape[-1] != n_modes:
        raise ValueError(
            f"expected coefficients with a last axis of length {n_modes}, "
            f"got an array of shape {coeffs.shape}"
        )
    # matmul contracts the last axis of `coeffs` with the mode axis, so one
    # expression maps (14,) -> (281,) and (N, 14) -> (N, 281).
    return xslice, coeffs @ modes


def plot_airfoil(coeffs, ax=None, label=None, color=None, linewidth=1.5, **kwargs):
    """Draw a single airfoil reconstructed from its 14 modal coefficients.

    The coefficients are flattened and turned into a contour with
    ``coeffs_to_coords``. The contour is drawn with equal x/y scaling and
    ``x/c`` / ``y/c`` axis labels.

    Args:
        coeffs: array-like of 14 modal coefficients. It is flattened before
            use, so a (1, 14) row is also accepted.
        ax: matplotlib Axes to draw into; a new 8x3-inch figure is created
            when None.
        label: optional legend label for the line.
        color: optional line color (None falls back to matplotlib's default).
        linewidth: line width handed to ``ax.plot``.
        **kwargs: extra keyword arguments forwarded to ``ax.plot``.

    Returns:
        The matplotlib Axes that was drawn on.
    """
    if ax is not None:
        target = ax
    else:
        target = plt.subplots(figsize=(8, 3))[1]

    flat_coeffs = np.asarray(coeffs).ravel()
    xs, ys = coeffs_to_coords(flat_coeffs)

    line_style = dict(color=color, linewidth=linewidth, label=label)
    target.plot(xs, ys, **line_style, **kwargs)

    target.set_aspect("equal", adjustable="box")
    target.set_xlabel("x/c")
    target.set_ylabel("y/c")
    return target


def plot_airfoils(coeffs_list, labels=None, colors=None, title=None, ax=None, **kwargs):
    """Overlay several airfoils on one Axes via repeated ``plot_airfoil`` calls.

    Args:
        coeffs_list: iterable of arrays, each of shape (14,).
        labels: optional sequence of legend labels, indexed in step with
            ``coeffs_list``. When given, a legend is drawn.
        colors: optional sequence of line colors, indexed in step with
            ``coeffs_list``.
        title: optional Axes title.
        ax: matplotlib Axes to draw into; a new 8x3-inch figure is created
            when None.
        **kwargs: forwarded to every ``plot_airfoil`` call (and so to ``ax.plot``).

    Returns:
        The matplotlib Axes that was drawn on.
    """
    if ax is None:
        ax = plt.subplots(figsize=(8, 3))[1]

    has_labels = labels is not None
    has_colors = colors is not None
    for position, single in enumerate(coeffs_list):
        plot_airfoil(
            single,
            ax=ax,
            label=labels[position] if has_labels else None,
            color=colors[position] if has_colors else None,
            **kwargs,
        )

    if title is not None:
        ax.set_title(title)
    if has_labels:
        ax.legend()
    return ax


def plot_airfoil_grid(
    coeffs_array,
    ncols=5,
    color="C0",
    linewidth=1.0,
    figscale=1.0,
):
    """Plot a grid of airfoils without axes, suitable for paper figures.

    Every cell uses the same x- and y-limits, so airfoils can be compared
    directly. Unused cells in the last row are left blank.

    Args:
        coeffs_array: array of shape (N, 14) — modal coefficients for N airfoils.
            A single airfoil of shape (14,) is accepted and treated as N = 1.
        ncols: number of columns in the grid (must be >= 1).
        color: line color for all airfoils.
        linewidth: line width.
        figscale: scaling factor for figure size.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if ``ncols`` < 1 or ``coeffs_array`` contains no airfoils.
    """
    if ncols < 1:
        raise ValueError(f"ncols must be >= 1, got {ncols}")

    # Promote a single (14,) airfoil to (1, 14) so that len() counts airfoils
    # rather than coefficients and all_y[idx] below is a full contour row.
    coeffs_array = np.atleast_2d(np.asarray(coeffs_array))
    if coeffs_array.size == 0:
        raise ValueError("coeffs_array must contain at least one airfoil")
    n = len(coeffs_array)
    nrows = int(np.ceil(n / ncols))

    # Precompute all coordinates to find global y-limits
    xslice, all_y = coeffs_to_coords(coeffs_array)  # all_y: (N, 281)
    y_min, y_max = all_y.min(), all_y.max()
    y_span = y_max - y_min
    # A zero span (e.g. all-zero coefficients) would give identical y-limits,
    # which matplotlib warns about; fall back to a small fixed pad instead.
    y_pad = 0.1 * y_span if y_span > 0 else 0.05
    y_lim = (y_min - y_pad, y_max + y_pad)
    x_lim = (-0.02, 1.02)

    cell_w = 2.0 * figscale
    cell_h = 0.9 * figscale
    fig, axes = plt.subplots(
        nrows, ncols,
        figsize=(cell_w * ncols, cell_h * nrows),
        squeeze=False,
    )

    for idx, ax in enumerate(axes.flat):
        if idx < n:
            ax.plot(xslice, all_y[idx], color=color, linewidth=linewidth)
            ax.set_xlim(*x_lim)
            ax.set_ylim(*y_lim)
            ax.set_aspect("equal", adjustable="box")
        # Hide axes on every cell, including the blank trailing ones.
        ax.axis("off")

    fig.subplots_adjust(wspace=0.05, hspace=0.05)
    return fig


if __name__ == "__main__":
    # Smoke test: render up to the first 20 training rows in a 5-column grid.
    from uq_diagcfm.data_utils_unifoil import TRAIN_DATAFILE

    data = np.loadtxt(TRAIN_DATAFILE, max_rows=20, dtype=np.float32)
    # The first 14 columns of each row are passed as the modal coefficients.
    fig = plot_airfoil_grid(data[:, :14], ncols=5)
    # Save through the returned Figure rather than pyplot's implicit
    # "current figure" state.
    fig.savefig("test_grid.png", dpi=150, bbox_inches="tight")
    plt.show()
