
"""
Windy Gridworld + JIPE Demo

What this script does
---------------------
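- Defines a (default 3x3) "windy gridworld" with a shared per-step wind
  variable: with probability p_wind a gust pushes the agent one cell left.
- Evaluates a fixed policy pi that greedily moves toward a given goal.
- Computes analytic ground truth for V (value), S2 (uncentered second moment
  of the return) and Cross (cross-moment of the returns from two states whose
  trajectories see the same wind) via linear solves.
- Computes JIPE estimates by repeatedly applying the joint operator.
- Builds per-action means, uncentered second moments, and covariances at any
  state; the demo reports the covariances as correlation matrices.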

Policy being evaluated
----------------------
Throughout, the value functions (V, S2, Cross) are for the fixed deterministic policy pi
that moves greedily toward the goal: if x < goal_x go Right; if x > goal_x go Left;
else move vertically toward goal_y. The goal is absorbing and gives reward=1 when entered.

"""

from dataclasses import dataclass
from typing import Tuple, List

import numpy as np

# Action labels; this order fixes the row/column order of per-action arrays.
ACTIONS = list("RLUD")
# Unit displacement (dx, dy) per action: "R" increases x, "U" increases y.
DELTA = dict(zip(ACTIONS, [(1, 0), (-1, 0), (0, 1), (0, -1)]))

@dataclass
class WindyGridSpec:
    """Configuration of the windy gridworld and of the policy evaluation.

    Declared as a dataclass so it can be constructed with keyword arguments
    (``WindyGridSpec(width=4, goal=(3, 3))``), as ``run_demo`` does; a plain
    class holding only class attributes accepts no constructor arguments.

    Raises:
        ValueError: if the grid has no cells, the goal lies outside the grid,
            or ``p_wind`` is not a probability in [0, 1].
    """
    width: int = 3                    # Number of columns; x in [0, width)
    height: int = 3                   # Number of rows; y in [0, height)
    goal: Tuple[int,int] = (2,2)      # Goal cell (x,y)
    p_wind: float = 0.35              # P(shared wind gust) each step
    gamma: float = 0.95               # Discount

    def __post_init__(self):
        # Fail fast on configurations the flat-index helpers cannot represent.
        if self.width < 1 or self.height < 1:
            raise ValueError(f"grid must be at least 1x1, got {self.width}x{self.height}")
        gx, gy = self.goal
        if not (0 <= gx < self.width and 0 <= gy < self.height):
            raise ValueError(f"goal {self.goal} lies outside the {self.width}x{self.height} grid")
        if not 0.0 <= self.p_wind <= 1.0:
            raise ValueError(f"p_wind must be in [0, 1], got {self.p_wind}")


def clamp(x, lo, hi):
    """Return x limited to [lo, hi]; if lo > hi, lo takes precedence."""
    capped = min(hi, x)
    return max(lo, capped)


def idx(x, y, w):
    """Flatten cell (x, y) into a row-major index for a grid of width w."""
    return x + w * y


def xy(i, w):
    """Inverse of idx: recover the cell (x, y) of flat index i for width w."""
    row, col = divmod(i, w)
    return (col, row)
def is_terminal(s, spec: WindyGridSpec):
    """Return True when flat state index s is the goal cell."""
    cell = xy(s, spec.width)
    return cell == spec.goal


def reward_from(x,y,nx,ny, spec: WindyGridSpec):
    """Return 1.0 if the move lands on the goal cell, else 0.0.

    Only the destination (nx, ny) matters; the source (x, y) is unused.
    """
    if (nx, ny) == spec.goal:
        return 1.0
    return 0.0


def greedy_policy_to_goal(x,y, goal):
    """Return the greedy action letter from (x, y) toward ``goal``.

    Horizontal alignment is fixed first ("R"/"L"), then vertical ("U"/"D").
    Returns None at the goal itself, where no action is taken.
    """
    goal_x, goal_y = goal
    if (x, y) == (goal_x, goal_y):
        return None
    candidates = (
        (x < goal_x, "R"),
        (x > goal_x, "L"),
        (y < goal_y, "U"),
        (y > goal_y, "D"),
    )
    return next((move for needed, move in candidates if needed), None)


def next_state_with_wind(x,y, action, wind_on, spec: WindyGridSpec):
    """Apply ``action`` plus the optional gust, clamped to the grid.

    A truthy ``wind_on`` adds the gust (-1, 0), pushing one cell toward smaller
    x; otherwise nothing is added. Moves past an edge are clamped, so the agent
    stays on the border.

    Returns:
        (nx, ny): the successor cell.
    """
    step = DELTA[action]
    gust = (-1, 0) if wind_on else (0, 0)
    max_x = spec.width - 1
    max_y = spec.height - 1
    return (
        clamp(x + step[0] + gust[0], 0, max_x),
        clamp(y + step[1] + gust[1], 0, max_y),
    )


def build_policy(spec: WindyGridSpec) -> List[str]:
    """Tabulate the greedy policy over all flat state indices.

    Entry s is the action letter chosen at cell xy(s); the goal entry is None.
    """
    n_cells = spec.width * spec.height
    cells = (xy(s, spec.width) for s in range(n_cells))
    return [greedy_policy_to_goal(cx, cy, spec.goal) for cx, cy in cells]

def action_outcomes(s, action, spec: WindyGridSpec):
    """List the one-step outcomes of taking ``action`` from flat state s.

    Returns:
        [(prob, reward, next_state), ...]: the no-wind branch first, then the
        wind branch.
    """
    x, y = xy(s, spec.width)
    branches = ((0, 1 - spec.p_wind), (1, spec.p_wind))
    results = []
    for gusting, prob in branches:
        nx, ny = next_state_with_wind(x, y, action, gusting, spec)
        landing = idx(nx, ny, spec.width)
        results.append((prob, reward_from(x, y, nx, ny, spec), landing))
    return results


def build_P_R_for_policy(spec: WindyGridSpec):
    """Build the transition matrix and expected one-step reward under pi.

    The goal (and any state without an action) becomes an absorbing
    self-loop with zero reward.

    Returns:
        (P, R, pi): P is (nS, nS), R is (nS,), pi is the policy list.
    """
    n_states = spec.width * spec.height
    pi = build_policy(spec)
    P = np.zeros((n_states, n_states))
    R = np.zeros(n_states)
    for s, action in enumerate(pi):
        if is_terminal(s, spec) or action is None:
            P[s, s] = 1.0   # absorbing; R[s] stays 0
            continue
        for prob, reward, s_next in action_outcomes(s, action, spec):
            P[s, s_next] += prob
            R[s] += prob * reward
    return P, R, pi

def solve_ground_truth(spec: WindyGridSpec):
    """Solve exactly for V, S2 and Cross under the greedy policy pi.

    With r the one-step reward and s' the successor, the linear systems are:

    * V(s)         = E[r] + g E[V(s')]
    * S2(s)        = E[r^2] + 2 g E[r V(s')] + g^2 E[S2(s')]
    * Cross(s1,s2) = E[r1 r2] + g E[r1 V(s2') + r2 V(s1')] + g^2 E[Cross(s1', s2')]

    In Cross both trajectories see the same wind draw at every step. Absorbing
    states (and pairs containing one) have zero reward and a self-loop, so all
    three quantities are 0 there for gamma < 1.

    Note: the Cross system is dense with nS^2 unknowns (an nS^2 x nS^2 matrix),
    so its memory grows as the fourth power of the number of states.

    Returns:
        (V, S2, Cross) with shapes (nS,), (nS,), (nS, nS).
    """
    g = spec.gamma
    P, R, pi = build_P_R_for_policy(spec)
    nS = P.shape[0]
    I = np.eye(nS)

    # Value
    V = np.linalg.solve(I - g*P, R)

    # Uncentered second moment S2. E[r^2] is computed explicitly instead of
    # reusing R = E[r]: the two only coincide for 0/1 rewards.
    R2 = np.zeros_like(R)
    M = np.zeros_like(R)    # M[s] = E[r * V(s')]
    for s in range(nS):
        if is_terminal(s,spec) or pi[s] is None:
            continue        # absorbing: no reward, no continuation
        outs = action_outcomes(s, pi[s], spec)
        R2[s] = sum(pw * r * r for pw, r, _ in outs)
        M[s] = sum(pw * r * V[sp] for pw, r, sp in outs)
    S2 = np.linalg.solve(I - (g**2)*P, R2 + 2*g*M)

    # Cross over state pairs under shared wind: pair (s1, s2) is flattened to
    # k = s1*nS + s2 and both members take their pi-action under the same gust.
    n2 = nS*nS
    Pj = np.zeros((n2,n2))
    B  = np.zeros(n2)
    wind_branches = [(0,1-spec.p_wind), (1,spec.p_wind)]
    for s1 in range(nS):
        x1,y1 = xy(s1, spec.width)
        a1 = pi[s1] if not is_terminal(s1,spec) else None
        for s2 in range(nS):
            x2,y2 = xy(s2, spec.width)
            a2 = pi[s2] if not is_terminal(s2,spec) else None
            k = s1*nS + s2
            if a1 is None or a2 is None:
                # A terminal member has zero return, so the cross-moment is 0.
                Pj[k,k] = 1.0
                B[k] = 0.0
                continue
            for wind_on, pw in wind_branches:
                nx1,ny1 = next_state_with_wind(x1,y1,a1,wind_on,spec)
                nx2,ny2 = next_state_with_wind(x2,y2,a2,wind_on,spec)
                sp1 = idx(nx1,ny1, spec.width)
                sp2 = idx(nx2,ny2, spec.width)
                k2  = sp1*nS + sp2
                Pj[k,k2] += pw
                r1 = reward_from(x1,y1,nx1,ny1,spec)
                r2 = reward_from(x2,y2,nx2,ny2,spec)
                # Future wind is independent of this step's gust, so
                # E[r1 * G2' | s1', s2'] = r1 * V(s2') (and symmetrically).
                B[k] += pw*(r1*r2 + g*r1*V[sp2] + g*r2*V[sp1])
    Cross = np.linalg.solve(np.eye(n2) - (g**2)*Pj, B).reshape((nS,nS))
    return V, S2, Cross


def jipe_iterative(spec: WindyGridSpec, tol=1e-12, max_it=10):
    """Estimate V, S2 and Cross by fixed-point iteration from zero.

    Each quantity is swept at most ``max_it`` times and stops early once the
    sup-norm change between sweeps is below ``tol``. V is iterated first and
    then held fixed while S2 and Cross are iterated. With a small ``max_it``
    the results are therefore truncated approximations of the values returned
    by ``solve_ground_truth``.

    Args:
        spec: Gridworld configuration.
        tol: Sup-norm convergence threshold for each iteration.
        max_it: Maximum number of sweeps per quantity.

    Returns:
        (V, S2, Cross) with shapes (nS,), (nS,), (nS, nS).
    """
    g = spec.gamma
    P, R, pi = build_P_R_for_policy(spec)
    nS = P.shape[0]
    wind_branches = [(0, 1-spec.p_wind), (1, spec.p_wind)]

    # V iteration
    V = np.zeros(nS)
    for _ in range(max_it):
        Vn = R + g*(P @ V)
        done = np.max(np.abs(Vn - V)) < tol
        V = Vn
        if done:
            break

    # S2 iteration. One-step outcomes do not change between sweeps, so they are
    # enumerated once; None marks absorbing states, where S2 stays 0.
    state_outcomes = [
        None if (is_terminal(s,spec) or pi[s] is None) else action_outcomes(s, pi[s], spec)
        for s in range(nS)
    ]
    S2 = np.zeros(nS)
    for _ in range(max_it):
        S2n = np.zeros_like(S2)
        for s, outs in enumerate(state_outcomes):
            if outs is None:
                continue
            # E[r^2] is written as r*r so the update is not tied to 0/1 rewards.
            S2n[s] = sum(pw*(r*r + 2*g*r*V[sp] + (g**2)*S2[sp]) for pw,r,sp in outs)
        done = np.max(np.abs(S2n - S2)) < tol
        S2 = S2n
        if done:
            break

    # Cross iteration (state-pair). The joint one-step transitions of the two
    # wind-coupled trajectories are also sweep-invariant: precompute them.
    # Pairs containing an absorbing state are omitted and keep Cross = 0.
    pair_outcomes = {}
    for s1 in range(nS):
        x1,y1 = xy(s1, spec.width)
        a1 = pi[s1] if not is_terminal(s1,spec) else None
        if a1 is None:
            continue
        for s2 in range(nS):
            x2,y2 = xy(s2, spec.width)
            a2 = pi[s2] if not is_terminal(s2,spec) else None
            if a2 is None:
                continue
            joint = []
            for wind_on, pw in wind_branches:    # same gust for both trajectories
                nx1,ny1 = next_state_with_wind(x1,y1,a1,wind_on,spec)
                nx2,ny2 = next_state_with_wind(x2,y2,a2,wind_on,spec)
                joint.append((
                    pw,
                    reward_from(x1,y1,nx1,ny1,spec),
                    reward_from(x2,y2,nx2,ny2,spec),
                    idx(nx1,ny1, spec.width),
                    idx(nx2,ny2, spec.width),
                ))
            pair_outcomes[(s1, s2)] = joint

    Cross = np.zeros((nS,nS))
    for _ in range(max_it):
        Cn = np.zeros_like(Cross)
        for (s1, s2), joint in pair_outcomes.items():
            Cn[s1,s2] = sum(
                pw*(r1*r2 + g*r1*V[sp2] + g*r2*V[sp1] + (g**2)*Cross[sp1,sp2])
                for pw, r1, r2, sp1, sp2 in joint
            )
        done = np.max(np.abs(Cn - Cross)) < tol
        Cross = Cn
        if done:
            break
    return V, S2, Cross


def per_action_moments_at_state(s, spec: WindyGridSpec, V, S2, Cross):
    """Per-action return moments at state s when pi is followed afterwards.

    For each action a in ACTIONS (taken once from s, then pi):
      mu[a]      = E[r_a + g G_a']
      Sbar[a,a]  = E[(r_a + g G_a')^2]                      (uses S2)
      Sbar[a,b]  = E[(r_a + g G_a')(r_b + g G_b')]          (uses Cross)
    where actions a and b see the same wind draw now and afterwards.
    Sigma = Sbar - mu mu^T.

    At the absorbing goal every return is 0 in this model (consistent with
    V = S2 = Cross = 0 there), so all outputs are zero for a terminal s.

    Returns:
        (mu, Sbar, Sigma) with shapes (nA,), (nA, nA), (nA, nA),
        where nA = len(ACTIONS) and entries follow the ACTIONS order.
    """
    g = spec.gamma
    nA = len(ACTIONS)
    mu   = np.zeros(nA)
    Sbar = np.zeros((nA,nA))
    if is_terminal(s, spec):
        # Without this guard, actions that clamp back onto the goal would be
        # credited a spurious reward of 1 for "entering" it.
        return mu, Sbar, Sbar - np.outer(mu, mu)

    # means and diagonals (r*r is E[r^2], not tied to 0/1 rewards)
    for i, a in enumerate(ACTIONS):
        outs = action_outcomes(s, a, spec)
        mu[i] = sum(pw*(r + g*V[sp]) for pw,r,sp in outs)
        Sbar[i,i] = sum(pw*(r*r + 2*g*r*V[sp] + (g**2)*S2[sp]) for pw,r,sp in outs)

    # off-diagonals with shared wind coupling
    x,y = xy(s, spec.width)
    wind_branches = [(0,1-spec.p_wind), (1,spec.p_wind)]
    for i, a1 in enumerate(ACTIONS):
        for j, a2 in enumerate(ACTIONS):
            if j <= i: continue
            exp = 0.0
            for wind_on, pw in wind_branches:
                nx1,ny1 = next_state_with_wind(x,y,a1,wind_on,spec)
                nx2,ny2 = next_state_with_wind(x,y,a2,wind_on,spec)
                sp1 = idx(nx1,ny1, spec.width)
                sp2 = idx(nx2,ny2, spec.width)
                r1 = reward_from(x,y,nx1,ny1,spec)
                r2 = reward_from(x,y,nx2,ny2,spec)
                exp += pw*(r1*r2 + g*r1*V[sp2] + g*r2*V[sp1] + (g**2)*Cross[sp1,sp2])
            Sbar[i,j] = exp
            Sbar[j,i] = exp

    Sigma = Sbar - np.outer(mu, mu)
    return mu, Sbar, Sigma

def covariance_to_correlation(cov_matrix):
    """Convert a covariance matrix to a correlation matrix.

    corr_ij = cov_ij / (std_i * std_j). Entries that involve a zero-variance
    variable are undefined and returned as NaN, without division warnings.
    Tiny negative variances from floating-point round-off are clamped to 0
    before taking the square root.

    Args:
        cov_matrix: Square covariance matrix (array-like).

    Returns:
        A new float ndarray with the same shape as ``cov_matrix``.
    """
    cov_matrix = np.asarray(cov_matrix, dtype=float)

    # Variances, clamped so round-off (e.g. -1e-18) cannot make sqrt fail.
    variances = np.clip(np.diag(cov_matrix), 0.0, None)
    std_devs = np.sqrt(variances)
    denom = np.outer(std_devs, std_devs)

    # Divide only where the denominator is positive; the rest stay NaN.
    correlation_matrix = np.full_like(cov_matrix, np.nan)
    np.divide(cov_matrix, denom, out=correlation_matrix, where=denom > 0)
    return correlation_matrix


def run_demo(
    width=3, height=3,
    goal=(2,2), p_wind=0.35, gamma=0.95,
    state_xy=(0,0), max_it=20   # corner cell (0,0) by default
):
    """Compare analytic and iterative (JIPE) per-action moments at one state.

    Prints the per-action means and correlation matrices (in ACTIONS order)
    from the exact solution and from ``jipe_iterative`` with ``max_it`` sweeps,
    together with the sup-norm of their differences.

    Raises:
        ValueError: if ``state_xy`` lies outside the grid.
    """
    spec = WindyGridSpec(width=width, height=height, goal=goal, p_wind=p_wind, gamma=gamma)
    sx, sy = state_xy
    if not (0 <= sx < spec.width and 0 <= sy < spec.height):
        # idx() would silently wrap an out-of-range x onto a neighbouring row.
        raise ValueError(f"state_xy={state_xy} is outside the {spec.width}x{spec.height} grid")
    s = idx(sx, sy, spec.width)

    # Analytic truth
    V_t, S2_t, Cross_t = solve_ground_truth(spec)
    mu_t, _, Sigma_t = per_action_moments_at_state(s, spec, V_t, S2_t, Cross_t)
    Corr_t = covariance_to_correlation(Sigma_t)

    # Iterative JIPE
    V_e, S2_e, Cross_e = jipe_iterative(spec, max_it=max_it)
    mu_e, _, Sigma_e = per_action_moments_at_state(s, spec, V_e, S2_e, Cross_e)
    Corr_e = covariance_to_correlation(Sigma_e)

    # Print
    def fmt(M): return np.array2string(M, precision=8, floatmode="fixed", suppress_small=False)
    print(f"Spec: grid={width}x{height}, goal={goal}, p_wind={p_wind}, gamma={gamma}, state={state_xy}")
    print("\nPolicy evaluated: deterministic, Manhattan-greedy toward the goal (goal is absorbing, reward=1 on entry).")
    print("\n--- Means [R L U D] ---")
    print("True:\n", fmt(mu_t))
    print("JIPE:\n", fmt(mu_e))
    print("||Δμ||_∞ =", np.max(np.abs(mu_t - mu_e)))

    print("\n--- Correlation [R L U D] ---")
    print("True:\n", fmt(Corr_t))
    print("JIPE:\n", fmt(Corr_e))
    print("||ΔΣ||_∞ =", np.max(np.abs(Corr_t - Corr_e)))

if __name__ == "__main__":
    # Script entry point: run the comparison with its default arguments.
    run_demo()
