"""Toy dataset generation from real molecular latent embeddings.

This module provides utilities for extracting latent embeddings from a trained VAE
and generating synthetic property targets for ablation studies.
"""

from __future__ import annotations
import os
from pathlib import Path
from typing import Dict, Any
import numpy as np
import pandas as pd
import torch

from moltenflow.models.vae import SmilesTokenVAE
from moltenflow.data.smiles_dataset import SmilesDataset, batchify
from moltenflow.data.latents import generate_toy_targets
from moltenflow.utils.device import get_device
from moltenflow.utils.io import save_json
from moltenflow.utils.logging import get_logger

# Module-level logger; used by save_toydata and create_toydata_from_vae for progress messages.
logger = get_logger(__name__)


def pool_latents(z: torch.Tensor, method: str = "mean") -> torch.Tensor:
    """Collapse the token axis (dim 1) of a latent tensor into one vector per sample.

    Args:
        z: Latent tensor of shape (B, K, d_latent).
        method: Reduction over the K tokens:
            - "mean": average of all tokens,
            - "max": element-wise maximum over tokens,
            - "first": the token at position 0.

    Returns:
        Tensor of shape (B, d_latent).

    Raises:
        ValueError: If ``method`` is not one of the names above.
    """
    if method == "mean":
        return z.mean(dim=1)
    if method == "max":
        # .max(dim=...) returns (values, indices); only the values are wanted.
        max_result = z.max(dim=1)
        return max_result.values
    if method == "first":
        return z[:, 0, :]
    raise ValueError(f"Unknown pooling method: {method}. Use 'mean', 'max', or 'first'.")


@torch.no_grad()
def extract_latents_from_vae(
    vae: SmilesTokenVAE,
    ds: SmilesDataset,
    indices: np.ndarray,
    batch_size: int = 256,
    pool_method: str = "mean",
    use_mu: bool = True,
) -> tuple[np.ndarray, list[str]]:
    """Extract pooled latent embeddings from a trained VAE.

    Runs under ``torch.no_grad()``. The model is moved to the device returned by
    ``get_device()`` and switched to eval mode; neither change is undone on return.

    Args:
        vae: Trained VAE model (moved to the inference device and set to eval mode).
        ds: SmilesDataset containing molecules.
        indices: Indices of samples to extract; batched in the given order (no shuffling).
        batch_size: Batch size for inference.
        pool_method: How to pool the K token latents - "mean", "max", or "first".
        use_mu: If True, use the posterior mean (mu). If False, use the ``z`` returned
            by ``vae.encode``.

    Returns:
        Tuple of (latents array of shape (N, d_latent), list of SMILES strings in the
        same row order). The shape assumes ``vae.encode`` yields (B, K, d_latent)
        tensors -- TODO confirm against SmilesTokenVAE.

    Raises:
        ValueError: If no batches were produced (e.g. ``indices`` is empty), or if
            ``pool_method`` is unknown.
    """
    device = get_device()
    vae = vae.to(device).eval()

    pooled_chunks: list[np.ndarray] = []
    all_smiles: list[str] = []

    for batch in batchify(ds, indices, batch_size, shuffle=False):
        # as_tensor skips an unnecessary copy when batch["x"] is already an
        # array/tensor (and avoids torch's warning for tensor inputs).
        x = torch.as_tensor(batch["x"], device=device)
        z, mu, _logvar = vae.encode(x)

        # mu gives deterministic embeddings; z is the sample returned by encode.
        z_out = mu if use_mu else z

        # Pool K tokens to a single vector per molecule, then move to host memory.
        pooled_chunks.append(pool_latents(z_out, method=pool_method).cpu().numpy())
        all_smiles.extend(batch["smiles"])

    if not pooled_chunks:
        # np.concatenate([]) would fail with an opaque message; explain the cause instead.
        raise ValueError(
            "extract_latents_from_vae produced no batches; "
            f"got {len(indices)} indices. Provide at least one sample index."
        )

    return np.concatenate(pooled_chunks, axis=0), all_smiles


def generate_toydata_from_latents(
    z: np.ndarray | torch.Tensor,
    smiles: list[str],
    n_properties: int = 2,
    seed: int = 42,
) -> pd.DataFrame:
    """Generate a toy dataset with synthetic properties from latent embeddings.

    Args:
        z: Latent embeddings of shape (N, d_latent), as a NumPy array or torch tensor.
        smiles: List of SMILES strings; must have exactly N entries (one per row of z).
        n_properties: Number of synthetic properties to generate.
        seed: Random seed forwarded to ``generate_toy_targets`` for reproducibility.

    Returns:
        DataFrame with columns: smiles, z_pooled (one Python list per row),
        y_0, y_1, ..., y_{n_properties-1}.

    Raises:
        ValueError: If ``len(smiles)`` does not match the number of rows in ``z``.
    """
    n_rows = len(z)
    if len(smiles) != n_rows:
        raise ValueError(
            f"smiles has {len(smiles)} entries but z has {n_rows} rows; they must match."
        )

    z_tensor = torch.from_numpy(z) if isinstance(z, np.ndarray) else z

    # Generate synthetic targets using the shared latents helper.
    y = generate_toy_targets(z_tensor, n_properties=n_properties, seed=seed)
    # detach/cpu so .numpy() also works when the targets require grad or live on an
    # accelerator (Tensor.numpy only supports CPU tensors without grad).
    y_np = y.detach().cpu().numpy()

    data: Dict[str, Any] = {
        "smiles": smiles,
        # Store pooled latents as nested lists (for parquet compatibility);
        # a single tolist() call converts all rows at once.
        "z_pooled": z.tolist(),
    }

    # Add property columns; assumes y has shape (N, n_properties) -- TODO confirm
    # against generate_toy_targets.
    for i in range(n_properties):
        data[f"y_{i}"] = y_np[:, i]

    return pd.DataFrame(data)


def save_toydata(
    df: pd.DataFrame,
    output_dir: str | os.PathLike,
    name: str = "real_latents_toydata",
    metadata: dict | None = None,
) -> tuple[Path, Path]:
    """Save a toy dataset as parquet plus a JSON metadata sidecar.

    Writes ``<output_dir>/<name>.parquet`` (pyarrow engine, no index) and
    ``<output_dir>/<name>_metadata.json``, creating ``output_dir`` if needed.

    Args:
        df: DataFrame with the toy dataset; when non-empty it must have a
            ``z_pooled`` column whose entries support ``len()``.
        output_dir: Directory to save files.
        name: Base name for output files.
        metadata: Optional metadata dict. It is copied, not modified; the saved copy
            gets ``n_samples``, ``n_properties`` (number of ``y_*`` columns) and
            ``latent_dim`` (length of the first ``z_pooled`` entry, 0 if empty),
            overriding any same-named keys.

    Returns:
        Tuple of (parquet_path, metadata_path).
    """
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    parquet_path = out_dir / f"{name}.parquet"
    df.to_parquet(parquet_path, index=False, engine="pyarrow")
    logger.info(f"Saved toy dataset to {parquet_path}")

    # Copy so the caller's dict is not mutated by the derived fields below.
    meta: dict[str, Any] = dict(metadata) if metadata is not None else {}
    meta["n_samples"] = len(df)
    meta["n_properties"] = len([c for c in df.columns if c.startswith("y_")])
    meta["latent_dim"] = len(df["z_pooled"].iloc[0]) if len(df) > 0 else 0

    metadata_path = out_dir / f"{name}_metadata.json"
    save_json(str(metadata_path), meta)
    logger.info(f"Saved metadata to {metadata_path}")

    return parquet_path, metadata_path


def load_toydata(parquet_path: str) -> tuple[pd.DataFrame, np.ndarray, np.ndarray]:
    """Load a toy dataset from parquet.

    Args:
        parquet_path: Path to parquet file (as written by ``save_toydata``).

    Returns:
        Tuple of (DataFrame, latents array (N, d), properties array (N, P)).
        Property columns ``y_*`` are ordered by their numeric suffix
        (``y_2`` before ``y_10``); non-numeric suffixes follow, alphabetically.
    """
    df = pd.read_parquet(parquet_path)

    # Extract latents: each z_pooled cell is one sequence of length d.
    z = np.array(df["z_pooled"].tolist())

    # Extract properties. A plain sorted() is lexicographic and would put y_10
    # before y_2, so order numerically by suffix. The leading 0/1 keeps int and
    # str keys from ever being compared with each other.
    y_cols = sorted(
        (c for c in df.columns if isinstance(c, str) and c.startswith("y_")),
        key=lambda c: (0, int(c[2:]), "") if c[2:].isdecimal() else (1, 0, c),
    )
    y = df[y_cols].to_numpy()

    return df, z, y


def create_toydata_from_vae(
    vae: SmilesTokenVAE,
    ds: SmilesDataset,
    indices: np.ndarray,
    output_dir: str,
    n_properties: int = 2,
    batch_size: int = 256,
    pool_method: str = "mean",
    seed: int = 42,
    vae_checkpoint_path: str | None = None,
) -> tuple[pd.DataFrame, Path, Path]:
    """End-to-end pipeline: encode molecules, synthesize targets, write to disk.

    Latents are always taken from the posterior mean (``use_mu=True``). The
    DataFrame is saved under the default base name of ``save_toydata``.

    Args:
        vae: Trained VAE model.
        ds: SmilesDataset containing molecules.
        indices: Indices of samples to use.
        output_dir: Directory to save outputs.
        n_properties: Number of synthetic properties.
        batch_size: Batch size for latent extraction.
        pool_method: How to pool K tokens.
        seed: Random seed for property generation (recorded as ``property_seed``).
        vae_checkpoint_path: Optional path recorded in the metadata as ``vae_checkpoint``.

    Returns:
        Tuple of (DataFrame, parquet_path, metadata_path).
    """
    # Provenance recorded alongside the dataset; key order is kept stable for the JSON.
    run_metadata = dict(
        vae_checkpoint=vae_checkpoint_path,
        pool_method=pool_method,
        property_seed=seed,
        n_properties=n_properties,
    )

    logger.info(f"Extracting latents from {len(indices)} samples...")
    latents, smiles_list = extract_latents_from_vae(
        vae,
        ds,
        indices,
        batch_size=batch_size,
        pool_method=pool_method,
        use_mu=True,
    )
    logger.info(f"Extracted latents with shape {latents.shape}")

    logger.info(f"Generating {n_properties} synthetic properties...")
    toy_df = generate_toydata_from_latents(
        latents, smiles_list, n_properties=n_properties, seed=seed
    )

    parquet_path, metadata_path = save_toydata(toy_df, output_dir, metadata=run_metadata)
    return toy_df, parquet_path, metadata_path
