import metaworld
import gymnasium as gym
import numpy as np
import numpy as np
from collections import defaultdict
import statsmodels.stats.proportion as smp



def _to_batch_obs(obs):
    if isinstance(obs, dict):
        return {k: np.expand_dims(v, 0).astype(np.float32) for k, v in obs.items()}
    import numpy as _np
    return _np.expand_dims(obs, 0).astype(_np.float32)

def eval_policy(policy, env_name: str, episodes: int, max_steps: int, seed: int):
    """Roll out ``policy`` on one Meta-World MT1 task and summarise the outcome.

    Args:
        policy: Object exposing ``predict(batched_obs)``; element ``0`` of the
            value it returns is passed to ``env.step`` as the action.
        env_name: Task name forwarded to ``gym.make("Meta-World/MT1", ...)``.
        episodes: Number of evaluation episodes; must be at least 1.
        max_steps: Upper bound on the number of steps taken per episode.
        seed: Base seed. The environment is created with ``seed`` and episode
            ``i`` is reset with ``seed + i``.

    Returns:
        A tuple ``(success_rate, mean_return, returns)`` where ``returns`` is
        the list of per-episode reward sums. An episode stops at the first
        step whose ``info`` reports ``success``, so its return only covers the
        steps up to and including that one.

    Raises:
        ValueError: If ``episodes`` is smaller than 1 (the success rate would
            otherwise be a division by zero).
    """
    if episodes < 1:
        raise ValueError(f"episodes must be >= 1, got {episodes}")

    # rely on the same registry ID you use everywhere else
    env = gym.make("Meta-World/MT1", env_name=env_name, seed=seed, render_mode=None)
    succ, returns = 0, []
    try:
        for ep in range(episodes):
            obs, info = env.reset(seed=seed + ep)
            ep_ret = 0.0
            for _ in range(max_steps):
                action = policy.predict(_to_batch_obs(obs))[0]
                obs, rew, term, trunc, info = env.step(action)
                ep_ret += float(rew)
                # Success is checked before term/trunc so a success reported on
                # the final step of an episode is still counted.
                if info.get("success", False):
                    succ += 1
                    break
                if term or trunc:
                    break
            returns.append(ep_ret)
    finally:
        # Release the environment even when the policy or the env raises.
        env.close()
    return succ / episodes, float(np.mean(returns)), returns



from collections import defaultdict
import numpy as np

# ---- helpers (no external deps) ----
def binomial_ci_wilson(k: int, n: int, alpha: float = 0.05):
    """Wilson score confidence interval for a binomial proportion.

    Args:
        k: Number of successes.
        n: Number of trials.
        alpha: Significance level; the interval has nominal coverage
            ``1 - alpha`` (default 0.05, i.e. a 95% interval).

    Returns:
        ``(lo, hi)`` as Python floats clipped to ``[0, 1]``. With ``n == 0``
        there is no information, so the full range ``(0.0, 1.0)`` is returned.

    Raises:
        ValueError: If ``alpha`` is not strictly between 0 and 1.
    """
    # Function-scope import: standard library only, keeps the helper
    # self-contained without touching the module's import block.
    from statistics import NormalDist

    if n == 0:
        return 0.0, 1.0
    if not 0.0 < alpha < 1.0:
        raise ValueError(f"alpha must be in (0, 1), got {alpha}")

    # Two-sided critical value of the standard normal, N^{-1}(1 - alpha/2);
    # roughly 1.96 for alpha = 0.05.
    z = NormalDist().inv_cdf(1.0 - alpha / 2.0)
    phat = k / n
    z2 = z * z
    denom = 1.0 + z2 / n
    centre = (phat + z2 / (2 * n)) / denom
    margin = (z / denom) * np.sqrt((phat * (1 - phat) + z2 / (4 * n)) / n)
    lo = float(max(0.0, centre - margin))
    hi = float(min(1.0, centre + margin))
    return lo, hi

def interquartile_mean(x: np.ndarray) -> float:
    """Mean of the central portion of ``x`` after sorting.

    The values are sorted as float64 and the slice from index
    ``floor(n / 4)`` up to (but excluding) ``ceil(3 * n / 4)`` is averaged.
    Empty input yields ``nan``.
    """
    ordered = np.asarray(x, dtype=np.float64)
    ordered = np.sort(ordered)
    count = len(ordered)
    if count == 0:
        return float("nan")

    start = count // 4              # floor(0.25 * count)
    stop = (3 * count + 3) // 4     # ceil(0.75 * count)
    if stop > start:
        return float(ordered[start:stop].mean())
    # Fallback kept from the original: average everything if the slice is empty.
    return float(ordered.mean())

def bootstrap_iqm_ci(x: np.ndarray, n_boot: int = 2000, alpha: float = 0.05, seed: int = 0):
    """Percentile-bootstrap confidence interval for the interquartile mean.

    Args:
        x: Sample values; any array-like is accepted and converted to a
            float64 array (the original indexing only worked on ndarrays).
        n_boot: Number of bootstrap resamples.
        alpha: Significance level; the ``alpha / 2`` and ``1 - alpha / 2``
            percentiles of the resampled IQMs are returned.
        seed: Seed for ``np.random.default_rng`` so the interval is
            reproducible.

    Returns:
        ``(lo, hi)`` as Python floats. Empty input yields ``(nan, nan)``,
        matching ``interquartile_mean`` on empty input.
    """
    data = np.asarray(x, dtype=np.float64)
    n = len(data)
    if n == 0:
        # Nothing to resample from; report an undefined interval rather than
        # failing inside the random index draw.
        return float("nan"), float("nan")

    rng = np.random.default_rng(seed)
    # Each resample draws n indices with replacement from the original sample.
    boots = [
        interquartile_mean(data[rng.integers(0, n, size=n)])
        for _ in range(n_boot)
    ]
    lo = float(np.percentile(boots, 100.0 * (alpha / 2.0)))
    hi = float(np.percentile(boots, 100.0 * (1.0 - alpha / 2.0)))
    return lo, hi

def eval_policies_across_envs(policies: dict, env_names: list, episodes: int, max_steps: int, seed: int):
    """Evaluate every policy on every environment and print a summary.

    Args:
        policies: Mapping from a string label to a policy object; each policy
            is handed to ``eval_policy``.
        env_names: Environment names to evaluate on.
        episodes: Episodes per (environment, policy) pair.
        max_steps: Step limit per episode, forwarded to ``eval_policy``.
        seed: Seed forwarded to ``eval_policy`` and reused for the bootstrap.

    Returns:
        ``(table, policy_stats)``. ``table`` holds one tuple
        ``(env, name, success_rate, avg_return, ci_lo, ci_hi)`` per pair, with
        a Wilson interval on the success rate. ``policy_stats`` maps each
        label to a dict with keys ``mean``, ``iqm``, ``iqm_lo`` and ``iqm_hi``
        computed over that label's per-environment success rates.
    """
    rows = []
    rates_by_label = defaultdict(list)

    # ---- Per (environment, policy) evaluation ----
    for env_name in env_names:
        for label, policy in policies.items():
            rate, mean_return, _ = eval_policy(
                policy, env_name, episodes=episodes, max_steps=max_steps, seed=seed
            )
            # Recover the integer success count from the returned rate.
            successes = int(round(rate * episodes))
            ci_lo, ci_hi = binomial_ci_wilson(successes, episodes)
            line = (
                f"{env_name:20s} {label:10s} {rate*100:6.2f}% "
                f"(95% CI [{ci_lo*100:5.1f}, {ci_hi*100:5.1f}]) "
                f"AvgRet={mean_return:8.2f}"
            )
            print(line)
            rows.append((env_name, label, rate, mean_return, ci_lo, ci_hi))
            rates_by_label[label].append(rate)

    # ---- Aggregate across envs (one entry per policy label) ----
    print("\n=== Aggregated across envs (success rate) ===")
    summary = {}
    for label, rates in rates_by_label.items():
        rate_arr = np.asarray(rates, dtype=np.float64)
        mean_rate = float(rate_arr.mean())
        iqm_rate = interquartile_mean(rate_arr)
        boot_lo, boot_hi = bootstrap_iqm_ci(rate_arr, n_boot=2000, alpha=0.05, seed=seed)
        summary[label] = {
            "mean": mean_rate,
            "iqm": iqm_rate,
            "iqm_lo": boot_lo,
            "iqm_hi": boot_hi,
        }
        print(
            f"reward={label:12s} MEAN={mean_rate*100:5.1f}%  "
            f"IQM={iqm_rate*100:5.1f}% (95% CI [{boot_lo*100:5.1f}, {boot_hi*100:5.1f}])  "
            f"over {len(rate_arr)} envs"
        )

    return rows, summary
