import argparse
import hashlib
import json
import math
import os
from typing import Any, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from datasets import Dataset
from sklearn.decomposition import PCA
from tabulate import tabulate
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


def save_figure_pdf_png(outfile: str, dpi: int = 300) -> List[str]:
    """Write the current matplotlib figure to ``outfile`` and to its PDF/PNG twin.

    A ``.pdf`` target is additionally written as ``.png`` and vice versa (the
    extension is matched case-insensitively). Any other extension is saved
    once, exactly as given.

    Args:
        outfile: Destination path for the figure.
        dpi: Resolution forwarded to ``plt.savefig``.

    Returns:
        The paths that were written, in the order they were saved.
    """
    stem, suffix = os.path.splitext(outfile)
    twin_suffix = {".pdf": ".png", ".png": ".pdf"}.get(suffix.lower())

    targets = [outfile]
    if twin_suffix is not None:
        targets.append(stem + twin_suffix)

    written: List[str] = []
    for target in targets:
        # Guard against writing the same path twice.
        if target in written:
            continue
        plt.savefig(target, dpi=dpi)
        written.append(target)
    return written


def load_progressive_dataset(dataset_path: str) -> Dataset:
    """Read the progressive embeddings ``Dataset`` stored on disk at ``dataset_path``."""
    dataset = Dataset.load_from_disk(dataset_path)
    return dataset


def flatten_embedding(row: Dict[str, Any]) -> np.ndarray:
    """Return ``row["embedding"]`` as a 1-D float32 NumPy array.

    Args:
        row: Dataset row; its ``"embedding"`` entry is converted with
            ``torch.tensor(..., dtype=torch.float32)``.

    Returns:
        The embedding values flattened into a single dimension.
    """
    tensor = torch.tensor(row["embedding"], dtype=torch.float32)
    flat = tensor.reshape(-1)
    return flat.detach().cpu().numpy()


def get_experiment_cache_file(dataset_path: str, model_checkpoint: str) -> str:
    """Return the JSON cache path for a (dataset, checkpoint) pair.

    The file lives in a ``.cache`` directory inside ``os.path.dirname(dataset_path)``
    (created here if missing). Its name embeds the MD5 hex digest of
    ``"<dataset_path>:<model_checkpoint>"``.

    Args:
        dataset_path: Path of the dataset the cache belongs to.
        model_checkpoint: Checkpoint identifier mixed into the cache key.

    Returns:
        Path of the cache file (the file itself is not created).
    """
    cache_root = os.path.join(os.path.dirname(dataset_path), ".cache")
    os.makedirs(cache_root, exist_ok=True)

    digest = hashlib.md5(f"{dataset_path}:{model_checkpoint}".encode()).hexdigest()
    file_name = f"visualize_multiple_trajectories_{digest}.json"
    return os.path.join(cache_root, file_name)


def load_experiment_cache(dataset_path: Optional[str], model_checkpoint: Optional[str]) -> Tuple[Dict[str, Any], Optional[str]]:
    """Load the experiment cache for a (dataset, checkpoint) pair.

    Args:
        dataset_path: Dataset path used to derive the cache location.
        model_checkpoint: Checkpoint identifier used to derive the cache key.

    Returns:
        ``(cache_data, cache_file)``. ``cache_file`` is ``None`` when either
        argument is missing/empty. ``cache_data`` is an empty dict when the
        file does not exist, cannot be read or parsed (a warning is printed),
        or does not contain a JSON object.
    """
    if not (dataset_path and model_checkpoint):
        return {}, None

    cache_file = get_experiment_cache_file(dataset_path, model_checkpoint)
    if os.path.exists(cache_file):
        try:
            with open(cache_file, "r") as handle:
                loaded = json.load(handle)
        except (json.JSONDecodeError, IOError) as e:
            print(f"Warning: Failed to load cache file {cache_file}: {e}")
        else:
            # Only a JSON object is a usable cache; anything else is ignored.
            if isinstance(loaded, dict):
                return loaded, cache_file
    return {}, cache_file


def save_experiment_cache(cache_file: Optional[str], cache_data: Dict[str, Any]) -> None:
    """Write ``cache_data`` to ``cache_file`` as indented JSON.

    Does nothing when ``cache_file`` is ``None``. An ``IOError`` while opening
    or writing is reported with a printed warning instead of being raised.

    Args:
        cache_file: Destination path, or ``None`` to skip saving.
        cache_data: JSON-serialisable cache contents.
    """
    if cache_file is not None:
        try:
            with open(cache_file, "w") as handle:
                json.dump(cache_data, handle, indent=2)
        except IOError as e:
            print(f"Warning: Failed to save cache file {cache_file}: {e}")


def get_cache_metrics(cache_data: Dict[str, Any]) -> Dict[str, Any]:
    """Return the ``"metrics"`` dict of ``cache_data``, creating it if needed.

    If the entry is missing or is not a dict, it is replaced in place by a new
    empty dict, and that dict is returned.
    """
    existing = cache_data.get("metrics")
    if not isinstance(existing, dict):
        existing = {}
        cache_data["metrics"] = existing
    return existing


def get_metric_map(cache_data: Dict[str, Any], metric_name: str) -> Dict[str, Any]:
    """Return the dict stored under ``metric_name`` in the cache metrics.

    A missing or non-dict entry is replaced in place by a new empty dict,
    which is then returned.
    """
    metrics = get_cache_metrics(cache_data)
    existing = metrics.get(metric_name)
    if not isinstance(existing, dict):
        existing = {}
        metrics[metric_name] = existing
    return existing


def get_metric_value(cache_data: Dict[str, Any], metric_name: str) -> Optional[Any]:
    """Return the cached value stored under ``metric_name``, or ``None`` if absent."""
    return get_cache_metrics(cache_data).get(metric_name)


def set_metric_value(cache_data: Dict[str, Any], metric_name: str, value: Any) -> bool:
    """Store ``value`` under ``metric_name`` in the cache metrics.

    Returns:
        ``True`` if the entry was written, ``False`` if the stored value
        already compared equal to ``value`` and nothing was changed.
    """
    metrics = get_cache_metrics(cache_data)
    previous = metrics.get(metric_name)
    if previous == value:
        updated = False
    else:
        metrics[metric_name] = value
        updated = True
    return updated


def set_metric_map_value(cache_data: Dict[str, Any], metric_name: str, key: Any, value: Any) -> bool:
    """Store ``value`` under ``str(key)`` in the map for ``metric_name``.

    Keys are stringified before use.

    Returns:
        ``True`` if the entry was written, ``False`` if the stored value
        already compared equal to ``value`` and nothing was changed.
    """
    entries = get_metric_map(cache_data, metric_name)
    normalized_key = str(key)
    if entries.get(normalized_key) == value:
        changed = False
    else:
        entries[normalized_key] = value
        changed = True
    return changed


def cache_has_metrics(
    cache_data: Dict[str, Any],
    sample_ids: List[int],
    require_random_proj: bool,
    require_info_gain: bool,
    require_embedding_stats: bool,
    allow_info_gain_from_dataset: bool,
) -> bool:
    """Tell whether the cache already holds every metric the run needs.

    Args:
        cache_data: Experiment cache dictionary.
        sample_ids: Sample ids that must each have per-sample entries.
        require_random_proj: Also require ``random_proj_99_var`` per sample.
        require_info_gain: Also require ``information_gain`` per sample,
            unless ``allow_info_gain_from_dataset`` is set.
        require_embedding_stats: Also require the ``embedding_statistics`` entry.
        allow_info_gain_from_dataset: Waives the cached ``information_gain``
            requirement.

    Returns:
        ``True`` only if ``sample_ids`` is non-empty, every required per-sample
        map contains ``str(sid)`` for each id, ``pca_99_var_all_embeds`` is
        present, and (when requested) ``embedding_statistics`` is present.
    """
    if not sample_ids:
        return False

    metrics = get_cache_metrics(cache_data)

    per_sample_metrics = ["trajectory_length", "pca_99_var"]
    if require_random_proj:
        per_sample_metrics.append("random_proj_99_var")
    if require_info_gain and not allow_info_gain_from_dataset:
        per_sample_metrics.append("information_gain")

    # Per-sample maps are keyed by the stringified sample id.
    wanted_keys = [str(sid) for sid in sample_ids]
    for name in per_sample_metrics:
        entries = metrics.get(name, {})
        if not isinstance(entries, dict):
            return False
        if any(wanted not in entries for wanted in wanted_keys):
            return False

    if "pca_99_var_all_embeds" not in metrics:
        return False
    return not (require_embedding_stats and "embedding_statistics" not in metrics)


def get_sample_ids(ds: Dataset) -> List[int]:
    """Return the sorted, de-duplicated integer sample ids found in ``ds``.

    Returns an empty list when the dataset has no ``sample_id`` column. Ids
    are read with ``ds.unique("sample_id")``; if that raises, the rows are
    scanned one by one and unreadable rows are skipped. ``None`` ids are
    ignored.
    """
    if "sample_id" not in ds.column_names:
        return []

    try:
        raw_ids = ds.unique("sample_id")
    except Exception:
        # Fallback: collect the ids row by row, ignoring rows that fail to load.
        raw_ids = []
        for idx in range(len(ds)):
            try:
                value = ds[idx].get("sample_id")
            except Exception:
                continue
            if value is not None:
                raw_ids.append(value)

    unique_ids = set()
    for value in raw_ids:
        if value is not None:
            unique_ids.add(int(value))
    return sorted(unique_ids)


def get_final_stage(stages: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Pick the last stage of a sample.

    Stages are ranked by ``(stage_index, stage_seq_len)``, with ``-1`` standing
    in for any missing/``None`` field. The highest-ranked stage is returned,
    and ties resolve to the earliest such stage in the list.

    Returns:
        The chosen stage dict, or ``None`` for an empty list.
    """
    if len(stages) == 0:
        return None

    def _rank(stage: Dict[str, Any]) -> Tuple[int, int]:
        index = stage.get("stage_index")
        seq_len = stage.get("stage_seq_len")
        primary = -1 if index is None else int(index)
        secondary = -1 if seq_len is None else int(seq_len)
        return primary, secondary

    return max(stages, key=_rank)


def filter_records(
    ds: Dataset,
    sample_id: Optional[int] = None,
    stage_index: Optional[int] = None,
    dataset_path: Optional[str] = None,
    model_checkpoint: Optional[str] = None,
    check_cache: bool = True,
) -> List[Dict[str, Any]]:
    """Filter dataset records by sample_id and/or stage_index.

    Auxiliary columns are dropped before rows are materialised. When the
    experiment cache already holds every required metric, the ``embedding``
    column is dropped as well, so returned rows may lack embeddings.

    Args:
        ds: Dataset to filter
        sample_id: Optional sample_id to filter by
        stage_index: Optional stage_index to filter by
        dataset_path: Optional dataset path used to locate the experiment cache
        model_checkpoint: Optional model checkpoint used to locate the experiment
            cache. If None, it is read from the first row's ``model_checkpoint``.
        check_cache: If True (and ``dataset_path`` is given), drop the embedding
            column when the cache contains all required metrics for every sample

    Returns:
        List of filtered records
    """
    rows: List[Dict[str, Any]] = []

    # Drop auxiliary columns, but only those actually present, so datasets that
    # were saved without some of them do not raise in remove_columns.
    # Column-name spellings are kept exactly as in the original code.
    droppable_columns = (
        "orig_embedding",
        "initialization_embedding",
        "low_dim_prjoection_b",
        "low_dim_prjoection_w",
    )
    present_columns = [name for name in droppable_columns if name in ds.column_names]
    if present_columns:
        ds = ds.remove_columns(present_columns)

    # Check if we can remove embedding column (if all cache metrics exist)
    if check_cache and dataset_path is not None and "embedding" in ds.column_names:
        if model_checkpoint is None and len(ds) > 0:
            try:
                model_checkpoint = ds[0].get("model_checkpoint")
            except Exception:
                model_checkpoint = None

        sample_ids_list = get_sample_ids(ds)
        if model_checkpoint is not None and sample_ids_list:
            cache_data, _ = load_experiment_cache(dataset_path, model_checkpoint)
            # Optional metrics are only required when enabled through the environment.
            require_random_proj = os.environ.get("VISUALIZE_MULTIPLE_TRAJECTORIES_COMPUTE_RAND_PROJ") == "1"
            require_info_gain = os.environ.get("VISUALIZE_MULTIPLE_TRAJECTORIES_COMPUTE_IG") == "1"
            require_emb_stats = os.environ.get("VISUALIZE_MULTIPLE_TRAJECTORIES_COMPUTE_EMB_STATS") == "1"
            allow_info_gain_from_dataset = "information_gain_bits" in ds.column_names
            if cache_has_metrics(
                cache_data,
                sample_ids_list,
                require_random_proj=require_random_proj,
                require_info_gain=require_info_gain,
                require_embedding_stats=require_emb_stats,
                allow_info_gain_from_dataset=allow_info_gain_from_dataset,
            ):
                print("Drop embeddings")
                ds = ds.remove_columns(["embedding"])

    for i in tqdm(range(len(ds)), desc="Filtering records"):
        r = ds[i]
        # Rows missing a filter field are treated as -1 and never match.
        if sample_id is not None and int(r.get("sample_id", -1)) != int(sample_id):
            continue
        if stage_index is not None and int(r.get("stage_index", -1)) != int(stage_index):
            continue
        rows.append(r)
    return rows


def collate_stages_by_sample(
    rows: List[Dict[str, Any]],
) -> Dict[int, List[Dict[str, Any]]]:
    """Bucket rows by integer ``sample_id`` and order each bucket by ``stage_index``.

    Rows without a ``sample_id`` go to bucket ``-1``; rows without a
    ``stage_index`` sort as stage ``0``. The sort is stable, so rows with equal
    stage indices keep their input order.
    """
    grouped: Dict[int, List[Dict[str, Any]]] = {}
    for row in rows:
        grouped.setdefault(int(row.get("sample_id", -1)), []).append(row)

    for stage_list in grouped.values():
        stage_list.sort(key=lambda stage: int(stage.get("stage_index", 0)))
    return grouped


def compute_num_pca_explained_99_var(
    embeddings: List[np.ndarray],
    cache_data: Optional[Dict[str, Any]] = None,
    cache_key_suffix: Optional[str] = None,
) -> float:
    """Count the leading PCA components whose cumulative explained variance is below 99%.

    PCA is fitted with ``min(512, n_samples - 1, n_features)`` components on the
    stacked embeddings.

    NOTE(review): ``compute_num_random_projections_explained_99_var`` reports
    the analogous count plus one, so the two metrics use different conventions.
    Confirm which is intended before comparing them.

    Args:
        embeddings: List of flattened embedding arrays
        cache_data: Optional cache dictionary for the experiment.
        cache_key_suffix: Optional suffix for per-sample cache key (e.g., sample_id).
            Caching is used only when both ``cache_data`` and this are given.

    Returns:
        The number of leading components whose cumulative explained-variance
        ratio is still below 0.99, as a float. ``-1.0`` if every fitted
        component is still below 0.99. NaN when fewer than two embeddings are
        given.
    """
    if len(embeddings) < 2:
        return float("nan")

    # Consult the cache before stacking so a hit does not copy every embedding.
    use_cache = cache_data is not None and cache_key_suffix is not None
    if use_cache:
        cached_result = get_metric_map(cache_data, "pca_99_var").get(str(cache_key_suffix))
        if cached_result is not None:
            return float(cached_result)

    # Stack embeddings: [n_samples, n_features]
    X = np.stack(embeddings, axis=0)
    n_samples, n_features = X.shape

    # At most n_samples - 1 components carry variance after centering; cap at 512.
    max_pca_components = min(512, n_samples - 1, n_features)
    if max_pca_components < 1:
        return float("nan")

    pca = PCA(n_components=max_pca_components, random_state=42)
    pca.fit(X)

    cumulative_var = np.cumsum(pca.explained_variance_ratio_)
    num_below_threshold = int((cumulative_var < 0.99).sum())
    if num_below_threshold == max_pca_components:
        # 99% was never reached with the fitted components.
        num_below_threshold = -1

    result = float(num_below_threshold)

    if use_cache:
        set_metric_map_value(cache_data, "pca_99_var", cache_key_suffix, result)

    return result


def compute_num_random_projections_explained_99_var(
    embeddings: List[np.ndarray],
    n_projections: int = 1000,
    random_state: int = 42,
    cache_data: Optional[Dict[str, Any]] = None,
    cache_key_suffix: Optional[str] = None,
) -> float:
    """Count the random directions needed to cover 99% of the projected variance.

    The centred embeddings are projected onto ``n_projections`` random unit
    directions. The per-direction variances are ranked from largest to
    smallest, and the result is how many of the top-ranked directions it takes
    for their share of the summed variance to reach 0.99.

    Args:
        embeddings: List of flattened embedding arrays
        n_projections: Number of random directions to draw
        random_state: Seed for the ``np.random.RandomState`` that draws them
        cache_data: Optional cache dictionary for the experiment
        cache_key_suffix: Optional per-sample cache key (e.g., sample_id);
            caching is used only when both ``cache_data`` and this are given

    Returns:
        The number of directions as a float (``-1.0`` if it would exceed
        ``n_projections``), or NaN when fewer than two embeddings are given or
        the summed projected variance is zero.
    """
    if len(embeddings) < 2:
        return float("nan")

    use_cache = cache_data is not None and cache_key_suffix is not None
    if use_cache:
        cached = get_metric_map(cache_data, "random_proj_99_var").get(str(cache_key_suffix))
        if cached is not None:
            return float(cached)

    # Rows are trajectory points, columns are features.
    points = np.stack(embeddings, axis=0)
    if points.shape[0] < 2:
        return float("nan")
    _, feature_dim = points.shape

    centered = points - points.mean(axis=0, keepdims=True)

    # Draw Gaussian directions and scale each to (almost exactly) unit length.
    generator = np.random.RandomState(random_state)
    directions = generator.randn(n_projections, feature_dim)
    lengths = np.linalg.norm(directions, axis=1, keepdims=True)
    directions = directions / (lengths + 1e-12)

    # Variance of the trajectory along every direction, largest first.
    per_direction_var = np.var(centered @ directions.T, axis=0)
    descending = np.argsort(per_direction_var)[::-1]
    ranked = per_direction_var[descending]

    total = np.sum(ranked)
    if total == 0:
        return float("nan")

    cumulative_share = np.cumsum(ranked) / total
    needed = (cumulative_share < 0.99).sum() + 1
    if needed > n_projections:
        needed = -1

    result = float(needed)
    if use_cache:
        set_metric_map_value(cache_data, "random_proj_99_var", cache_key_suffix, result)

    return result


def compute_trajectory_length(
    embeddings: List[np.ndarray],
    cache_data: Optional[Dict[str, Any]] = None,
    cache_key_suffix: Optional[str] = None,
) -> float:
    """Return the path length of a trajectory: the sum of L2 step sizes.

    Args:
        embeddings: List of flattened embedding arrays, in trajectory order
        cache_data: Optional cache dictionary for the experiment.
        cache_key_suffix: Optional per-sample cache key (e.g., sample_id);
            caching is used only when both ``cache_data`` and this are given.

    Returns:
        Sum of the distances between consecutive embeddings, or 0.0 when fewer
        than two embeddings are given.
    """
    if len(embeddings) < 2:
        return 0.0

    # Stacking and unpacking the shape requires equally sized 1-D embeddings;
    # the array itself is not used further.
    stacked = np.stack(embeddings, axis=0)
    _num_points, _num_features = stacked.shape

    use_cache = cache_data is not None and cache_key_suffix is not None
    if use_cache:
        known = get_metric_map(cache_data, "trajectory_length").get(str(cache_key_suffix))
        if known is not None:
            return float(known)

    total = 0.0
    for previous, current in zip(embeddings, embeddings[1:]):
        total += np.linalg.norm(current - previous)

    result = float(total)

    if use_cache:
        set_metric_map_value(cache_data, "trajectory_length", cache_key_suffix, result)

    return result


def compute_information_gain(
    rows: List[Dict[str, Any]],
    model_checkpoint: Optional[str] = None,
    device: Optional[torch.device] = None,
    cache_data: Optional[Dict[str, Any]] = None,
) -> List[float]:
    """Compute information gain (cross-entropy reduction, in bits) per sample.

    Information Gain = H_LM - H_LM+[mem]
    where H_LM is the summed next-token cross-entropy of the text without the memory
    vector and H_LM+[mem] is the same quantity with the memory (compression) embeddings
    prepended to the token embeddings. Both are sums over the scored tokens, not means.

    The computation is opt-in: unless the environment variable
    VISUALIZE_MULTIPLE_TRAJECTORIES_COMPUTE_IG equals "1", an empty list is returned.

    Args:
        rows: List of dataset rows, each containing 'text', 'embedding', 'num_compression_tokens', etc.
        model_checkpoint: Model checkpoint path. If None, tries to extract from first row.
        device: Device to run computation on. If None, uses CUDA if available.
        cache_data: Optional cache dictionary for the experiment. Computed values are
            stored in its "information_gain" map under the sample id.

    Returns:
        List of information gain values, one per processed sample (using the final stage
        of each sample). Samples whose final stage has empty text or no embedding are
        omitted, and the list carries no sample ids.
    """

    # Opt-in switch: skip the (model-loading) computation unless explicitly enabled.
    if os.environ.get("VISUALIZE_MULTIPLE_TRAJECTORIES_COMPUTE_IG") != "1":
        return []

    if len(rows) == 0:
        return []

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Get model checkpoint from first row if not provided
    if model_checkpoint is None:
        model_checkpoint = rows[0].get("model_checkpoint")
        if not model_checkpoint:
            print("Warning: model_checkpoint not found in dataset, skipping information gain computation")
            return []

    # Cache fast path: only taken when every sample id in `rows` has a cached value.
    # NOTE(review): this path returns values in ascending sample-id order, while the
    # computed path below follows the grouping order of `rows` and may skip samples.
    sample_ids = sorted({int(row.get("sample_id", -1)) for row in rows if row.get("sample_id") is not None})
    if cache_data is not None and sample_ids:
        metric_map = get_metric_map(cache_data, "information_gain")
        if all(str(sid) in metric_map for sid in sample_ids):
            return [float(metric_map[str(sid)]) for sid in sample_ids]

    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
    model = AutoModelForCausalLM.from_pretrained(
        model_checkpoint,
        torch_dtype=torch.bfloat16,
    ).to(device)
    model.eval()
    # Fall back to the EOS token when the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Group rows by sample_id and get final stage for each sample
    by_sid = collate_stages_by_sample(rows)

    # For each sample, get the final stage (highest stage_index or highest stage_seq_len)
    information_gains = []

    for sid, stages in by_sid.items():
        if len(stages) == 0:
            continue

        final_stage = get_final_stage(stages)
        if final_stage is None:
            continue

        # Samples without usable text or without an embedding are skipped silently.
        text = final_stage.get("text", "")
        if not isinstance(text, str) or text.strip() == "":
            continue

        embedding = final_stage.get("embedding")
        if embedding is None:
            continue

        # Defaults to a single compression token when the row does not say otherwise.
        num_compression_tokens = int(final_stage.get("num_compression_tokens", 1))

        # Tokenize text (single sequence, so no padding is added)
        enc = tokenizer(text, truncation=True, padding=False, return_tensors="pt")
        input_ids = enc["input_ids"].to(device)
        attention_mask = enc["attention_mask"].to(device)

        # Compute H_LM: cross-entropy without memory vector
        with torch.no_grad():
            outputs_lm = model(input_ids=input_ids, attention_mask=attention_mask)
            logits_lm = outputs_lm.logits  # [1, seq_len, vocab_size]

            # Compute cross-entropy: shift logits and labels for next-token prediction.
            # The logits at position t are scored against token t + 1, so the first
            # token of the text is never scored.
            shift_logits_lm = logits_lm[:, :-1, :].contiguous()
            shift_labels_lm = input_ids[:, 1:].contiguous()
            shift_mask_lm = attention_mask[:, 1:].contiguous()

            # Flatten for cross-entropy
            shift_logits_lm_flat = shift_logits_lm.view(-1, shift_logits_lm.size(-1))
            shift_labels_lm_flat = shift_labels_lm.view(-1)
            shift_mask_lm_flat = shift_mask_lm.view(-1)

            # Keep only positions the attention mask marks as real tokens
            valid_mask = shift_mask_lm_flat.bool()
            if valid_mask.sum() == 0:
                continue

            ce_lm = F.cross_entropy(
                shift_logits_lm_flat[valid_mask],
                shift_labels_lm_flat[valid_mask],
                reduction="sum",
            )
            # Convert from nats to bits: divide by ln(2)
            H_LM = ce_lm.item() / math.log(2)

        # Compute H_LM+[mem]: cross-entropy with memory vector
        embedding_tensor = torch.tensor(embedding, dtype=torch.bfloat16, device=device)
        if embedding_tensor.ndim == 1:
            # Reshape if needed: assume [num_compression_tokens * hidden_size] -> [num_compression_tokens, hidden_size]
            hidden_size = model.config.hidden_size
            if embedding_tensor.shape[0] == num_compression_tokens * hidden_size:
                embedding_tensor = embedding_tensor.reshape(num_compression_tokens, hidden_size)
            else:
                # Length does not match: treat the whole vector as one row.
                # NOTE(review): the attention mask below is still built from
                # num_compression_tokens, so this fallback presumably only works when
                # that value is 1 and the vector length equals hidden_size - confirm.
                embedding_tensor = embedding_tensor.unsqueeze(0)
        if embedding_tensor.ndim == 2:
            embedding_tensor = embedding_tensor.unsqueeze(0)  # [1, num_compression_tokens, hidden_size]

        # Get token embeddings
        # NOTE(review): relies on the loaded model exposing `.model.embed_tokens`.
        token_embeddings = model.model.embed_tokens(input_ids)  # [1, seq_len, hidden_size]

        # Concatenate compression tokens with token embeddings; the compression
        # positions are always attended to (mask of ones).
        compression_attention_mask = torch.ones((1, num_compression_tokens), device=device, dtype=attention_mask.dtype)
        united_token_embeddings = torch.cat([embedding_tensor, token_embeddings], dim=1)
        united_attention_mask = torch.cat([compression_attention_mask, attention_mask], dim=1)

        with torch.no_grad():
            outputs_mem = model(inputs_embeds=united_token_embeddings.to(torch.bfloat16), attention_mask=united_attention_mask)
            logits_mem = outputs_mem.logits  # [1, num_compression_tokens + seq_len, vocab_size]

            # Align logits: drop the outputs at the compression-token positions so that
            # index t corresponds to text token t, exactly as in the H_LM pass.
            aligned_logits_mem = logits_mem[:, num_compression_tokens:, :]  # [1, seq_len, vocab_size]

            # Compute cross-entropy: shift for next-token prediction (same scoring
            # positions as H_LM, so the two sums are directly comparable)
            shift_logits_mem = aligned_logits_mem[:, :-1, :].contiguous()
            shift_labels_mem = input_ids[:, 1:].contiguous()
            shift_mask_mem = attention_mask[:, 1:].contiguous()

            # Flatten for cross-entropy
            shift_logits_mem_flat = shift_logits_mem.view(-1, shift_logits_mem.size(-1))
            shift_labels_mem_flat = shift_labels_mem.view(-1)
            shift_mask_mem_flat = shift_mask_mem.view(-1)

            # Keep only positions the attention mask marks as real tokens
            valid_mask = shift_mask_mem_flat.bool()
            if valid_mask.sum() == 0:
                continue

            ce_mem = F.cross_entropy(
                shift_logits_mem_flat[valid_mask],
                shift_labels_mem_flat[valid_mask],
                reduction="sum",
            )
            # Convert from nats to bits: divide by ln(2)
            H_LM_mem = ce_mem.item() / math.log(2)

        # Information gain = H_LM - H_LM+[mem]
        info_gain = H_LM - H_LM_mem
        information_gains.append(info_gain)
        # Record the value in the caller's cache dict, keyed by the sample id.
        if cache_data is not None:
            set_metric_map_value(cache_data, "information_gain", sid, info_gain)

    return information_gains


def extract_information_gain_from_dataset(rows: List[Dict[str, Any]]) -> List[float]:
    """Read precomputed information-gain values out of the dataset rows.

    For every sample, the final stage is located and its
    ``information_gain_bits`` field is converted to ``float``.

    Args:
        rows: List of dataset rows, each potentially containing 'information_gain_bits'

    Returns:
        One value per sample whose final stage carries a convertible
        ``information_gain_bits``; samples without one are left out.
    """
    if len(rows) == 0:
        return []

    gains: List[float] = []
    for stages in collate_stages_by_sample(rows).values():
        if len(stages) == 0:
            continue

        final_stage = get_final_stage(stages)
        if final_stage is None:
            continue

        raw_gain = final_stage.get("information_gain_bits")
        if raw_gain is None:
            continue

        # Values that cannot be converted to float are dropped.
        try:
            gain = float(raw_gain)
        except (ValueError, TypeError):
            continue
        gains.append(gain)

    return gains


def compute_embedding_statistics(
    rows: List[Dict[str, Any]],
    model_checkpoint: Optional[str] = None,
    device: Optional[torch.device] = None,
    cache_data: Optional[Dict[str, Any]] = None,
) -> str:
    """Compare L2 norms of compression embeddings with those of the vocabulary embeddings.

    The computation is opt-in: unless the environment variable
    VISUALIZE_MULTIPLE_TRAJECTORIES_COMPUTE_EMB_STATS equals "1", "nan" is returned.
    Compression embeddings are taken from the final stage of every sample;
    vocabulary embeddings are looked up for token ids ``0 .. len(tokenizer) - 1``.

    Args:
        rows: List of dataset rows, each containing 'embedding', 'num_compression_tokens', etc.
        model_checkpoint: Model checkpoint path. If None, tries to extract from first row.
        device: Device to run computation on. If None, uses CUDA if available.
        cache_data: Optional cache dictionary for the experiment; the result is
            read from / stored under its "embedding_statistics" metric.

    Returns:
        Formatted string with "comp_norm_avg±comp_norm_std / vocab_norm_avg±vocab_norm_std" or "nan" if not computable
    """
    if os.environ.get("VISUALIZE_MULTIPLE_TRAJECTORIES_COMPUTE_EMB_STATS") != "1":
        return "nan"

    if len(rows) == 0:
        return "nan"

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Without an explicit checkpoint, use the one recorded on the first row.
    if model_checkpoint is None:
        model_checkpoint = rows[0].get("model_checkpoint")
        if not model_checkpoint:
            return "nan"

    if cache_data is not None:
        cached = get_metric_value(cache_data, "embedding_statistics")
        if cached is not None:
            return str(cached)

    # Any failure while loading the tokenizer or model is reported and yields "nan".
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
        model = AutoModelForCausalLM.from_pretrained(
            model_checkpoint,
            torch_dtype=torch.bfloat16,
        ).to(device)
        model.eval()
    except Exception as e:
        print(f"Warning: Failed to load model for embedding statistics: {e}")
        return "nan"

    # Embedding rows for every token id the tokenizer knows about.
    all_token_ids = torch.arange(len(tokenizer), device=device)
    with torch.no_grad():
        vocab_matrix = model.model.embed_tokens(all_token_ids)  # [vocab_size, hidden_size]
        vocab_matrix_np = vocab_matrix.float().cpu().numpy()

    # One [tokens, hidden] matrix per sample, taken from its final stage.
    per_sample_matrices = []
    for stages in collate_stages_by_sample(rows).values():
        if len(stages) == 0:
            continue

        final_stage = get_final_stage(stages)
        if final_stage is None:
            continue

        raw_embedding = final_stage.get("embedding")
        if raw_embedding is None:
            continue

        token_count = int(final_stage.get("num_compression_tokens", 1))
        token_matrix = torch.tensor(raw_embedding, dtype=torch.float32)
        hidden_size = model.config.hidden_size

        # Only 1-D (flattened) and 2-D embeddings are supported.
        if token_matrix.ndim not in (1, 2):
            continue
        if token_matrix.ndim == 1:
            if token_matrix.shape[0] == token_count * hidden_size:
                token_matrix = token_matrix.reshape(token_count, hidden_size)
            else:
                # Unexpected length: keep the vector as a single row.
                token_matrix = token_matrix.unsqueeze(0)

        per_sample_matrices.append(token_matrix.numpy())

    if len(per_sample_matrices) == 0:
        return "nan"

    # One row per compression token across all samples.
    compression_matrix = np.vstack(per_sample_matrices)

    compression_norms = np.linalg.norm(compression_matrix, axis=1)
    comp_norm_avg = np.mean(compression_norms)
    comp_norm_std = np.std(compression_norms)

    vocab_norms = np.linalg.norm(vocab_matrix_np, axis=1)
    vocab_norm_avg = np.mean(vocab_norms)
    vocab_norm_std = np.std(vocab_norms)

    result = f"{comp_norm_avg:.4f}±{comp_norm_std:.4f} / {vocab_norm_avg:.4f}±{vocab_norm_std:.4f}"

    if cache_data is not None:
        set_metric_value(cache_data, "embedding_statistics", result)

    return result


def extract_trajectory(
    dataset_path: str,
    sample_id: Optional[int] = None,
) -> np.ndarray:
    """Extract the embedding trajectory of one sample from a dataset.

    All rows of the dataset are loaded and grouped by sample. If the embedding
    column was dropped during that load (because cached metrics were complete),
    the chosen sample is reloaded with embeddings.

    Args:
        dataset_path: Path to the progressive embeddings dataset
        sample_id: Sample to extract. If None, the smallest sample id present
            in the dataset is used.

    Returns:
        Embeddings array of shape [n_stages, n_features], ordered by stage_index

    Raises:
        ValueError: If the dataset has no records, the requested sample is not
            present, or the chosen sample has no (or incomplete) embeddings.
    """
    ds = load_progressive_dataset(dataset_path)
    all_rows = filter_records(ds, sample_id=None, dataset_path=dataset_path, model_checkpoint=None, check_cache=True)

    if not all_rows:
        raise ValueError(f"No records found in {dataset_path}")

    all_by_sid = collate_stages_by_sample(all_rows)

    # Decide which sample to extract: the requested one, or the smallest id.
    if sample_id is not None:
        if sample_id not in all_by_sid:
            raise ValueError(f"Sample {sample_id} not found in {dataset_path}")
        vis_sample_id = sample_id
    else:
        vis_sample_id = min(all_by_sid)

    stages = all_by_sid[vis_sample_id]
    has_embeddings_for_vis = any(stage.get("embedding") is not None for stage in stages)

    if not has_embeddings_for_vis:
        # The first load dropped the embedding column; reload this sample with
        # cache checking disabled so the embeddings are kept.
        ds_vis = load_progressive_dataset(dataset_path)
        vis_rows = filter_records(ds_vis, sample_id=vis_sample_id, check_cache=False)
        stages = collate_stages_by_sample(vis_rows).get(vis_sample_id, [])

    # Collect the flattened embeddings in stage order.
    embeddings = []
    for stage in stages:
        if stage.get("embedding") is None:
            raise ValueError(f"Embeddings not available for sample {vis_sample_id} in {dataset_path}")
        embeddings.append(flatten_embedding(stage))

    if len(embeddings) == 0:
        # Report the sample actually selected, not the possibly-None argument.
        raise ValueError(f"No embeddings found for sample {vis_sample_id} in {dataset_path}")

    return np.stack(embeddings, axis=0)


# Ordered (substring, replacement) pairs used to turn long experiment directory
# names into compact legend entries for the 4-component figure. They are
# applied sequentially, so the most specific pattern has to come first.
_LEGEND_NAME_REPLACEMENTS = (
    ("sl_256_Meta-Llama-3.1-8B_ds_pg19_limit_1", "Base"),
    ("limit_1", ""),
    ("_lr_0.1", ""),
    ("sl_256_Meta-Llama-3.1-8B_", ""),
    ("ds_pg19-model-sampled-llama3.1-8B-prefix-64-max_len-2048_", "Sampled"),
    ("ds_pg19-lowercased-partial-64_", "Lowercase"),
    ("ds_pg19-random-suffix-shuffle-64_", "Random"),
)


def _shorten_legend_name(name: str) -> str:
    """Apply the hard-coded experiment-name substitutions to ``name`` in order."""
    for old, new in _LEGEND_NAME_REPLACEMENTS:
        name = name.replace(old, new)
    return name


def plot_pca_trajectories(
    trajectories: List[np.ndarray],
    checkpoint_names: List[str],
    outfile: str,
    n_components: int = 2,
    show_labels: bool = False,
    labels_list: Optional[List[List[str]]] = None,
):
    """Plot multiple embedding trajectories on a single PCA plot.

    One PCA is fitted on the row-wise concatenation of all trajectories and
    every trajectory is projected with that shared basis. The first point of
    each trajectory is marked with a large circle and the last one with a
    large square. The figure is written through ``save_figure_pdf_png``.

    Args:
        trajectories: List of embedding arrays, each of shape [n_stages, n_features]
        checkpoint_names: List of names for each trajectory (for legend). Names are
            paired with trajectories positionally via ``zip``, so surplus entries
            on either side are not drawn.
        outfile: Output file path
        n_components: Number of PCA components to use (2 or 4). The value is clamped
            to ``min(n_samples - 1, n_features)``; the clamped value must still be
            2 or 4.
        show_labels: Whether to show stage labels on points. Only the 2-component
            plot draws labels; the 4-component grid ignores this flag.
        labels_list: Optional list of label lists for each trajectory

    Raises:
        ValueError: If no trajectories are given, if there are fewer than 2 rows or
            2 features in total, or if the effective number of components is not
            2 or 4.
    """
    if len(trajectories) == 0:
        raise ValueError("No trajectories provided")

    # Combine all embeddings to fit a single PCA
    all_embeddings = np.vstack(trajectories)
    n_samples, n_features = all_embeddings.shape

    if n_samples < 2 or n_features < 2:
        raise ValueError(f"Insufficient data: {n_samples} samples, {n_features} features")

    n_components = min(n_components, n_samples - 1, n_features)
    if n_components < 2:
        raise ValueError(f"Cannot compute {n_components} components")
    # Reject unsupported layouts before doing any PCA work: clamping can turn a
    # requested 4 into 3, for which no figure layout exists.
    if n_components not in (2, 4):
        raise ValueError(f"n_components must be 2 or 4, got {n_components}")

    # Fit PCA on all embeddings so every trajectory shares the same axes
    pca = PCA(n_components=n_components, random_state=42)
    pca.fit(all_embeddings)
    explained_var = pca.explained_variance_ratio_

    # Transform each trajectory
    transformed_trajectories = [pca.transform(traj) for traj in trajectories]

    # Create distinct colors for checkpoints
    # Use a predefined set of highly distinct colors with maximum hue separation
    distinct_colors = [
        "#E6194B",  # bright red
        "#3CB44B",  # bright green
        "#FFE119",  # bright yellow
        "#4363D8",  # bright blue
        "#F58231",  # bright orange
        "#911EB4",  # bright purple
        "#42D4F4",  # bright cyan
        "#F032E6",  # bright magenta
        "#BFEF45",  # lime green
        "#FABED4",  # light pink
        "#469990",  # teal
        "#DCBEFF",  # light purple
        "#9A6324",  # brown
        "#FFFAC8",  # beige
        "#800000",  # maroon
        "#000075",  # navy
        "#A9A9A9",  # gray
        "#000000",  # black
    ]
    # Cycle through distinct colors if we have more trajectories than colors
    colors = [distinct_colors[i % len(distinct_colors)] for i in range(len(trajectories))]

    if n_components == 2:
        # Single 2D plot
        plt.figure(figsize=(10, 8))
        legend_handles = []
        for idx, (traj_transformed, name, color) in enumerate(zip(transformed_trajectories, checkpoint_names, colors)):
            x_data = traj_transformed[:, 0]
            y_data = traj_transformed[:, 1]

            # Plot points
            plt.scatter(x_data, y_data, c=[color], s=30, alpha=0.2, linewidths=0.5)

            # Empty scatter used only as a legend proxy, so the legend marker is
            # larger and more opaque than the faint data points.
            legend_handles.append(plt.scatter([], [], c=color, s=60, alpha=0.7, edgecolors="black", linewidths=0.5, label=name))

            # Add labels if requested
            if show_labels and labels_list is not None and idx < len(labels_list):
                labeled_positions = []
                # zip stops at the shorter sequence, so labels beyond the number
                # of points are ignored.
                for x_pt, y_pt, lab in zip(x_data, y_data, labels_list[idx]):
                    # Skip the label when an already labeled point of this
                    # trajectory lies within distance < 0.5 (PCA units).
                    too_close = any(
                        np.linalg.norm([x_pt - px, y_pt - py]) < 0.5 for px, py in labeled_positions
                    )
                    if not too_close:
                        plt.text(x_pt, y_pt, lab, fontsize=12, ha="left", va="bottom", color=color)
                        labeled_positions.append([x_pt, y_pt])

            # Mark start and end points
            if len(x_data) > 0:
                plt.scatter(x_data[0], y_data[0], c=[color], s=150, marker="o", edgecolors="black", linewidths=2, zorder=5)
                plt.scatter(x_data[-1], y_data[-1], c=[color], s=150, marker="s", edgecolors="black", linewidths=2, zorder=5)

        plt.xlabel(f"PC1 ({explained_var[0]:.4f})", fontsize=18)
        plt.ylabel(f"PC2 ({explained_var[1]:.4f})", fontsize=18)
        plt.title(
            f"PCA Trajectories Comparison\nCumulative variance: {explained_var.sum():.4f}",
            fontsize=20,
        )
        plt.legend(handles=legend_handles, loc="best", fontsize=18)
        plt.grid(True, alpha=0.3)
        plt.axis("equal")
        plt.tight_layout()
        saved_paths = save_figure_pdf_png(outfile, dpi=300)
        plt.close()
        print(f"Saved 2D PCA plot to: {', '.join(saved_paths)}")
    else:
        # n_components == 4: one subplot per unordered pair of components.
        # Four components give exactly 6 pairs, which fill the 2x3 grid.
        pairs = [(i, j) for i in range(n_components) for j in range(i + 1, n_components)]
        fig, axes = plt.subplots(2, 3, figsize=(15, 6))
        axes = axes.flatten()

        # Legend text does not depend on the subplot, so compute it once.
        legend_names = [_shorten_legend_name(name) for name in checkpoint_names]

        legend_handles = []
        for pair_idx, ((i, j), ax) in enumerate(zip(pairs, axes)):
            for traj_transformed, legend_name, color in zip(transformed_trajectories, legend_names, colors):
                x_data = traj_transformed[:, i]
                y_data = traj_transformed[:, j]

                # Plot points
                ax.scatter(x_data, y_data, c=[color], s=60, alpha=0.3, linewidths=0, edgecolors="none")

                # Create legend handle with scatter marker (only for first subplot)
                if pair_idx == 0:
                    legend_handles.append(
                        ax.scatter([], [], c=color, s=60, alpha=0.7, linewidths=0, edgecolors="none", label=legend_name)
                    )

                # Mark start and end points
                if len(x_data) > 0:
                    ax.scatter(x_data[0], y_data[0], c=[color], s=150, marker="o", edgecolors="black", linewidths=2, zorder=5)
                    ax.scatter(x_data[-1], y_data[-1], c=[color], s=150, marker="s", edgecolors="black", linewidths=2, zorder=5)

            ax.set_xlabel(f"PC{i+1} ({explained_var[i]:.3f})", fontsize=14)
            ax.set_ylabel(f"PC{j+1} ({explained_var[j]:.3f})", fontsize=14)
            ax.set_title(f"PC{i+1} vs PC{j+1}", fontsize=16)
            ax.grid(True, alpha=0.3)
            ax.axis("equal")
            if pair_idx == 0:
                ax.legend(handles=legend_handles, loc="best", fontsize=14)

        plt.tight_layout()
        saved_paths = save_figure_pdf_png(outfile, dpi=300)
        plt.close(fig)
        print(f"Saved 4-component PCA plot to: {', '.join(saved_paths)}")


def compute_pairwise_distances(final_embeddings: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Build L2, L1 and cosine distance matrices for a set of embeddings.

    Args:
        final_embeddings: Equal-shaped 1-D embedding arrays, one per experiment.

    Returns:
        ``(l2_distances, l1_distances, cosine_distances)``. Each is a square
        matrix with one row/column per input embedding. When fewer than two
        embeddings are given, three empty arrays are returned instead.
    """
    if len(final_embeddings) < 2:
        # Nothing to compare against; callers detect this via ``.size == 0``.
        return np.array([]), np.array([]), np.array([])

    # Rows are experiments, columns are features.
    stacked = np.stack(final_embeddings, axis=0)

    # Broadcast to every (row_i - row_j) difference vector; both norm-based
    # metrics reduce over the trailing feature axis.
    pair_deltas = stacked[:, None, :] - stacked[None, :, :]
    l2_distances = np.linalg.norm(pair_deltas, axis=-1)
    l1_distances = np.linalg.norm(pair_deltas, ord=1, axis=-1)

    # Cosine distance: normalise each row (epsilon guards all-zero rows), take
    # the Gram matrix, and clip round-off that strays outside [-1, 1].
    row_norms = np.linalg.norm(stacked, axis=1, keepdims=True)
    unit_rows = stacked / (row_norms + 1e-12)
    similarity = (unit_rows @ unit_rows.T).clip(-1.0, 1.0)

    return l2_distances, l1_distances, 1.0 - similarity


def print_statistics_table(
    checkpoint_names: List[str],
    statistics: List[Dict[str, Any]],
    tablefmt: str = "grid",
):
    """Render per-experiment statistics as a table on stdout.

    Args:
        checkpoint_names: Experiment labels, one per table row.
        statistics: One dict per experiment. The keys looked up are listed in
            ``stat_columns`` below; a missing key is shown as the string "nan".
        tablefmt: Table format name forwarded to ``tabulate``.

    Nothing is printed when either list is empty. Names and statistics are
    paired positionally with ``zip``.
    """
    if not (len(checkpoint_names) and len(statistics)):
        return

    # (column header, key inside the statistics dict), in display order.
    stat_columns = [
        ("# Compr. Tok", "num_embeddings"),
        ("Traj. Len", "trajectory_length"),
        ("Steps Taken", "steps_taken"),
        ("PCA 99%", "num_pca_for99_var"),
        ("PCA ALL 99%", "num_pca_for99_var_all_embeds"),
        ("Rand. Proj. 99%", "num_random_projections_for99_var"),
        ("Info Gain", "information_gain"),
        ("Info Gain (Dataset)", "information_gain_from_dataset"),
        ("Emb. Stats (Comp/Vocab)", "embedding_statistics"),
    ]

    column_titles = ["Experiment"] + [title for title, _ in stat_columns]
    body = [
        [label] + [record.get(key, "nan") for _, key in stat_columns]
        for label, record in zip(checkpoint_names, statistics)
    ]
    rendered = tabulate(body, headers=column_titles, tablefmt=tablefmt, numalign="right", stralign="left")

    rule = "=" * 80
    for line in ("\n" + rule, "Progressive Embeddings Statistics", rule, rendered, rule + "\n"):
        print(line)


def _print_distance_section(
    title: str,
    checkpoint_names: List[str],
    distances: np.ndarray,
    tablefmt: str,
) -> None:
    """Print one banner plus one square distance table for ``distances``."""
    print("\n" + "=" * 80)
    print(title)
    print("=" * 80)

    n = len(checkpoint_names)
    rows = []
    for i, row_name in enumerate(checkpoint_names):
        # The diagonal is printed as a fixed literal rather than read from the
        # matrix, so round-off in self-distances never shows up in the table.
        cells = ["0.000" if i == j else f"{distances[i, j]:.4f}" for j in range(n)]
        rows.append([row_name, *cells])

    headers = ["Experiment", *checkpoint_names]
    print(tabulate(rows, headers=headers, tablefmt=tablefmt, numalign="right", stralign="left"))


def print_pairwise_distances_table(
    checkpoint_names: List[str],
    l2_distances: np.ndarray,
    l1_distances: np.ndarray,
    cosine_distances: np.ndarray,
    tablefmt: str = "grid",
):
    """Print pairwise distances tables using tabulate.

    Three tables are printed in order (L2, L1, cosine), each preceded by a
    banner, followed by one closing rule. Off-diagonal cells are formatted
    with four decimals; diagonal cells are the literal "0.000".

    Args:
        checkpoint_names: List of experiment labels
        l2_distances: L2 distance matrix [n_experiments, n_experiments]
        l1_distances: L1 distance matrix [n_experiments, n_experiments]
        cosine_distances: Cosine distance matrix [n_experiments, n_experiments]
        tablefmt: Table format name forwarded to ``tabulate``.

    Nothing is printed when fewer than two names are given or when
    ``l2_distances`` is empty.
    """
    if len(checkpoint_names) < 2 or l2_distances.size == 0:
        return

    sections = (
        ("Pairwise L2 Distances Between Final Embeddings", l2_distances),
        ("Pairwise L1 Distances Between Final Embeddings", l1_distances),
        ("Pairwise Cosine Distances Between Final Embeddings", cosine_distances),
    )
    for title, matrix in sections:
        _print_distance_section(title, checkpoint_names, matrix, tablefmt)
    print("=" * 80 + "\n")


def parse_names_mapping(names_str: Optional[str]) -> Tuple[Dict[str, str], Optional[List[str]]]:
    """Interpret the ``--names_mapping`` string.

    The presence of a colon anywhere in the string selects the format:

    * Path-based, ``'path1:name1,path2:name2'``: each comma-separated item is
      split at its first colon into a stripped key and value. Items without a
      colon are dropped. Result: ``(mapping, None)``.
    * Positional, ``'name1,name2,name3'``: items are stripped and empty ones
      discarded. Result: ``({}, names)``, or ``({}, None)`` if nothing is left.

    Returns:
        Tuple of (path_mapping_dict, positional_names_list); ``({}, None)``
        when ``names_str`` is None.
    """
    if names_str is None:
        return {}, None

    items = names_str.split(",")

    if ":" not in names_str:
        # Positional list format
        stripped = (item.strip() for item in items)
        positional = [item for item in stripped if item]
        return {}, (positional or None)

    # Path-based format
    path_to_name: Dict[str, str] = {}
    for item in items:
        path, colon, display = item.partition(":")
        if colon:
            path_to_name[path.strip()] = display.strip()
    return path_to_name, None


def main():
    """Command-line entry point.

    Parses the arguments, loads one trajectory per checkpoint with
    ``extract_trajectory``, picks a display name for each, and renders all of
    them on a shared PCA plot via ``plot_pca_trajectories``.

    Raises:
        ValueError: If a positional ``--names_mapping`` list has a different
            length than ``--checkpoints``, or if no trajectory was loaded.
        FileNotFoundError: If any checkpoint path is not an existing directory.
    """
    parser = argparse.ArgumentParser(
        description="Visualize multiple progressive embeddings training trajectories on one PCA plot"
    )
    parser.add_argument(
        "--checkpoints",
        type=str,
        nargs="+",
        required=True,
        help="Paths to progressive embeddings datasets (checkpoints)",
    )
    parser.add_argument(
        "--output",
        type=str,
        required=True,
        help="Output file path for the plot",
    )
    parser.add_argument(
        "--sample_id",
        type=int,
        default=None,
        help="Sample ID to visualize (default: first available sample)",
    )
    parser.add_argument(
        "--n_components",
        type=int,
        default=2,
        choices=[2, 4],
        help="Number of PCA components (2 or 4)",
    )
    parser.add_argument(
        "--show_labels",
        action="store_true",
        help="Show stage labels on trajectory points",
    )
    parser.add_argument(
        "--names_mapping",
        type=str,
        default=None,
        help="Optional mapping of checkpoint paths to display names. "
        "Two formats supported: 1) Path-based: 'path1:name1,path2:name2' "
        "2) Positional list: 'name1,name2,name3' (corresponds to --checkpoints order)",
    )
    parser.add_argument(
        "--tablefmt",
        type=str,
        default="grid",
        help="Tabulate table format for printed statistics (e.g., grid, simple, github). Default: grid.",
    )

    args = parser.parse_args()

    # Parse names mapping
    path_mapping, positional_names = parse_names_mapping(args.names_mapping)

    # Validate positional names length if provided
    if positional_names is not None and len(positional_names) != len(args.checkpoints):
        raise ValueError(
            f"Number of names in --names_mapping ({len(positional_names)}) "
            f"does not match number of checkpoints ({len(args.checkpoints)})"
        )

    # Validate user-supplied paths with a real exception: an ``assert`` would
    # be stripped when Python runs with -O and the check would vanish.
    not_exists_checkpoints = [path for path in args.checkpoints if not os.path.isdir(path)]
    if not_exists_checkpoints:
        raise FileNotFoundError(f"checkpoints not exists: {not_exists_checkpoints}")

    # Extract trajectories from each checkpoint
    trajectories = []
    checkpoint_names = []
    # tqdm wraps the list (not the enumerate object) so it can read the length
    # and display a total / ETA.
    for idx, checkpoint_path in enumerate(tqdm(args.checkpoints)):
        traj = extract_trajectory(checkpoint_path, sample_id=args.sample_id)
        trajectories.append(traj)

        # Determine name for this checkpoint
        if positional_names is not None:
            # Use positional mapping
            name = positional_names[idx]
        elif checkpoint_path in path_mapping:
            # Use path-based mapping (exact string match on the given path)
            name = path_mapping[checkpoint_path]
        else:
            # Extract a short name from the path: the parent directory name,
            # falling back to the last path component when there is no parent.
            name = os.path.basename(os.path.dirname(checkpoint_path))
            if not name or name == ".":
                name = os.path.basename(checkpoint_path)
        checkpoint_names.append(name)

        print(f"Loaded trajectory from {checkpoint_path}: {traj.shape[0]} stages, {traj.shape[1]} features")

    if not trajectories:
        raise ValueError("No valid trajectories loaded")

    # Create output directory if needed ("." when the output has no directory part)
    os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)

    # Plot trajectories
    # NOTE(review): labels_list is always None here, and plot_pca_trajectories
    # only draws labels when labels_list is provided, so --show_labels currently
    # has no visible effect. Confirm which labels should be passed.
    plot_pca_trajectories(
        trajectories=trajectories,
        checkpoint_names=checkpoint_names,
        outfile=args.output,
        n_components=args.n_components,
        show_labels=args.show_labels,
        labels_list=None,
    )

    # Mirror save_figure_pdf_png: a .pdf/.png output is saved in both formats.
    root, ext = os.path.splitext(args.output)
    ext = ext.lower()
    if ext in {".pdf", ".png"}:
        print(f"Visualization complete. Saved to: {args.output} (and {root + ('.png' if ext == '.pdf' else '.pdf')})")
    else:
        print(f"Visualization complete. Saved to: {args.output}")


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
