# Graphing Parameters
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch

# Module-wide matplotlib defaults applied to every figure created below:
# large markers, thin default lines, large tick/axis/title/legend fonts,
# and LaTeX rendering switched off (labels such as "$n$" go through
# matplotlib's built-in mathtext instead).
mpl.rcParams.update({
    "lines.markersize": 12,
    "lines.linewidth": 1,
    "xtick.labelsize": 25,
    "ytick.labelsize": 25,
    "axes.labelsize": 25,
    "legend.fontsize": 20,
    "axes.titlesize": 25,
    "text.usetex": False,
})

def _mse_with_quantile_errors(est, true_value, q):
    """Return ``(mse, yerr)`` for estimates ``est`` of shape (R, K).

    ``mse`` (shape (K,)) is the squared error averaged over the R runs.
    ``yerr`` (shape (2, K)) holds the distances from ``mse`` down to the
    ``q[0]`` quantile and up to the ``q[1]`` quantile of the squared errors,
    clipped at zero so it is always a valid ``Axes.errorbar`` argument.
    """
    se = (est - true_value) ** 2
    mse = se.mean(axis=0)
    lo, hi = np.quantile(se, q, axis=0)
    # Squared errors are right-skewed, so the mean can lie above the upper
    # quantile (or below the lower one). The raw difference would then be
    # negative, which Axes.errorbar rejects with ValueError; after clipping
    # the drawn bar spans [min(lo, mse), max(hi, mse)].
    below = np.clip(mse - lo, 0.0, None)
    above = np.clip(hi - mse, 0.0, None)
    return mse, np.vstack([below, above])


def plot_mse_with_bands(arr, N, true_value,
                        q=(0.1, 0.9),
                        logx=True,
                        logy=False,
                        show=True):
    """
    Plot MSE vs N with empirical deviation bands.

    For each sample size, the squared errors of the raw ("Plug-in") and
    bias-corrected ("One-step") estimates are averaged over the R runs.
    The vertical capped bars mark the ``q`` quantiles of those squared
    errors around the mean.

    Parameters
    ----------
    arr : np.ndarray
        Shape (R, K, C) with C >= 2 (typically (R, K, 4)); only the first
        two channels are used:
          arr[:, :, 0] = raw estimate
          arr[:, :, 1] = bias-corrected estimate
    N : array-like
        Shape (K,), sample sizes.
    true_value : float
        True estimand value
    q : tuple
        Quantiles for deviation bands (e.g. (0.1, 0.9) or (0.025, 0.975))
    logx, logy : bool
        Use a log scale on the x / y axis.
    show : bool
        If True, call ``plt.show()`` before returning.

    Returns
    -------
    fig, ax
        The matplotlib Figure and Axes.

    Raises
    ------
    ValueError
        If ``arr`` is not 3-D with at least two channels, or if ``len(N)``
        differs from ``arr.shape[1]``.
    """

    arr = np.asarray(arr)
    N = np.asarray(N)

    if arr.ndim != 3 or arr.shape[2] < 2:
        raise ValueError(f"Expected arr shape (R, K, C) with C >= 2. Got {arr.shape}")
    K = arr.shape[1]
    if len(N) != K:
        raise ValueError(f"len(N) must equal K={K}. Got {len(N)}")

    # Mean MSE (shape (K,)) and asymmetric quantile errors (shape (2, K))
    mse_raw, raw_err = _mse_with_quantile_errors(arr[:, :, 0], true_value, q)
    mse_bc, bc_err = _mse_with_quantile_errors(arr[:, :, 1], true_value, q)

    fig, ax = plt.subplots(figsize=(8, 5))

    # Plot means
    ax.plot(N, mse_raw, marker='s', linewidth=2, color='brown', label="Plug-in")
    ax.plot(N, mse_bc,  marker='*', linewidth=2, color='teal',  label="One-step")

    # Capped vertical error bars
    ax.errorbar(N, mse_raw, yerr=raw_err, fmt='none',
                ecolor='brown', elinewidth=3.5, capsize=5, alpha=0.6)

    ax.errorbar(N, mse_bc, yerr=bc_err, fmt='none',
                ecolor='teal', elinewidth=2.5, capsize=5, alpha=0.6)

    if logx:
        ax.set_xscale("log")
        # Pad multiplicatively on a log axis: an additive pad of 50 gives a
        # non-positive left limit whenever N.min() <= 50 (matplotlib warns
        # and ignores it), and a near-zero one just above 50.
        ax.set_xlim(N.min() * 0.9, N.max() * 1.1)
    else:
        ax.set_xlim(N.min() - 50, N.max() + 50)
    if logy:
        ax.set_yscale("log")

    # One labelled tick per sample size actually simulated.
    ax.set_xticks(N)
    ax.set_xticklabels([str(n) for n in N])
    ax.tick_params(axis="x", which="major", length=6)

    ax.set_xlabel("Sample size ($n$)")
    ax.set_ylabel("MSE")
    ax.legend()

    fig.tight_layout()
    if show:
        plt.show()

    return fig, ax

def plot_single_run_per_n_with_ci(arr, n_grid, true_value, *,
                                 seed=0,
                                 logx=True,
                                 show=True):
    """
    For each n (index k), randomly select one MC run r_k and plot:
      - raw estimate (arr[r_k, k, 0])
      - debiased estimate (arr[r_k, k, 1])
      - Wald CI band from (arr[r_k, k, 2], arr[r_k, k, 3])
      - horizontal line at true_value

    Parameters
    ----------
    arr : np.ndarray
        Shape (R, K, 4):
          [:, :, 0] raw estimate
          [:, :, 1] debiased estimate
          [:, :, 2] Wald CI lower (for debiased)
          [:, :, 3] Wald CI upper (for debiased)
    n_grid : array-like
        Shape (K,)
    true_value : float
    seed : int
        RNG seed for reproducibility
    logx : bool
        Use log scale for x-axis
    show : bool
        If True, call ``plt.show()`` before returning.

    Returns
    -------
    fig, ax, r_idx
        The matplotlib Figure and Axes, and the (K,) array of selected run
        indices.

    Raises
    ------
    ValueError
        If ``arr`` is not (R, K, 4) or ``len(n_grid) != K``.
    """

    arr = np.asarray(arr)
    n_grid = np.asarray(n_grid)

    if arr.ndim != 3 or arr.shape[2] != 4:
        raise ValueError(f"Expected arr shape (R, K, 4). Got {arr.shape}")
    R, K, _ = arr.shape
    if len(n_grid) != K:
        raise ValueError(f"len(n_grid) must equal K={K}. Got {len(n_grid)}")

    rng = np.random.default_rng(seed)

    # Pick one random run index per n (per column k)
    r_idx = rng.integers(low=0, high=R, size=K)

    cols = np.arange(K)
    raw = arr[r_idx, cols, 0]
    deb = arr[r_idx, cols, 1]
    lo  = arr[r_idx, cols, 2]
    hi  = arr[r_idx, cols, 3]

    # Non-finite values (e.g. inf from a blown-up run) become NaN so that
    # matplotlib leaves a gap at that n instead of distorting the y-range.
    raw = np.where(np.isfinite(raw), raw, np.nan)
    deb = np.where(np.isfinite(deb), deb, np.nan)
    lo  = np.where(np.isfinite(lo),  lo,  np.nan)
    hi  = np.where(np.isfinite(hi),  hi,  np.nan)

    fig, ax = plt.subplots(figsize=(8.5, 5))

    # True value
    ax.axhline(true_value, linestyle="--", linewidth=1.5, label="Truth", color='black')

    # Wald CI band (for debiased)
    ax.fill_between(n_grid, lo, hi, alpha=0.25, label="Wald 95% CI", color='teal')

    # Raw and debiased estimates
    ax.plot(n_grid, raw, linewidth=2, color='brown', label="Plug-in", marker='s')
    ax.plot(n_grid, deb, linewidth=2,  color='teal', label="One-step", marker='*')

    if logx:
        ax.set_xscale("log")

    # Single, padded x-range (previously overridden by a second, unpadded
    # set_xlim that half-clipped the markers at the first and last n).
    ax.set_xlim(n_grid.min() * 0.9, n_grid.max() * 1.1)
    ax.set_xticks(n_grid)
    ax.set_xticklabels([str(n) for n in n_grid])
    ax.tick_params(axis="x", which="major", length=6)

    ax.set_xlabel("Sample size ($n$)")
    ax.set_ylabel("DTE")
    # NOTE(review): with tight_layout disabled, anchoring the legend's lower
    # edge at y=-0.5 (axes coords) likely puts it outside the visible figure
    # in plt.show(); presumably callers save with bbox_inches="tight" — confirm.
    ax.legend(
        loc="lower center",
        bbox_to_anchor=(0.5, -0.5),
        ncol=4,
    )

    if show:
        plt.show()

    return fig, ax, r_idx


def plot_coverage_vs_n(arr_ci, n_grid, G_true, *,
                       logx=True,
                       show=True,
                       nominal=0.95):
    """
    Coverage plot for Wald CIs vs sample size.

    Parameters
    ----------
    arr_ci : np.ndarray
        Shape (R, K, 2) where:
          arr_ci[:, :, 0] = lower CI
          arr_ci[:, :, 1] = upper CI
    n_grid : array-like
        Shape (K,), sample sizes.
    G_true : float
        True estimand value.
    logx : bool
        Use log scale for x-axis.
    show : bool
        If True, call ``plt.show()`` before returning.
    nominal : float
        Nominal coverage level drawn as a reference line (default 0.95).

    Returns
    -------
    fig, ax, coverage
        The matplotlib Figure and Axes, and the (K,) empirical coverage.
        Coverage for a column is computed over runs whose bounds are both
        finite; a column with no such run is NaN.

    Raises
    ------
    ValueError
        If ``arr_ci`` is not (R, K, 2) or ``len(n_grid) != K``.
    """
    arr_ci = np.asarray(arr_ci)
    n_grid = np.asarray(n_grid)

    if arr_ci.ndim != 3 or arr_ci.shape[2] != 2:
        raise ValueError(f"Expected arr_ci shape (R, K, 2). Got {arr_ci.shape}")
    R, K, _ = arr_ci.shape
    if len(n_grid) != K:
        raise ValueError(f"len(n_grid) must equal K={K}. Got {len(n_grid)}")

    lo = arr_ci[:, :, 0]
    hi = arr_ci[:, :, 1]

    # Runs with NaN/inf bounds (blow-ups) are excluded from both the
    # numerator and the denominator.
    valid = np.isfinite(lo) & np.isfinite(hi)
    covered = valid & (lo <= G_true) & (G_true <= hi)
    n_valid = valid.sum(axis=0)

    # Explicit division instead of np.nanmean: same values, but a column
    # with no valid run yields NaN without a "Mean of empty slice" warning.
    coverage = np.full(K, np.nan)
    np.divide(covered.sum(axis=0), n_valid, out=coverage, where=n_valid > 0)

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(n_grid, coverage, linewidth=2, marker="o", color='brown', label="Empirical coverage")
    # ":g" prints 0.95 as "95" (and 0.975 as "97.5") without float noise.
    ax.axhline(nominal, linestyle="--", linewidth=1.5,
               label=f"Nominal {nominal * 100:g}%", color='teal')

    if logx:
        ax.set_xscale("log")

    ax.set_xlim(n_grid.min() * 0.9, n_grid.max() * 1.1)
    ax.set_xticks(n_grid)
    ax.set_xticklabels([str(n) for n in n_grid])
    ax.tick_params(axis="x", which="major", length=6)

    ax.set_xlabel(r"Sample size ($n$)")
    ax.set_ylabel("Coverage probability")
    ax.legend()

    fig.tight_layout()
    if show:
        plt.show()

    return fig, ax, coverage

