import json
from pathlib import Path
from typing import Tuple, Any, Optional

import jax.numpy as jnp
import numpy as np

import lsci.conformal as lsci


# ----------------------------
# Data loading helpers
# ----------------------------
def _first_existing(paths: Tuple[str, ...]) -> Path:
    """Return the first candidate in *paths* that is non-empty and exists.

    Empty / falsy entries are skipped without touching the filesystem.

    Raises
    ------
    FileNotFoundError
        If no candidate is non-empty and present on disk.
    """
    found = next((Path(candidate) for candidate in paths if candidate and Path(candidate).exists()), None)
    if found is None:
        raise FileNotFoundError(f"None of the provided paths exist: {paths}")
    return found


def load_residual_npz(npz_path: Path) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Load covariates and residuals from a single .npz file with keys:
      xval, rval, xtest, rtest

    Returns
    -------
    tuple of np.ndarray
        ``(xval, rval, xtest, rtest)``, each converted to float32.

    Raises
    ------
    KeyError
        If any of the required keys is missing from the archive.
    """
    required = {"xval", "rval", "xtest", "rtest"}
    # NpzFile keeps the underlying file open; the context manager closes it
    # once the arrays have been materialized in memory.
    with np.load(npz_path) as npz:
        missing = required - set(npz.files)
        if missing:
            raise KeyError(f"Missing keys {missing} in {npz_path}. Expected keys: {sorted(required)}")
        return (
            np.asarray(npz["xval"], dtype=np.float32),
            np.asarray(npz["rval"], dtype=np.float32),
            np.asarray(npz["xtest"], dtype=np.float32),
            np.asarray(npz["rtest"], dtype=np.float32),
        )


def load_residual_two_files(calib_npz: Path, test_npz: Path) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Load from two .npz files:
      calib must have xval, rval
      test  must have xtest, rtest

    Returns
    -------
    tuple of np.ndarray
        ``(xval, rval, xtest, rtest)``, each converted to float32.

    Raises
    ------
    KeyError
        If the calibration file lacks xval/rval or the test file lacks xtest/rtest.
    """
    # Both archives are opened with context managers so their file handles are
    # released even when a required key is missing.
    with np.load(calib_npz) as npz_c:
        if not {"xval", "rval"}.issubset(npz_c.files):
            raise KeyError(f"Calibration file {calib_npz} must contain xval, rval")
        xval = np.asarray(npz_c["xval"], dtype=np.float32)
        rval = np.asarray(npz_c["rval"], dtype=np.float32)
    with np.load(test_npz) as npz_t:
        if not {"xtest", "rtest"}.issubset(npz_t.files):
            raise KeyError(f"Test file {test_npz} must contain xtest, rtest")
        xtest = np.asarray(npz_t["xtest"], dtype=np.float32)
        rtest = np.asarray(npz_t["rtest"], dtype=np.float32)
    return xval, rval, xtest, rtest


# ----------------------------
# Programmatic entrypoint
# ----------------------------
def run_lsci(
        *,
        # Data sources (exactly one of npz_all or (calib_npz & test_npz) must be provided)
        npz_all: Optional[str | Path] = None,
        calib_npz: Optional[str | Path] = None,
        test_npz: Optional[str | Path] = None,
        # Method parameters
        lam: int = 5,
        n_proj: int = 5,
        alpha: float = 0.1,
        # Output controls
        out_dir: str | Path = "outputs/lsci_heat_demo",
        save_weights: bool = False,
        suffix: str = "",
        # Logging
        verbose: bool = False,
) -> dict[str, Any]:
    """
    Run the LSCI pipeline (README-equivalent) programmatically.

    Parameters
    ----------
    npz_all : str | Path | None
        Path to a single .npz with keys xval, rval, xtest, rtest.
    calib_npz, test_npz : str | Path | None
        Paths to calibration and test .npz files (xval/rval) and (xtest/rtest).
    lam : int
        Bandwidth parameter for localization.
    n_proj : int
        Random projections for phi-depth.
    alpha : float
        Miscoverage level.
    out_dir : str | Path
        Directory to save artifacts (npz/json). Created if missing.
    save_weights : bool
        If True, saves local weights (can be large).
    suffix : str
        Suffix appended to output filenames.
    verbose : bool
        If True, prints a short summary.

    Returns
    -------
    dict
        {
          "coverage": float,
          "quant_val": np.ndarray (N_calib,),
          "local_weights": np.ndarray | None (N_test, N_calib),
          "paths": {
              "meta": Path,
              "calib_npz": Path,
              "test_npz": Path,
              "weights_npz": Path | None,
          },
          "meta": dict (same as saved JSON),
          "shapes": {
              "n_val": int,
              "n_test": int,
          }
        }
    """
    # Validate input choice
    have_single = npz_all is not None
    have_pair = (calib_npz is not None) and (test_npz is not None)
    if have_single == have_pair:
        raise ValueError("Provide exactly one of: npz_all OR (calib_npz AND test_npz).")

    # Load data
    if have_single:
        npz_all = _first_existing((str(npz_all),))
        xval, rval, xtest, rtest = load_residual_npz(npz_all)
        source_desc = {"npz_all": str(npz_all)}
    else:
        calib_npz = _first_existing((str(calib_npz),))
        test_npz = _first_existing((str(test_npz),))
        xval, rval, xtest, rtest = load_residual_two_files(calib_npz, test_npz)
        source_desc = {"calib_npz": str(calib_npz), "test_npz": str(test_npz)}

    # Convert to JAX arrays
    xval_j = jnp.asarray(xval)
    xtest_j = jnp.asarray(xtest)
    rval_j = jnp.asarray(rval)
    rtest_j = jnp.asarray(rtest)

    # Pipeline:
    # 1) localize
    # 2) depth on val and test
    # 3) quantiles from val depth
    # 4) coverage on test
    local_weights = lsci.localize(xval_j, xtest_j, lam)
    depth_val = lsci.local_phi_depth(rval_j, rval_j, local_weights, n_proj)
    depth_test = lsci.local_phi_depth(rval_j, rtest_j, local_weights, n_proj, reduce=False)
    quant_val = lsci.local_quantile(depth_val, alpha)
    coverage = jnp.mean(depth_test > quant_val)

    # Save artifacts
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    meta = {
        "lam": int(lam),
        "n_proj": int(n_proj),
        "alpha": float(alpha),
        "coverage": float(coverage),
        "n_val": int(xval.shape[0]),
        "n_test": int(xtest.shape[0]),
        "source": source_desc,
    }
    meta_path = out_dir / f"meta{suffix}.json"
    meta_path.write_text(json.dumps(meta, indent=2))

    calib_out = out_dir / f"calib_data{suffix}.npz"
    test_out = out_dir / f"test_data{suffix}.npz"
    np.savez_compressed(
        calib_out,
        xval=xval.astype(np.float32),
        rval=rval.astype(np.float32),
        quant_val=np.asarray(quant_val, dtype=np.float32),
    )
    np.savez_compressed(
        test_out,
        xtest=xtest.astype(np.float32),
        rtest=rtest.astype(np.float32),
    )

    weights_out = None
    if save_weights:
        weights_out = out_dir / f"local_weights{suffix}.npz"
        np.savez_compressed(
            weights_out,
            local_weights_val=np.array(local_weights, dtype=np.float32),
        )

    if verbose:
        print("LSCI (README-style) Results")
        print(f"- coverage:   {float(coverage):.3f}")
        print(f"- saved to:   {out_dir}")

    return {
        "coverage": float(coverage),
        "quant_val": np.asarray(quant_val, dtype=np.float32),
        "local_weights": (np.array(local_weights, dtype=np.float32) if save_weights else None),
        "paths": {
            "meta": meta_path,
            "calib_npz": calib_out,
            "test_npz": test_out,
            "weights_npz": weights_out,
        },
        "meta": meta,
        "shapes": {"n_val": int(xval.shape[0]), "n_test": int(xtest.shape[0])},
    }


# ----------------------------
# Optional CLI wrapper (kept for backward-compatibility)
# ----------------------------
def _cli_main(argv: Optional[list[str]] = None) -> None:
    """
    Command-line entry point: parse arguments and run :func:`run_lsci` verbosely.

    Parameters
    ----------
    argv : list[str] | None
        Arguments to parse. ``None`` (the default) makes argparse read
        ``sys.argv[1:]``, matching the previous zero-argument behavior.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Run LSCI exactly as in README on provided heat-equation residuals."
    )
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--npz_all", type=str, help="Path to a single NPZ with xval, rval, xtest, rtest")
    group.add_argument(
        "--npz_pair",
        nargs=2,
        metavar=("CALIB_NPZ", "TEST_NPZ"),
        help="Two NPZ files: calib (xval,rval) and test (xtest,rtest)",
    )
    parser.add_argument("--lam", type=int, default=5, help="Parameter for Bandwidth")
    parser.add_argument("--n_proj", type=int, default=5, help="Random projections for phi-depth")
    parser.add_argument("--alpha", type=float, default=0.1, help="Miscoverage level")
    parser.add_argument("--out_dir", type=str, default="outputs/lsci_heat_demo", help="Directory to save artifacts")
    parser.add_argument("--save_weights", action="store_true", help="Save local weights (large)")
    # run_lsci appends the suffix to output filenames, not to the directory.
    parser.add_argument("--suffix", type=str, default="", help="Suffix to append to output filenames")
    args = parser.parse_args(argv)

    # The required, mutually exclusive group guarantees exactly one source is set.
    # Compare against None so an empty --npz_all value reaches run_lsci's path
    # check instead of falling into the pair branch (where npz_pair is None).
    if args.npz_all is not None:
        source_kwargs = {"npz_all": args.npz_all}
    else:
        calib_npz, test_npz = args.npz_pair
        source_kwargs = {"calib_npz": calib_npz, "test_npz": test_npz}

    run_lsci(
        **source_kwargs,
        lam=args.lam,
        n_proj=args.n_proj,
        alpha=args.alpha,
        out_dir=args.out_dir,
        save_weights=args.save_weights,
        suffix=args.suffix,
        verbose=True,
    )


# Script entry point. NOTE: this guard executes before
# plot_sample_with_conformal_bands (defined further down) exists at module
# scope; that is harmless because _cli_main never calls the plotting helper.
if __name__ == "__main__":
    _cli_main()


# ----------------------------
# Plotting helper (unchanged API)
# ----------------------------
def plot_sample_with_conformal_bands(
        out_dir: str,
        sample_index: int,
        *,
        suffix: str = "",
        alpha: float = 0.1,
        n_proj: int = 5,
        n_samp: int = 2000,
        x_axis: np.ndarray | None = None,
        title: str | None = None,
        save_path: str | None = None,
        show: bool = True,
        figsize: tuple[float, float] = (8.0, 4.0),
        band_alpha: float = 0.25,
        truth_color: str = "black",
        band_color: str = "tab:blue",
        outline: bool = True,
) -> tuple[Any, Any]:
    """
    Plot a single test sample's ground truth trajectory with conformal bands.

    Inputs are artifacts produced by run_lsci():
      - {out_dir}/test_data{suffix}.npz with xtest, rtest
      - {out_dir}/calib_data{suffix}.npz with rval
      - {out_dir}/local_weights{suffix}.npz with local_weights_val  (requires save_weights=True)
    """
    import matplotlib.pyplot as plt  # local import to avoid hard dependency at module import time

    out = Path(out_dir)
    test_npz = np.load(out / f"test_data{suffix}.npz")
    if not {"xtest", "rtest"}.issubset(test_npz.files):
        raise KeyError(f"test_data.npz must contain xtest and rtest. Found keys: {sorted(test_npz.files)}")

    xtest = np.asarray(test_npz["xtest"], dtype=np.float32)
    rtest = np.asarray(test_npz["rtest"], dtype=np.float32)
    n_test, spatial_dim = xtest.shape
    if sample_index < 0 or sample_index >= n_test:
        raise IndexError(f"sample_index {sample_index} is out of range [0, {n_test})")

    calib_npz = np.load(out / f"calib_data{suffix}.npz")
    if not {"rval"}.issubset(calib_npz.files):
        raise KeyError(f"calib_data.npz must contain rval. Found keys: {sorted(calib_npz.files)}")
    rval = np.asarray(calib_npz["rval"], dtype=np.float32)  # (N_calib, X)

    weights_file = out / f"local_weights{suffix}.npz"
    if not weights_file.exists():
        raise FileNotFoundError(
            f"{weights_file} not found. Re-run run_lsci with save_weights=True to save local weights."
        )
    weights_npz = np.load(weights_file)
    if not {"local_weights_val"}.issubset(weights_npz.files):
        raise KeyError(
            f"local_weights.npz must contain 'local_weights_val'. Found keys: {sorted(weights_npz.files)}"
        )
    local_weights = np.asarray(weights_npz["local_weights_val"], dtype=np.float32)  # (N_test, N_calib)
    if local_weights.shape[0] != n_test:
        raise ValueError(f"local_weights first dim {local_weights.shape[0]} must equal n_test {n_test}")

    local_weights = jnp.asarray(local_weights)
    rval = jnp.asarray(rval)

    residual_samples = lsci.local_sampler(rval, local_weights[sample_index], float(alpha), int(n_samp), int(n_proj))

    base = xtest[sample_index]  # u(0)
    truth = base + rtest[sample_index]  # u(h)
    lower = base + residual_samples.min(axis=0)
    upper = base + residual_samples.max(axis=0)

    x_vals = x_axis if x_axis is not None else np.arange(spatial_dim)
    if x_vals.shape[0] != spatial_dim:
        raise ValueError(f"x_axis length {x_vals.shape[0]} must equal spatial dimension {spatial_dim}")

    fig, ax = plt.subplots(figsize=figsize)
    ax.fill_between(x_vals, lower, upper, color=band_color, alpha=band_alpha, label="Conformal band")
    if outline:
        ax.plot(x_vals, lower, color=band_color, linewidth=1.0, alpha=0.9)
        ax.plot(x_vals, upper, color=band_color, linewidth=1.0, alpha=0.9)
    ax.plot(x_vals, truth, color=truth_color, linewidth=2.0, label="Ground truth")

    ax.set_xlabel("x")
    ax.set_ylabel("u(x)")
    if title:
        ax.set_title(title)
    ax.legend()

    if save_path is not None:
        sp = Path(save_path)
        sp.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(sp, bbox_inches="tight")
    if show:
        plt.show()
    return fig, ax
