"""
BW-DAM Retrieval on Gaussian Sentence Embeddings (GaussCSE).

This experiment:
1. Trains a GaussCSE model that maps sentences to Gaussian distributions
2. Stores sentence embeddings as patterns in BW-DAM
3. Perturbs embeddings and runs retrieval dynamics
4. Measures accuracy across different temperature (β) values
5. Provides `plot_qualitative_example` for visualizing qualitative retrieval examples (not invoked by `main`)

The model is based on contrastive learning with KL-divergence similarity
between Gaussian sentence representations.

Requirements:
    - transformers
    - torch
    - numpy
    - pandas
    - matplotlib
    - tqdm
"""

from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer
from tqdm.auto import tqdm
import warnings

# Silence every Python warning for the whole process, presumably to keep the
# experiment's console output readable. NOTE(review): this also hides warnings
# that may indicate real problems (deprecations, numerical issues); consider
# narrowing the filter by category or module.
warnings.filterwarnings("ignore")


# =============================================================================
# Configuration
# =============================================================================

@dataclass
class GaussCSEConfig:
    """Hyperparameters for fine-tuning the GaussCSE encoder on NLI triplets."""

    # Pretrained checkpoint name passed to ``from_pretrained`` for the backbone
    # and the tokenizer.
    model_name: str = "bert-base-uncased"
    # Optimisation settings.
    batch_size: int = 32
    learning_rate: float = 3e-5
    num_epochs: int = 3
    # Divisor applied to the similarity logits before the cross-entropy loss.
    temperature: float = 0.05
    # Tokenizer truncation length (in tokens).
    max_seq_len: int = 64
    # Number of rows sampled from the NLI CSV for training.
    train_size: int = 10000


@dataclass
class BWDAMSentenceConfig:
    """Settings for the BW-DAM sentence-retrieval experiment."""

    # How many sentences are embedded and stored as memory patterns.
    num_patterns: int = 1000
    # Stored patterns perturbed and retrieved per trial (per β value).
    num_samples: int = 100
    # Independent trials per β; their spread gives the error bars.
    num_trials: int = 5
    # Retrieval stops once a single update moves the state by less than this (W₂).
    epsilon: float = 0.001
    # Iteration cap for BW-DAM retrieval; only effective where a caller passes
    # it on as ``max_steps``.
    max_steps: int = 100


# =============================================================================
# GaussCSE Model
# =============================================================================

@dataclass
class GaussOutput:
    """Diagonal-Gaussian sentence embeddings as a (mean, standard deviation) pair."""

    mu: torch.FloatTensor  # mean vectors
    std: torch.FloatTensor  # per-dimension standard deviations (diagonal covariance)


class GaussCSEModel(nn.Module):
    """
    Sentence encoder that outputs a diagonal Gaussian per input sequence.

    The [CLS] hidden state of a pretrained backbone is projected twice:
    - μ = tanh(W_μ h), so every mean coordinate lies in (-1, 1);
    - σ = sqrt(exp(W_var h) + 1e-8), i.e. the second head predicts log-variance.
    """

    def __init__(self, model_name: str = "bert-base-uncased"):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(model_name)
        self.hidden_size = self.backbone.config.hidden_size
        width = self.hidden_size
        # Creation order (w_mu before w_var) fixes the parameter order and the
        # RNG draws used for initialisation.
        self.w_mu = nn.Linear(width, width)
        self.w_var = nn.Linear(width, width)
        self.activation = nn.Tanh()

    def forward(self, input_ids, attention_mask, **kwargs) -> GaussOutput:
        """Encode a tokenized batch; extra tokenizer outputs in ``kwargs`` are ignored."""
        hidden = self.backbone(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        cls_vec = hidden[:, 0]
        mean = self.activation(self.w_mu(cls_vec))
        # Small additive floor keeps the square root away from exactly zero.
        spread = torch.exp(self.w_var(cls_vec)).add(1e-8).sqrt()
        return GaussOutput(mu=mean, std=spread)


def asymmetrical_kl_sim_matrix(
    mu1: torch.Tensor,
    std1: torch.Tensor,
    mu2: torch.Tensor,
    std2: torch.Tensor,
) -> torch.Tensor:
    """Pairwise similarity ``1 / (1 + KL(N1_i || N2_j))`` for diagonal Gaussians.

    Rows of the result index the first set, columns the second. KL is not
    symmetric, so swapping the argument pairs generally changes the values.
    Inputs are ``[N, D]`` means / stds; the output is ``[N1, N2]``.
    """
    left_mu, left_var = mu1[:, None], (std1 ** 2)[:, None]
    right_mu, right_var = mu2[None], (std2 ** 2)[None]

    # Closed-form KL between diagonal Gaussians, summed over the last dimension.
    per_dim = (
        torch.log(right_var / left_var)
        + (left_var + (left_mu - right_mu) ** 2) / right_var
        - 1
    )
    kl = 0.5 * per_dim.sum(dim=-1)
    return 1 / (1 + kl)


# =============================================================================
# Training
# =============================================================================

class _NLITripletCollator:
    """Tokenize a batch of NLI rows into anchor / positive / hard-negative encodings.

    Defined at module level (rather than as a closure inside
    ``train_gausscse``) so DataLoader worker processes can pickle it when the
    ``spawn`` multiprocessing start method is used.
    """

    def __init__(self, tokenizer, max_seq_len: int):
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

    def _encode(self, texts: List[str]):
        """Pad/truncate ``texts`` into a batch of PyTorch tensors."""
        return self.tokenizer(
            texts, padding=True, truncation=True,
            max_length=self.max_seq_len, return_tensors="pt",
        )

    def __call__(self, batch: List[Dict]) -> Dict:
        return {
            "sent0": self._encode([row["sent0"] for row in batch]),
            "sent1": self._encode([row["sent1"] for row in batch]),
            "hard_neg": self._encode([row["hard_neg"] for row in batch]),
        }


def train_gausscse(
    config: GaussCSEConfig,
    device: torch.device,
) -> Tuple[GaussCSEModel, AutoTokenizer]:
    """Train GaussCSE model on NLI data.

    ``config.train_size`` rows (``random_state=42``) of the SimCSE NLI CSV are
    used; each row supplies ``sent0``, ``sent1`` and ``hard_neg`` sentences.
    For a batch of B rows, the logits are the concatenation (along columns) of
    three B×B similarity matrices divided by ``config.temperature``:

    - ``sim(sent1_i -> sent0_j)``
    - ``sim(sent0_i -> sent1_j)``
    - ``sim(hard_neg_i -> sent0_j)``

    Here ``sim(a -> b) = 1 / (1 + KL(a || b))``. The target for row i is column
    i, i.e. ``sim(sent1_i -> sent0_i)``. The reverse-direction pair
    ``sim(sent0_i -> sent1_i)`` is therefore scored as a negative.

    Returns:
        ``(model, tokenizer)``. The model is left in training mode.
    """
    print("Loading training data...")
    df = pd.read_csv(
        "https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse/"
        "resolve/main/nli_for_simcse.csv"
    )
    df_train = df.sample(n=config.train_size, random_state=42).reset_index(drop=True)

    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    model = GaussCSEModel(config.model_name).to(device)

    train_loader = DataLoader(
        df_train.to_dict("records"),
        batch_size=config.batch_size,
        shuffle=True,
        # A picklable, top-level collator works with worker processes under
        # both the ``fork`` and ``spawn`` start methods.
        collate_fn=_NLITripletCollator(tokenizer, config.max_seq_len),
        num_workers=2,
        drop_last=True,
    )

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=config.learning_rate, weight_decay=0.01
    )

    print("Training GaussCSE model...")
    model.train()
    for epoch in range(config.num_epochs):
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.num_epochs}", leave=False)
        for batch in pbar:
            sent0 = {k: v.to(device) for k, v in batch["sent0"].items()}
            sent1 = {k: v.to(device) for k, v in batch["sent1"].items()}
            hard_neg = {k: v.to(device) for k, v in batch["hard_neg"].items()}

            out0 = model(**sent0)
            out1 = model(**sent1)
            out_neg = model(**hard_neg)

            pos_mat = asymmetrical_kl_sim_matrix(out1.mu, out1.std, out0.mu, out0.std)
            rev_mat = asymmetrical_kl_sim_matrix(out0.mu, out0.std, out1.mu, out1.std)
            neg_mat = asymmetrical_kl_sim_matrix(out_neg.mu, out_neg.std, out0.mu, out0.std)

            # [B, 3B] logits; row i's target is the diagonal entry of pos_mat.
            sim_mat = torch.cat([pos_mat, rev_mat, neg_mat], dim=1) / config.temperature
            labels = torch.arange(sim_mat.size(0), device=device)
            loss = F.cross_entropy(sim_mat, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

    print("✓ Training complete!")
    return model, tokenizer


@torch.no_grad()
def encode_sentences(
    sentences: List[str],
    model: GaussCSEModel,
    tokenizer: AutoTokenizer,
    device: torch.device,
    batch_size: int = 64,
    max_seq_len: int = 64,
) -> GaussOutput:
    """Embed ``sentences`` as Gaussians, in input order, without gradients.

    The model is switched to eval mode (and left there). Each chunk of
    ``batch_size`` sentences is tokenized, moved to ``device`` and encoded.
    The per-chunk means and stds are copied to CPU and concatenated row-wise.
    """
    model.eval()
    mu_chunks: List[torch.Tensor] = []
    std_chunks: List[torch.Tensor] = []

    for start in range(0, len(sentences), batch_size):
        encoded = tokenizer(
            sentences[start:start + batch_size],
            padding=True,
            truncation=True,
            max_length=max_seq_len,
            return_tensors="pt",
        )
        output = model(**{name: tensor.to(device) for name, tensor in encoded.items()})
        mu_chunks.append(output.mu.cpu())
        std_chunks.append(output.std.cpu())

    stacked_mu = torch.cat(mu_chunks, dim=0)
    stacked_std = torch.cat(std_chunks, dim=0)
    return GaussOutput(mu=stacked_mu, std=stacked_std)


# =============================================================================
# BW-DAM for Diagonal Covariances
# =============================================================================

class BWDAMRetrieval:
    """Bures–Wasserstein dense associative memory over diagonal Gaussians.

    The memory holds ``N`` stored Gaussians as row-aligned mean and standard
    deviation matrices (``mu.shape == (N, D)``). Retrieval iterates a state
    ``(m, omega_std)`` toward a softmax(-β·W₂²)-weighted blend of the stored
    patterns until one update moves the state by less than a threshold.
    """

    def __init__(
        self,
        mu: torch.Tensor,
        std: torch.Tensor,
        beta: float = 1.0,
        eps: float = 1e-8,
    ):
        # Stored pattern parameters; rows index patterns.
        self.mu = mu
        self.std = std
        self.var = std ** 2  # squared stds; kept as an attribute, not read by the methods here
        self.beta = beta  # inverse temperature of the softmax over distances
        self.eps = eps  # small floor added inside divisions / square roots
        self.N, self.D = mu.shape

    def w2_squared(self, m: torch.Tensor, omega_std: torch.Tensor) -> torch.Tensor:
        """Return the squared W₂ distance from ``(m, omega_std)`` to every stored pattern.

        With diagonal covariances this is ``||μ_i - m||² + ||σ_i - ω||²``.
        """
        mean_gap = (self.mu - m).pow(2).sum(dim=-1)
        spread_gap = (self.std - omega_std).pow(2).sum(dim=-1)
        return mean_gap + spread_gap

    def state_distance_sq(
        self,
        m1: torch.Tensor,
        std1: torch.Tensor,
        m2: torch.Tensor,
        std2: torch.Tensor,
    ) -> float:
        """Return the squared diagonal W₂ distance between two states as a Python float."""
        return ((m1 - m2) ** 2).sum().item() + ((std1 - std2) ** 2).sum().item()

    def step(
        self,
        m: torch.Tensor,
        omega_std: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Apply one BW-DAM update and return ``(m_new, omega_std_new, weights)``.

        The weights are softmax(-β·W₂²) over the stored patterns. The new mean
        is the weighted mean of the pattern means. The new std rescales
        ``omega_std`` by the weighted average of the per-pattern ratios σ_i / ω.
        """
        weights = torch.softmax(-self.beta * self.w2_squared(m, omega_std), dim=0)
        column = weights.unsqueeze(-1)

        # Per-pattern diagonal transport coefficients from the current std.
        ratios = self.std / (omega_std + self.eps)

        m_next = (column * self.mu).sum(dim=0)
        blended_ratio = (column * ratios).sum(dim=0)
        omega_std_next = torch.sqrt(blended_ratio ** 2 * omega_std ** 2 + self.eps)
        return m_next, omega_std_next, weights

    def retrieve_until_convergence(
        self,
        m_init: torch.Tensor,
        omega_std_init: torch.Tensor,
        epsilon: float = 0.001,
        max_steps: int = 100,
    ) -> List[Dict]:
        """Iterate ``step`` from the initial state and return the full trajectory.

        Entry 0 is the (cloned) initial state, followed by one entry per update.
        Each entry holds ``step``, ``m``, ``omega_std`` and ``nearest_idx``
        (the index of the closest stored pattern in W₂). Iteration stops after
        the first update whose W₂ displacement is below ``epsilon``, or after
        ``max_steps`` updates.
        """
        m = m_init.clone()
        omega_std = omega_std_init.clone()

        def record(index, mean, spread):
            return {
                "step": index,
                "m": mean.clone(),
                "omega_std": spread.clone(),
                "nearest_idx": self.w2_squared(mean, spread).argmin().item(),
            }

        trajectory = [record(0, m, omega_std)]
        for index in range(1, max_steps + 1):
            m_next, omega_next, _ = self.step(m, omega_std)
            moved = np.sqrt(self.state_distance_sq(m, omega_std, m_next, omega_next))
            m, omega_std = m_next, omega_next
            trajectory.append(record(index, m, omega_std))
            if moved < epsilon:
                break

        return trajectory


def perturb_at_w2_distance(
    mu: torch.Tensor,
    std: torch.Tensor,
    target_w2: float,
    seed: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Perturb a Gaussian to be at target W₂ distance.

    The Gaussian is displaced along a random direction. A standard-normal
    direction is drawn jointly over (mean, std) and rescaled so that
    ``||Δμ||² + ||Δσ||² = target_w2²``, which is the diagonal W₂² used
    elsewhere in this file.

    The perturbed std is then passed through ``abs`` to keep it non-negative.
    Whenever that fold triggers (``|Δσ_k| > σ_k`` for some k), the realized W₂
    distance to the original is smaller than ``target_w2``, and a component
    can end up exactly 0.

    Args:
        mu: Mean of the Gaussian to perturb.
        std: Per-dimension standard deviation (assumed same shape/device as ``mu``).
        target_w2: Desired W₂ distance; its absolute value is used. NumPy
            scalars are accepted.
        seed: If given, the draw is reproducible. A private ``torch.Generator``
            is used, so the global torch RNG state is left untouched.

    Returns:
        ``(m_perturbed, std_perturbed)`` with the shapes of ``mu`` and ``std``.
    """
    generator = None
    if seed is not None:
        generator = torch.Generator(device=mu.device)
        generator.manual_seed(seed)

    delta_mu = torch.randn(mu.shape, generator=generator, dtype=mu.dtype, device=mu.device)
    delta_std = torch.randn(std.shape, generator=generator, dtype=std.dtype, device=std.device)

    norm = torch.sqrt(delta_mu.pow(2).sum() + delta_std.pow(2).sum())
    # Empty inputs give a zero-norm direction; avoid 0/0 -> NaN in that case.
    scale = abs(float(target_w2)) / norm if norm.item() > 0 else 0.0
    delta_mu = delta_mu * scale
    delta_std = delta_std * scale

    m_perturbed = mu + delta_mu
    std_perturbed = torch.abs(std + delta_std)

    return m_perturbed, std_perturbed


# =============================================================================
# Experiments
# =============================================================================

def run_single_retrieval(
    original_idx: int,
    pattern_embeddings: GaussOutput,
    lambda_min: float,
    beta: float,
    epsilon: float = 0.001,
    seed: Optional[int] = None,
    max_steps: int = 100,
) -> Dict:
    """Run single retrieval test.

    The stored Gaussian at ``original_idx`` is perturbed to W₂ distance
    ``sqrt(lambda_min)``. BW-DAM dynamics are then run over all stored
    patterns, and the pattern nearest (in W₂) to the final state is compared
    with the original.

    Args:
        original_idx: Row of ``pattern_embeddings`` to perturb (NumPy ints accepted).
        pattern_embeddings: Stored means/stds, one pattern per row.
        lambda_min: Squared perturbation radius.
        beta: Inverse temperature of the retrieval softmax.
        epsilon: Convergence threshold on the per-step state displacement.
        seed: Seed for the perturbation direction.
        max_steps: Maximum number of retrieval updates.

    Returns:
        Dict with ``original_idx``, ``final_idx``, ``recovered`` (bool), the
        full ``trajectory`` and ``num_steps`` (updates actually performed).
    """
    original_idx = int(original_idx)  # normalize NumPy integer indices to a plain int
    original_mu = pattern_embeddings.mu[original_idx]
    original_std = pattern_embeddings.std[original_idx]

    target_w2 = np.sqrt(lambda_min)
    m_init, omega_init = perturb_at_w2_distance(original_mu, original_std, target_w2, seed)

    bw_dam = BWDAMRetrieval(
        mu=pattern_embeddings.mu,
        std=pattern_embeddings.std,
        beta=beta,
    )

    trajectory = bw_dam.retrieve_until_convergence(
        m_init, omega_init, epsilon=epsilon, max_steps=max_steps
    )
    final_idx = trajectory[-1]["nearest_idx"]

    return {
        "original_idx": original_idx,
        "final_idx": final_idx,
        "recovered": final_idx == original_idx,
        "trajectory": trajectory,
        "num_steps": len(trajectory) - 1,
    }


def run_accuracy_vs_beta(
    stored_sentences: List[str],
    pattern_embeddings: GaussOutput,
    lambda_min: float,
    beta_values: List[float],
    config: BWDAMSentenceConfig,
) -> Tuple[List[float], List[float]]:
    """Compute accuracy vs β with error bars.

    For every β the same ``config.num_trials`` trials are run. Trial ``t``
    draws ``config.num_samples`` distinct pattern indices with seed
    ``t * 12345``, so all β values are evaluated on identical samples and
    perturbations. ``config.epsilon`` and ``config.max_steps`` control each
    retrieval.

    Returns:
        ``(means, stds)``: per-β mean and population standard deviation of the
        trial accuracies, in percent.
    """
    # Perturbation seeds are ``trial * stride + i``. A stride of at least
    # num_samples keeps them distinct across trials (it is 1000 for the defaults).
    seed_stride = max(1000, config.num_samples)
    means, stds = [], []

    for beta in tqdm(beta_values, desc="Testing β values"):
        trial_accuracies = []

        for trial in range(config.num_trials):
            # RandomState(seed).choice draws exactly what the pair
            # np.random.seed(seed); np.random.choice(...) would, but it does
            # not reset NumPy's global RNG for the rest of the program.
            rng = np.random.RandomState(trial * 12345)
            sample_indices = rng.choice(
                len(stored_sentences), size=config.num_samples, replace=False
            )

            num_recovered = 0
            for i, idx in enumerate(sample_indices):
                result = run_single_retrieval(
                    original_idx=idx,
                    pattern_embeddings=pattern_embeddings,
                    lambda_min=lambda_min,
                    beta=beta,
                    epsilon=config.epsilon,
                    seed=trial * seed_stride + i,
                    max_steps=config.max_steps,
                )
                if result["recovered"]:
                    num_recovered += 1

            trial_accuracies.append(num_recovered / config.num_samples * 100)

        means.append(np.mean(trial_accuracies))
        stds.append(np.std(trial_accuracies))

    return means, stds


# =============================================================================
# Visualization
# =============================================================================

def plot_accuracy_vs_beta(
    beta_values: List[float],
    means: List[float],
    stds: List[float],
    output_path: str = "bwdam_sentences_accuracy.png",
) -> None:
    """Plot retrieval accuracy vs β.

    Mean accuracy is drawn with ±std error bars on a log-scaled β axis, and
    the figure is saved to ``output_path``. The figure is left open so a
    caller can still ``plt.show()`` it.
    """
    _fig, ax = plt.subplots(figsize=(8, 6))

    ax.errorbar(
        beta_values, means, yerr=stds,
        fmt="o-", capsize=4, capthick=1.5,
        markersize=6, linewidth=2,
        color="#2E7D32", ecolor="#2E7D32",
        label="BW-DAM"
    )

    ax.set_xscale("log")
    ax.set_xlabel(r"$\beta$ (inverse temperature)", fontsize=14)
    ax.set_ylabel("Retrieval Accuracy (%)", fontsize=14)
    # Pad the x-range multiplicatively (the axis is logarithmic) around the β
    # values actually plotted, so no point is clipped. For the default sweep
    # 1e-2..1e2 this gives [0.008, 150].
    if len(beta_values) > 0:
        ax.set_xlim([min(beta_values) * 0.8, max(beta_values) * 1.5])
    ax.set_ylim([0, 105])
    ax.grid(True, alpha=0.3)
    ax.tick_params(axis="both", which="major", labelsize=12)

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches="tight", facecolor="white")
    print(f"Saved: {output_path}")


def plot_qualitative_example(
    example: Dict,
    output_path: str = "bwdam_sentences_qualitative.png",
) -> None:
    """Render one retrieval trajectory as a labelled table image.

    ``example`` must provide:

    - ``"trajectory"``: a list of dicts, each with ``"nearest_idx"``;
    - ``"sentences_along_trajectory"``: one sentence per trajectory entry;
    - ``"original_idx"`` and ``"original_sentence"``.

    Consecutive repeated sentences are collapsed. If more than six rows
    remain, only the first five and the last are shown. Rows whose nearest
    pattern is the original, and the header row, are shaded green. Sentences
    longer than 70 characters are truncated.
    """
    _, ax = plt.subplots(figsize=(14, 3.5))
    ax.axis("off")

    max_chars = 70
    hit_color = "#c8e6c9"
    miss_color = "#ffffff"

    def shorten(text):
        if len(text) > max_chars:
            return text[:max_chars] + "..."
        return text

    trajectory = example["trajectory"]
    sentences = example["sentences_along_trajectory"]
    target_idx = example["original_idx"]
    original_sentence = example["original_sentence"]

    # (step, sentence, matches_original) for each new sentence along the path.
    rows = []
    last_seen = None
    for step_no, point in enumerate(trajectory):
        sentence = sentences[step_no]
        nearest = point["nearest_idx"]
        if sentence != last_seen:
            rows.append((step_no, sentence, nearest == target_idx))
            last_seen = sentence

    if len(rows) > 6:
        rows = rows[:5] + rows[-1:]

    row_labels = ["Original"]
    cell_text = [[shorten(original_sentence)]]
    cell_colors = [[hit_color]]

    for step_no, sentence, matches in rows:
        if step_no > 0:
            row_labels.append(f"Step {step_no}")
        else:
            row_labels.append("Step 0 (perturbed)")
        cell_text.append([shorten(sentence)])
        cell_colors.append([hit_color if matches else miss_color])

    table = ax.table(
        cellText=cell_text,
        rowLabels=row_labels,
        cellLoc="left",
        rowLoc="center",
        loc="center",
        cellColours=cell_colors
    )

    table.auto_set_font_size(False)
    table.set_fontsize(11)
    table.scale(1.0, 1.8)

    for (_row, col), cell in table.get_celld().items():
        cell.set_edgecolor("#cccccc")
        cell.set_linewidth(0.5)
        if col != -1:
            cell.set_width(0.85)
            continue
        # Column -1 holds the row labels.
        cell.set_text_props(fontweight="bold", fontsize=10)
        cell.set_facecolor("#f5f5f5")
        cell.set_width(0.15)

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches="tight", facecolor="white")
    print(f"Saved: {output_path}")


# =============================================================================
# Main
# =============================================================================

def main():
    """Run the full sentence retrieval experiment.

    The steps are:

    1. Train GaussCSE.
    2. Embed ``num_patterns`` sentences drawn from the same NLI subset used for
       training.
    3. Take the smallest stored variance as the perturbation scale.
    4. Sweep β and save/show the accuracy plot.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Train GaussCSE
    train_config = GaussCSEConfig()
    model, tokenizer = train_gausscse(train_config, device)

    # Prepare patterns
    bwdam_config = BWDAMSentenceConfig()

    print(f"\nPreparing {bwdam_config.num_patterns} stored patterns...")
    # Same CSV, train_size and random_state as training, so the stored
    # patterns come from the training subset.
    df = pd.read_csv(
        "https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse/"
        "resolve/main/nli_for_simcse.csv"
    )
    df_train = df.sample(n=train_config.train_size, random_state=42).reset_index(drop=True)

    # Deduplicate while keeping first-occurrence order. Iterating a ``set`` of
    # strings follows hash order, which changes between interpreter runs
    # (hash randomization) and would make the seeded draw below irreproducible.
    unique_sentences = list(
        dict.fromkeys(df_train["sent0"].tolist() + df_train["sent1"].tolist())
    )
    rng = np.random.RandomState(42)
    picked = rng.choice(
        len(unique_sentences), size=bwdam_config.num_patterns, replace=False
    )
    stored_sentences = [unique_sentences[i] for i in picked]

    pattern_embeddings = encode_sentences(
        stored_sentences, model, tokenizer, device,
        max_seq_len=train_config.max_seq_len
    )
    # Smallest per-dimension variance over all stored patterns.
    lambda_min = (pattern_embeddings.std ** 2).min().item()
    print(f"λ_min = {lambda_min:.6f}")

    # Run accuracy experiment
    beta_values = np.logspace(-2, 2, num=20).tolist()
    means, stds = run_accuracy_vs_beta(
        stored_sentences, pattern_embeddings, lambda_min,
        beta_values, bwdam_config
    )

    # Plot results
    plot_accuracy_vs_beta(beta_values, means, stds)
    plt.show()


if __name__ == "__main__":
    main()
