"""
Shared-Randomness Bandit + JIPE Demo
------------------------------------
Defines a one-state K-armed bandit where, at each time step, the reward
vector R_t is i.i.d. across time with mean mu (shape [K]) and
covariance Sigma (shape [K,K]). Different actions at the same time can be
correlated via Sigma (shared randomness).
Evaluates a fixed policy pi that always plays a chosen action a_pi.
Computes analytic ground truth for per-action means and covariances
of the return.
Computes JIPE estimates by iterating until convergence.

"""

import warnings
from dataclasses import dataclass

import numpy as np

# eq=False keeps the default identity comparison: a generated field-wise
# __eq__ would compare the ndarray fields with `==` and can raise on the
# ambiguous truth value of the resulting arrays.
@dataclass(eq=False)
class BanditSpec:
    """Parameters of the one-state K-armed bandit and of the evaluated policy.

    Instances are built with keyword (or positional) arguments in the order
    ``mu, Sigma, gamma, policy_action``. No validation or array conversion is
    performed here; the values are stored as given.
    """

    mu: np.ndarray          # per-action mean reward, shape (K,)
    Sigma: np.ndarray       # per-step reward covariance across actions, shape (K,K)
    gamma: float            # discount factor applied to future rewards
    policy_action: int      # index of the action the evaluated policy always plays

def _check_spec(spec: BanditSpec):
    """Normalize and validate the fields of a bandit specification.

    Args:
        spec: object exposing ``mu``, ``Sigma``, ``gamma`` and
            ``policy_action`` attributes.

    Returns:
        Tuple ``(mu, Sigma, K)`` where ``mu`` is ``spec.mu`` flattened to a
        float array of shape (K,), ``Sigma`` is ``spec.Sigma`` as a float
        array of shape (K,K), and ``K`` is the number of actions.

    Raises:
        ValueError: if ``Sigma`` does not have shape (K,K), or if
            ``abs(gamma) < 1`` does not hold (this also rejects NaN).
        IndexError: if ``policy_action`` does not index into ``mu``.
    """
    mu = np.asarray(spec.mu, dtype=float).reshape(-1)
    Sigma = np.asarray(spec.Sigma, dtype=float)
    K = mu.shape[0]

    # A larger Sigma would otherwise be silently read as its top-left block.
    if Sigma.shape != (K, K):
        raise ValueError(
            f"Sigma must have shape ({K}, {K}) to match mu, got {Sigma.shape}"
        )

    # Negative indices are accepted because numpy indexing accepts them too.
    ap = spec.policy_action
    if not -K <= ap < K:
        raise IndexError(
            f"policy_action {ap} is out of range for {K} action(s)"
        )

    # The return moments divide by (1 - gamma) and (1 - gamma**2), and the
    # iterative scheme only contracts when |gamma| < 1.
    if not abs(spec.gamma) < 1:
        raise ValueError(
            f"gamma must satisfy |gamma| < 1, got {spec.gamma}"
        )

    return mu, Sigma, K

# Analytic ground truth
def analytic_per_action(spec: BanditSpec):
    """Closed-form per-action moments of the discounted return.

    The return of action ``a`` is modelled as its immediate reward plus
    ``gamma`` times the discounted tail obtained by always playing
    ``spec.policy_action``; the tail's mean and variance are geometric sums.

    Args:
        spec: bandit specification (``mu``, ``Sigma``, ``gamma``,
            ``policy_action``).

    Returns:
        Tuple ``(mu_true, Sbar_true, Sigma_true)``: the per-action mean
        returns (shape (K,)), the matrix of raw second moments (shape (K,K))
        and the covariance ``Sbar_true - outer(mu_true, mu_true)``.
    """
    mu, Sigma, K = _check_spec(spec)
    g = spec.gamma
    ap = spec.policy_action

    # Mean, variance and raw second moment of the policy's discounted tail.
    tail_mean = mu[ap] / (1 - g)
    tail_var = Sigma[ap,ap] / (1 - g**2)
    tail_m2 = tail_var + tail_mean**2

    mu_true = mu + g * tail_mean

    Sbar_true = np.empty((K,K))
    for i in range(K):
        # Diagonal entry: raw second moment of action i's own return.
        own_m2 = Sigma[i,i] + mu[i]**2
        Sbar_true[i,i] = own_m2 + 2*g*mu[i]*tail_mean + g**2*tail_m2

        # Off-diagonal entries: only the upper triangle of Sigma is read and
        # the result is mirrored into the lower triangle.
        for j in range(i+1, K):
            cross = (Sigma[i,j] + mu[i]*mu[j]) + g*(mu[i]+mu[j])*tail_mean + g**2*tail_m2
            Sbar_true[i,j] = cross
            Sbar_true[j,i] = cross

    Sigma_true = Sbar_true - np.outer(mu_true, mu_true)
    return mu_true, Sbar_true, Sigma_true

# JIPE
def jipe_per_action(spec: BanditSpec, tol=1e-12, max_it=100000):
    """Estimate per-action return moments by fixed-point iteration.

    Starting from zero, the first moment ``V`` and raw second moment ``S2`` of
    the policy's discounted return are updated together as

        V  <- mu[ap] + gamma * V
        S2 <- E[R_ap^2] + 2 * gamma * mu[ap] * V + gamma**2 * S2

    until both change by less than ``tol`` in one sweep. The per-action
    moments are then assembled from the final ``V`` and ``S2``.

    Args:
        spec: bandit specification (``mu``, ``Sigma``, ``gamma``,
            ``policy_action``).
        tol: absolute threshold on the largest per-sweep change.
        max_it: maximum number of sweeps.

    Returns:
        Tuple ``(mu_est, Sbar_est, Sigma_est)``: the per-action mean returns
        (shape (K,)), the matrix of raw second moments (shape (K,K)) and the
        covariance ``Sbar_est - outer(mu_est, mu_est)``.

    Warns:
        RuntimeWarning: if ``tol`` is not reached within ``max_it`` sweeps;
            the last iterate is still returned.
    """
    mu, Sigma, K = _check_spec(spec)
    g = spec.gamma
    ap = spec.policy_action

    V, S2 = 0.0, 0.0
    E_R2_pi = Sigma[ap,ap] + mu[ap]**2

    for _ in range(max_it):
        # Both updates read the previous sweep's V and S2.
        V_new = mu[ap] + g*V
        S2_new = E_R2_pi + 2*g*mu[ap]*V + g**2*S2
        converged = max(abs(V_new - V), abs(S2_new - S2)) < tol
        V, S2 = V_new, S2_new
        if converged:
            break
    else:
        # Reached only when the loop ends without `break`: report it rather
        # than silently returning unconverged estimates.
        warnings.warn(
            f"jipe_per_action did not reach tol={tol} within "
            f"max_it={max_it} iterations; estimates may be inaccurate",
            RuntimeWarning,
            stacklevel=2,
        )

    mu_est = mu + g*V
    Sbar_est = np.empty((K,K))

    for a in range(K):
        # Diagonal entry: raw second moment of action a's own return.
        E_R2_a = Sigma[a,a] + mu[a]**2
        Sbar_est[a,a] = E_R2_a + 2*g*mu[a]*V + g**2*S2

        # Upper triangle of Sigma is read; the result is mirrored below.
        for b in range(a+1, K):
            c_ab = (Sigma[a,b] + mu[a]*mu[b]) + g*(mu[a]+mu[b])*V + g**2*S2
            Sbar_est[a,b] = Sbar_est[b,a] = c_ab

    Sigma_est = Sbar_est - np.outer(mu_est, mu_est)
    return mu_est, Sbar_est, Sigma_est

def covariance_to_correlation(cov_matrix):
    """Rescale a covariance matrix into a correlation matrix.

    Each entry is divided by the product of the standard deviations of its
    row and column variables: ``corr[i, j] = cov[i, j] / (sd[i] * sd[j])``,
    where ``sd`` is the square root of the diagonal.

    Args:
        cov_matrix: square array-like of covariances.

    Returns:
        A new array of the same shape holding the correlations.
    """
    cov = np.array(cov_matrix)
    sd = np.sqrt(np.diag(cov))
    return cov / np.outer(sd, sd)


def run_demo(mu=(0.0,0.2), Sigma=((0.8,0.6),(0.6,1.0)), gamma=.90, policy_action=1):
    """Compare analytic and JIPE return statistics and print a report.

    Builds a ``BanditSpec`` from the arguments, computes the per-action mean
    returns and return covariances both analytically and with JIPE, converts
    the covariances to correlations, and prints both versions together with
    their maximum absolute difference.

    Args:
        mu: per-action mean rewards.
        Sigma: per-step reward covariance across actions.
        gamma: discount factor.
        policy_action: index of the action the evaluated policy always plays.
    """
    mean_reward = np.array(mu, dtype=float)
    reward_cov = np.array(Sigma, dtype=float)
    spec = BanditSpec(
        mu=mean_reward,
        Sigma=reward_cov,
        gamma=gamma,
        policy_action=policy_action,
    )

    # The raw second-moment matrices are not part of the report.
    mean_exact, _, cov_exact = analytic_per_action(spec)
    mean_jipe, _, cov_jipe = jipe_per_action(spec)

    corr_exact = covariance_to_correlation(cov_exact)
    corr_jipe = covariance_to_correlation(cov_jipe)

    print("Shared-Randomness Bandit JIPE Demo")
    print("----------------------------------")
    print("Policy evaluated: always take action a_pi =", policy_action)
    print("mu =", mean_reward)
    print("Sigma =\n", reward_cov)
    print("gamma =", gamma, "\n")

    # (heading, label for the exact value, label for the estimate, exact, estimate)
    sections = (
        ("--- Means ---", "True:", "JIPE:", mean_exact, mean_jipe),
        ("\n--- Correlation ---", "True:\n", "JIPE:\n", corr_exact, corr_jipe),
    )
    for heading, exact_label, jipe_label, exact, estimate in sections:
        print(heading)
        print(exact_label, exact)
        print(jipe_label, estimate)
        print("Max absolute error =", np.max(np.abs(exact - estimate)))

# Script entry point: run the demo with its default arguments.
if __name__ == "__main__":
    run_demo()
