"""Generate airfoil grid plots for the paper.

Supports five modes:
  --mode train          Airfoils sampled from the training set
  --mode val            Airfoils sampled from the validation set
  --mode gen-train      Airfoils generated by the model from training-set
                        performance labels and physical parameters
  --mode gen-val        Airfoils generated by the model from validation-set
                        performance labels and physical parameters
  --mode all            All four grids above (default)

Each grid contains ``rows * cols`` airfoils and is saved as a PDF in the
paper figures directory.

Usage:
    python -m uq_diagcfm.generate_and_plot_unifoil --mode train --rows 4 --cols 5
    python -m uq_diagcfm.generate_and_plot_unifoil --mode gen-val --rows 4 --cols 5
    python -m uq_diagcfm.generate_and_plot_unifoil --mode all --rows 4 --cols 5
"""

import argparse

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch

from uq_diagcfm.data_utils_unifoil import (
    LEN_DESIGN_PARAMETERS,
    LEN_PHYSICAL_PERFORMANCE,
    UnifoilDataset,
    split_unifoil_data,
)
from uq_diagcfm.paths import PAPER_FIGURES_DIR, ensure_paper_dirs_exist
from uq_diagcfm.plot_airfoil import plot_airfoil_grid


def _sample_dataset_airfoils(split, n, seed):
    """Draw ``n`` distinct random rows from a UniFoil dataset split.

    Args:
        split: Name of the dataset split passed to ``UnifoilDataset``
            (this script uses ``"train"`` or ``"val"``).
        n: Number of rows to draw, without replacement.
        seed: Seed for ``np.random.default_rng``; the same seed and split
            always select the same rows.

    Returns:
        Tuple ``(design_params, physical_params, performance, indices)``.
        The first three items are the output of ``split_unifoil_data`` on the
        sampled rows. ``indices`` holds the chosen row positions.
    """
    dataset_rows = UnifoilDataset(split).data
    generator = np.random.default_rng(seed)
    picked = generator.choice(len(dataset_rows), size=n, replace=False)
    design, physical, perf = split_unifoil_data(dataset_rows[picked])
    return design, physical, perf, picked


# Generator models already loaded in this process, keyed by ``str(device)``.
# ``--mode all`` calls ``_generate_airfoils`` once per generation mode; without
# this cache the ensemble checkpoint would be loaded from scratch each time.
_GENERATOR_CACHE = {}


def _generate_airfoils(physical_params, performance, device):
    """Generate airfoil designs from performance labels and physical parameters.

    Loads the UniFoil diag-CFM ensemble (once per device per process). It then
    runs ``inverse_pass`` with the first ensemble member, drawing one sample per
    input row.

    Args:
        physical_params: Physical parameters used as conditioning. May be a
            NumPy array or a tensor.
        performance: Target performance labels. May be a NumPy array or a tensor.
        device: Torch device to run on. If ``None``, ``get_device()`` is used.

    Returns:
        Generated design parameters as a NumPy array on the CPU.
    """
    from uq_diagcfm.ensembles import load_unifoil_diag_cfm_ensemble
    from uq_diagcfm.evaluation_utils import inverse_pass
    from uq_diagcfm.utils import get_device

    if device is None:
        device = get_device()

    cache_key = str(device)
    model = _GENERATOR_CACHE.get(cache_key)
    if model is None:
        models, _, _, _ = load_unifoil_diag_cfm_ensemble(device=device)
        # Only the first ensemble member is used for generation.
        model = models[0]
        _GENERATOR_CACHE[cache_key] = model

    # as_tensor accepts NumPy arrays and tensors alike; torch.tensor would emit
    # a copy-construct warning if the dataset helpers already return tensors.
    y = torch.as_tensor(performance, dtype=torch.float32, device=device)
    cond = torch.as_tensor(physical_params, dtype=torch.float32, device=device)

    with torch.no_grad():
        x_gen = inverse_pass(
            model,
            y,
            num_design_params=LEN_DESIGN_PARAMETERS,
            num_labels=LEN_PHYSICAL_PERFORMANCE,
            device=device,
            num_samples=1,
            diag_cfm=True,
            conditioning=cond,
        )
    # NOTE(review): assumes inverse_pass returns a leading num_samples axis
    # (size 1 here) that squeeze(0) removes — confirm against inverse_pass.
    return x_gen.squeeze(0).cpu().numpy()


def _save_grid(designs, cols, stem):
    """Plot ``designs`` as a grid with ``cols`` columns and save ``<stem>.pdf``.

    The file is written to the paper figures directory. The figure is always
    closed, even if saving fails, so pyplot does not keep stale figures open
    across the modes of ``--mode all``.
    """
    fig = plot_airfoil_grid(designs, ncols=cols)
    out = PAPER_FIGURES_DIR / f"{stem}.pdf"
    try:
        fig.savefig(out, bbox_inches="tight")
    finally:
        plt.close(fig)
    print(f"Saved {out}")


def run(mode, rows, cols, seed):
    """Produce the airfoil grid PDF(s) for ``mode``.

    Args:
        mode: One of ``"train"``, ``"val"``, ``"gen-train"``, ``"gen-val"``
            or ``"all"`` (runs the other four in turn).
        rows: Number of grid rows; must be >= 1.
        cols: Number of grid columns; must be >= 1.
        seed: Seed for sampling dataset rows. Each generation mode samples
            its labels with the same seed as the matching dataset mode.

    Raises:
        ValueError: If ``mode`` is unknown or ``rows``/``cols`` is not positive.
    """
    # Validate before any side effects (directory creation, data/model loading).
    if mode not in ("train", "val", "gen-train", "gen-val", "all"):
        raise ValueError(f"Unknown mode: {mode}")
    if rows < 1 or cols < 1:
        raise ValueError(
            f"rows and cols must be positive, got rows={rows}, cols={cols}"
        )

    if mode == "all":
        for m in ("train", "val", "gen-train", "gen-val"):
            run(m, rows, cols, seed)
        return

    n = rows * cols
    ensure_paper_dirs_exist()
    # e.g. "train" -> airfoil_grid_train, "gen-val" -> airfoil_grid_gen_val
    stem = f"airfoil_grid_{mode.replace('-', '_')}"

    if mode in ("train", "val"):
        design_params, _, _, _ = _sample_dataset_airfoils(mode, n, seed)
        _save_grid(design_params, cols, stem)
    else:
        from uq_diagcfm.utils import get_device

        split = "train" if mode == "gen-train" else "val"
        _, phys, perf, _ = _sample_dataset_airfoils(split, n, seed)
        gen_designs = _generate_airfoils(phys, perf, get_device())
        _save_grid(gen_designs, cols, stem)


def main():
    """Parse the command line and produce the requested airfoil grid(s)."""
    cli = argparse.ArgumentParser(description="Generate airfoil grid plots.")
    cli.add_argument(
        "--mode",
        choices=["train", "val", "gen-train", "gen-val", "all"],
        default="all",
        help="Which grid(s) to produce.",
    )
    # The remaining options are all plain integers with a documented default.
    integer_options = (
        ("--rows", 4, "Grid rows (default: 4)."),
        ("--cols", 5, "Grid columns (default: 5)."),
        ("--seed", 42, "Random seed (default: 42)."),
    )
    for flag, fallback, description in integer_options:
        cli.add_argument(flag, type=int, default=fallback, help=description)

    opts = cli.parse_args()
    run(opts.mode, opts.rows, opts.cols, opts.seed)


# Script entry point, e.g. ``python -m uq_diagcfm.generate_and_plot_unifoil``.
if __name__ == "__main__":
    main()
