##
## (c) Anonymous authors (2026)
##
## > Script to evaluate residual-based informativeness
##
##

import itertools
import multiprocessing
import os
import random

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from scipy.spatial.distance import pdist, squareform
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import KFold
from tqdm import tqdm

from informativeness.synthetic_informed_POMDP.envs import Episode, RandomSyntheticInformedPOMDP
from informativeness.synthetic_informed_POMDP.models import RecurrentCritic

# Run on CPU only, and let the caller's environment override the wandb
# defaults (online logging, silenced console output).
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ.setdefault("WANDB_MODE", "online")
os.environ.setdefault("WANDB_SILENT", "true")

# Discount factor for returns and TD targets
GAMMA = 0.99

# Recurrent critic (history encoder): hidden size, training epochs, Adam step size
CRITIC_HIDDEN, CRITIC_EPOCHS, CRITIC_LR = 64, 1, 1e-4

# Cross-fitting / independence-test settings:
# random-forest size, number of K-fold splits, number of permutations
RF_TREES, N_SPLITS, HSIC_B = 100, 5, 1000

DEVICE = torch.device("cpu")


def collect_episodes(env, n_episodes, horizon):
    """Roll out fixed-length episodes under a uniform-random policy.

    At every step the observation is generated from the current state, an
    action is drawn uniformly from ``range(env.A)`` with the global NumPy RNG,
    and the environment is stepped. Episodes never terminate early.

    Args:
        env: environment exposing ``reset()``, ``generate_obs(s)``,
            ``step(a)`` (returning ``(reward, next_state)``) and ``A``.
        n_episodes: number of episodes to collect.
        horizon: number of steps per episode.

    Returns:
        list of ``Episode`` objects built from the per-step states,
        observations, actions, rewards and discounted returns-to-go
        (discount ``GAMMA``).
    """
    episodes = []
    for _ in range(n_episodes):
        s = env.reset()
        states, obs, actions, rewards = [], [], [], []

        for _ in range(horizon):
            o = env.generate_obs(s)
            a = np.random.randint(env.A)
            r, s_next = env.step(a)

            states.append(s)
            obs.append(o)
            actions.append(a)
            rewards.append(r)

            s = s_next

        # Discounted returns-to-go G_t = r_t + GAMMA * G_{t+1}, accumulated
        # backwards and reversed once at the end. Appending and reversing is
        # linear in the horizon; repeated list.insert(0, ...) is quadratic.
        G = []
        ret = 0.0
        for r in reversed(rewards):
            ret = r + GAMMA * ret
            G.append(ret)
        G.reverse()

        episodes.append(Episode(states, obs, actions, rewards, G))
    return episodes


def train_td_critic(episodes, obs_dim, action_dim):
    """Fit a recurrent critic with one-step TD(0) targets on whole episodes.

    Each episode becomes one input sequence. Step ``t`` of that sequence is
    the observation concatenated with a one-hot encoding of the action taken
    at ``t``. The critic's value at step ``t`` is regressed (MSE) onto
    ``r_t + GAMMA * V_{t+1}``, with the bootstrap value detached. At the last
    step of the episode the target is ``r_t`` alone. One optimizer step is
    taken per episode.

    Args:
        episodes: sequence of episodes with ``obs``, ``actions`` and
            ``rewards``. The caller's sequence is NOT modified: a private copy
            is shuffled each epoch.
        obs_dim: observation dimensionality (passed to ``RecurrentCritic``).
        action_dim: number of discrete actions (size of the one-hot encoding).

    Returns:
        the trained ``RecurrentCritic`` (left in training mode).
    """
    critic = RecurrentCritic(obs_dim, action_dim, CRITIC_HIDDEN).to(DEVICE)
    optimizer = optim.Adam(critic.parameters(), lr=CRITIC_LR)
    loss_fn = nn.MSELoss()
    action_eye = np.eye(action_dim, dtype=np.float32)

    # Shuffle a private copy so the caller's list is not reordered in place.
    # The shuffle sequence is identical because it depends only on the length
    # and the global `random` state.
    order = list(episodes)

    critic.train()
    for _ in range(CRITIC_EPOCHS):
        random.shuffle(order)

        for ep in order:
            # (T, obs_dim + action_dim): observation followed by the action one-hot.
            a_onehot = action_eye[np.asarray(ep.actions, dtype=np.int64)]
            x = np.concatenate([np.asarray(ep.obs), a_onehot], axis=1)
            x = torch.from_numpy(x).float().unsqueeze(0).to(DEVICE)

            values, _ = critic(x)
            # Assumes values has shape (1, T), one scalar per step
            # (TODO confirm against RecurrentCritic).
            v = values[0]

            rewards = torch.as_tensor(ep.rewards, dtype=v.dtype, device=v.device)
            # Bootstrap from the detached next-step value. The final step has
            # no successor, so its bootstrap term is zero.
            bootstrap = torch.cat([v[1:].detach(), v.new_zeros(1)])
            targets = rewards + GAMMA * bootstrap

            loss = loss_fn(v, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return critic


def extract_embeddings(critic, episodes, obs_dim, action_dim):
    """Encode every episode with the critic and flatten per-step outputs.

    The critic is switched to eval mode and run without gradients on each
    episode's (observation, one-hot action) sequence.

    Returns:
        ``(Z, G, S)``:
        - ``Z`` is an array stacking the critic's second output (hidden state)
          at every step;
        - ``G`` is the array of the matching ``ep.returns`` entries;
        - ``S`` is a plain list of the matching ``ep.states`` entries.

        Rows are ordered episode by episode, then step by step.
    """
    critic.eval()
    hidden_rows, returns, states = [], [], []
    action_eye = np.eye(action_dim, dtype=np.float32)

    with torch.no_grad():
        for episode in episodes:
            steps = range(len(episode.obs))
            features = np.asarray([
                np.concatenate([episode.obs[k], action_eye[episode.actions[k]]])
                for k in steps
            ])

            batch = torch.from_numpy(features).float().unsqueeze(0).to(DEVICE)
            _, hidden = critic(batch)
            hidden = hidden.squeeze(0).cpu().numpy()

            for k in steps:
                hidden_rows.append(hidden[k])
                returns.append(episode.returns[k])
                states.append(episode.states[k])

    return np.array(hidden_rows), np.array(returns), states


def compute_i_from_s(states, env, info_type, seed):
    """Build the privileged-information variable I for each visited state.

    Args:
        states: sequence of environment states.
        env: environment exposing ``generate_latent(s)``, whose result is
            indexed with ``info_type``.
        info_type: either the string ``"noise"`` or an index / list of
            indices into the object returned by ``env.generate_latent(s)``.
            With ``"noise"``, I is standard Gaussian noise of width 2 drawn
            independently of the states.
        seed: seed of the noise generator (only used for ``"noise"``).

    Returns:
        ``np.ndarray`` with one row per state. A single selected component,
        either a scalar selection or a one-column selection, is returned as a
        flat ``(len(states),)`` vector.
    """
    # isinstance guard: comparing an array-valued info_type with a string
    # would give an element-wise (ambiguous) result.
    if isinstance(info_type, str) and info_type == "noise":
        rng = np.random.default_rng(seed)
        latent_dim = 2
        return rng.normal(size=(len(states), latent_dim))

    privileged_information = np.stack(
        [np.asarray(env.generate_latent(s)[info_type]) for s in states], axis=0
    )
    # A scalar selection already stacks to 1-D, which has no second axis.
    # Only a one-column 2-D result needs flattening.
    if privileged_information.ndim == 2 and privileged_information.shape[1] == 1:
        privileged_information = privileged_information.ravel()
    return privileged_information


def generate_fixed_env(rnd_seed):
    """Build the synthetic informed POMDP shared by all experiments.

    The state latents (``fixed_latents``) and the transition tensor are drawn
    from the global NumPy RNG right after it is reseeded with the constant 42.
    They are therefore identical on every call, whatever ``rnd_seed`` is.
    ``rnd_seed`` is only forwarded to ``RandomSyntheticInformedPOMDP`` as
    ``seed`` (presumably seeding the environment's own sampling; confirm in
    envs.py).

    Side effect: reseeds the global ``torch``, ``numpy`` and ``random`` RNGs
    with 42. Later draws from those global generators in the same process
    continue from the state this function leaves behind.

    Args:
        rnd_seed: seed forwarded to the environment constructor.

    Returns:
        a ``RandomSyntheticInformedPOMDP`` with 20 states, 4 actions,
        ``obs_dim=2`` and ``latent_dim=5``.
    """
    num_states = 20
    latent_dim = 5
    obs_dim = 2
    num_actions = 4

    # shared seed to initialize state feature and transitions
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    # One latent vector per state: shape (num_states, latent_dim).
    fixed_latents = np.random.randn(num_states, latent_dim)

    # One weight per latent dimension (len == latent_dim). Presumably the
    # reward is a weighted combination of the latent; confirm in envs.py.
    reward_weights = [0.0001, 0.0001, -0.0001, -1.0, 1.0]

    # Sparse random transition tensor T[s, a, s']. Each successor s' is kept
    # with probability 0.25 and given a Uniform(0, 1) weight. If no successor
    # is kept, a single random successor receives all the mass. Each (s, a)
    # row is then normalized to sum to 1.
    # NOTE: the exact sequence of np.random calls below determines the tensor
    # (and the global RNG state afterwards). Reordering it changes every run.
    transitions = np.zeros((num_states, num_actions, num_states))
    for s in range(num_states):
        for a in range(num_actions):
            for s_prime in range(num_states):
                rnd_prob = np.random.rand()
                if rnd_prob < 0.25:
                    transitions[s, a, s_prime] = np.random.uniform(0, 1)
            # Guarantee at least one reachable successor so the row can be normalized.
            if np.sum(transitions[s, a]) == 0:
                transitions[s, a, np.random.randint(num_states)] = 1.0
            transitions[s, a] /= np.sum(transitions[s, a])

    return RandomSyntheticInformedPOMDP(num_states=num_states, obs_dim=obs_dim, latent_dim=latent_dim,
                                        num_actions=num_actions, fixed_latent_map=fixed_latents,
                                        transitions=transitions, reward_weights=reward_weights, seed=rnd_seed)


def crossfit_residuals_episodewise(X, Z, n_episodes, horizon, seed, n_splits):
    """
    Compute cross-fitted baseline predictions episode-wise.

    A random-forest regressor of ``X`` on ``Z`` is fitted with K-fold
    cross-fitting. Folds are formed over whole episodes: all ``horizon``
    steps of an episode land in the same fold, so within-episode dependence
    cannot leak from training rows into held-out predictions. Every row gets
    an out-of-fold prediction.

    Args:
        X: targets with ``n_episodes * horizon`` rows (1-D or 2-D), ordered
            episode by episode, then step by step.
        Z: regressors with the same number of rows as ``X``.
        n_episodes: number of episodes in the row layout.
        horizon: number of steps per episode in the row layout.
        seed: seed for fold shuffling and for the random forests.
        n_splits: number of folds (must not exceed ``n_episodes``).

    Returns:
        ``(residuals, x_hat, mse, kf, sample_episode)``:
        - ``residuals = X - x_hat``;
        - ``mse`` is the mean squared error of the out-of-fold predictions;
        - ``kf`` is the ``KFold`` splitter used;
        - ``sample_episode`` is the episode index of every row.

    Raises:
        ValueError: if ``X`` or ``Z`` does not have ``n_episodes * horizon`` rows.
    """
    X = np.asarray(X)
    Z = np.asarray(Z)
    n_samples = n_episodes * horizon
    if X.shape[0] != n_samples or Z.shape[0] != n_samples:
        raise ValueError(
            f"expected n_episodes * horizon = {n_samples} rows, "
            f"got X with {X.shape[0]} and Z with {Z.shape[0]}"
        )

    sample_episode = np.repeat(np.arange(n_episodes), horizon)
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=seed)

    # Forest predictions are real-valued. np.zeros_like(X) would inherit an
    # integer dtype from X and silently truncate them, so use a float buffer.
    out_dtype = X.dtype if np.issubdtype(X.dtype, np.floating) else np.float64
    x_hat = np.zeros(X.shape, dtype=out_dtype)

    for train_eps, test_eps in kf.split(np.arange(n_episodes)):
        train_mask = np.isin(sample_episode, train_eps)
        test_mask = np.isin(sample_episode, test_eps)

        model = RandomForestRegressor(
            n_estimators=RF_TREES,
            random_state=seed
        )

        model.fit(Z[train_mask], X[train_mask])
        x_hat[test_mask] = model.predict(Z[test_mask])

    r = X - x_hat
    mse_X = mean_squared_error(X, x_hat)

    return r, x_hat, mse_X, kf, sample_episode


def rbf_kernel(X, sigma):
    """Return the dense Gaussian (RBF) Gram matrix of the rows of ``X``.

    ``K[i, j] = exp(-||x_i - x_j||^2 / (2 * sigma^2))``. A 1-D ``X`` is
    treated as ``n`` scalar samples. With fewer than two samples an
    ``(n, n)`` matrix of ones is returned.
    """
    points = np.asarray(X)
    points = points[:, None] if points.ndim == 1 else points

    n_points = points.shape[0]
    if n_points > 1:
        sq_dists = squareform(pdist(points, metric="sqeuclidean"))
        return np.exp(-sq_dists / (2 * sigma ** 2))
    return np.ones((n_points, n_points))


def median_heuristic(X, subsample=1000, rng=None):
    """Median pairwise Euclidean distance, used as an RBF kernel bandwidth.

    Args:
        X: samples, either 1-D (``n`` scalar samples) or 2-D ``(n, d)``.
        subsample: if there are more than this many samples, the median is
            computed on a random subset of this size (without replacement).
            This bounds the quadratic pairwise-distance cost.
        rng: optional ``np.random.Generator`` for the subsampling. When
            ``None`` (default) the global ``np.random`` state is used.

    Returns:
        the median distance, floored at 1e-6 so it is always usable as a
        bandwidth. If there is no pairwise distance (fewer than two samples)
        the floor is returned instead of NaN.
    """
    X = np.asarray(X)
    # Only 1-D (or 0-D) inputs become a column of scalar samples. A 2-D
    # single-row input is one d-dimensional sample and must not be transposed.
    if X.ndim < 2:
        X = X.reshape(-1, 1)

    n = X.shape[0]
    if n > subsample:
        chooser = np.random if rng is None else rng
        idx = chooser.choice(n, subsample, replace=False)
        X = X[idx]

    dists = pdist(X, metric="euclidean")
    if dists.size == 0:
        return 1e-6
    med = np.median(dists)
    return max(med, 1e-6)


def nystrom_features(X, sigma, m, rng):
    """
    Compute Nystroem feature map for RBF kernel.

    Samples ``min(m, n)`` landmark rows of ``X`` without replacement and
    returns ``Phi = K_nm @ K_mm^{-1/2}``. ``Phi @ Phi.T`` then approximates
    the RBF Gram matrix ``exp(-||x_i - x_j||^2 / (2 * sigma^2))``, and is
    exact when every row is a landmark.

    Args:
        X: samples, 1-D (``n`` scalars) or 2-D ``(n, d)``.
        sigma: RBF bandwidth.
        m: requested number of landmarks (capped at ``n``).
        rng: ``np.random.Generator`` used to draw the landmarks.

    Returns:
        feature matrix of shape ``(n, min(m, n))``.
    """
    from scipy.spatial.distance import cdist

    X = np.asarray(X)
    if X.ndim == 1:
        X = X[:, None]

    n = X.shape[0]
    m = min(m, n)

    # Sample landmarks
    idx = rng.choice(n, size=m, replace=False)
    X_land = X[idx]

    scale = 2 * sigma ** 2
    # Only the n x m cross block is needed. cdist computes it directly
    # instead of building the full (n + m) x (n + m) matrix of stacked points.
    K_nm = np.exp(-cdist(X, X_land, "sqeuclidean") / scale)
    K_mm = np.exp(-squareform(pdist(X_land, "sqeuclidean")) / scale)

    # Symmetric inverse square root of K_mm. Eigenvalues are floored at 1e-8
    # because K_mm can be numerically singular (e.g. duplicated landmarks).
    eigvals, eigvecs = np.linalg.eigh(K_mm)
    eigvals = np.maximum(eigvals, 1e-8)
    # Column-scaling eigvecs equals eigvecs @ diag(1 / sqrt(eigvals)) without the dense diag.
    K_mm_inv_sqrt = (eigvecs / np.sqrt(eigvals)) @ eigvecs.T

    # Compute feature map
    return K_nm @ K_mm_inv_sqrt


def center_features(Phi):
    """Return ``Phi`` with every column shifted to zero mean over the rows."""
    feature_means = Phi.mean(axis=0, keepdims=True)
    centered = Phi - feature_means
    return centered


def nystrom_hsic(Phi_X, Phi_Y):
    """HSIC estimate from explicit feature maps.

    Computes ``||Phi_X^T Phi_Y||_F^2 / (n - 1)^2`` with ``n`` the number of
    rows. The caller is expected to pass column-centred features.
    """
    n_rows = Phi_X.shape[0]
    cross_cov = Phi_X.T @ Phi_Y
    return np.square(cross_cov).sum() / (n_rows - 1) ** 2


def residual_based_test_nystrom(
        G_res,
        I_res,
        Z,
        n_episodes,
        horizon,
        seed=0,
        m=256,
        num_permutations=HSIC_B
):
    """Episode-wise permutation test of independence between two residuals.

    Tests H0 "``G_res`` is independent of ``I_res``" with a
    Nystroem-approximated HSIC statistic. Both variables use RBF kernels with
    median-heuristic bandwidths. The null distribution is simulated by
    permuting whole episodes of ``G_res`` against ``I_res``, which keeps the
    within-episode time structure intact.

    Args:
        G_res: residual with ``n_episodes * horizon`` rows, ordered episode
            by episode, then step by step.
        I_res: residual with the same row count and ordering as ``G_res``.
        Z: unused; kept for interface compatibility.
        n_episodes: number of episodes in the row layout.
        horizon: number of steps per episode in the row layout.
        seed: seed of the generator used for landmarks and permutations.
        m: number of Nystroem landmarks per variable.
        num_permutations: number of episode permutations for the null.

    Returns:
        dict with keys:
        - ``"hsic"``: the observed statistic;
        - ``"p_value"``: the permutation p-value, with the +1 correction;
        - ``"null_mean"`` / ``"null_std"``: mean and std of the null statistics.

    Raises:
        ValueError: if either residual does not have ``n_episodes * horizon`` rows.
    """
    n_samples = n_episodes * horizon
    if len(G_res) != n_samples or len(I_res) != n_samples:
        raise ValueError(
            f"G_res and I_res must have n_episodes * horizon = {n_samples} rows, "
            f"got {len(G_res)} and {len(I_res)}"
        )

    rng = np.random.default_rng(seed)

    # Compute kernel bandwidths
    bw_G = median_heuristic(G_res)
    bw_I = median_heuristic(I_res)

    # Compute centred Nystroem features. The draw order (G, then I) fixes the landmarks.
    Phi_G = center_features(nystrom_features(G_res, bw_G, m, rng))
    Phi_I = center_features(nystrom_features(I_res, bw_I, m, rng))

    # Compute HSIC under H_1
    hsic_obs = nystrom_hsic(Phi_G, Phi_I)

    # Column means are invariant to row permutations. Centring once before
    # permuting is therefore equivalent (up to float rounding) to re-centring
    # every permuted copy.
    n_features = Phi_G.shape[1]
    Phi_G_episodes = Phi_G.reshape(n_episodes, horizon, n_features)

    # Permutation test (episode-wise)
    null_stats = np.zeros(num_permutations)
    for b in range(num_permutations):
        perm = rng.permutation(n_episodes)
        Phi_G_perm = Phi_G_episodes[perm].reshape(-1, n_features)
        # Compute HSIC under H_0
        null_stats[b] = nystrom_hsic(Phi_G_perm, Phi_I)

    # Get p-value
    p_value = (1 + np.sum(null_stats >= hsic_obs)) / (num_permutations + 1)

    return {
        "hsic": hsic_obs,
        "p_value": p_value,
        "null_mean": null_stats.mean(),
        "null_std": null_stats.std(),
    }


def chunks(lst, n):
    """Yield consecutive slices of ``lst`` of length ``n`` (the last may be shorter)."""
    starts = range(0, len(lst), n)
    yield from (lst[start:start + n] for start in starts)


def run_worker_and_write(args):
    """Run one experiment (one seed, one information type) and log it to wandb.

    Pipeline:
    1. build the environment;
    2. collect training episodes and fit the TD critic;
    3. collect evaluation episodes and extract critic embeddings Z;
    4. cross-fit residuals of the returns G and of the privileged information
       I on Z;
    5. run the episode-wise HSIC permutation test on the two residuals.

    The wandb run is always closed. If the pipeline raises, the run is
    finished with exit code 1 before the exception propagates.

    Args:
        args: tuple ``(n_train_eps, n_eval_eps, horizon, info_type, n_splits,
            seed, m)``. It is a single tuple so it can be used with
            ``Pool.imap_unordered``.

    Returns:
        dict with ``seed``, ``info_type`` and the HSIC test results.
    """
    # n_splits is kept local so it does not shadow the module-level N_SPLITS.
    n_train_eps, n_eval_eps, horizon, info_type, n_splits, seed, m = args

    wandb.init(
        entity="ENTITY_NAME",
        project="PROJECT_NAME",
        group="GROUP_NAME",
        config={
            "n_train_eps": n_train_eps,
            "n_eval_eps": n_eval_eps,
            "horizon": horizon,
            "info_type": info_type,
            "N_SPLITS": n_splits,
            "n_RF_trees": RF_TREES,
            "B": HSIC_B,
            "seed": seed,
            "nystrom_m": m,
            "critic_hidden": CRITIC_HIDDEN,
            "critic_lr": CRITIC_LR,
            "gamma": GAMMA,
        },
        reinit=True,
        name=f"seed={seed}_info={info_type}",
    )

    try:
        env = generate_fixed_env(rnd_seed=seed)
        train_eps = collect_episodes(env, n_train_eps, horizon)
        critic = train_td_critic(train_eps, env.obs_dim, env.A)
        eval_eps = collect_episodes(env, n_eval_eps, horizon)
        Z, G, S = extract_embeddings(critic, eval_eps, env.obs_dim, env.A)

        # Residuals of the return and of the privileged information after
        # regressing each on the history embedding Z.
        G_res, _, mse_G, _, _ = crossfit_residuals_episodewise(G, Z, n_eval_eps, horizon, seed, n_splits)
        i = compute_i_from_s(S, env, info_type, seed)
        i_res, _, mse_I, _, _ = crossfit_residuals_episodewise(i, Z, n_eval_eps, horizon, seed, n_splits)

        result = residual_based_test_nystrom(G_res, i_res, Z, n_eval_eps, horizon, seed=seed, m=m)

        wandb.log({
            "hsic": result["hsic"],
            "p_value": result["p_value"],
            "null_mean": result["null_mean"],
            "null_std": result["null_std"],
            "mse_G": mse_G,
            "mse_I": mse_I,
        })
    except BaseException:
        # Close the run as failed instead of leaving it open, then propagate.
        wandb.finish(exit_code=1)
        raise

    wandb.finish()

    return {"seed": seed, "info_type": info_type, **result}


def run_experiments(seeds, info_types, n_train_eps, n_eval_eps, horizon, m=512, outdir="results",
                    num_workers=None, batch_size=90):
    """Run every (info_type, seed) experiment in parallel and save a CSV summary.

    Configurations are processed in batches of ``batch_size``. Each batch
    gets a fresh process pool of ``num_workers`` processes.

    Args:
        seeds: list of seeds to run.
        info_types: list of information types (see ``compute_i_from_s``).
        n_train_eps: number of training episodes per run.
        n_eval_eps: number of evaluation episodes per run.
        horizon: episode length.
        m: number of Nystroem landmarks.
        outdir: directory for the CSV. It is created up front, so an
            unusable path fails before any compute is spent.
        num_workers: pool size; defaults to ``multiprocessing.cpu_count()``.
        batch_size: number of configurations handled per pool (default 90).

    Returns:
        DataFrame with one row per experiment, in completion order. It is
        also written to
        ``{outdir}/evaluation_hsic_nystrom_{N_SPLITS}_{len(seeds)}_{len(info_types)}_{n_train_eps}_{n_eval_eps}.csv``.
    """
    if num_workers is None:
        num_workers = multiprocessing.cpu_count()

    os.makedirs(outdir, exist_ok=True)

    results = []

    experiment_configs = [(n_train_eps, n_eval_eps, horizon, info_type, N_SPLITS, seed, m) for info_type in
                          info_types for seed in seeds]

    with tqdm(total=len(experiment_configs)) as pbar:
        for batch in chunks(experiment_configs, batch_size):
            # Use the requested pool size rather than always cpu_count().
            with multiprocessing.Pool(processes=num_workers) as pool:
                for result in pool.imap_unordered(run_worker_and_write, batch):
                    results.append(result)
                    pbar.update(1)

    # Generating the result dataframe and saving it to .csv file
    final_df = pd.DataFrame(results)
    out_path = os.path.join(
        outdir,
        f"evaluation_hsic_nystrom_{N_SPLITS}_{len(seeds)}_{len(info_types)}_{n_train_eps}_{n_eval_eps}.csv",
    )
    final_df.to_csv(out_path, index=False)

    print("All experiments complete. Results saved.")
    return final_df


if __name__ == "__main__":
    latent_dims = [0, 1, 2, 3, 4]
    # Every tested info_type must contain dimensions 0 and 1. This yields the
    # 8 supersets of {0, 1} within the 5 dims, ordered by size and then
    # lexicographically.
    mandatory_dims = {0, 1}
    index_subsets = [
        list(subset)
        for size in range(1, len(latent_dims) + 1)
        for subset in itertools.combinations(latent_dims, size)
        if mandatory_dims.issubset(subset)
    ]

    run_experiments(
        seeds=list(range(10)),
        info_types=index_subsets,
        n_train_eps=50,
        n_eval_eps=250,
        horizon=25,
        m=512,
        outdir="results",
        num_workers=64,
    )
