
import argparse
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import BoundaryNorm, ListedColormap

# Multipliers nominally associated with the three discrete actions.
# NOTE(review): appears unused in this file -- the DP builds its action set
# directly as (0.5 * Kelly, Kelly, directional all-in).
EPSILON_ACTIONS = np.asarray([0.5, 1.0, 2.0], dtype=float)

# Colorbar tick labels for the policy plot, indexed by action id
# (0: half Kelly, 1: Kelly, 2: all-in).
ACTION_LABELS = list(
    (
        r"$\dfrac{\text{Kelly}}{2}$",
        r"$\text{Kelly}$",
        r"$\text{All-in}$",
    )
)



def safe_bounds(m, eps_cap=1e-3):
    """Return (lam_max_pos, lam_max_neg) ensuring 1+λ(X-m) ≥ eps_cap for all X in [0,1].

    The wealth multiplier ``1 + λ(X - m)`` is affine in ``X``, so its minimum
    over ``X ∈ [0, 1]`` sits at an endpoint: ``X = 0`` when ``λ > 0`` and
    ``X = 1`` when ``λ < 0``. Requiring that minimum to be at least
    ``eps_cap`` gives::

        λ ≤ (1 - eps_cap) / m        and        λ ≥ -(1 - eps_cap) / (1 - m)

    Args:
        m: Null-hypothesis mean; must lie strictly inside (0, 1).
        eps_cap: Minimum allowed wealth multiplier; must lie in [0, 1).

    Returns:
        ``(lam_max_pos, lam_max_neg)`` as Python floats: the largest positive
        and the most negative admissible bet.

    Raises:
        ValueError: if ``m`` or ``eps_cap`` is out of range. Without this check
            ``m`` in {0, 1} divides by zero and ``m`` outside [0, 1] silently
            yields bounds with the wrong sign.
    """
    m = float(m)
    eps_cap = float(eps_cap)
    if not 0.0 < m < 1.0:
        raise ValueError(f"m must lie strictly in (0, 1), got {m!r}")
    if not 0.0 <= eps_cap < 1.0:
        raise ValueError(f"eps_cap must lie in [0, 1), got {eps_cap!r}")
    lam_max_pos = (1.0 - eps_cap) / m
    lam_max_neg = -(1.0 - eps_cap) / (1.0 - m)
    return lam_max_pos, lam_max_neg


def sample_beta_mixture(rng, size, mu, conc=6.0):
    """50/50 mixture of Beta(conc*mu, conc*(1-mu)) and Beta(2*conc*mu, 2*conc*(1-mu)).

    Both components have mean ``mu``; the second is twice as concentrated.
    Shape parameters are floored at 1e-8 so ``mu`` at 0 or 1 still yields
    valid (near-degenerate) Beta draws.

    Args:
        rng: ``numpy.random.Generator`` used for every draw.
        size: Output shape passed to ``rng.random`` (typically an int).
        mu: Mixture mean, expected in [0, 1].
        conc: Concentration ``a + b`` of the first component.

    Returns:
        ``float32`` array of shape ``size`` with values in [0, 1].
    """
    a1, b1 = max(1e-8, conc * mu), max(1e-8, conc * (1.0 - mu))
    a2, b2 = max(1e-8, 2.0 * conc * mu), max(1e-8, 2.0 * conc * (1.0 - mu))
    # Component choice per sample: True -> first component, False -> second.
    mask = rng.random(size) < 0.5
    X = np.empty(size, dtype=np.float32)
    X[mask] = rng.beta(a1, b1, size=int(mask.sum())).astype(np.float32)
    X[~mask] = rng.beta(a2, b2, size=int((~mask).sum())).astype(np.float32)
    return X


def kelly_bet_mc(X, m, lam_min, lam_max, max_iter=80, tol=1e-10):
    """
    Estimate the Kelly bet from Monte Carlo samples by bisection.

    The target is the λ in [lam_min, lam_max] maximising the empirical
    log-growth mean(log(1 + λ(X - m))). That objective is concave in λ, so its
    slope g(λ) = mean((X - m) / (1 + λ(X - m))) is non-increasing. The optimum
    is therefore the root of g when g changes sign on the bracket, or the
    endpoint toward which g points otherwise.

    Stops early once |g(mid)| < tol, else after max_iter halvings.
    Returns a Python float.
    """
    centred = np.asarray(X, dtype=np.float64) - float(m)

    def slope(lam):
        return np.mean(centred / (1.0 + lam * centred))

    slope_at_min = slope(lam_min)
    slope_at_max = slope(lam_max)

    # No sign change across the bracket: the optimum sits on a boundary.
    if slope_at_min <= 0.0:
        return float(lam_min)
    if slope_at_max >= 0.0:
        return float(lam_max)

    left, right = float(lam_min), float(lam_max)
    steps_left = int(max_iter)
    while steps_left > 0:
        steps_left -= 1
        centre = 0.5 * (left + right)
        g = slope(centre)
        if abs(g) < tol:
            return float(centre)
        # g is non-increasing: a positive slope means the root lies to the right.
        left, right = (centre, right) if g > 0.0 else (left, centre)
    return float(0.5 * (left + right))


def max_increment_for_lambda(lam, m):
    """Largest one-step log-wealth change log(1 + λ(X - m)) attainable for X in [0, 1].

    The increment is monotone in X, so its maximum is at X = 1 when λ ≥ 0
    and at X = 0 when λ < 0. Returns a Python float.
    """
    lam, m = float(lam), float(m)
    if lam >= 0:
        gap = 1.0 - m
    else:
        gap = 0.0 - m
    return float(np.log(1.0 + lam * gap))


def _sample_world(rng, world, size, mu, conc):
    """Draw ``size`` outcomes X in [0, 1] from the named world model.

    "beta_mixture" delegates to sample_beta_mixture (float32 result); "beta"
    draws Beta(conc*mu, conc*(1-mu)) with shape parameters floored at 1e-8
    (float64 result). Callers cast to the dtype they need.

    Raises:
        ValueError: for any other ``world``.
    """
    if world == "beta_mixture":
        return sample_beta_mixture(rng, size=size, mu=mu, conc=conc)
    if world == "beta":
        return rng.beta(max(1e-8, conc * mu), max(1e-8, conc * (1 - mu)), size=size)
    raise ValueError("world must be 'beta' or 'beta_mixture'")


def _continuation_value(y_next, y_below, v_below, T):
    """Evaluate V_{t+1} at the log-wealth points ``y_next`` (any array shape).

    Reaching ``T`` is absorbing, so every ``y_next >= T`` is worth exactly 1.
    Below ``T`` the value is linearly interpolated using only the grid centers
    below ``T`` (``y_below`` / ``v_below``). V jumps to 1 at the threshold, and
    interpolating across that jump would bias near-threshold states. Points
    below the grid get 0; points between the last below-threshold center and
    ``T`` hold that center's value.
    """
    hit = y_next >= T
    if y_below.size == 0:
        return hit.astype(np.float64)
    v = np.interp(y_next, y_below, v_below, left=0.0)
    return np.where(hit, 1.0, v)


def dp_optimal_policy_discrete(
    N,
    alpha,
    m,
    mu,
    conc=6.0,
    world="beta_mixture",
    eps_cap=1e-3,
    num_y_bins=120,
    mc_samples=6000,
    seed=0,
    y_margin=0.75,
):
    """
    Finite-horizon DP with Bellman recursion on state (t, y=log-wealth), using 3 discrete actions:
      a=0: 0.5 * Kelly
      a=1: 1.0 * Kelly
      a=2: all-in (directional endpoint of the safe bet range)

    The value V_t(y) is the probability of reaching log-wealth T = log(1/alpha)
    by step N under the best action sequence. Reaching T is treated as
    absorbing (value 1).

    Args:
        N: Horizon (number of betting rounds).
        alpha: Significance level; sets the threshold T = log(1/alpha).
        m: Null-hypothesis mean in the payoff 1 + λ(X - m).
        mu, conc: Parameters of the data-generating world.
        world: "beta" or "beta_mixture".
        eps_cap: Safety margin forwarded to safe_bounds.
        num_y_bins: Number of log-wealth bins.
        mc_samples: Number of X samples used for every DP expectation.
        seed: Seed for numpy.random.default_rng.
        y_margin: Extra log-wealth margin below and above the grid range.

    Returns a dict with:
      - T, bounds=(lam_max_neg, lam_max_pos), lambdas, lam_kelly, lam_end, Jmax
      - y_edges, y_centers
      - policy_idx (Ybins, N) with -1 meaning "irrelevant" (above threshold or unreachable)
      - V (Ybins, N+1) value function grid

    Raises:
        ValueError: if ``world`` is not "beta" or "beta_mixture".
    """
    rng = np.random.default_rng(seed)
    T = float(np.log(1.0 / alpha))

    lam_max_pos, lam_max_neg = safe_bounds(m, eps_cap=eps_cap)

    # Kelly bet (known distribution): solve the derivative condition via MC +
    # bisection on a large sample. Drawn before the DP sample, so the RNG
    # stream (and thus results for a given seed) is fixed by this order.
    X_kelly = _sample_world(rng, world, 200000, mu, conc).astype(np.float32)
    lam_kelly = kelly_bet_mc(X_kelly, m, lam_min=lam_max_neg, lam_max=lam_max_pos)

    # Directional all-in: the safe endpoint on the side where the true mean lies.
    lam_end = lam_max_pos if (mu - m) >= 0 else lam_max_neg

    lambdas = np.clip(
        np.array([0.5 * lam_kelly, lam_kelly, lam_end], dtype=np.float64),
        lam_max_neg,
        lam_max_pos,
    )

    # y-grid: every state from which T is still reachable within N steps
    # (each step gains at most Jmax), plus a margin on both sides.
    Jmax = max(max_increment_for_lambda(lam, m) for lam in lambdas)
    y_min = T - N * Jmax - float(y_margin)
    y_max = T + float(y_margin)

    y_edges = np.linspace(y_min, y_max, int(num_y_bins) + 1, dtype=np.float64)
    y_centers = 0.5 * (y_edges[:-1] + y_edges[1:])

    # One fixed MC sample for every DP expectation (for stability)
    X_dp = _sample_world(rng, world, int(mc_samples), mu, conc).astype(np.float64)

    # Log-increments log(1 + λ_a (X - m)) for each discrete action: shape (A, S)
    h = np.stack([np.log1p(lam * (X_dp - m)) for lam in lambdas], axis=0)

    A = len(lambdas)
    M = len(y_centers)

    above = y_centers >= T        # already hit: value 1, no decision needed
    y_below = y_centers[~above]   # interpolation nodes for V_{t+1} below T

    V = np.zeros((M, N + 1), dtype=np.float64)
    policy_idx = np.full((M, N), fill_value=-1, dtype=np.int32)

    # Terminal value
    V[:, N] = above.astype(np.float64)

    # Backward recursion:
    # V_t(y) = max_a E[ V_{t+1}( y + log(1+λ_a(X-m)) ) ]
    for t in range(N - 1, -1, -1):
        v_below = V[~above, t + 1]
        action_values = np.empty((A, M), dtype=np.float64)

        for a in range(A):
            y_next = y_centers[:, None] + h[a][None, :]  # (M, S)
            action_values[a] = _continuation_value(y_next, y_below, v_below, T).mean(axis=1)

        best_a = action_values.argmax(axis=0)
        best_v = action_values.max(axis=0)

        # Irrelevant: already above threshold
        best_v[above] = 1.0
        best_a[above] = -1

        # Unreachable: even max gain each remaining step can't hit threshold
        impossible = y_centers + (N - t) * Jmax < T
        best_v[impossible] = 0.0
        best_a[impossible] = -1

        policy_idx[:, t] = best_a
        V[:, t] = best_v

    return dict(
        T=T,
        bounds=(lam_max_neg, lam_max_pos),
        lambdas=lambdas,
        lam_kelly=lam_kelly,
        lam_end=lam_end,
        Jmax=Jmax,
        y_edges=y_edges,
        y_centers=y_centers,
        policy_idx=policy_idx,
        V=V,
    )


def _slice_rows_by_y(info, y_plot_min=None, y_plot_max=None):
    """
    Select a contiguous subset of y-rows for plotting, based on y_centers.
    Returns (row_slice, y_edges_plot).
    """
    y_centers = info["y_centers"]
    M = len(y_centers)

    lo = 0
    hi = M

    if y_plot_min is not None:
        y_plot_min = float(y_plot_min)
        lo = int(np.searchsorted(y_centers, y_plot_min, side="left"))
        lo = max(0, min(lo, M - 1))

    if y_plot_max is not None:
        y_plot_max = float(y_plot_max)
        hi = int(np.searchsorted(y_centers, y_plot_max, side="right"))
        hi = max(lo + 1, min(hi, M))

    y_edges = info["y_edges"]
    y_edges_plot = y_edges[lo : hi + 1]
    return slice(lo, hi), y_edges_plot



def plot_dp_value_grid(
    info,
    filename="dp_value_grid.png",
    title=None,
    include_terminal=False,
    y_plot_min=None,
    y_plot_max=None,
    vmin=0.0,
    vmax=1.0,
):
    """
    Plot the DP value function grid V_t(y) as a heatmap and save it to ``filename``.

    info["V"] has shape (Ybins, N+1); info["y_edges"], info["y_centers"] and
    info["T"] are also read.

    include_terminal=False  -> plot t = 0..N-1 (decision times)
    include_terminal=True   -> plot t = 0..N   (includes terminal V_N column)

    title: optional axes title. None (default) draws no title, leaving it to
        the caption.
    y_plot_min / y_plot_max: plotting-only crop of the y-axis (by bin center).
    vmin / vmax: color-scale limits.

    Font settings are applied through a temporary rc context, so matplotlib's
    global rcParams are unchanged after the call. Returns ``filename``.
    """
    V = info["V"]
    N = V.shape[1] - 1
    T = info["T"]

    if include_terminal:
        Z = V                          # (M, N+1)
        t_edges = np.arange(N + 2)     # N+1 columns -> N+2 edges
        xlabel = r"$t$ (including terminal)"
    else:
        Z = V[:, :N]                   # (M, N)
        t_edges = np.arange(N + 1)     # N columns -> N+1 edges
        xlabel = r"$t$"

    # crop y for plotting only
    row_slice, y_edges_plot = _slice_rows_by_y(info, y_plot_min=y_plot_min, y_plot_max=y_plot_max)
    Z_plot = Z[row_slice, :]

    # Typography shared with the policy plot
    bold = "bold"
    base_fs = 14
    label_fs = 18
    tick_fs = 14
    title_fs = 18
    legend_fs = 14
    cbar_fs = 18
    thr_gray = "0.40"

    rc = {
        "font.size": base_fs,
        "axes.labelsize": label_fs,
        "axes.titlesize": title_fs,
        "xtick.labelsize": tick_fs,
        "ytick.labelsize": tick_fs,
        "axes.linewidth": 0.8,
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
    }

    # Scope the style to this figure instead of mutating global rcParams.
    with plt.rc_context(rc):
        fig, ax = plt.subplots(figsize=(5.9, 4.4), constrained_layout=True)
        try:
            im = ax.pcolormesh(
                t_edges, y_edges_plot, Z_plot,
                shading="auto",
                vmin=vmin, vmax=vmax,
                rasterized=True,
            )

            ax.set_xlabel(xlabel, labelpad=1, fontsize=label_fs, fontweight=bold)
            ax.set_ylabel(r"$y=\log W_t$", labelpad=1, fontsize=label_fs, fontweight=bold)

            ax.tick_params(axis="both", which="major", pad=2, labelsize=tick_fs)
            for ticklab in ax.get_xticklabels() + ax.get_yticklabels():
                ticklab.set_fontweight(bold)

            if title is not None:
                ax.set_title(title, fontweight=bold)

            # Threshold line, with a legend entry scaled to match the fonts
            ax.axhline(T, linestyle="--", linewidth=1.4, color=thr_gray, alpha=0.95)
            line = plt.Line2D([0], [0], color=thr_gray, lw=1.4, ls="--")
            leg = ax.legend([line], [r"threshold $\log(1/\alpha)$"], loc="best",
                            fontsize=legend_fs, frameon=True)
            for txt in leg.get_texts():
                txt.set_fontweight(bold)

            cbar = fig.colorbar(im, ax=ax, fraction=0.06, pad=0.01)
            cbar.set_label(
                r"$\Pr(\mathrm{hit\ by\ }N \mid (t,y), \mathrm{optimal})$",
                fontsize=cbar_fs,
                fontweight=bold,
                labelpad=7,
                rotation=90,
            )
            cbar.ax.tick_params(labelsize=tick_fs, pad=3)
            for ticklab in cbar.ax.get_yticklabels():
                ticklab.set_fontweight(bold)

            fig.savefig(filename, dpi=150, bbox_inches="tight")
        finally:
            # Always release the figure, even if drawing/saving fails.
            plt.close(fig)
    return filename



def plot_dp_policy_grid(
    info,
    filename="dp_optimal_policy_grid.pdf",
    title=None,                 # recommend: keep None; put title in caption
    y_plot_min=None,
    y_plot_max=None,
    show_contours=False,
):
    """
    Plot the DP optimal-action grid (t, y) and save it to ``filename``.

    Reads info["policy_idx"] (shape (Ybins, N), -1 for irrelevant/unreachable
    cells, which are drawn blank), info["T"], info["y_edges"] and
    info["y_centers"].

    title: optional axes title.
    y_plot_min / y_plot_max: plotting-only crop of the y-axis (by bin center).
    show_contours: draw contour lines at the action boundaries.

    When ``filename`` ends in ".png", a PDF copy is saved alongside it.
    Font settings are applied through a temporary rc context, so matplotlib's
    global rcParams are unchanged after the call. Returns ``filename``.
    """
    policy_idx = info["policy_idx"]  # (Ybins, N) with -1 for irrelevant/unreachable
    N = policy_idx.shape[1]
    T = info["T"]

    # Crop rows for plotting only
    row_slice, y_edges_plot = _slice_rows_by_y(info, y_plot_min=y_plot_min, y_plot_max=y_plot_max)
    P = policy_idx[row_slice, :]  # (Mplot, N)

    # Mask invalid cells (-1) so they render with the "bad" color
    A = np.ma.masked_where(P < 0, P)

    # Paper palette
    aggr_fill = "#E69F00"   # All-in
    kelly_fill = "#56B4E9"  # Kelly
    less_fill = "#009E73"   # 0.5×Kelly
    dead_fill = "1.0"       # masked / irrelevant

    # 0: 0.5×Kelly, 1: Kelly, 2: All-in
    cmap = ListedColormap([less_fill, kelly_fill, aggr_fill], name="actions3")
    cmap.set_bad(dead_fill)
    norm = BoundaryNorm([-0.5, 0.5, 1.5, 2.5], ncolors=cmap.N)

    # Time edges
    t_edges = np.arange(N + 1)

    # Big + bold typography
    bold = "bold"
    base_fs = 14
    label_fs = 18
    tick_fs = 14
    title_fs = 18
    annot_fs = 14
    cbar_fs = 18
    line_gray = "0.35"
    thr_gray = "0.40"

    rc = {
        "font.size": base_fs,
        "axes.labelsize": label_fs,
        "axes.titlesize": title_fs,
        "xtick.labelsize": tick_fs,
        "ytick.labelsize": tick_fs,
        "axes.linewidth": 0.8,
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
    }

    # Scope the style to this figure instead of mutating global rcParams.
    with plt.rc_context(rc):
        fig, ax = plt.subplots(figsize=(5.9, 4.4), constrained_layout=True)
        try:
            # Rasterize the heavy part so PDF stays light
            im = ax.pcolormesh(
                t_edges,
                y_edges_plot,
                A,
                cmap=cmap,
                norm=norm,
                shading="auto",
                rasterized=True,
            )

            ax.set_xlabel(r"$t$", labelpad=1, fontsize=label_fs, fontweight=bold)
            ax.set_ylabel(r"$y=\log W_t$", labelpad=1, fontsize=label_fs, fontweight=bold)

            ax.tick_params(axis="both", which="major", pad=2, labelsize=tick_fs)

            # Bold tick labels explicitly
            for ticklab in ax.get_xticklabels() + ax.get_yticklabels():
                ticklab.set_fontweight(bold)

            if title is not None:
                ax.set_title(title, fontweight=bold)

            # Threshold line + direct label (no legend box)
            ax.axhline(T, linestyle="--", linewidth=1.4, color=thr_gray, alpha=0.95)
            ax.text(
                t_edges[-1], T,
                r"  Threshold $\log(1/\alpha)$",
                ha="right", va="bottom",
                fontsize=annot_fs, fontweight=bold
            )

            # Optional crisp regime boundaries
            if show_contours:
                y_cent = 0.5 * (y_edges_plot[:-1] + y_edges_plot[1:])
                t_cent = 0.5 * (t_edges[:-1] + t_edges[1:])
                TT, YY = np.meshgrid(t_cent, y_cent)
                ax.contour(
                    TT, YY, A,
                    levels=[0.5, 1.5],
                    colors=line_gray,
                    linewidths=1.2,
                    alpha=0.85
                )

            ax.grid(False)

            # Colorbar with centered ticks
            cbar = fig.colorbar(im, ax=ax, ticks=[0, 1, 2], fraction=0.06, pad=0.01)
            cbar.ax.set_yticklabels(ACTION_LABELS)

            # Bold the colorbar tick labels explicitly
            for ticklab in cbar.ax.get_yticklabels():
                ticklab.set_fontweight(bold)
                ticklab.set_fontsize(tick_fs)

            cbar.set_label("Optimal Action", rotation=90, labelpad=7,
                           fontsize=cbar_fs)
            cbar.ax.tick_params(pad=3)
            cbar.outline.set_linewidth(0.9)

            fig.savefig(filename, bbox_inches="tight")
            # PDF companion for PNG output only. For any other extension the
            # save above already wrote the file; saving it again is redundant.
            name = str(filename)
            if name.endswith(".png"):
                fig.savefig(name[: -len(".png")] + ".pdf", bbox_inches="tight")
        finally:
            # Always release the figure, even if drawing/saving fails.
            plt.close(fig)
    return filename


def _main():
    """Command-line entry point: solve the DP, save policy/value plots, optionally dump V to .npz."""
    parser = argparse.ArgumentParser(description="DP baseline for discrete Kelly/all-in actions.")
    parser.add_argument("--N", type=int, default=200, help="Horizon length")
    parser.add_argument("--alpha", type=float, default=0.05, help="Significance level")
    parser.add_argument("--m", type=float, default=0.44, help="Null hypothesis parameter")
    parser.add_argument("--mu", type=float, default=0.40, help="True mean parameter")
    parser.add_argument("--conc", type=float, default=6, help="Beta concentration")
    parser.add_argument("--world", type=str, default="beta_mixture", choices=["beta", "beta_mixture"])
    parser.add_argument("--eps_cap", type=float, default=1e-3, help="Safety margin in bounds")
    parser.add_argument("--num_y_bins", type=int, default=360, help="Number of y bins for DP grid")
    parser.add_argument("--mc_samples", type=int, default=6000, help="Monte Carlo samples for DP expectation")
    parser.add_argument("--seed", type=int, default=0, help="RNG seed")
    parser.add_argument("--y_margin", type=float, default=0.75, help="Extra margin around threshold in y-grid")
    parser.add_argument("--plot", type=str, default="dp_optimal_policy_grid.png", help="Output policy plot filename")
    parser.add_argument(
        "--value_plot",
        type=str,
        default="dp_value_grid.png",
        help="Filename for the DP value heatmap (pass an empty string to skip it)",
    )
    parser.add_argument("--save_value_npz", type=str, default=None, help="If set, save V grid and y grid to this .npz file")
    parser.add_argument("--y_plot_min", type=float, default=-10, help="Plotting only: minimum y shown (crop)")
    parser.add_argument("--y_plot_max", type=float, default=None, help="Plotting only: maximum y shown (crop)")
    args = parser.parse_args()

    # Reject inputs that make the DP ill-defined: T = log(1/alpha) <= 0,
    # division by zero in the bet bounds, or empty grids / sample sets.
    if args.N < 1:
        parser.error("--N must be at least 1")
    if not 0.0 < args.alpha < 1.0:
        parser.error("--alpha must lie strictly in (0, 1)")
    if not 0.0 < args.m < 1.0:
        parser.error("--m must lie strictly in (0, 1)")
    if not 0.0 <= args.mu <= 1.0:
        parser.error("--mu must lie in [0, 1]")
    if args.num_y_bins < 1:
        parser.error("--num_y_bins must be at least 1")
    if args.mc_samples < 1:
        parser.error("--mc_samples must be at least 1")

    info = dp_optimal_policy_discrete(
        N=args.N,
        alpha=args.alpha,
        m=args.m,
        mu=args.mu,
        conc=args.conc,
        world=args.world,
        eps_cap=args.eps_cap,
        num_y_bins=args.num_y_bins,
        mc_samples=args.mc_samples,
        seed=args.seed,
        y_margin=args.y_margin,
    )
    plot_dp_policy_grid(
        info,
        filename=args.plot,
        y_plot_min=args.y_plot_min,
        y_plot_max=args.y_plot_max,
    )

    # Truthiness check so an empty string disables the value plot.
    if args.value_plot:
        plot_dp_value_grid(
            info,
            filename=args.value_plot,
            y_plot_min=args.y_plot_min,
            y_plot_max=args.y_plot_max,
        )

    if args.save_value_npz is not None:
        np.savez(
            args.save_value_npz,
            V=info["V"],
            y_centers=info["y_centers"],
            y_edges=info["y_edges"],
            T=info["T"],
            lambdas=info["lambdas"],
        )

    print("[dp] Saved optimal policy grid to", args.plot)
    if args.value_plot:
        print("[dp] Saved value grid to", args.value_plot)
    if args.save_value_npz is not None:
        print("[dp] Saved value arrays to", args.save_value_npz)
    print(
        f"[dp] Kelly={info['lam_kelly']:.4f}, endpoint={info['lam_end']:.4f}, "
        f"bounds=({info['bounds'][0]:.4f}, {info['bounds'][1]:.4f}), "
        f"T=log(1/alpha)={info['T']:.4f}"
    )


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    _main()
