##
## (c) Anonymous authors (2026)
##
## > A2C training script
##
##

import ast
import itertools
import multiprocessing
import os
import random
from dataclasses import dataclass

import numpy as np
import pandas as pd
import scipy
import torch
import torch.optim as optim
import wandb
from sklearn.model_selection import KFold
from tqdm import tqdm

from informativeness.synthetic_informed_POMDP.envs import RandomSyntheticInformedPOMDP
from informativeness.synthetic_informed_POMDP.models import RecurrentActor, RecurrentCritic, AsymmetricRecurrentCritic

# Expose no CUDA devices to this process; DEVICE below is the CPU as well.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# Respect WANDB_* values already set in the environment, otherwise use these defaults.
os.environ["WANDB_MODE"] = os.getenv("WANDB_MODE", "online")
os.environ["WANDB_SILENT"] = os.getenv("WANDB_SILENT", "true")

# Discount factor used for the discounted returns and the TD targets.
GAMMA = 0.99
# In this file the value is only recorded in the W&B run config.
NOISE_FACTOR = 0.0

# Critic parameters
CRITIC_HIDDEN = 64  # passed to the critic constructors
CRITIC_LR = 1e-4  # Adam learning rate of the critic optimizer

# Actor parameters
ACTOR_HIDDEN = 64  # passed to the actor constructor
ACTOR_LR = 1e-4  # Adam learning rate of the actor optimizer

N_ITERS = 15000  # training iterations per run
EVAL_INTERVAL = 50  # an evaluation is run every EVAL_INTERVAL iterations
N_EVAL_EPISODES = 50  # episodes collected for each evaluation

# Device the actor and critic modules are moved to.
DEVICE = torch.device("cpu")


@dataclass
class InformedEpisode:
    """One recorded rollout; every field is a per-step list in time order."""

    states: list  # environment states visited during the rollout
    obs: list  # observations generated from those states
    info: list  # privileged latent entries; stays empty when no info_type is requested
    actions: list  # sampled action indices
    rewards: list  # rewards returned by the environment step
    returns: list  # discounted returns computed from ``rewards``


def collect_policy_episodes(env, actor, n_episodes, horizon, info_type=None):
    """Roll out the current policy and return the recorded trajectories.

    The actor is switched to eval mode and run without gradient tracking for
    the duration of the rollouts, then switched back to train mode.

    Args:
        env: environment providing ``reset``, ``generate_obs``, ``step`` and
            ``generate_latent``.
        actor: recurrent policy called as ``actor(obs, hidden)`` and returning
            ``(action_probabilities, hidden)``.
        n_episodes: number of episodes to collect.
        horizon: number of steps in every episode.
        info_type: index into ``env.generate_latent(state)`` selecting the
            privileged information to record; ``None`` records nothing.

    Returns:
        A list of ``InformedEpisode`` objects, one per episode.
    """
    record_info = info_type is not None
    rollouts = []

    actor.eval()
    with torch.no_grad():
        for _ in range(n_episodes):
            state = env.reset()
            hidden = None  # recurrent state is restarted for every episode
            visited, observations, privileged = [], [], []
            chosen, collected = [], []

            for _step in range(horizon):
                observation = env.generate_obs(state)
                # The actor is fed a (1, 1, obs_dim) tensor: a single step of a single episode.
                actor_input = torch.tensor(observation, dtype=torch.float32).view(1, 1, -1)

                action_probs, hidden = actor(actor_input, hidden)
                policy = torch.distributions.Categorical(action_probs.squeeze(0).squeeze(0))
                action = policy.sample().item()

                reward, next_state = env.step(action)

                visited.append(state)
                observations.append(observation)
                chosen.append(action)
                collected.append(reward)

                if record_info:
                    privileged.append(env.generate_latent(state)[info_type])

                state = next_state

            # Discounted returns: accumulate backwards, then flip into time order.
            discounted = []
            running = 0
            for reward in reversed(collected):
                running = reward + GAMMA * running
                discounted.append(running)
            discounted.reverse()

            rollouts.append(
                InformedEpisode(
                    states=visited,
                    obs=observations,
                    info=privileged,
                    actions=chosen,
                    rewards=collected,
                    returns=discounted,
                )
            )

    actor.train()
    return rollouts


def train_actor_critic(actor, critic, actor_opt, critic_opt, env, episodes, critic_type="symmetric"):
    """Run one actor and one critic gradient step per episode.

    For every episode the critic is evaluated on the concatenation of the
    observations and the one-hot actions (plus the privileged information when
    ``critic_type == "asymmetric"``). The one-step TD error
    ``r_t + GAMMA * V_{t+1} - V_t`` (with no bootstrap on the last step) is
    used both as the advantage in the policy-gradient loss and, squared, as
    the critic loss.

    Args:
        actor: recurrent policy; ``actor(obs)`` returns ``(probs, hidden)``.
        critic: recurrent critic; returns ``(values, hidden)`` where
            ``values[0]`` holds one value per time step.
        actor_opt: optimizer of the actor parameters.
        critic_opt: optimizer of the critic parameters.
        env: environment; only ``env.A`` (number of actions) is read.
        episodes: iterable of ``InformedEpisode``.
        critic_type: ``"asymmetric"`` to also feed ``ep.info`` to the critic,
            anything else for the observation-only critic.

    Returns:
        ``(actor, critic, actor_opt, critic_opt, actor_loss, critic_loss)``.
        The two losses are detached tensors from the LAST processed episode;
        they are NaN when ``episodes`` is empty.
    """
    # Defined up front so that an empty batch returns NaN losses instead of
    # failing with UnboundLocalError at the return statement.
    actor_loss = torch.tensor(float("nan"))
    critic_loss = torch.tensor(float("nan"))

    for ep in episodes:
        obs = torch.tensor(ep.obs, dtype=torch.float32).unsqueeze(0)
        actions = torch.tensor(ep.actions).unsqueeze(0)

        # One-hot actions for critic
        a_onehot = torch.nn.functional.one_hot(actions, num_classes=env.A).float()
        critic_in = torch.cat([obs, a_onehot], dim=-1)

        if critic_type == "asymmetric":
            info = torch.tensor(ep.info, dtype=torch.float32).unsqueeze(0)
            values, _ = critic(critic_in, info)
        else:
            values, _ = critic(critic_in)

        # Policy forward pass
        probs, _ = actor(obs)
        dist = torch.distributions.Categorical(probs)
        log_probs = dist.log_prob(actions)

        # TD targets and advantages.
        step_values = values[0]
        # Rewards are cast to the critic's dtype explicitly: letting torch infer
        # the dtype from the reward objects can yield float64 (e.g. for numpy
        # scalars) and silently upcast the whole loss.
        rewards = torch.tensor(ep.rewards, dtype=step_values.dtype, device=step_values.device)
        # Value of the next step, with 0 after the final step (no bootstrap).
        next_values = torch.cat([step_values[1:].detach(), step_values.new_zeros(1)])
        targets = rewards + GAMMA * next_values
        advantages = targets - step_values

        # Losses
        actor_loss = -(log_probs.squeeze(0) * advantages.detach()).mean()
        critic_loss = advantages.pow(2).mean()

        actor_opt.zero_grad()
        critic_opt.zero_grad()
        actor_loss.backward()
        critic_loss.backward()
        actor_opt.step()
        critic_opt.step()

    return actor, critic, actor_opt, critic_opt, actor_loss.detach(), critic_loss.detach()


def compute_i_from_s(states, env, info_type, seed):
    """Build the privileged-information array for a sequence of states.

    Args:
        states: sequence of environment states.
        env: environment exposing ``generate_latent(state)``; not used when
            ``info_type == "noise"``.
        info_type: ``"noise"`` for standard-normal noise, otherwise the index
            applied to ``env.generate_latent(state)``.
        seed: seed of the generator used for the noise branch.

    Returns:
        For ``"noise"``: an array of shape ``(len(states), 2)``. Otherwise the
        selected latent entries stacked along axis 0, flattened to 1-D when
        the stacked array has exactly one column.
    """
    if info_type == "noise":
        noise_dim = 2
        generator = np.random.default_rng(seed)
        return generator.normal(size=(len(states), noise_dim))

    selected = [env.generate_latent(state)[info_type] for state in states]
    stacked = np.stack(selected, axis=0)
    # A single selected dimension is returned as a flat vector.
    return stacked.ravel() if stacked.shape[1] == 1 else stacked


def generate_fixed_env(rnd_seed):
    """Build the synthetic POMDP whose structure is shared by every run.

    The latent map and the transition kernel are drawn after seeding the
    global RNGs with 42, so they are the same for every ``rnd_seed``; only the
    ``seed`` handed to the environment constructor varies.

    NOTE(review): the torch / numpy / random GLOBAL generators are reseeded to
    42 on every call, so anything drawn from them afterwards in the same
    process starts from the same state regardless of ``rnd_seed`` -- confirm
    this is intended.

    Args:
        rnd_seed: seed forwarded to ``RandomSyntheticInformedPOMDP``.

    Returns:
        The constructed ``RandomSyntheticInformedPOMDP``.
    """
    n_states, n_actions = 20, 4
    n_latent, n_obs = 5, 2

    # shared seed to initialize state feature and transitions
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    latent_table = np.random.randn(n_states, n_latent)

    weights = [0.0001, 0.0001, -0.0001, -1.0, 1.0]

    kernel = np.zeros((n_states, n_actions, n_states))
    for state, action in itertools.product(range(n_states), range(n_actions)):
        row = kernel[state, action]  # view: writes below land in ``kernel``
        # Each successor receives a random weight with probability 0.25.
        for successor in range(n_states):
            if np.random.rand() < 0.25:
                row[successor] = np.random.uniform(0, 1)
        # Guarantee at least one reachable successor before normalising.
        if np.sum(row) == 0:
            row[np.random.randint(n_states)] = 1.0
        row /= np.sum(row)

    return RandomSyntheticInformedPOMDP(
        num_states=n_states,
        obs_dim=n_obs,
        latent_dim=n_latent,
        num_actions=n_actions,
        fixed_latent_map=latent_table,
        transitions=kernel,
        reward_weights=weights,
        seed=rnd_seed,
    )


def chunks(lst, n):
    """Yield consecutive slices of ``lst`` holding ``n`` items each (the last may be shorter)."""
    yield from (lst[offset:offset + n] for offset in range(0, len(lst), n))


def run_worker_and_write(args):
    """Train one A2C configuration, log it to W&B and save its evaluation log.

    Args:
        args: tuple ``(n_episodes, horizon, info_type, seed)``.
            ``info_type`` is ``None`` for the symmetric critic; otherwise it
            is the collection of latent indices given to the asymmetric
            critic (its length is used as the critic's information size).

    Returns:
        A ``pandas.DataFrame`` with one row per evaluation and the columns
        ``iteration, seed, info_type, actor_loss, critic_loss, eval_return,
        train_return``. The same table is written to
        ``tmp_results/a2c_training_<N_ITERS>_<n_episodes>_<info>_<seed>.csv``.
    """
    n_episodes, horizon, info_type, seed = args
    outdir = "tmp_results"
    symmetric = info_type is None

    wandb.init(
        entity="ebida",
        project="synthetic-pomdp",
        group="ebida",  # or your team/user name
        config={
            "n_episodes": n_episodes,
            "horizon": horizon,
            "info_type": info_type,
            "seed": seed,
            "critic_hidden": CRITIC_HIDDEN,
            "actor_hidden": ACTOR_HIDDEN,
            "critic_lr": CRITIC_LR,
            "actor_lr": ACTOR_LR,
            "n_iterations": N_ITERS,
            "eval_interval": EVAL_INTERVAL,
            "n_eval_episodes": N_EVAL_EPISODES,
            "gamma": GAMMA,
            "noise_factor": NOISE_FACTOR,
        },
        reinit=True,
        name=f"seed={seed}_info={info_type}",
    )

    env = generate_fixed_env(rnd_seed=seed)
    actor = RecurrentActor(env.obs_dim, env.A, ACTOR_HIDDEN).to(DEVICE)

    if symmetric:
        critic = RecurrentCritic(env.obs_dim, env.A, CRITIC_HIDDEN).to(DEVICE)
        critic_type = "symmetric"
    else:
        critic = AsymmetricRecurrentCritic(env.obs_dim, env.A, CRITIC_HIDDEN, len(info_type)).to(DEVICE)
        critic_type = "asymmetric"

    actor_opt = optim.Adam(actor.parameters(), lr=ACTOR_LR)
    critic_opt = optim.Adam(critic.parameters(), lr=CRITIC_LR)

    log_columns = ["iteration", "seed", "info_type", "actor_loss", "critic_loss", "eval_return", "train_return"]
    # Rows are gathered in a plain list and turned into a DataFrame once at
    # the end; concatenating inside the loop is quadratic in the number of rows.
    eval_rows = []

    for it in range(N_ITERS):
        episodes = collect_policy_episodes(env, actor, n_episodes, horizon, info_type)
        actor, critic, actor_opt, critic_opt, actor_loss, critic_loss = train_actor_critic(
            actor, critic, actor_opt, critic_opt, env, episodes, critic_type=critic_type)
        # Convert the 0-dim loss tensors to plain floats so that the CSV holds
        # numbers rather than "tensor(...)" strings.
        actor_loss = float(actor_loss)
        critic_loss = float(critic_loss)
        avg_return_train = np.mean([sum(ep.rewards) for ep in episodes])

        avg_return_eval = None
        if (it + 1) % EVAL_INTERVAL == 0:
            eval_episodes = collect_policy_episodes(env, actor, N_EVAL_EPISODES, horizon, info_type)
            avg_return_eval = np.mean([sum(ep.rewards) for ep in eval_episodes])
            eval_rows.append({
                "iteration": it + 1,
                "seed": seed,
                "info_type": info_type,
                "actor_loss": actor_loss,
                "critic_loss": critic_loss,
                "eval_return": avg_return_eval,
                "train_return": avg_return_train,
            })

        wandb.log({
            "iter": it,
            "avg_return_train": avg_return_train,
            "actor_loss": actor_loss,
            "critic_loss": critic_loss,
            "avg_return_eval": avg_return_eval
        })

    wandb.finish()

    eval_log = pd.DataFrame(eval_rows, columns=log_columns)

    if symmetric:
        info_type_str = "sym"
    else:
        # ast.literal_eval only accepts a string (or AST node); calling it on
        # the index list that the entry point passes raises ValueError after
        # the whole training run. Only parse when a string was actually given.
        indices = ast.literal_eval(info_type) if isinstance(info_type, str) else info_type
        info_type_str = "-".join(str(i) for i in indices)

    os.makedirs(outdir, exist_ok=True)
    eval_log.to_csv(
        f"{outdir}/a2c_training_{N_ITERS}_{n_episodes}_{info_type_str}_{seed}.csv",
        index=False)

    return eval_log


def run_experiments(seeds, info_types, n_episodes, horizon, outdir="models",
                    num_workers=None):
    """Run every (info_type, seed) configuration in a pool of worker processes.

    A symmetric-critic baseline (``info_type=None``) is always added to the
    requested ``info_types``. Configurations are processed in batches of 90,
    each batch in its own process pool.

    Args:
        seeds: iterable of seeds; every info type is run once per seed.
        info_types: privileged-information selections to evaluate. The list
            is not modified.
        n_episodes: episodes collected per training iteration.
        horizon: number of steps per episode.
        outdir: accepted for interface compatibility; it is not used here,
            the worker chooses its own output directory.
        num_workers: maximum number of worker processes; defaults to the
            number of CPUs.

    Returns:
        ``True`` once every configuration has finished.
    """
    if num_workers is None:
        num_workers = multiprocessing.cpu_count()

    # Work on a copy: appending the baseline to the caller's list would make
    # it grow on every call.
    all_info_types = list(info_types) + [None]
    experiment_configs = [(n_episodes, horizon, info_type, seed)
                          for info_type in all_info_types for seed in seeds]

    batch_size = 90
    with tqdm(total=len(experiment_configs)) as pbar:
        for batch in chunks(experiment_configs, batch_size):
            # Honour the requested worker count (previously ignored) and never
            # start more processes than there are jobs in the batch.
            pool_size = max(1, min(num_workers, len(batch)))
            with multiprocessing.Pool(processes=pool_size) as pool:
                # The workers persist their own results, so the returned
                # tables are dropped here instead of being kept in memory.
                for _ in pool.imap_unordered(run_worker_and_write, batch):
                    pbar.update(1)

    print("All experiments complete. Results saved.")
    return True


if __name__ == "__main__":
    idx = [0, 1, 2, 3, 4]
    required_dims = {0, 1}

    # Every subset of the latent indices, by increasing size, that contains
    # all of the required dimensions.
    selected_subsets = []
    for subset_size in range(1, len(idx) + 1):
        for candidate in itertools.combinations(idx, subset_size):
            if required_dims.issubset(candidate):
                selected_subsets.append(list(candidate))

    run_experiments(
        seeds=list(range(0, 10)),
        info_types=selected_subsets,
        n_episodes=16,
        horizon=25,
        outdir="results",
        num_workers=90
    )
