"""Simple TopK Sparse Autoencoder implementation."""

import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path

from .config import SAEConfig


class TopKSAE(nn.Module):
    """Sparse autoencoder whose codes keep only the k largest pre-activations.

    The encoder is an affine map ``input_dim -> dict_size``. The decoder is a
    bias-free linear map ``dict_size -> input_dim`` whose weight columns are the
    dictionary atoms; they are rescaled to unit L2 norm at construction.

    Note: no non-linearity is applied around the top-k selection, so a retained
    activation can be negative when fewer than ``k`` pre-activations of a
    sample are positive.
    """

    def __init__(self, input_dim: int, dict_size: int, k: int):
        """Build the encoder/decoder pair.

        Args:
            input_dim: Size of the last dimension of the inputs.
            dict_size: Number of dictionary atoms (width of the sparse code).
            k: Number of activations kept per sample.

        Raises:
            ValueError: If ``k`` is not in ``[1, dict_size]``.
        """
        super().__init__()
        # Fail early: torch.topk would otherwise raise an opaque error on the
        # first encode() call when k exceeds the code width.
        if not 1 <= k <= dict_size:
            raise ValueError(
                f"k must satisfy 1 <= k <= dict_size ({dict_size}), got {k}"
            )
        self.input_dim = input_dim
        self.dict_size = dict_size
        self.k = k

        # Encoder: input_dim -> dict_size
        self.encoder = nn.Linear(input_dim, dict_size, bias=True)
        # Decoder: dict_size -> input_dim (no bias, dictionary atoms)
        self.decoder = nn.Linear(dict_size, input_dim, bias=False)

        # Weight is (input_dim, dict_size); dim=0 normalizes each column (atom).
        # Copy in place so the parameter keeps its storage.
        with torch.no_grad():
            self.decoder.weight.copy_(F.normalize(self.decoder.weight, dim=0))

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode input to sparse codes using TopK.

        Args:
            x: Tensor whose last dimension is ``input_dim``.

        Returns:
            Tensor shaped like the encoder output (last dimension
            ``dict_size``) holding the ``k`` largest pre-activations per
            sample at their original positions and zeros elsewhere.
        """
        h = self.encoder(x)  # (..., dict_size)
        topk_vals, topk_idx = torch.topk(h, self.k, dim=-1)
        # Out-of-place scatter onto zeros; gradients flow through topk_vals.
        return torch.zeros_like(h).scatter(-1, topk_idx, topk_vals)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode sparse codes to a reconstruction via the dictionary."""
        return self.decoder(z)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Run encode then decode.

        Returns:
            ``(x_hat, z)``: the reconstruction and the sparse codes.
        """
        z = self.encode(x)
        x_hat = self.decode(z)
        return x_hat, z

    def get_dictionary(self) -> torch.Tensor:
        """Return the dictionary matrix D (input_dim x dict_size).

        PyTorch Linear stores weights as (out_features, in_features).
        Decoder maps dict_size -> input_dim, so weight is (input_dim, dict_size),
        which is already the correct orientation for D.

        The returned tensor is detached from autograd but shares storage with
        the decoder weight; clone it before modifying it in place.
        """
        return self.decoder.weight.detach()  # (input_dim, dict_size)


def _evaluate_sae(model: nn.Module, data: torch.Tensor, batch_size: int, device) -> tuple[float, float]:
    """Return sample-weighted (mean MSE, mean L0) of ``model`` over ``data``.

    Batches are weighted by their size so a short final batch does not skew
    the averages. Returns ``(nan, nan)`` when ``data`` is empty. The model's
    train/eval mode is restored on exit.

    Args:
        model: Module whose forward returns ``(reconstruction, codes)``.
        data: Tensor of samples along dim 0.
        batch_size: Maximum number of samples per evaluation batch.
        device: Device the batches are moved to before the forward pass.
    """
    n_samples = len(data)
    if n_samples == 0:
        return float("nan"), float("nan")

    was_training = model.training
    model.eval()
    loss_sum = 0.0
    l0_sum = 0.0
    try:
        with torch.no_grad():
            for start in range(0, n_samples, batch_size):
                batch = data[start:start + batch_size].to(device)
                recon, codes = model(batch)
                # mse_loss is a per-element mean; scale back by batch size.
                loss_sum += F.mse_loss(recon, batch).item() * len(batch)
                l0_sum += (codes != 0).float().sum().item()
    finally:
        model.train(was_training)
    return loss_sum / n_samples, l0_sum / n_samples


def train_sae(
    embeddings_path: Path,
    config: SAEConfig,
    output_path: Path,
    no_wandb: bool = True,
    val_split: float = 0.1,
) -> TopKSAE:
    """Train a TopK SAE on embeddings from an HDF5 file and save it.

    The ``embeddings`` dataset is transposed on load (it is expected to be
    stored as (d, n)), randomly split into train/validation rows, and used to
    fit a ``TopKSAE`` with Adam on an MSE reconstruction loss. Decoder columns
    are renormalized to unit norm after every optimizer step.

    Args:
        embeddings_path: HDF5 file containing an ``embeddings`` dataset.
        config: Supplies ``input_dim``, ``dict_size``, ``k``, ``device``,
            ``learning_rate``, ``batch_size``, ``num_steps`` and ``val_every``.
        output_path: Checkpoint destination; the dictionary is also written
            next to it with the suffix replaced by ``.D.npy``.
        no_wandb: When True (default) wandb is neither imported nor used.
        val_split: Fraction of rows held out for validation, in ``[0, 1)``.

    Returns:
        The trained model, left in training mode.

    Raises:
        ValueError: If ``val_split`` is outside ``[0, 1)``, the embedding
            width differs from ``config.input_dim``, or no training rows remain.
    """
    import h5py
    import numpy as np
    from tqdm import tqdm

    if not 0.0 <= val_split < 1.0:
        raise ValueError(f"val_split must be in [0, 1), got {val_split}")

    use_wandb = not no_wandb
    if use_wandb:
        # Imported lazily so wandb is only required when logging is enabled.
        import wandb

        wandb.init(project="icml2026-sae", config=vars(config))

    # Load embeddings
    print(f"Loading embeddings from {embeddings_path}...")
    with h5py.File(embeddings_path, "r") as f:
        # Data is (d, n), we need (n, d) for PyTorch
        emb = f["embeddings"][:].T  # (n, d)

    # Explicit raise instead of assert: asserts are stripped under `python -O`.
    if emb.shape[1] != config.input_dim:
        raise ValueError(f"Dim mismatch: {emb.shape[1]} vs {config.input_dim}")

    # Split train/val
    n = emb.shape[0]
    n_val = int(n * val_split)
    indices = np.random.permutation(n)
    val_idx, train_idx = indices[:n_val], indices[n_val:]
    n_train = len(train_idx)
    if n_train == 0:
        raise ValueError(f"No training samples available in {embeddings_path}")

    emb_train = torch.from_numpy(emb[train_idx]).float()
    emb_val = torch.from_numpy(emb[val_idx]).float()

    print(f"Train: {len(train_idx)}, Val: {len(val_idx)}")

    # Create model
    model = TopKSAE(config.input_dim, config.dict_size, config.k).to(config.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

    # Training loop
    model.train()
    actual_batch_size = min(config.batch_size, n_train)
    last_step = config.num_steps - 1

    for step in tqdm(range(config.num_steps), desc="Training SAE"):
        # Sample batch (without replacement within the batch)
        batch_idx = np.random.choice(n_train, actual_batch_size, replace=False)
        x = emb_train[batch_idx].to(config.device)

        # Forward
        x_hat, _ = model(x)
        loss = F.mse_loss(x_hat, x)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Project decoder columns back onto the unit sphere, in place.
        with torch.no_grad():
            model.decoder.weight.copy_(F.normalize(model.decoder.weight, dim=0))

        # Validation
        if step % config.val_every == 0 or step == last_step:
            val_loss, val_l0 = _evaluate_sae(
                model, emb_val, config.batch_size, config.device
            )
            train_loss = loss.item()

            log_dict = {
                "step": step,
                "train_loss": train_loss,
                "val_loss": val_loss,
                "val_l0": val_l0,
            }

            if use_wandb:
                wandb.log(log_dict)

            if step % (config.val_every * 10) == 0:
                print(f"Step {step}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}, val_l0={val_l0:.1f}")

    # Save model
    output_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save({
        "model_state_dict": model.state_dict(),
        "config": vars(config),
    }, output_path)
    print(f"Saved SAE to {output_path}")

    # Also save dictionary in numpy format for comparison with KSVD
    D = model.get_dictionary().cpu().numpy()
    np.save(output_path.with_suffix(".D.npy"), D)
    print(f"Saved dictionary to {output_path.with_suffix('.D.npy')}")

    if use_wandb:
        wandb.finish()

    return model
