import os
from pathlib import Path
from typing import Callable, Dict, Iterable, Optional, Sequence, Tuple

import numpy as np
import matplotlib.pyplot as plt


# -------------------------
# Ground-truth components
# -------------------------

def phi1_star(x):
    """Ground-truth component 1: the sine arch ``2 * sin(pi * x)``."""
    angle = np.pi * x
    return 2.0 * np.sin(angle)
def phi2_star(x):
    """Ground-truth component 2: the parabola ``0.5 * (x - 0.5)**2``."""
    offset = x - 0.5
    return 0.5 * offset ** 2
def phi3_star(x):
    """Ground-truth component 3: a hinge, zero up to 0.3 then slope 1.5."""
    excess = x - 0.3
    return 1.5 * np.maximum(0.0, excess)
def phi4_star(x):
    """Ground-truth component 4: the inverted parabola ``0.8 * (0.25 - (x - 0.5)**2)``."""
    offset = x - 0.5
    bump = 0.25 - offset ** 2
    return 0.8 * bump

def psi12_star(x1, x2):
    """Ground-truth pairwise term: ``0.6 * sin(2*pi*x1) * (x2 - 0.5)``."""
    wave = np.sin(2 * np.pi * x1)
    tilt = x2 - 0.5
    return 0.6 * wave * tilt

def S_star(X: np.ndarray) -> np.ndarray:
    """Evaluate the ground-truth score for each row of ``X``.

    ``X`` is converted to a float array and indexed as 2-D; the first four
    columns are used. The score is the sum of the four univariate components
    plus the (x1, x2) pairwise term.
    """
    pts = np.asarray(X, dtype=float)
    x1 = pts[:, 0]
    x2 = pts[:, 1]
    x3 = pts[:, 2]
    x4 = pts[:, 3]

    # Same left-to-right summation as a single chained expression.
    main_effects = phi1_star(x1) + phi2_star(x2) + phi3_star(x3) + phi4_star(x4)
    return main_effects + psi12_star(x1, x2)


# -------------------------
# Metrics
# -------------------------

def pairwise_rank_accuracy(s_true: np.ndarray, s_hat: np.ndarray) -> float:
    """Return the fraction of pairs (i < j) that both score vectors order alike.

    For every pair the sign of ``s[i] - s[j]`` is computed under both score
    vectors; the pair counts as correct when the two signs are equal. A tie
    therefore counts as correct only when both vectors tie on that pair.
    Pairs involving NaN never count as correct.

    Args:
        s_true: Reference scores; flattened to 1-D.
        s_hat: Scores to evaluate; flattened to 1-D.

    Returns:
        Agreeing pairs divided by the total number of pairs, or 0.0 when
        there are fewer than two items (no pairs to compare).

    Raises:
        ValueError: If the flattened inputs differ in length.
    """
    s_true = np.asarray(s_true).reshape(-1)
    s_hat = np.asarray(s_hat).reshape(-1)
    if s_true.size != s_hat.size:
        raise ValueError(
            f"s_true and s_hat must have the same number of scores "
            f"(got {s_true.size} and {s_hat.size})."
        )

    n = s_true.size
    if n < 2:
        return 0.0

    # All index pairs with first < second, evaluated in one vectorised pass
    # instead of a Python-level double loop.
    first, second = np.triu_indices(n, k=1)
    sign_true = np.sign(s_true[first] - s_true[second])
    sign_hat = np.sign(s_hat[first] - s_hat[second])

    agreeing = int(np.count_nonzero(sign_true == sign_hat))
    return agreeing / first.size


def evaluate_function_fit(
    model,
    gt_ft: Callable[[np.ndarray], np.ndarray],
    samples_L2: int = 30000,
    trials_rank: int = 3000,
    slate_N: int = 8,
    seed: int = 1,
) -> Tuple[float, float]:
    """
    Score a fitted model against a ground-truth function.

    Two quantities are computed from points drawn uniformly on [0, 1)^K,
    where K is ``model.K``:
      - RMSE between ``gt_ft`` and the first output of ``model.forward`` on
        ``samples_L2`` random points
      - mean pairwise ranking accuracy over ``trials_rank`` random slates of
        ``slate_N`` points each

    Returns ``(rmse, mean_rank_accuracy)``. All randomness comes from one
    generator seeded with ``seed``.
    """
    rng = np.random.default_rng(seed)

    def _draw(n_rows):
        # One uniform sample block with a column per model input.
        return rng.uniform(0, 1, size=(n_rows, model.K))

    # --- pointwise fit ---
    points = _draw(samples_L2)
    target = gt_ft(points)
    predicted, _ = model.forward(points)
    residual = target - predicted
    rmse = float(np.sqrt(np.mean(residual ** 2)))

    # --- ranking quality on random slates ---
    def _slate_accuracy():
        slate = _draw(slate_N)
        truth = gt_ft(slate)
        estimate, _ = model.forward(slate)
        return pairwise_rank_accuracy(truth, estimate)

    accuracies = [_slate_accuracy() for _trial in range(trials_rank)]
    mean_accuracy = float(np.mean(accuracies))

    return rmse, mean_accuracy


def transfer_test(
    model,
    gt_ft: Callable[[np.ndarray], np.ndarray],
    Ns: Sequence[int] = (4, 8, 16, 32),
    trials: int = 1000,
    seed: int = 2,
) -> Dict[int, dict]:
    """
    Compare greedy selection under the learned score against an oracle.

    For every slate size N in ``Ns``, ``trials`` slates are drawn uniformly
    on [0, 1)^K (K is ``model.K``). On each slate the oracle takes the argmax
    of ``gt_ft`` and the learner takes the argmax of the first output of
    ``model.forward``; both picks are rewarded with the true score.

    Returns a dict keyed by ``int(N)`` whose values hold:
      - "avg_reward_learn":  mean true score of the learner's pick
      - "avg_reward_oracle": mean true score of the oracle's pick
      - "reward_gap":        oracle mean minus learner mean
      - "ranking_acc":       mean pairwise ranking accuracy
    """
    rng = np.random.default_rng(seed)

    def _run_trial(n_items):
        # One slate: (oracle reward, learner reward, pairwise accuracy).
        slate = rng.uniform(0, 1, size=(n_items, model.K))
        truth = gt_ft(slate)
        estimate, _ = model.forward(slate)

        pick_oracle = int(np.argmax(truth))
        pick_learn = int(np.argmax(estimate))

        return (
            float(truth[pick_oracle]),
            float(truth[pick_learn]),
            pairwise_rank_accuracy(truth, estimate),
        )

    summary: Dict[int, dict] = {}
    for n_items in Ns:
        outcomes = [_run_trial(n_items) for _trial in range(trials)]

        oracle_mean = float(np.mean([o[0] for o in outcomes]))
        learn_mean = float(np.mean([o[1] for o in outcomes]))
        acc_mean = float(np.mean([o[2] for o in outcomes]))

        summary[int(n_items)] = {
            "avg_reward_learn": learn_mean,
            "avg_reward_oracle": oracle_mean,
            "reward_gap": oracle_mean - learn_mean,
            "ranking_acc": acc_mean,
        }

    return summary


# -------------------------
# Plotting helpers
# -------------------------

def _ensure_dir(path: Optional[str]) -> Optional[Path]:
    """Create directory ``path`` (including parents) and return it as a Path.

    Returns None without touching the filesystem when ``path`` is None.
    An already-existing directory is not an error.
    """
    if path is not None:
        directory = Path(path)
        directory.mkdir(parents=True, exist_ok=True)
        return directory
    return None


def plot_components_1d(
    model,
    phi_true_fns: Optional[Sequence[Callable[[np.ndarray], np.ndarray]]] = None,
    num: int = 121,
    save_dir: Optional[str] = None,
    prefix: str = "components",
    show: bool = False,
):
    """
    Draw one figure per learned 1D component, as returned by
    ``model.components_1d(num=num)``.

    When ``phi_true_fns`` is given, ``phi_true_fns[k]`` is evaluated on the
    k-th grid, mean-centered, and drawn alongside the estimate (dashed).

    With ``save_dir`` set, each figure is written to
    ``<save_dir>/<prefix>_phi<k>.png`` and closed, and None is returned.
    Otherwise the figures are returned in a list, and ``plt.show()`` is
    called after each one when ``show`` is true.
    """
    target_dir = _ensure_dir(save_dir)
    saving = target_dir is not None
    grids, phis = model.components_1d(num=num)

    collected = []
    for k in range(model.K):
        grid_k = grids[k]
        estimate = phis[k]

        fig, ax = plt.subplots(figsize=(4, 3))
        if phi_true_fns is None:
            ax.plot(grid_k, estimate, label="Estimated")
        else:
            # Center the truth so it is comparable with the estimate.
            truth = phi_true_fns[k](grid_k)
            truth = truth - np.mean(truth)
            ax.plot(grid_k, truth, label="True")
            ax.plot(grid_k, estimate, "--", label="Estimated")
        ax.legend()

        ax.set_xlabel(f"x{k+1}")
        ax.set_ylabel(f"phi{k+1}")
        ax.set_title(f"1D component k={k+1}")
        fig.tight_layout()

        if saving:
            fig.savefig(target_dir / f"{prefix}_phi{k+1}.png", dpi=200)
            plt.close(fig)
            continue

        collected.append(fig)
        if show:
            plt.show()

    return None if saving else collected


def plot_pair_slice(
    model,
    i: int,
    j: int,
    psi_true_fn: Optional[Callable[[np.ndarray, np.ndarray], np.ndarray]] = None,
    xi_num: int = 121,
    xj_fixed_values: Sequence[float] = (0.25, 0.5, 0.75),
    save_dir: Optional[str] = None,
    prefix: str = "pair_slice",
    show: bool = False,
    ylim: Optional[Tuple[float, float]] = (-0.15, 0.15),
):
    """
    Plot slices of psi_{ij}(x_i, x_j_fixed) along x_i.

    One figure is produced per entry of ``xj_fixed_values``. The estimate
    comes from ``model.component_2d_slice(i, j, xi, xj_fixed)`` evaluated on
    ``xi_num`` evenly spaced points in [0, 1]. When ``psi_true_fn`` is given
    it is evaluated on the same slice, mean-centered along the slice, and
    drawn alongside the estimate (dashed).

    Args:
        model: Object exposing ``component_2d_slice``.
        i, j: Zero-based feature indices; labels and file names use ``i+1``
            and ``j+1``.
        psi_true_fn: Optional ground-truth ``psi(x_i, x_j)``.
        xi_num: Number of grid points along x_i.
        xj_fixed_values: Values at which x_j is held fixed.
        save_dir: If set, figures are saved there as PNG and closed.
        prefix: File name prefix used when saving.
        show: When not saving, call ``plt.show()`` after each figure.
        ylim: ``(low, high)`` y-axis limits applied to every figure. The
            default matches the limits this function has always used; pass
            None to let matplotlib autoscale.

    Returns:
        A list of figures when ``save_dir`` is None, otherwise None.
    """
    out_dir = _ensure_dir(save_dir)
    xi = np.linspace(0.0, 1.0, xi_num)

    figs = []
    for xj_fixed in xj_fixed_values:
        y_hat = model.component_2d_slice(i, j, xi, float(xj_fixed))

        fig, ax = plt.subplots(figsize=(4, 3))
        if psi_true_fn is not None:
            y_true = psi_true_fn(xi, float(xj_fixed) * np.ones_like(xi))
            y_true = y_true - np.mean(y_true)
            ax.plot(xi, y_true, label="True")
            ax.plot(xi, y_hat, "--", label="Estimated")
        else:
            ax.plot(xi, y_hat, label="Estimated")
        ax.legend()

        ax.set_xlabel(f"x{i+1}")
        ax.set_ylabel(f"psi{i+1}{j+1} slice")
        ax.set_title(f"psi({i+1},{j+1}) at x{j+1}={xj_fixed:.2f}")

        if ylim is not None:
            ax.set_ylim(*ylim)

        fig.tight_layout()

        if out_dir is not None:
            # Only the decimal point of the fixed value becomes "p" (0.25 ->
            # 0p25); the ".png" extension and the prefix are left untouched.
            value_tag = f"{xj_fixed:.2f}".replace(".", "p")
            fname = f"{prefix}_psi{i+1}{j+1}_x{j+1}_{value_tag}.png"
            fig.savefig(out_dir / fname, dpi=200)
            plt.close(fig)
        else:
            figs.append(fig)
            if show:
                plt.show()

    return figs if out_dir is None else None


def plot_interaction_3d(
    model,
    i: int,
    j: int,
    psi_true_fn: Optional[Callable[[np.ndarray, np.ndarray], np.ndarray]] = None,
    gridN: int = 61,
    center: bool = True,
    save_dir: Optional[str] = None,
    prefix: str = "psi3d",
    show: bool = False,
):
    """
    3D surface of psi_{ij}. If psi_true_fn is provided, also plots true and error heatmap.
    Saves PNGs if save_dir is provided; otherwise returns figures.

    The estimated surface is built on a ``gridN`` x ``gridN`` grid over
    [0, 1]^2 from the pair basis ``model.Phi_pair[key]`` and the weights
    ``model.w_psi[key]``, where ``key`` is ``(min(i, j), max(i, j))``.

    Args:
        model: Object exposing ``Phi_pair`` and ``w_psi``, both indexed by
            the sorted index pair.
        i, j: Zero-based feature indices; labels and file names use ``i+1``
            and ``j+1``.
        psi_true_fn: Optional ground-truth ``psi(x_i, x_j)``, evaluated on
            the same grid.
        gridN: Number of grid points per axis.
        center: Subtract the grid mean from each surface before plotting.
        save_dir: If set, figures are saved there as PNG and closed.
        prefix: File name prefix used when saving.
        show: When not saving, call ``plt.show()`` after each figure.

    Returns:
        A list of figures (estimate, then truth and error heatmap when
        ``psi_true_fn`` is given) when ``save_dir`` is None, otherwise None.

    Raises:
        ValueError: If the model has no pair basis for ``(i, j)``.
    """
    out_dir = _ensure_dir(save_dir)

    key = (min(i, j), max(i, j))
    if key not in model.Phi_pair:
        raise ValueError(f"Interaction {key} not found in model (no pair basis).")

    # NOTE(review): the bases are looked up under the sorted key, while axis
    # labels and file names use i and j as passed. If Phi_pair stores its two
    # bases in sorted-index order, a call with i > j would label the axes
    # swapped -- confirm against the model.
    Phi_i, Phi_j = model.Phi_pair[key]
    x = np.linspace(0.0, 1.0, gridN)
    # indexing="ij": the first array axis follows xi, the second follows xj.
    xi, xj = np.meshgrid(x, x, indexing="ij")

    Bi = Phi_i(xi.reshape(-1))
    Bj = Phi_j(xj.reshape(-1))
    # Per grid point, the outer product of the two basis rows, flattened.
    Bij = (Bi[:, :, None] * Bj[:, None, :]).reshape(Bi.shape[0], -1)

    psi_hat = (Bij @ model.w_psi[key]).reshape(gridN, gridN)
    if center:
        psi_hat = psi_hat - np.mean(psi_hat)

    figs = []

    def _emit(fig, tag):
        # Save and close when an output directory is set; otherwise collect
        # the figure and optionally show it.
        if out_dir is not None:
            fig.savefig(out_dir / f"{prefix}_{tag}_{i+1}{j+1}.png", dpi=200)
            plt.close(fig)
            return
        figs.append(fig)
        if show:
            plt.show()

    def _surface(values, zlabel, title):
        # Draw one 3D surface over the (xi, xj) grid.
        fig = plt.figure(figsize=(5, 4))
        ax = fig.add_subplot(111, projection="3d")
        ax.plot_surface(xi, xj, values, rstride=2, cstride=2, linewidth=0.0, antialiased=True)
        ax.set_xlabel(f"x{i+1}")
        ax.set_ylabel(f"x{j+1}")
        ax.set_zlabel(zlabel)
        ax.set_title(title)
        fig.tight_layout()
        return fig

    # --- estimated surface ---
    _emit(_surface(psi_hat, "psi_hat", f"Estimated psi({i+1},{j+1})"), "hat")

    if psi_true_fn is None:
        return figs if out_dir is None else None

    # --- true surface ---
    psi_true = psi_true_fn(xi, xj)
    if center:
        psi_true = psi_true - np.mean(psi_true)

    _emit(_surface(psi_true, "psi_true", f"True psi({i+1},{j+1})"), "true")

    # --- error heatmap ---
    err = np.abs(psi_hat - psi_true)
    fig3, ax3 = plt.subplots(figsize=(4.5, 3.6))
    # imshow puts the first array axis on the vertical; err is indexed
    # [xi, xj], so transpose to get xi horizontal and match the axis labels.
    im = ax3.imshow(err.T, origin="lower", extent=[0, 1, 0, 1], aspect="auto")
    ax3.set_xlabel(f"x{i+1}")
    ax3.set_ylabel(f"x{j+1}")
    ax3.set_title(f"|psi_hat - psi_true| for ({i+1},{j+1})")
    fig3.colorbar(im, ax=ax3, fraction=0.046, pad=0.04)
    fig3.tight_layout()

    _emit(fig3, "err")

    return figs if out_dir is None else None
