"""Pretrain VAE on large unlabeled SMILES corpus."""

import os
import tempfile
from pathlib import Path

import pandas as pd

from moltenflow.models.vae import SmilesTokenVAE, VAEConfig
from moltenflow.training.train_vae import train_vae
from moltenflow.data.smiles_dataset import load_csv_dataset
from moltenflow.utils.config import load_yaml
from moltenflow.utils.logging import get_logger
from moltenflow.utils.seeds import set_seed

# Module-level logger used for the progress messages emitted by this script.
logger = get_logger(__name__)


def _load_pretrain_dataset(data_cfg: dict, seed: int) -> tuple:
    """Load the unlabeled SMILES corpus through ``load_csv_dataset``.

    ``load_csv_dataset`` requires property columns (``y_cols``), so a constant
    ``dummy_prop`` column (0.0) is added when the CSV does not already have one.
    The augmented frame is written to a private temporary directory rather than
    next to the source CSV. This means an existing user file is never
    overwritten or deleted, concurrent runs cannot collide, and a read-only
    data directory still works. The directory is removed even when writing or
    loading fails.

    Args:
        data_cfg: The ``data`` section of the config. ``csv_path`` is required;
            ``smiles_col``, ``max_len``, ``subset`` and ``representation`` are
            optional.
        seed: Seed forwarded to ``load_csv_dataset`` for the split.

    Returns:
        Tuple ``(ds, splits, max_len)``: ``ds`` and ``splits`` as returned by
        ``load_csv_dataset``, plus the ``max_len`` that was used.

    Raises:
        ValueError: If ``csv_path`` is missing.
    """
    csv_path = data_cfg.get("csv_path")
    if csv_path is None:
        raise ValueError("data.csv_path must be specified in config")

    smiles_col = data_cfg.get("smiles_col", "smiles")
    max_len = data_cfg.get("max_len", 128)
    subset = data_cfg.get("subset", None)  # For debugging with smaller dataset
    representation = data_cfg.get("representation", "smiles")

    logger.info(f"Loading dataset from {csv_path}")
    if subset is not None:
        logger.info(f"Using subset of {subset} samples for debugging")
    # nrows=None reads the whole file. With a subset, only the first `subset`
    # rows are parsed (same rows as df.head(subset)), so the full large corpus
    # is never materialised just to be truncated.
    df = pd.read_csv(csv_path, nrows=subset)

    # Pretraining needs no labels, but the loader requires a property column.
    if "dummy_prop" not in df.columns:
        df["dummy_prop"] = 0.0

    with tempfile.TemporaryDirectory(prefix="moltenflow_pretrain_") as tmp_dir:
        temp_csv = Path(tmp_dir) / "temp_pretrain.csv"
        df.to_csv(temp_csv, index=False)
        ds, splits = load_csv_dataset(
            str(temp_csv),
            smiles_col,
            ["dummy_prop"],
            max_len,
            seed=seed,
            representation=representation,
        )
    return ds, splits, max_len


def _build_vae_config(vae_section: dict, vocab_size: int, max_len: int) -> VAEConfig:
    """Build a ``VAEConfig`` from the ``vae`` config section, applying defaults.

    Args:
        vae_section: The ``vae`` section of the config (may be empty).
        vocab_size: Size of the dataset vocabulary.
        max_len: Maximum sequence length used by the dataset.

    Returns:
        The model configuration passed to ``SmilesTokenVAE``.
    """
    return VAEConfig(
        vocab_size=vocab_size,
        max_len=max_len,
        d_model=vae_section.get("d_model", 256),
        nhead=vae_section.get("nhead", 8),
        enc_layers=vae_section.get("enc_layers", 6),
        dec_layers=vae_section.get("dec_layers", 6),
        dim_ff=vae_section.get("dim_ff", 1024),
        dropout=vae_section.get("dropout", 0.1),
        K=vae_section.get("K", 8),
        # The config key is `latent_dim`; VAEConfig names the field `d_latent`.
        d_latent=vae_section.get("latent_dim", 128),
    )


def main(config_path: str = "configs/pretrain.yaml") -> None:
    """Pretrain the SMILES VAE on a large unlabeled corpus.

    The model is trained without property labels (``surrogate_head=None``,
    ``property_weight=0.0``). The resulting checkpoint can be used for:
    - Direct generation (unconditional)
    - Fine-tuning with property supervision
    - Latent flow training

    Config sections read:
    - ``train``: ``seed``, ``epochs``, ``batch_size`` and ``lr`` are required;
      ``lr_schedule``, ``lr_min``, ``beta_warmup_frac`` and ``grad_clip`` are
      optional.
    - ``data``: ``csv_path`` is required; the other keys are optional.
    - ``vae``: optional; every key has a default (including ``beta``).
    - ``output.workdir``: optional, defaults to ``outputs/pretrain``.

    Args:
        config_path: Path to YAML configuration file.

    Raises:
        ValueError: If ``data.csv_path`` is not set.
        KeyError: If a required ``train`` key is missing.
    """
    logger.info(f"Loading config from {config_path}")
    cfg = load_yaml(config_path)
    train_cfg = cfg["train"]
    seed = train_cfg["seed"]
    set_seed(seed)

    logger.info("=== VAE Pretraining ===")

    # `or {}` also tolerates a section present in YAML but left empty (null).
    data_cfg = cfg.get("data") or {}
    vae_section = cfg.get("vae") or {}
    output_cfg = cfg.get("output") or {}

    # Loading happens in a helper so the (potentially large) DataFrame is
    # released before training starts.
    ds, splits, max_len = _load_pretrain_dataset(data_cfg, seed)

    vocab_size = len(ds.vocab.id_to_token)
    logger.info(
        f"Dataset: {len(ds)} molecules, vocab size: {vocab_size}, max_len: {max_len}"
    )
    logger.info(
        f"Splits: train={len(splits['train'])}, val={len(splits['val'])}, test={len(splits['test'])}"
    )

    # Create model
    vae_cfg = _build_vae_config(vae_section, vocab_size, max_len)
    model = SmilesTokenVAE(vae_cfg, pad_id=ds.vocab.pad_id)
    logger.info(f"Created VAE with {sum(p.numel() for p in model.parameters()):,} parameters")

    # Setup output directory
    workdir = output_cfg.get("workdir", "outputs/pretrain")
    os.makedirs(workdir, exist_ok=True)
    logger.info(f"Output directory: {workdir}")

    # Train with no property head and zero property loss weight.
    logger.info("Starting VAE pretraining (reconstruction only)...")
    train_vae(
        ds=ds,
        splits=splits,
        model=model,
        workdir=workdir,
        epochs=train_cfg["epochs"],
        batch_size=train_cfg["batch_size"],
        lr=train_cfg["lr"],
        lr_schedule=train_cfg.get("lr_schedule", "cosine"),
        lr_min=train_cfg.get("lr_min", 0.0),
        beta_max=vae_section.get("beta", 0.12),
        beta_warmup_frac=train_cfg.get("beta_warmup_frac", 0.35),
        grad_clip=train_cfg.get("grad_clip", 1.0),
        seed=seed,
        surrogate_head=None,  # No property prediction
        property_weight=0.0,
        freeze_decoder=False,
    )

    logger.info(f"VAE pretraining complete! Models saved to {workdir}")
    logger.info(f"Pretrained checkpoint: {workdir}/vae_best.pt")
