"""
BW-DAM Retrieval on Gaussian Word Embeddings (Word2Gauss).

This experiment:
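1. Trains Word2Gauss on the text8 corpus to obtain Gaussian word embeddings
   (the training step itself is not implemented in this file; main() expects
   the embeddings to have been produced beforehand, e.g. in the Colab notebook)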
2. Perturbs word embeddings and runs BW-DAM retrieval dynamics
3. Measures retrieval accuracy across different temperature (β) values
4. Visualizes word evolution trajectories during retrieval

Requirements:
    - word2gauss (installed from https://github.com/seomoz/word2gauss)
    - text8 dataset (downloaded automatically)

Note: This script is designed to run in Google Colab. Some paths may need
adjustment for local execution.
"""

import os
import sys
import time
import random
import pickle
import logging
from collections import Counter
from itertools import islice
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
import matplotlib.pyplot as plt


# =============================================================================
# Configuration
# =============================================================================

@dataclass
class Word2GaussConfig:
    """Hyperparameters and paths for a Word2Gauss training run.

    The class is a plain value container: it performs no validation and
    nothing in this file reads it, so the accepted values listed below are
    the ones documented by the original author rather than enforced here.
    """

    # --- Dataset -----------------------------------------------------------
    # Corpus name: "text8" or "text8_small".
    dataset: str = "text8"

    # --- Model hyperparameters ---------------------------------------------
    embedding_dim: int = 50
    # Covariance parameterisation: "diagonal" or "spherical".
    covariance_type: str = "diagonal"
    # Energy function: "KL" or "IP".
    energy_type: str = "KL"
    min_count: int = 100

    # --- Training parameters -----------------------------------------------
    window_size: int = 5
    # Number of negative samples drawn per positive example.
    n_samples: int = 2
    batch_size: int = 20
    n_workers: int = 2
    chunk_size: int = 1000
    report_interval: int = 100

    # --- Paths -------------------------------------------------------------
    model_path: str = "word2gauss_model.tar.gz"


@dataclass
class BWDAMConfig:
    """Configuration for the BW-DAM retrieval experiment.

    The two β fields default to ``None`` and are filled in by
    ``__post_init__`` so that every instance receives its own freshly
    created default container instead of sharing a mutable default.
    """

    # Inverse-temperature values swept by the retrieval experiment.
    # ``None`` selects the built-in sweep; any sequence (or scalar) given
    # here is converted to a 1-D float array in ``__post_init__``.
    beta_values: Optional[np.ndarray] = None
    # β values for which word trajectories are recorded and tabulated.
    # ``None`` selects ``[1, 10, 50]``.
    beta_trajectory_values: Optional[List[float]] = None

    # Dynamics parameters
    epsilon: float = 0.001  # Convergence threshold
    max_iterations: int = 100

    # Display parameters
    num_display_words: int = 5
    num_iterations_to_display: int = 10

    # Random seed
    seed: int = 42

    def __post_init__(self):
        """Fill in default β lists and normalise ``beta_values`` to an array."""
        if self.beta_values is None:
            self.beta_values = np.array([0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 50, 100])
        else:
            # Honour the declared field type even when a list, tuple or a
            # single number was supplied, so downstream code can iterate it.
            self.beta_values = np.atleast_1d(
                np.asarray(self.beta_values, dtype=float)
            )
        if self.beta_trajectory_values is None:
            self.beta_trajectory_values = [1, 10, 50]


# =============================================================================
# Vocabulary
# =============================================================================

class SimpleVocabulary:
    """Bidirectional word <-> integer-id lookup built from a single dict.

    Token arrays produced by this class use dtype ``uint32``; words missing
    from the vocabulary are encoded as ``LARGEST_UINT32`` unless the caller
    asks for them to be dropped.
    """

    # Sentinel id written for out-of-vocabulary tokens (2**32 - 1).
    LARGEST_UINT32 = 4294967295

    def __init__(self, word2id_dict: Dict[str, int]):
        """Store the forward mapping and derive the inverse mapping from it."""
        self._word2id = word2id_dict
        self._id2word = {idx: word for word, idx in word2id_dict.items()}
        self._n = len(word2id_dict)

    def __len__(self) -> int:
        """Return the number of entries in the forward mapping."""
        return self._n

    def word2id(self, word: str) -> int:
        """Return the id of *word*; raises ``KeyError`` if it is unknown."""
        return self._word2id[word]

    def id2word(self, i: int) -> str:
        """Return the word with id *i*; raises ``KeyError`` if it is unknown."""
        return self._id2word[i]

    def tokenize_ids(self, text: str, remove_oov: bool = False) -> np.ndarray:
        """Split *text* on whitespace and map each token to its id.

        With ``remove_oov=True`` unknown tokens are skipped; otherwise they
        are kept in place and encoded as ``LARGEST_UINT32``.
        """
        lookup = self._word2id
        words = text.strip().split()
        if remove_oov:
            ids = [lookup[w] for w in words if w in lookup]
        else:
            unknown = self.LARGEST_UINT32
            ids = [lookup.get(w, unknown) for w in words]
        return np.array(ids, dtype=np.uint32)

    def random_ids(self, num: int) -> np.ndarray:
        """Draw *num* ids uniformly from ``[0, len(self))`` via NumPy's global RNG."""
        drawn = np.random.randint(0, self._n, size=num)
        return drawn.astype(np.uint32)


# =============================================================================
# BW-DAM for Diagonal Covariances
# =============================================================================

def w2_squared_diagonal(
    m1: np.ndarray,
    sigma1: np.ndarray,
    m2: np.ndarray,
    sigma2: np.ndarray,
) -> float:
    """Return the squared W₂ distance between two diagonal Gaussians.

    The result is the squared Euclidean gap between the means plus the
    squared Euclidean gap between the element-wise square roots of
    ``sigma1`` and ``sigma2``.
    """
    mean_gap = m1 - m2
    root_gap = np.sqrt(sigma1) - np.sqrt(sigma2)
    return np.sum(mean_gap ** 2) + np.sum(root_gap ** 2)


def w2_distance_diagonal(
    m1: np.ndarray,
    sigma1: np.ndarray,
    m2: np.ndarray,
    sigma2: np.ndarray,
) -> float:
    """Return the W₂ distance between two diagonal Gaussians.

    This is the square root of ``w2_squared_diagonal`` for the same inputs.
    """
    squared = w2_squared_diagonal(m1, sigma1, m2, sigma2)
    return np.sqrt(squared)


def compute_all_distances(
    m: np.ndarray,
    omega: np.ndarray,
    mu_all: np.ndarray,
    sigma_all: np.ndarray,
) -> np.ndarray:
    """Return the W₂² distance from the query ``(m, omega)`` to every stored row.

    ``mu_all`` and ``sigma_all`` hold one stored pattern per row; the query
    is broadcast against them and the squared gaps are summed along axis 1,
    giving one distance per row.
    """
    squared_mean_gap = np.square(mu_all - m)
    squared_root_gap = np.square(np.sqrt(sigma_all) - np.sqrt(omega))
    return np.sum(squared_mean_gap, axis=1) + np.sum(squared_root_gap, axis=1)


def softmax_weights(distances: np.ndarray, beta: float) -> np.ndarray:
    """Return softmax weights of ``-beta * distances``.

    The largest logit is subtracted before exponentiating so that the
    biggest exponent is exactly zero, which keeps ``np.exp`` from
    overflowing.
    """
    logits = -beta * distances
    unnormalized = np.exp(logits - np.max(logits))
    return unnormalized / np.sum(unnormalized)


def bwdam_step_diagonal(
    m: np.ndarray,
    omega: np.ndarray,
    mu_all: np.ndarray,
    sigma_all: np.ndarray,
    beta: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """Apply a single BW-DAM update to the diagonal-Gaussian state ``(m, omega)``.

    Each stored pattern is weighted by a softmax over its W₂² distance to
    the current state. The new mean is the weighted average of the stored
    means; the new ``omega`` is the square of the weighted average of the
    element-wise square roots of the stored ``sigma`` rows.

    Returns:
        ``(m_new, omega_new)``.
    """
    attention = softmax_weights(
        compute_all_distances(m, omega, mu_all, sigma_all), beta
    )
    # Column vector so the weights broadcast across the embedding axis.
    column = attention[:, np.newaxis]

    m_new = np.sum(column * mu_all, axis=0)
    mixed_root = np.sum(column * np.sqrt(sigma_all), axis=0)

    return m_new, mixed_root ** 2


def run_bwdam_diagonal(
    m_init: np.ndarray,
    omega_init: np.ndarray,
    mu_all: np.ndarray,
    sigma_all: np.ndarray,
    beta: float,
    epsilon: float = 0.001,
    max_iter: int = 100,
    track_trajectory: bool = False,
) -> Tuple[np.ndarray, np.ndarray, int, List[int]]:
    """Iterate the BW-DAM update until the state stops moving.

    The loop ends as soon as the W₂ distance between two consecutive states
    drops below ``epsilon``, or after ``max_iter`` updates, whichever comes
    first.

    Args:
        m_init: Initial mean; copied, never modified.
        omega_init: Initial diagonal covariance entries; copied, never modified.
        mu_all: Stored means, one pattern per row.
        sigma_all: Stored diagonal covariance entries, one pattern per row.
        beta: Inverse temperature passed to the softmax weighting.
        epsilon: Convergence threshold on the per-step W₂ distance.
        max_iter: Maximum number of updates. A value of zero or less performs
            no update and returns copies of the initial state.
        track_trajectory: When true, record the index of the nearest stored
            pattern before every update and once more after the last one.

    Returns:
        ``(m, omega, n_updates, trajectory)`` where ``n_updates`` is the
        number of updates actually performed and ``trajectory`` is the list
        of nearest-pattern indices (empty when tracking is disabled).
    """
    m, omega = m_init.copy(), omega_init.copy()
    trajectory: List[int] = []
    # Explicit counter instead of reusing the loop variable: the loop variable
    # is never bound when ``max_iter <= 0``, which used to make the return
    # statement fail with UnboundLocalError.
    n_updates = 0

    for _ in range(max_iter):
        if track_trajectory:
            distances = compute_all_distances(m, omega, mu_all, sigma_all)
            trajectory.append(int(np.argmin(distances)))

        m_new, omega_new = bwdam_step_diagonal(m, omega, mu_all, sigma_all, beta)
        step_size = w2_distance_diagonal(m, omega, m_new, omega_new)
        m, omega = m_new, omega_new
        n_updates += 1

        if step_size < epsilon:
            break

    if track_trajectory:
        # Nearest pattern of the final state closes the trajectory.
        distances = compute_all_distances(m, omega, mu_all, sigma_all)
        trajectory.append(int(np.argmin(distances)))

    return m, omega, n_updates, trajectory


def perturb_gaussian_diagonal(
    mu: np.ndarray,
    sigma: np.ndarray,
    perturbation_distance: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """Shift the mean of a diagonal Gaussian along a random unit direction.

    The direction is a normalised standard-normal draw from NumPy's global
    RNG, so the returned mean lies exactly ``perturbation_distance`` away
    from ``mu`` in Euclidean norm. The covariance entries are returned as an
    unchanged copy of ``sigma``.

    Returns:
        ``(perturbed_mean, sigma_copy)``.
    """
    raw = np.random.randn(len(mu))
    unit = raw / np.linalg.norm(raw)
    return mu + perturbation_distance * unit, sigma.copy()


# =============================================================================
# Experiment Runner
# =============================================================================

def run_retrieval_experiment(
    mu_all: np.ndarray,
    sigma_all: np.ndarray,
    vocab: SimpleVocabulary,
    config: BWDAMConfig,
) -> Dict:
    """
    Run BW-DAM retrieval experiment across multiple β values.

    For every β in ``config.beta_values`` each stored word is perturbed by
    moving its mean a distance of ``sqrt(min(sigma_all))`` in a random
    direction, the BW-DAM dynamics are run from that state, and retrieval
    counts as successful when the stored pattern nearest to the final state
    is the original word.

    A small set of randomly chosen words is tracked: their perturbations are
    drawn once and reused for every β, and for β values listed in
    ``config.beta_trajectory_values`` the sequence of nearest words visited
    during retrieval is recorded. All other words get a fresh perturbation
    for each β.

    Side effects: reseeds NumPy's global RNG with ``config.seed`` and prints
    a progress report to stdout.

    Args:
        mu_all: Stored means with shape ``(N, K)``.
        sigma_all: Stored diagonal covariance entries, one row per word.
        vocab: Vocabulary used to translate indices into words.
        config: Experiment settings.

    Returns:
        Dictionary with keys:
            'beta': list of the β values that were run.
            'accuracy': retrieval accuracy in percent, aligned with 'beta'.
            'trajectories': ``{β: {word: {'indices': [...], 'words': [...]}}}``
                for the tracked words.
            'tracked_words': the tracked words, in selection order (the list
                that ``plot_results`` expects as ``random_words``).

    Raises:
        ValueError: If ``mu_all`` contains no stored patterns.
    """
    np.random.seed(config.seed)

    N, K = mu_all.shape
    if N == 0:
        # Fail with a clear message instead of an opaque reduction error
        # from np.min and a division by zero further down.
        raise ValueError("mu_all must contain at least one stored pattern")

    lambda_min = np.min(sigma_all)
    perturbation_dist = np.sqrt(lambda_min)

    print("=" * 60)
    print("BW-DAM Retrieval Experiment on Word Embeddings")
    print("=" * 60)
    print("\nConfiguration:")
    print(f"  Number of words (N):      {N}")
    print(f"  Embedding dimension (K):  {K}")
    print(f"  λ_min:                    {lambda_min:.6f}")
    print(f"  Perturbation distance:    {perturbation_dist:.6f}")
    print(f"  Convergence threshold:    {config.epsilon}")

    # Select random words for trajectory tracking. Sampling without
    # replacement cannot return more than N items, so clamp the request.
    n_tracked = min(config.num_display_words, N)
    random_indices = [
        int(i) for i in np.random.choice(N, n_tracked, replace=False)
    ]
    random_words = [vocab.id2word(i) for i in random_indices]
    print(f"\nWords for trajectory tracking: {random_words}")

    # Pre-generate perturbations so tracked words start from the same state
    # for every β. The dict doubles as an O(1) "is this word tracked?" test.
    perturbations = {}
    for i in random_indices:
        perturbations[i] = perturb_gaussian_diagonal(
            mu_all[i], sigma_all[i], perturbation_dist
        )

    # Run experiment
    results = {
        "beta": [],
        "accuracy": [],
        "trajectories": {b: {} for b in config.beta_trajectory_values},
        "tracked_words": random_words,
    }

    print(f"\n{'β':<12} {'Accuracy (%)':<15}")
    print("-" * 27)

    for beta in config.beta_values:
        n_success = 0
        track_beta = beta in config.beta_trajectory_values

        for i in range(N):
            is_tracked = i in perturbations

            # Use pre-generated perturbation for tracked words
            if is_tracked:
                m_init, omega_init = perturbations[i]
            else:
                m_init, omega_init = perturb_gaussian_diagonal(
                    mu_all[i], sigma_all[i], perturbation_dist
                )

            track_this = track_beta and is_tracked

            m_final, omega_final, _, trajectory = run_bwdam_diagonal(
                m_init, omega_init, mu_all, sigma_all, beta,
                epsilon=config.epsilon,
                max_iter=config.max_iterations,
                track_trajectory=track_this
            )

            # Check retrieval: success means the nearest stored pattern to
            # the final state is the word we started from.
            distances_final = compute_all_distances(
                m_final, omega_final, mu_all, sigma_all
            )
            if int(np.argmin(distances_final)) == i:
                n_success += 1

            if track_this:
                results["trajectories"][beta][vocab.id2word(i)] = {
                    "indices": trajectory,
                    "words": [vocab.id2word(idx) for idx in trajectory]
                }

        accuracy = 100.0 * n_success / N
        results["beta"].append(beta)
        results["accuracy"].append(accuracy)
        print(f"{beta:<12.4f} {accuracy:<15.2f}")

    print("\n✓ Experiment complete!")
    return results


# =============================================================================
# Visualization
# =============================================================================

def plot_results(
    results: Dict,
    config: BWDAMConfig,
    random_words: List[str],
    output_path: str = "bwdam_words_results.png",
) -> None:
    """Draw the accuracy-vs-β curve plus one word-evolution table per β.

    The figure has the accuracy plot on top and, below it, one table for
    every entry of ``config.beta_trajectory_values``. Each table row starts
    with a tracked word followed by the first
    ``config.num_iterations_to_display - 1`` entries of its recorded
    trajectory; shorter trajectories are padded with their last entry (or
    with the word itself when the trajectory is empty). Words without a
    recorded trajectory for a given β are left out of that table.

    The figure is written to ``output_path`` and a confirmation is printed.
    """
    traj_betas = config.beta_trajectory_values
    n_tables = len(traj_betas)
    n_iter_cols = config.num_iterations_to_display - 1

    fig = plt.figure(figsize=(14, 4 + 2.5 * n_tables))
    gs = fig.add_gridspec(
        1 + n_tables,
        1,
        height_ratios=[3] + [1.5] * n_tables,
        hspace=0.4,
    )

    # Top panel: retrieval accuracy against β on a logarithmic x axis.
    ax1 = fig.add_subplot(gs[0, 0])
    ax1.semilogx(results["beta"], results["accuracy"], "b-o", linewidth=2, markersize=6)
    ax1.set_xlabel(r"$\beta$ (inverse temperature)", fontsize=12)
    ax1.set_ylabel("Retrieval Accuracy (%)", fontsize=12)
    ax1.grid(True, alpha=0.3)
    ax1.set_ylim([0, 105])

    # The header row is the same for every table, so build it once.
    col_labels = ["Original"] + [f"Iter {k + 1}" for k in range(n_iter_cols)]

    # Lower panels: one table of word trajectories per tracked β.
    for table_idx, beta in enumerate(traj_betas):
        ax = fig.add_subplot(gs[1 + table_idx, 0])
        ax.axis("off")

        trajectories = results["trajectories"][beta]

        table_data = []
        for word in random_words:
            if word not in trajectories:
                continue
            traj = trajectories[word]["words"]
            # Value used to pad rows whose trajectory ended early.
            filler = traj[-1] if traj else word
            cells = [
                traj[k] if k < len(traj) else filler
                for k in range(n_iter_cols)
            ]
            table_data.append([word] + cells)

        if table_data:
            table = ax.table(
                cellText=table_data,
                colLabels=col_labels,
                loc="center",
                cellLoc="center"
            )
            table.auto_set_font_size(False)
            table.set_fontsize(9)
            table.scale(1.2, 1.5)

            # Row 0 of the table is the header row: shade it and make it bold.
            for col in range(len(col_labels)):
                header_cell = table[(0, col)]
                header_cell.set_facecolor("#E6E6E6")
                header_cell.set_text_props(fontweight="bold")

        ax.set_title(f"Word evolution at β = {beta}", fontsize=11, pad=10)

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    print(f"\n✓ Plot saved to {output_path}")


# =============================================================================
# Main
# =============================================================================

def main():
    """
    Print usage guidance for this script.

    No training or retrieval is run from here: the function only writes a
    banner and a short note explaining that pre-trained Word2Gauss embeddings
    are required, together with example statements for loading them.
    """
    rule = "=" * 60
    messages = (
        rule,
        "BW-DAM on Gaussian Word Embeddings",
        rule,
        "\nNote: This script requires pre-trained Word2Gauss embeddings.",
        "For training, run the full notebook in Google Colab.",
        "\nTo use pre-trained embeddings:",
        "  mu = np.load('mu_embeddings.npy')",
        "  sigma = np.load('sigma_embeddings.npy')",
        "  vocab = pickle.load(open('vocabulary.pkl', 'rb'))",
    )
    for message in messages:
        print(message)


# Call main() only when the file is executed as a script, so that importing
# the module for its functions produces no output.
if __name__ == "__main__":
    main()
