"""GRU hidden state analysis via dimensionality reduction."""

import torch
import numpy as np
import networkx as nx
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE


def collect_hidden_states(policy, obs_dataset):
    """Step a GRU policy through each trajectory and record every hidden state.

    For each trajectory the hidden state starts at zeros with shape
    ``(1, 1, policy.hidden_dim)``. Observations are then fed one timestep at
    a time through ``policy.obs_encoder`` followed by ``policy.gru``, and the
    hidden state returned by the GRU after each step is stored.

    Args:
        policy: Module exposing ``hidden_dim``, ``obs_encoder`` and ``gru``.
            It is switched to eval mode and left in that mode.
        obs_dataset: Iterable of mappings with an ``"obs"`` tensor indexed by
            timestep and a ``"states"`` sequence. ``len(d["states"])`` steps
            are processed, so ``"obs"`` must hold at least that many entries.

    Returns:
        Tuple ``(hidden, positions, timesteps, traj_ids)`` of NumPy arrays
        with one entry per processed (trajectory, timestep) pair. Each row of
        ``hidden`` is the flattened hidden state for that step.
    """
    policy.eval()

    # Run on whatever device the policy's weights live on; a policy without
    # parameters falls back to the CPU.
    ref_param = next(iter(policy.parameters()), None)
    device = ref_param.device if ref_param is not None else torch.device("cpu")

    all_hidden, all_positions, all_timesteps, all_traj_ids = [], [], [], []

    with torch.no_grad():
        for tid, d in enumerate(obs_dataset):
            # Hidden state is reset at the start of every trajectory.
            h = torch.zeros(1, 1, policy.hidden_dim, device=device)
            for t in range(len(d["states"])):
                # Add two leading singleton axes to the single observation
                # before handing it to the encoder.
                obs_t = d["obs"][t].unsqueeze(0).unsqueeze(0).to(device)
                encoded = policy.obs_encoder(obs_t)
                _, h = policy.gru(encoded, h)
                # reshape(-1) rather than squeeze(): squeeze() would collapse
                # the vector to a scalar when hidden_dim == 1.
                all_hidden.append(h.reshape(-1).cpu().numpy().copy())
                all_positions.append(d["states"][t])
                all_timesteps.append(t)
                all_traj_ids.append(tid)

    return (
        np.array(all_hidden),
        np.array(all_positions),
        np.array(all_timesteps),
        np.array(all_traj_ids),
    )


def collect_hidden_states_generic(policy, obs_dataset):
    """Collect hidden states from any architecture exposing get_recurrent_output().

    When ``policy`` has a ``get_recurrent_output`` method, it is called once
    per trajectory on ``d["obs"]`` with a leading axis of size one added; that
    leading axis is squeezed off the result, which is then indexed by
    timestep. Otherwise the function falls back to stepping
    ``policy.obs_encoder`` and ``policy.gru`` one observation at a time from a
    zero hidden state of shape ``(1, 1, policy.hidden_dim)``.

    Args:
        policy: Module exposing either ``get_recurrent_output`` or the trio
            ``hidden_dim`` / ``obs_encoder`` / ``gru``. It is switched to eval
            mode and left in that mode.
        obs_dataset: Iterable of mappings with an ``"obs"`` tensor indexed by
            timestep and a ``"states"`` sequence. ``len(d["states"])`` steps
            are recorded per trajectory.

    Returns:
        Tuple ``(hidden, positions, timesteps, traj_ids)`` of NumPy arrays
        with one entry per recorded (trajectory, timestep) pair.
    """
    policy.eval()

    # The attribute check does not depend on the trajectory, so do it once.
    has_recurrent_api = hasattr(policy, "get_recurrent_output")

    # Run on whatever device the policy's weights live on; a policy without
    # parameters falls back to the CPU.
    ref_param = next(iter(policy.parameters()), None)
    device = ref_param.device if ref_param is not None else torch.device("cpu")

    all_hidden, all_positions, all_timesteps, all_traj_ids = [], [], [], []

    with torch.no_grad():
        for tid, d in enumerate(obs_dataset):
            n_steps = len(d["states"])

            if has_recurrent_api:
                obs = d["obs"].unsqueeze(0).to(device)
                h_seq = policy.get_recurrent_output(obs).squeeze(0)
                # Move to host memory once, then copy out one row per step.
                h_seq_np = h_seq.cpu().numpy()
                traj_hidden = [h_seq_np[t].copy() for t in range(n_steps)]
            else:
                # Hidden state is reset at the start of every trajectory.
                h = torch.zeros(1, 1, policy.hidden_dim, device=device)
                traj_hidden = []
                for t in range(n_steps):
                    obs_t = d["obs"][t].unsqueeze(0).unsqueeze(0).to(device)
                    encoded = policy.obs_encoder(obs_t)
                    _, h = policy.gru(encoded, h)
                    # reshape(-1) rather than squeeze(): squeeze() would
                    # collapse the vector to a scalar when hidden_dim == 1.
                    traj_hidden.append(h.reshape(-1).cpu().numpy().copy())

            # Bookkeeping is identical for both paths.
            all_hidden.extend(traj_hidden)
            all_positions.extend(d["states"][t] for t in range(n_steps))
            all_timesteps.extend(range(n_steps))
            all_traj_ids.extend([tid] * n_steps)

    return (
        np.array(all_hidden),
        np.array(all_positions),
        np.array(all_timesteps),
        np.array(all_traj_ids),
    )


def run_dimensionality_reduction(hidden_states, n_pca=50, perplexity=30):
    """PCA then t-SNE on hidden states; returns (hidden_2d, pca, hidden_pca).

    Args:
        hidden_states: 2-D array with one row per sample and one column per
            hidden unit.
        n_pca: Upper bound on the number of PCA components; the actual number
            is also limited by the sample and feature counts.
        perplexity: Upper bound on the t-SNE perplexity; the value used is
            additionally capped at a quarter of the sample count (but never
            below 1).

    Returns:
        Tuple ``(hidden_2d, pca, hidden_pca)``: the 2-D t-SNE embedding, the
        fitted PCA object, and the PCA-projected hidden states that were fed
        to t-SNE.

    Raises:
        ValueError: If fewer than two samples are supplied.
    """
    n_samples, n_features = hidden_states.shape[0], hidden_states.shape[1]
    if n_samples < 2:
        raise ValueError(
            f"t-SNE needs at least 2 samples, got {n_samples}"
        )

    n_components = min(n_pca, n_features, n_samples)
    # Seeded like the t-SNE step below so the whole pipeline is reproducible
    # even when PCA selects a randomized solver.
    pca = PCA(n_components=n_components, random_state=42)
    hidden_pca = pca.fit_transform(hidden_states)

    # Keep perplexity well below the sample count, but never let the cap
    # reach 0: with fewer than 4 samples n_samples // 4 is 0, which is not a
    # valid perplexity.
    perplexity_cap = max(1, n_samples // 4)
    tsne_perp = min(perplexity, perplexity_cap)
    tsne = TSNE(n_components=2, perplexity=tsne_perp, max_iter=2000,
                init="pca", random_state=42)
    hidden_2d = tsne.fit_transform(hidden_pca)

    return hidden_2d, pca, hidden_pca
