"""
Joint Iterative Policy Evaluation (JIPE) for CliffWalking

This script adapts the JIPE method to Gymnasium's CliffWalking environment with 
slippery dynamics that create shared randomness across actions.

It computes:
- Value function V for a fixed policy
- Uncentered second moment S2
- Cross-covariances between state pairs under shared wind
- Per-action expected returns and covariance matrices at each state

The CliffWalking environment with is_slippery=True:
- 4x12 grid (48 states)
- Start at (0,0), Goal at (11,0)
- Cliff region from (1,0) to (10,0)
- Falling off cliff gives -100 reward and returns to start
- Each step gives -1 reward
- With slippery dynamics: intended action happens with prob (1-wind_prob),
  otherwise one of the other three actions is taken, each with equal probability
  (the slip is shared across all actions in that step)
"""

import numpy as np
from typing import Tuple, List, Dict
from dataclasses import dataclass
import gymnasium as gym

# Seed NumPy's global RNG so that the np.random.choice draws made in
# simulate_policy are reproducible from run to run.
np.random.seed(1)

# Action mappings for CliffWalking: the position in this list is the integer
# action index interpreted by apply_action.
ACTIONS = ["U", "R", "D", "L"]  # Up=0, Right=1, Down=2, Left=3

@dataclass
class SlipperyCliffSpec:
    """Parameters of the slippery CliffWalking grid.

    Cells are addressed as (x, y), with x the column and y the row.
    """

    # Grid size: number of columns and number of rows.
    width: int = 12
    height: int = 4

    # Start and goal cells, as (x, y).
    start: Tuple[int, int] = (0, 0)
    goal: Tuple[int, int] = (11, 0)

    # Probability of slipping
    wind_prob: float = 0.1

    # Discount factor applied to future rewards.
    gamma: float = 0.95
    
def idx(x, y, w):
    """Flatten cell (x, y) into a row-major state index for a grid of width w."""
    return x + w * y

def xy(s, w):
    """Inverse of idx: turn state index s into its (x, y) cell for width w."""
    row, col = divmod(s, w)
    return (col, row)

def is_cliff(x, y):
    """Return whether (x, y) is a cliff cell: row y == 0, columns 1 through 10."""
    on_bottom_row = y == 0
    return on_bottom_row and 1 <= x <= 10

def is_terminal(s, spec):
    """Return whether state index s is the goal cell of spec."""
    cell = xy(s, spec.width)
    return cell == spec.goal

def get_slip_actions(action):
    """
    List the actions a slip can substitute for `action`.

    Returns every action index in 0..3 except `action` itself, in
    ascending order.
    """
    return [candidate for candidate in range(4) if candidate != action]

def apply_action(x, y, action, spec):
    """
    Move one cell from (x, y) in the direction of `action`, clamped to the grid.

    Action 0 increases y (Up), 1 increases x (Right), 2 decreases y (Down);
    any other value is treated as Left. Cliff and goal handling is left to
    the caller.
    """
    if action == 0:  # Up
        return x, min(y + 1, spec.height - 1)
    if action == 1:  # Right
        return min(x + 1, spec.width - 1), y
    if action == 2:  # Down
        return x, max(y - 1, 0)
    # Left (action == 3)
    return max(x - 1, 0), y

def get_transition_outcomes(s, action, spec):
    """
    Enumerate the one-step outcomes of taking `action` in state `s`.

    Returns a list of (probability, reward, next_state) tuples. The goal is
    absorbing with reward 0. Otherwise the intended action is listed first
    with probability 1 - wind_prob, followed by each slip action from
    get_slip_actions with probability wind_prob / 3. Moving into a cliff
    cell yields reward -100 and sends the agent to the start state; every
    other move yields reward -1.
    """
    x, y = xy(s, spec.width)

    if (x, y) == spec.goal:
        # Goal state is absorbing
        return [(1.0, 0.0, s)]

    start_state = idx(spec.start[0], spec.start[1], spec.width)

    def resolve(executed, prob):
        # Outcome tuple for actually executing `executed` with weight `prob`.
        nx, ny = apply_action(x, y, executed, spec)
        if is_cliff(nx, ny):
            return (prob, -100.0, start_state)
        return (prob, -1.0, idx(nx, ny, spec.width))

    # Intended action first, then the slip alternatives in ascending order.
    outcomes = [resolve(action, 1 - spec.wind_prob)]
    outcomes.extend(
        resolve(slip_action, spec.wind_prob / 3)
        for slip_action in get_slip_actions(action)
    )
    return outcomes

def get_joint_outcomes(s, a1, a2, spec):
    """
    Enumerate joint one-step outcomes of actions a1 and a2 from state s
    when both are exposed to the same slip randomness.

    Returns a list of (probability, r1, ns1, r2, ns2) tuples.

    The joint law is built so that each action's marginal is exactly the
    distribution returned by get_transition_outcomes: with probability
    1 - wind_prob both intended actions execute, and with probability
    wind_prob a shared slip index j is drawn uniformly and action i is
    replaced by the j-th entry of get_slip_actions(a_i). The previous
    version replaced both actions by a wind direction drawn from all four
    actions with probability wind_prob / 4 each, whose marginals did not
    match get_transition_outcomes, so cross moments built from it were
    inconsistent with the per-action means and second moments.

    Because a single shared index drives every action, the pairwise laws
    for all action pairs come from one joint distribution, and calling
    this with a1 == a2 reproduces the single-action outcomes with both
    components identical.
    """
    x, y = xy(s, spec.width)

    if (x, y) == spec.goal:
        return [(1.0, 0.0, s, 0.0, s)]

    start_state = idx(spec.start[0], spec.start[1], spec.width)

    def step(executed):
        # Reward and next state for actually executing `executed` from (x, y).
        nx, ny = apply_action(x, y, executed, spec)
        if is_cliff(nx, ny):
            return -100.0, start_state
        return -1.0, idx(nx, ny, spec.width)

    # Case 1: no slip, both actions execute as intended.
    r1, ns1 = step(a1)
    r2, ns2 = step(a2)
    outcomes = [(1 - spec.wind_prob, r1, ns1, r2, ns2)]

    # Case 2: slip, the same slip index selects the replacement for each action.
    slips1 = get_slip_actions(a1)
    slips2 = get_slip_actions(a2)
    slip_prob = spec.wind_prob / len(slips1)
    for executed1, executed2 in zip(slips1, slips2):
        r1, ns1 = step(executed1)
        r2, ns2 = step(executed2)
        outcomes.append((slip_prob, r1, ns1, r2, ns2))

    return outcomes


def solve_qp_simplex(mu, Sigma, lam, tol=1e-6, max_iter=100):
    """
    Solve a mean-variance problem on the probability simplex with an
    active-set iteration:

        maximize    mu @ w - lam * (w @ Sigma @ w)
        subject to  sum(w) == 1,  w >= 0

    Args:
        mu: 1-D array of expected returns.
        Sigma: Symmetric positive-definite matrix matching mu. It is
            Cholesky-factorised, so numpy.linalg.LinAlgError is raised
            when it is not positive definite.
        lam: Risk-aversion weight; must be strictly positive.
        tol: Weights at or below tol are removed from the support, and
            reduced costs at or below tol are not treated as violations.
        max_iter: Maximum number of active-set updates.

    Returns:
        1-D array of weights, zero outside the final support.

    Raises:
        ValueError: If lam is not strictly positive.
        RuntimeError: If no solution is accepted within max_iter updates.
    """
    if lam <= 0:
        # The closed form below divides by 2 * lam.
        raise ValueError("lam must be strictly positive")

    n = mu.size
    one = np.ones(n, dtype=mu.dtype)

    chol = np.linalg.cholesky(Sigma)

    def chol_solve(b):
        # Solve Sigma z = b using the factor: chol y = b, then chol.T z = y.
        y = np.linalg.solve(chol, b)
        return np.linalg.solve(chol.T, y)

    # Solution with only the equality constraint, used to seed the support.
    c1 = chol_solve(one)         # Sigma^{-1} * 1
    c2 = chol_solve(mu)          # Sigma^{-1} * mu
    nu_unc = (np.dot(one, c2) - 2.0 * lam) / np.dot(one, c1)
    w_unc = (c2 - nu_unc * c1) / (2.0 * lam)

    # Fall back to the single best-mean coordinate if nothing is positive.
    support = set(np.where(w_unc > tol)[0].tolist()) or {int(np.argmax(mu))}

    for _ in range(max_iter):
        active = np.array(sorted(support), dtype=int)
        k = active.size

        # KKT system restricted to the support:
        #   2*lam*Sigma_SS w_S + nu*1 = mu_S,   1' w_S = 1
        KKT = np.zeros((k + 1, k + 1), dtype=mu.dtype)
        KKT[:k, :k] = 2.0 * lam * Sigma[np.ix_(active, active)]
        KKT[:k, k] = 1.0
        KKT[k, :k] = 1.0

        rhs = np.concatenate([mu[active], np.array([1.0], dtype=mu.dtype)])

        sol = np.linalg.solve(KKT, rhs)
        w_active, nu = sol[:k], sol[k]

        # Enforce nonnegativity: drop every coordinate that is not clearly
        # positive. At least one such coordinate exists here, so the support
        # strictly shrinks and no extra "force progress" step is needed.
        if np.any(w_active <= tol):
            support = set(active[w_active > tol].tolist())
            continue

        # Build full w
        w = np.zeros(n, dtype=mu.dtype)
        w[active] = w_active

        # Reduced-cost / stationarity violation for inactive coords
        r = mu - 2.0 * lam * (Sigma @ w) - nu * one

        viol = [i for i in range(n) if i not in support and r[i] > tol]
        if not viol:
            return w

        # Add most violating index to active set
        support.add(int(max(viol, key=lambda i: float(r[i]))))

    raise RuntimeError("did not converge")


def get_markowitz_policy(spec, means, covs, lam):
    """
    Build a stochastic policy by solving one simplex QP per state.

    means[s] and covs[s] are passed to solve_qp_simplex together with lam,
    and the resulting weights become row s of the returned (nS, 4) array.
    The goal state gets all of its probability on action 0.
    """
    nS = spec.width * spec.height
    policy = np.zeros(shape=(nS, 4))
    goal_row = np.array([1.0, 0.0, 0.0, 0.0])

    for s in range(nS):
        if xy(s, spec.width) == spec.goal:
            policy[s] = goal_row
            continue
        policy[s] = solve_qp_simplex(means[s], covs[s], lam)

    return policy


def get_unsafe_policy(spec):
    """
    Get an unsafe policy for CliffWalking that walks along the cliff.

    Returns an integer array with one action index per state
    (0=Up, 1=Right, 2=Down, 3=Left):
      - row y == 0 (start and cliff cells): Up, onto the row beside the cliff;
      - row y == 1: Right along the cliff edge, then Down once the agent is
        in the goal's column;
      - rows y > 1: Down toward the cliff edge.

    The Down step in the goal column is required for the policy to reach
    the goal: previously the cell directly above the goal was assigned
    Right, which pushes into the wall and never enters the goal without
    a slip.
    """
    nS = spec.width * spec.height
    policy = np.zeros(nS, dtype=int)

    for s in range(nS):
        y, x = divmod(s, spec.width)

        if (x, y) == spec.goal:
            policy[s] = 0  # Doesn't matter, terminal state
        elif y == 1:
            # Hug the cliff edge; descend when directly above the goal.
            policy[s] = 2 if x == spec.goal[0] else 1
        elif y > 1:
            # Too high, go down
            policy[s] = 2  # Down
        else:
            # Bottom row: step up off the start / cliff row.
            policy[s] = 0
    return policy

def get_optimal_policy(spec):
    """
    Get a safe policy for CliffWalking that avoids the cliff.
    This goes up first, then right, then down to goal.

    Returns an integer array with one action index per state
    (0=Up, 1=Right, 2=Down, 3=Left).

    Cells in the goal's column above the goal always move Down. Previously
    the cell directly above the goal was assigned Up (because y < 2) while
    the cell above it was assigned Down, so the policy oscillated between
    the two and never entered the goal without a slip.
    """
    nS = spec.width * spec.height
    policy = np.zeros(nS, dtype=int)
    goal_x, goal_y = spec.goal

    for s in range(nS):
        y, x = divmod(s, spec.width)

        if (x, y) == spec.goal:
            policy[s] = 0  # Doesn't matter, terminal state
        elif x == goal_x and y > goal_y:
            # In the goal's column: descend all the way into the goal.
            policy[s] = 2  # Down
        elif y < 2:
            # On the two bottom rows, go up to get away from the cliff
            policy[s] = 0  # Up
        elif x < goal_x:
            # On safe row, go right toward goal
            policy[s] = 1  # Right
        else:
            # Past the goal's column, go down
            policy[s] = 2  # Down

    return policy

def build_P_R_for_policy(spec, policy):
    """
    Assemble the Markov chain induced by a deterministic policy.

    Returns (P, R): P[s, ns] is the probability of moving from s to ns and
    R[s] is the expected immediate reward, both accumulated from
    get_transition_outcomes(s, policy[s], spec).
    """
    nS = spec.width * spec.height
    P = np.zeros((nS, nS))
    R = np.zeros(nS)

    for state in range(nS):
        for p, r, nxt in get_transition_outcomes(state, policy[state], spec):
            P[state, nxt] += p
            R[state] += p * r

    return P, R

def solve_ground_truth(spec, policy):
    """
    Solve for V, S2 and Cross directly with linear algebra.

    Args:
        spec: Environment specification (gamma is read from it).
        policy: Deterministic policy, one action index per state.

    Returns:
        (V, S2, Cross) where
          V[s]           is the expected discounted return from s,
          S2[s]          is the uncentered second moment of that return,
          Cross[s1, s2]  is the expected product of the returns from s1 and
                         s2. For s1 == s2 both returns follow one shared
                         trajectory; for s1 != s2 the two one-step
                         transitions are treated as independent.

    The Cross system is dense with (nS * nS) unknowns.
    """
    g = spec.gamma
    P, R = build_P_R_for_policy(spec, policy)
    nS = P.shape[0]
    I = np.eye(nS)

    # Per-state outcomes and terminal flags depend only on (spec, policy),
    # so compute them once instead of inside the nS x nS pair loop.
    outcomes_by_state = [get_transition_outcomes(s, policy[s], spec) for s in range(nS)]
    terminal = [is_terminal(s, spec) for s in range(nS)]

    # Value function
    V = np.linalg.solve(I - g * P, R)

    # M[s] = E[r * V(s')] and R2[s] = E[r^2], the inhomogeneous terms of the
    # S2 recursion S2 = R2 + 2*g*M + g^2 * P @ S2.
    M = np.zeros_like(R)
    R2 = np.zeros_like(R)
    for s in range(nS):
        outcomes = outcomes_by_state[s]
        if not terminal[s]:
            M[s] = sum(prob * reward * V[ns] for prob, reward, ns in outcomes)
        R2[s] = sum(prob * reward**2 for prob, reward, ns in outcomes)

    S2_rhs = R2 + 2 * g * M
    S2 = np.linalg.solve(I - (g**2) * P, S2_rhs)

    # Cross moments between state pairs; pair (s1, s2) is flattened to
    # index s1 * nS + s2.
    n2 = nS * nS
    Pj = np.zeros((n2, n2))
    B = np.zeros(n2)

    for s1 in range(nS):
        out1 = outcomes_by_state[s1]
        for s2 in range(nS):
            out2 = outcomes_by_state[s2]
            k = s1 * nS + s2

            if terminal[s1] and terminal[s2]:
                Pj[k, k] = 1.0
                B[k] = 0.0
            elif terminal[s1]:
                # s1 terminal, s2 not
                for prob, r2, ns2 in out2:
                    k2 = s1 * nS + ns2
                    Pj[k, k2] += prob
                    B[k] += prob * g * r2 * V[s1]
            elif terminal[s2]:
                # s2 terminal, s1 not
                for prob, r1, ns1 in out1:
                    k2 = ns1 * nS + s2
                    Pj[k, k2] += prob
                    B[k] += prob * g * r1 * V[s2]
            elif s1 == s2:
                # Same state: both returns follow the same transition.
                for prob, r, ns in out1:
                    k2 = ns * nS + ns
                    Pj[k, k2] += prob
                    B[k] += prob * (r**2 + 2 * g * r * V[ns])
            else:
                # Different states: transitions are taken as independent,
                # so the pair law is the product of the two marginals.
                for p1, r1, ns1 in out1:
                    for p2, r2, ns2 in out2:
                        k2 = ns1 * nS + ns2
                        Pj[k, k2] += p1 * p2
                        B[k] += p1 * p2 * (r1 * r2 + g * r1 * V[ns2] + g * r2 * V[ns1])

    Cross_flat = np.linalg.solve(np.eye(n2) - (g**2) * Pj, B)
    Cross = Cross_flat.reshape((nS, nS))

    return V, S2, Cross

def jipe_iterative(spec, policy, tol=1e-10, max_it=1000):
    """
    Compute V, S2 and Cross for a deterministic policy by fixed-point
    iteration.

    Args:
        spec: Environment specification (gamma is read from it).
        policy: Deterministic policy, one action index per state.
        tol: An iteration stops once the largest absolute change between
            two successive sweeps falls below tol.
        max_it: Maximum number of sweeps for each of the three iterations.

    Returns:
        (V, S2, Cross): expected return, uncentered second moment of the
        return, and expected product of returns for each pair of states
        (same state: shared trajectory; different states: independent
        one-step transitions).

    Warns:
        RuntimeWarning: If an iteration uses all max_it sweeps without
        meeting tol; the last iterate is still returned.
    """
    import warnings  # function-scope import: not among the file's top-level imports

    g = spec.gamma
    P, R = build_P_R_for_policy(spec, policy)
    nS = P.shape[0]

    # Outcomes and terminal flags do not change between sweeps; compute once.
    outcomes_by_state = [get_transition_outcomes(s, policy[s], spec) for s in range(nS)]
    terminal = [is_terminal(s, spec) for s in range(nS)]

    def warn_unconverged(quantity):
        warnings.warn(
            f"jipe_iterative: {quantity} iteration did not reach tol={tol} "
            f"within {max_it} sweeps; returning the last iterate",
            RuntimeWarning,
            stacklevel=3,
        )

    # Value iteration
    V = np.zeros(nS)
    for _ in range(max_it):
        Vn = R + g * (P @ V)
        converged = np.max(np.abs(Vn - V)) < tol
        V = Vn
        if converged:
            break
    else:
        warn_unconverged("V")

    # S2 iteration
    S2 = np.zeros(nS)
    for _ in range(max_it):
        S2n = np.zeros_like(S2)
        for s in range(nS):
            if terminal[s]:
                continue  # terminal entries stay 0.0
            S2n[s] = sum(prob * (r**2 + 2*g*r*V[ns] + (g**2)*S2[ns])
                         for prob, r, ns in outcomes_by_state[s])
        converged = np.max(np.abs(S2n - S2)) < tol
        S2 = S2n
        if converged:
            break
    else:
        warn_unconverged("S2")

    # Cross iteration
    Cross = np.zeros((nS, nS))
    for _ in range(max_it):
        Cn = np.zeros_like(Cross)
        for s1 in range(nS):
            out1 = outcomes_by_state[s1]
            for s2 in range(nS):
                out2 = outcomes_by_state[s2]

                if terminal[s1] and terminal[s2]:
                    continue  # stays 0.0
                if terminal[s1]:
                    Cn[s1, s2] = sum(prob * g * r2 * V[s1]
                                     for prob, r2, ns2 in out2)
                elif terminal[s2]:
                    Cn[s1, s2] = sum(prob * g * r1 * V[s2]
                                     for prob, r1, ns1 in out1)
                elif s1 == s2:
                    # Same state is perfectly correlated
                    Cn[s1, s2] = sum(prob * (r**2 + 2*g*r*V[ns] + (g**2)*Cross[ns, ns])
                                     for prob, r, ns in out1)
                else:
                    # Different states are independent
                    exp = 0.0
                    for p1, r1, ns1 in out1:
                        for p2, r2, ns2 in out2:
                            exp += p1*p2*(r1*r2 + g*r1*V[ns2] + g*r2*V[ns1] + (g**2)*Cross[ns1, ns2])
                    Cn[s1, s2] = exp

        converged = np.max(np.abs(Cn - Cross)) < tol
        Cross = Cn
        if converged:
            break
    else:
        warn_unconverged("Cross")

    return V, S2, Cross

def per_action_moments_at_state(s, spec, V, S2, Cross):
    """
    Compute the mean vector and covariance matrix of the four per-action
    returns at state s.

    Args:
        s: State index.
        spec: Environment specification (gamma is read from it).
        V: Expected return per state.
        S2: Uncentered second moment of the return per state.
        Cross: Expected product of returns for each pair of states.

    Returns:
        (mu, Sigma): mu[a] is the expected return of taking action a at s
        and then following the evaluated policy; Sigma is the symmetrised
        4x4 covariance of those returns, with eigenvalues floored at 1e-6
        so that it is positive definite.
    """
    g = spec.gamma
    nA = 4

    mu = np.zeros(nA)
    # Second moment matrix E[Q_i * Q_j]
    M2 = np.zeros((nA, nA))

    # Means and diagonal elements E[Q_i^2], from one outcome list per action.
    for a in range(nA):
        outcomes = get_transition_outcomes(s, a, spec)
        mu[a] = sum(prob * (r + g * V[ns]) for prob, r, ns in outcomes)

        second_moment = 0.0
        for prob, r, ns in outcomes:
            # Immediate reward squared term
            second_moment += prob * r**2
            # Cross term
            second_moment += prob * 2 * g * r * V[ns]
            # Future value squared term
            second_moment += prob * (g**2) * S2[ns]
        M2[a, a] = second_moment

    # Off-diagonal elements E[Q_i * Q_j] with shared wind
    for a1 in range(nA):
        for a2 in range(a1 + 1, nA):
            second_moment = 0.0

            for prob, r1, ns1, r2, ns2 in get_joint_outcomes(s, a1, a2, spec):
                second_moment += prob * r1 * r2
                second_moment += prob * g * r1 * V[ns2]
                second_moment += prob * g * r2 * V[ns1]
                # For the cross term of future values
                if ns1 == ns2:
                    # Same state - use S2
                    second_moment += prob * (g**2) * S2[ns1]
                else:
                    # Different states - use Cross
                    second_moment += prob * (g**2) * Cross[ns1, ns2]

            M2[a1, a2] = second_moment
            M2[a2, a1] = second_moment

    # Covariance matrix: Cov = E[Q*Q'] - E[Q]*E[Q]'
    Sigma = M2 - np.outer(mu, mu)

    # Force symmetry before the symmetric eigendecomposition.
    Sigma = (Sigma + Sigma.T) / 2

    # One decomposition supplies both the eigenvalues for the check and the
    # eigenvectors for the repair, so the two are guaranteed to correspond.
    eigvals, eigvecs = np.linalg.eigh(Sigma)
    if np.min(eigvals) < 1e-6:
        # Floor the spectrum at 1e-6 and rebuild the matrix.
        eigvals = np.maximum(eigvals, 1e-6)
        Sigma = eigvecs @ np.diag(eigvals) @ eigvecs.T

    return mu, Sigma

def compute_all_state_moments(spec, V, S2, Cross):
    """
    Evaluate per_action_moments_at_state for every state.

    Returns (all_mu, all_cov), two dicts keyed by state index holding the
    per-action mean vector and covariance matrix respectively.
    """
    n_states = spec.width * spec.height
    moments = [
        per_action_moments_at_state(state, spec, V, S2, Cross)
        for state in range(n_states)
    ]
    all_mu = {state: mean for state, (mean, _) in enumerate(moments)}
    all_cov = {state: cov for state, (_, cov) in enumerate(moments)}
    return all_mu, all_cov

def run_slippery_cliffwalking_jipe(wind_prob=0.1, gamma=0.95, use_iterative=True, print_sample=True):
    """
    Evaluate the unsafe (cliff-hugging) policy and derive per-action moments.

    Args:
        wind_prob: Slip probability passed to SlipperyCliffSpec.
        gamma: Discount factor passed to SlipperyCliffSpec.
        use_iterative: If True use jipe_iterative, otherwise
            solve_ground_truth.
        print_sample: If True, print the state value and the per-action
            mean and variance for a few sample states. This flag previously
            had no effect: the sample list was built but never printed.

    Returns:
        (all_mu, all_cov): dicts keyed by state index with the per-action
        mean vector and covariance matrix at that state.
    """
    spec = SlipperyCliffSpec(wind_prob=wind_prob, gamma=gamma)

    # Get unsafe policy
    policy = get_unsafe_policy(spec)

    # Compute V, S2, Cross
    if use_iterative:
        print(f"Computing JIPE (wind_prob={wind_prob})...")
        V, S2, Cross = jipe_iterative(spec, policy)
    else:
        print(f"Computing analytically (wind_prob={wind_prob})...")
        V, S2, Cross = solve_ground_truth(spec, policy)

    # Compute per-action moments for all states
    all_mu, all_cov = compute_all_state_moments(spec, V, S2, Cross)

    if print_sample:
        # Print results for a few sample states
        sample_states = [
            (0, 0, "Start"),
            (5, 1, "Safe path middle"),
            (10, 0, "Near goal (cliff edge)"),
            (11, 0, "Goal"),
        ]
        for x, y, label in sample_states:
            s = idx(x, y, spec.width)
            print(f"State ({x}, {y}) [{label}]: V = {V[s]:.4f}")
            for a, action_name in enumerate(ACTIONS):
                print(
                    f"  {action_name}: mean = {all_mu[s][a]:.4f}, "
                    f"variance = {all_cov[s][a, a]:.4f}"
                )

    return all_mu, all_cov

def simulate_policy(policy, spec, n_episodes=100, max_steps=200, stochastic=False):
    """
    Simulate a policy on the slippery CliffWalking environment.

    Each episode starts at spec.start and runs until the goal is reached or
    max_steps steps have been taken. Sampling uses NumPy's global RNG.

    Args:
        policy: If stochastic=False: array with one action index per state.
               If stochastic=True: array of shape (n_states, 4) with action
               probabilities for each state.
        spec: SlipperyCliffSpec object defining the environment
        n_episodes: Number of episodes to simulate (at least 1)
        max_steps: Maximum steps per episode
        stochastic: Whether the policy is stochastic

    Returns:
        Tuple (mean_reward, std_reward, min_reward, max_reward) over the
        discounted episode returns. std_reward is the sample standard
        deviation (ddof=1), which is NaN when n_episodes == 1.

    Raises:
        ValueError: If n_episodes < 1.
    """
    if n_episodes < 1:
        # Fail early with a clear message instead of inside a NumPy reduction.
        raise ValueError("n_episodes must be at least 1")

    start_state = idx(spec.start[0], spec.start[1], spec.width)
    episode_rewards = []

    for _ in range(n_episodes):
        s = start_state
        total_reward = 0.0

        for step in range(max_steps):
            # Check if terminal
            if is_terminal(s, spec):
                break

            # Select action
            if stochastic:
                # Sample from probability distribution
                action = np.random.choice(4, p=policy[s])
            else:
                # Deterministic policy
                action = policy[s]

            # Get possible outcomes with slippery dynamics
            outcomes = get_transition_outcomes(s, action, spec)

            # Sample outcome
            outcome_probs = [p for p, r, ns in outcomes]
            outcome_idx = np.random.choice(len(outcomes), p=outcome_probs)
            _, reward, next_state = outcomes[outcome_idx]

            # Apply discount and accumulate reward
            total_reward += (spec.gamma ** step) * reward
            s = next_state

        episode_rewards.append(total_reward)

    # Calculate statistics
    episode_rewards = np.array(episode_rewards)
    mean_reward = np.mean(episode_rewards)
    std_reward = np.std(episode_rewards, ddof=1)  # Sample standard deviation
    min_reward = np.min(episode_rewards)
    max_reward = np.max(episode_rewards)

    return mean_reward, std_reward, min_reward, max_reward

if __name__ == "__main__":
    # Per-action moments under the unsafe policy; these feed the Markowitz
    # policies evaluated below.
    all_mu, all_cov = run_slippery_cliffwalking_jipe(
        wind_prob=0.5, gamma=0.95, use_iterative=True, print_sample=True
    )

    spec = SlipperyCliffSpec(wind_prob=0.5, gamma=0.95)
    unsafe_policy = get_unsafe_policy(spec)
    opt_policy = get_optimal_policy(spec)  # constructed but not simulated here

    # Labels for the four statistics, in the order simulate_policy returns them.
    stat_labels = ("Mean", "std", "min", "max")

    # Baseline: the unsafe policy.
    unsafe_stats = simulate_policy(unsafe_policy, spec, n_episodes=2500)
    for label, value in zip(stat_labels, unsafe_stats):
        print(f"{label} of unsafe policy:", value)

    # Markowitz policies across a sweep of risk-aversion weights.
    for lam in (0.01, 0.1, 1.0, 10.0, 100.0):
        markowitz_policy = get_markowitz_policy(spec, all_mu, all_cov, lam=lam)

        markowitz_stats = simulate_policy(
            markowitz_policy, spec, stochastic=True, n_episodes=2500
        )
        for label, value in zip(stat_labels, markowitz_stats):
            print(f"{label} of markowitz policy with lambda={lam}:", value)