##
## (c) Anonymous authors (2026)
##
## > Script to evaluate post-hoc informativeness
##
##
import itertools
import multiprocessing
import os
import random
from dataclasses import dataclass

import numpy as np
import pandas as pd
import scipy
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from sklearn.model_selection import KFold
from tqdm import tqdm

from informativeness.synthetic_informed_POMDP.envs import RandomSyntheticInformedPOMDP
from informativeness.synthetic_informed_POMDP.models import RecurrentCritic, AsymmetricRecurrentCritic

# Keep any Weights & Biases settings already exported by the caller; otherwise
# default to online logging with silenced console output.
os.environ.setdefault("WANDB_MODE", "online")
os.environ.setdefault("WANDB_SILENT", "true")

# Discount factor applied to returns and TD targets.
GAMMA = 0.99
# Within this file the value is only recorded in the run configuration.
NOISE_FACTOR = 0.0

# Critic hyperparameters: hidden size, passes over the training episodes, Adam step size.
CRITIC_HIDDEN = 64
CRITIC_EPOCHS = 3
CRITIC_LR = 1e-4

# Number of cross-validation folds and number of bootstrap resamples.
N_SPLITS = 5
B = 1000

# All tensors and models are placed on the CPU.
DEVICE = torch.device("cpu")


@dataclass
class InformedEpisode:
    """One fixed-length rollout, stored as parallel per-step lists.

    Instances are built by ``collect_episodes``; index ``t`` of every list refers to
    the same time step.
    """
    # States the environment was in at each step.
    states: list
    # Observations produced by ``env.generate_obs`` for each state.
    obs: list
    # Privileged information per step (output of ``compute_i_from_s``); empty when
    # no ``info_type`` was requested.
    info: list
    # Actions taken at each step.
    actions: list
    # Rewards received at each step.
    rewards: list
    # Discounted return from each step to the end of the episode.
    returns: list


def collect_episodes(env, n_episodes, horizon, info_type=None, seed=None):
    """Roll out a uniformly random policy and record fixed-length episodes.

    Args:
        env: Environment exposing ``reset()``, ``generate_obs(state)``,
            ``step(action)`` (returning ``(reward, next_state)``) and the number of
            actions ``A``.
        n_episodes: Number of episodes to collect.
        horizon: Number of steps per episode.
        info_type: Forwarded to ``compute_i_from_s``. When ``None`` no privileged
            information is computed and each episode's ``info`` is an empty list.
        seed: Seed of the generator handed to ``compute_i_from_s`` (that function
            only draws from it for the ``"noise"`` info type). Actions are sampled
            from NumPy's global RNG and are not affected by this argument.

    Returns:
        A list of ``n_episodes`` ``InformedEpisode`` objects.
    """
    # Build the generator once for the whole call. compute_i_from_s passes its
    # ``seed`` argument to np.random.default_rng, which returns an existing
    # Generator unchanged, so successive episodes draw successive samples.
    # Passing the raw seed every time would give every episode the exact same
    # "noise" matrix.
    info_rng = np.random.default_rng(seed) if info_type is not None else None

    episodes = []
    for _ in range(n_episodes):
        s = env.reset()
        states, obs, actions, rewards = [], [], [], []

        for _step in range(horizon):
            o = env.generate_obs(s)
            a = np.random.randint(env.A)
            r, s_next = env.step(a)

            states.append(s)
            obs.append(o)
            actions.append(a)
            rewards.append(r)

            s = s_next

        # Discounted returns: accumulate backwards, then restore time order
        # (appending and reversing avoids repeated front insertions).
        returns = []
        ret = 0.0
        for r in reversed(rewards):
            ret = r + GAMMA * ret
            returns.append(ret)
        returns.reverse()

        if info_type is not None:
            info = list(compute_i_from_s(states, env, info_type, seed=info_rng))
        else:
            info = []

        episodes.append(InformedEpisode(states, obs, info, actions, rewards, returns))
    return episodes


def train_td_critic(critic, episodes, action_dim, critic_type="symmetric", info_type=None):
    """Fit a critic with one-step TD targets, one gradient step per episode.

    Each episode is fed to the critic as a single sequence whose per-step input is
    the observation concatenated with the one-hot action. The target at step ``t``
    is ``r_t + GAMMA * V[t + 1]`` (with the bootstrap value detached) and ``r_t``
    at the last step.

    Args:
        critic: Module called as ``critic(x)`` (symmetric) or ``critic(x, i)``
            (asymmetric) and returning ``(values, _)`` with ``values[0]`` holding
            one value per time step.
        episodes: Sequence of ``InformedEpisode``. The caller's sequence is not
            reordered.
        action_dim: Number of discrete actions (width of the one-hot encoding).
        critic_type: ``"asymmetric"`` to also feed the privileged information;
            any other value trains on observations and actions only.
        info_type: Must be non-``None`` when ``critic_type`` is ``"asymmetric"``.

    Returns:
        The same ``critic`` object, trained in place.

    Raises:
        ValueError: If an asymmetric critic is requested but an episode carries no
            privileged information.
    """
    optimizer = optim.Adam(critic.parameters(), lr=CRITIC_LR)
    loss_fn = nn.MSELoss()

    use_info = critic_type == "asymmetric"
    one_hot = np.eye(action_dim, dtype=np.float32)

    # Shuffle a private copy so the caller's list keeps its order. The copy is made
    # once and reshuffled every epoch, which yields the same sequence of
    # permutations as shuffling the original list in place.
    shuffled = list(episodes)

    critic.train()
    for _ in range(CRITIC_EPOCHS):
        random.shuffle(shuffled)

        for ep in shuffled:
            T = len(ep.obs)
            if T == 0:
                # Nothing to learn from an empty episode.
                continue

            # (T, obs_dim + action_dim): observation followed by one-hot action.
            features = np.concatenate(
                [np.asarray(ep.obs, dtype=np.float32), one_hot[ep.actions]], axis=1)
            x = torch.from_numpy(features).float().unsqueeze(0).to(DEVICE)

            if use_info:
                if info_type is None or len(ep.info) == 0:
                    raise ValueError("asymmetric critic requires privileged information in every episode")
                # reshape keeps a trailing feature axis even when the information
                # was stored as one scalar per step.
                info = np.asarray(ep.info, dtype=np.float32).reshape(T, -1)
                i = torch.from_numpy(info).float().unsqueeze(0).to(DEVICE)
                values, _ = critic(x, i)
            else:
                values, _ = critic(x)

            rewards = torch.as_tensor(np.asarray(ep.rewards, dtype=np.float64)).to(
                dtype=values.dtype, device=values.device)

            # Bootstrap from the (detached) next-step value; the last step has no
            # successor and keeps a zero bootstrap, so its target is the reward.
            next_values = torch.zeros_like(rewards)
            next_values[:-1] = values[0, 1:].detach()
            targets = rewards + GAMMA * next_values

            loss = loss_fn(values[0], targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return critic


def compute_squared_error_gain(env, episodes, symmetric_critic, asymmetric_critic):
    """Per-episode reduction in mean squared error from using privileged information.

    Both critics are evaluated on each whole episode as one sequence, in the same
    input layout ``train_td_critic`` uses, and their per-step predictions are
    compared with the recorded discounted returns.

    Args:
        env: Environment; only its number of actions ``A`` is read.
        episodes: Sequence of ``InformedEpisode`` with non-empty ``info``.
        symmetric_critic: Critic called as ``critic(x)``.
        asymmetric_critic: Critic called as ``critic(x, i)``.

    Returns:
        A NumPy array with one entry per episode: symmetric MSE minus asymmetric
        MSE (positive when the asymmetric critic is more accurate). An episode
        without steps yields ``nan``.
    """
    symmetric_critic.eval()
    asymmetric_critic.eval()

    one_hot = np.eye(env.A, dtype=np.float32)
    gains = []

    with torch.no_grad():
        for ep in episodes:
            T = len(ep.obs)
            if T == 0:
                gains.append(np.nan)
                continue

            # Feed the full sequence at once. Calling the critics one step at a
            # time and discarding the returned state would evaluate them on
            # inputs unlike the sequences they were trained on.
            features = np.concatenate(
                [np.asarray(ep.obs, dtype=np.float32), one_hot[ep.actions]], axis=1)
            # reshape keeps a trailing feature axis even when the information was
            # stored as one scalar per step.
            info = np.asarray(ep.info, dtype=np.float32).reshape(T, -1)

            x = torch.from_numpy(features).float().unsqueeze(0).to(DEVICE)
            i = torch.from_numpy(info).float().unsqueeze(0).to(DEVICE)

            q_symmetric, _ = symmetric_critic(x)
            q_asymmetric, _ = asymmetric_critic(x, i)

            # True discounted returns, one per step.
            returns = torch.as_tensor(np.asarray(ep.returns, dtype=np.float64)).to(
                dtype=q_symmetric.dtype, device=q_symmetric.device)

            symmetric_mse = torch.mean((q_symmetric[0].reshape(-1) - returns) ** 2).item()
            asymmetric_mse = torch.mean((q_asymmetric[0].reshape(-1) - returns) ** 2).item()

            gains.append(symmetric_mse - asymmetric_mse)

    return np.array(gains)


def compute_i_from_s(states, env, info_type, seed):
    """Build the per-step privileged information for a sequence of states.

    Args:
        states: Sequence of states, one per time step.
        env: Environment exposing ``generate_latent(state)``; unused for noise.
        info_type: ``"noise"`` for Gaussian noise, otherwise an index applied to
            each state's latent vector.
        seed: Passed to ``np.random.default_rng``; only used for noise.

    Returns:
        For ``"noise"``: an array of shape ``(len(states), 2)`` of standard normal
        draws. Otherwise the selected latent entries stacked along axis 0,
        flattened to one dimension when exactly one entry is selected per state.
    """
    if info_type == "noise":
        noise_dim = 2
        return np.random.default_rng(seed).normal(size=(len(states), noise_dim))

    per_state = [env.generate_latent(state)[info_type] for state in states]
    selected = np.stack(per_state, axis=0)
    # A single selected column is returned as a flat vector.
    return selected.ravel() if selected.shape[1] == 1 else selected


def generate_fixed_env(rnd_seed):
    """Create the synthetic informed POMDP shared by all runs.

    The latent map and the transition tensor are drawn after re-seeding the global
    torch, NumPy and ``random`` generators with 42, so they are the same on every
    call; ``rnd_seed`` is only forwarded to the environment constructor.

    Note that the global random generators are left in the state reached after
    these draws.

    Args:
        rnd_seed: Seed passed to ``RandomSyntheticInformedPOMDP``.

    Returns:
        A ``RandomSyntheticInformedPOMDP`` with 20 states, 4 actions, 2-dimensional
        observations and 5-dimensional latents.
    """
    n_states, n_latent, n_obs, n_actions = 20, 5, 2, 4

    # Common seed: state features and dynamics do not depend on ``rnd_seed``.
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    latent_map = np.random.randn(n_states, n_latent)

    weights = [0.0001, 0.0001, -0.0001, -1.0, 1.0]

    dynamics = np.zeros((n_states, n_actions, n_states))
    for state, action in itertools.product(range(n_states), range(n_actions)):
        row = dynamics[state, action]  # view into ``dynamics``
        # Each successor gets a random weight with probability 0.25.
        for successor in range(n_states):
            if np.random.rand() < 0.25:
                row[successor] = np.random.uniform(0, 1)
        # Guarantee at least one reachable successor before normalising.
        if np.sum(row) == 0:
            row[np.random.randint(n_states)] = 1.0
        row /= np.sum(row)

    return RandomSyntheticInformedPOMDP(
        num_states=n_states,
        obs_dim=n_obs,
        latent_dim=n_latent,
        num_actions=n_actions,
        fixed_latent_map=latent_map,
        transitions=dynamics,
        reward_weights=weights,
        seed=rnd_seed,
    )


def post_hoc_informativeness_test(gains, N, delta=0.05, n_permutations=B):
    """Test whether the mean gain is positive and report the decision.

    Args:
        gains: Per-episode gains.
        N: Sample size; selects the test and is used by the t-test.
        delta: Significance level.
        n_permutations: Number of bootstrap resamples (small-sample test only).

    Returns:
        ``(reject_null, p)`` where ``reject_null`` is ``p < delta``.
    """
    # Fewer than 500 samples: bootstrap; otherwise the one-sided t-test.
    use_bootstrap = N < 500
    p = bootstrap_test(gains, n_permutations) if use_bootstrap else t_test(gains, N)

    # The null hypothesis is rejected when the p-value falls below delta.
    return p < delta, p


def bootstrap_test(gains, n_permutations):
    """Empirical one-sided p-value from bootstrap resampling of the gains.

    Args:
        gains: Per-episode gains.
        n_permutations: Number of bootstrap resamples to draw.

    Returns:
        ``(1 + k) / (n_permutations + 1)`` where ``k`` is the number of resampled
        means that are less than or equal to zero.
    """
    sample_size = len(gains)

    # Each resample draws ``sample_size`` gains with replacement (global NumPy RNG).
    resampled_means = np.array([
        np.mean(np.random.choice(gains, size=sample_size, replace=True))
        for _ in range(n_permutations)
    ])

    non_positive = np.sum(resampled_means <= 0)
    return (1 + non_positive) / (n_permutations + 1)


def t_test(gains, N):
    """One-sided one-sample t-test of the mean gain against zero.

    The t statistic is also logged to the active Weights & Biases run.

    Args:
        gains: Per-episode gains.
        N: Sample size used for the standard error and the degrees of freedom
            (``N - 1``).

    Returns:
        The upper-tail p-value ``P(T >= t_statistic)``.
    """
    # Import the submodule explicitly instead of relying on ``import scipy``
    # having made ``scipy.stats`` available.
    from scipy import stats

    mean_gain = np.mean(gains)
    std_dev = np.std(gains, ddof=1)
    t_statistic = mean_gain / (std_dev / np.sqrt(N))
    wandb.log({"t_statistic": t_statistic})

    # Survival function = 1 - cdf, evaluated without the cancellation that makes
    # ``1 - cdf`` collapse to exactly 0 for large t statistics.
    p_value = stats.t.sf(t_statistic, df=N - 1)
    return p_value


def chunks(lst, n):
    """Lazily yield consecutive slices of ``lst``, each at most ``n`` items long."""
    starts = range(0, len(lst), n)
    yield from (lst[start:start + n] for start in starts)


def run_worker_and_write(args):
    """Run one experiment (one seed, one information type) and log it to W&B.

    Episodes are collected once, split into folds, and for every fold a fresh pair
    of critics is trained on the training episodes and scored on the held-out
    episodes. The resulting per-episode gains are passed to the informativeness
    test.

    Args:
        args: Tuple ``(n_episodes, horizon, info_type, n_folds, seed)``.

    Returns:
        A dict with keys ``seed``, ``info_type``, ``p_value``, ``reject_null``,
        ``gains_mean`` and ``gains_std``.
    """
    n_episodes, horizon, info_type, n_folds, seed = args

    wandb.init(
        entity="ENTITY_NAME",
        project="PROJECT_NAME",
        group="GROUP_NAME",
        config={
            "n_episodes": n_episodes,
            "horizon": horizon,
            "info_type": info_type,
            "N_SPLITS": n_folds,
            "B": B,
            "seed": seed,
            "critic_hidden": CRITIC_HIDDEN,
            "critic_lr": CRITIC_LR,
            "epochs": CRITIC_EPOCHS,
            "gamma": GAMMA,
            "noise_factor": NOISE_FACTOR,
        },
        reinit=True,
        name=f"seed={seed}_info={info_type}",
    )

    try:
        env = generate_fixed_env(rnd_seed=seed)
        episodes = collect_episodes(env, n_episodes, horizon, info_type, seed)

        # Width of the privileged information as actually stored in the episodes;
        # len(info_type) is only a fallback (it would be wrong for "noise").
        first_info = episodes[0].info if episodes else []
        info_dim = int(np.size(first_info[0])) if len(first_info) > 0 else len(info_type)

        # Use the fold count received in ``args`` (the one recorded in the config).
        kf = KFold(n_splits=n_folds, shuffle=True, random_state=seed)

        gains = np.zeros(n_episodes)

        for train_eps, test_eps in kf.split(np.arange(n_episodes)):
            # Fresh critics per fold: reusing the same models across folds would
            # score them on episodes they were trained on in an earlier fold.
            symmetric_critic = RecurrentCritic(env.obs_dim, env.A, CRITIC_HIDDEN).to(DEVICE)
            asymmetric_critic = AsymmetricRecurrentCritic(
                env.obs_dim, env.A, CRITIC_HIDDEN, info_dim).to(DEVICE)

            symmetric_critic = train_td_critic(
                symmetric_critic, [episodes[i] for i in train_eps], env.A, "symmetric", info_type)
            asymmetric_critic = train_td_critic(
                asymmetric_critic, [episodes[i] for i in train_eps], env.A, "asymmetric", info_type)

            gains[test_eps] = compute_squared_error_gain(
                env, [episodes[i] for i in test_eps], symmetric_critic, asymmetric_critic)

        reject_null, p_value = post_hoc_informativeness_test(
            gains, len(gains), delta=0.05, n_permutations=B)

        gains_mean = np.mean(gains)
        gains_std = np.std(gains)

        wandb.log({
            "p_value": p_value,
            "reject_null": reject_null,
            "gains_mean": gains_mean,
            "gains_std": gains_std,
            "N": len(gains)
        })
    finally:
        # Close the run even when the experiment raises.
        wandb.finish()

    return {"seed": seed, "info_type": info_type, "p_value": p_value, "reject_null": reject_null,
            "gains_mean": gains_mean, "gains_std": gains_std}


def run_experiments(seeds, info_types, n_episodes, horizon, outdir="results",
                    num_workers=None):
    """Run every (info_type, seed) experiment in parallel and save the results.

    Args:
        seeds: Iterable of seeds.
        info_types: Iterable of information types, each forwarded to the worker.
        n_episodes: Episodes collected per experiment.
        horizon: Steps per episode.
        outdir: Directory that receives the CSV file (created if missing).
        num_workers: Maximum number of worker processes; defaults to the CPU count.

    Returns:
        A ``pandas.DataFrame`` with one row per experiment, in completion order.
    """
    if num_workers is None:
        num_workers = multiprocessing.cpu_count()

    results = []

    experiment_configs = [(n_episodes, horizon, info_type, N_SPLITS, seed) for info_type in
                          info_types for seed in seeds]

    # Experiments are processed in batches, each with its own pool.
    batch_size = 80
    with tqdm(total=len(experiment_configs)) as pbar:
        for batch in chunks(experiment_configs, batch_size):
            # Honour the requested worker count, never starting more processes
            # than there are jobs in the batch.
            n_procs = max(1, min(num_workers, len(batch)))
            with multiprocessing.Pool(processes=n_procs) as pool:
                for result in pool.imap_unordered(run_worker_and_write, batch):
                    results.append(result)
                    pbar.update(1)

    # Generating the result dataframe and saving it to .csv file
    os.makedirs(outdir, exist_ok=True)
    final_df = pd.DataFrame(results)
    final_df.to_csv(
        f"{outdir}/evaluation_post_hoc_criterion_{len(seeds)}_{len(info_types)}_{n_episodes}.csv",
        index=False)

    print("All experiments complete. Results saved.")
    return final_df


if __name__ == "__main__":
    # Candidate indices into the latent vector, and the indices every tested
    # subset must contain.
    idx = [0, 1, 2, 3, 4]
    required_dims = {0, 1}

    # All subsets of ``idx`` (by increasing size) that include the required indices.
    selected_info_types = []
    for subset_size in range(1, len(idx) + 1):
        for subset in itertools.combinations(idx, subset_size):
            if required_dims.issubset(subset):
                selected_info_types.append(list(subset))

    run_experiments(
        seeds=list(range(0, 10)),
        info_types=selected_info_types,
        n_episodes=2500,
        horizon=25,
        outdir="results",
        num_workers=64,
    )
