"""Fine-tune pretrained VAE with property supervision."""

import os
import torch

from moltenflow.models.vae import SmilesTokenVAE, VAEConfig
from moltenflow.models.surrogate_head import SurrogateHead
from moltenflow.training.train_vae import train_vae
from moltenflow.data.smiles_dataset import load_csv_dataset
from moltenflow.utils.config import load_yaml
from moltenflow.utils.logging import get_logger
from moltenflow.utils.seeds import set_seed

# Module-level logger for this script, named after the module via ``__name__``.
logger = get_logger(__name__)


# Training modes accepted in ``train.mode``; anything else is a config error.
_VALID_TRAINING_MODES = ("reconstruction", "joint")


def _section(cfg, name):
    """Return ``cfg[name]`` as a dict, treating a missing or null section as empty."""
    # ``cfg.get(name, {})`` alone would still return None for a YAML key with no value.
    return cfg.get(name) or {}


def _normalize_y_cols(y_cols):
    """Return the property column names as a list.

    A single column written in YAML as a scalar (``y_cols: logP``) arrives as a
    ``str``; without this, ``len()`` and ``join`` would treat it as a sequence of
    characters and the surrogate head would get one output per letter.
    Returns an empty list for ``None`` or an empty value.
    """
    if isinstance(y_cols, str):
        return [y_cols]
    return list(y_cols) if y_cols else []


def _build_vae_config(vae_section, vocab_size, max_len):
    """Create a ``VAEConfig`` from the ``vae`` config section, using defaults for missing keys."""
    return VAEConfig(
        vocab_size=vocab_size,
        max_len=max_len,
        d_model=vae_section.get("d_model", 256),
        nhead=vae_section.get("nhead", 8),
        enc_layers=vae_section.get("enc_layers", 6),
        dec_layers=vae_section.get("dec_layers", 6),
        dim_ff=vae_section.get("dim_ff", 1024),
        dropout=vae_section.get("dropout", 0.1),
        K=vae_section.get("K", 8),
        d_latent=vae_section.get("latent_dim", 128),
    )


def _load_pretrained_weights(model, pretrained_path):
    """Load model weights from ``pretrained_path`` into ``model`` (strict key matching).

    Accepts either a bundle dict with a ``"model"`` entry or a bare state dict.
    """
    logger.info(f"Loading pretrained VAE from {pretrained_path}")
    # SECURITY NOTE(review): torch.load may unpickle arbitrary objects (depending on
    # the installed PyTorch version's ``weights_only`` default); only load trusted files.
    checkpoint = torch.load(pretrained_path, map_location="cpu")
    if isinstance(checkpoint, dict) and "model" in checkpoint:
        state_dict = checkpoint["model"]
    else:
        state_dict = checkpoint
    # NOTE(review): the vocabulary is rebuilt from the fine-tuning CSV; if it differs
    # from the pretraining vocab, embedding shapes or token-id meanings may not match.
    model.load_state_dict(state_dict)
    logger.info("Pretrained VAE loaded successfully")


def _build_surrogate_head(surrogate_section, vae_cfg, out_dim):
    """Create the property-prediction head sized to the VAE latent (K x d_latent)."""
    head = SurrogateHead(
        K=vae_cfg.K,
        d_latent=vae_cfg.d_latent,
        out_dim=out_dim,
        hidden_dim=surrogate_section.get("hidden_dim", 256),
        aggregation=surrogate_section.get("aggregation", "mean"),
        dropout=surrogate_section.get("dropout", 0.1),
    )
    logger.info(
        f"Created surrogate head: {sum(p.numel() for p in head.parameters()):,} parameters"
    )
    return head


def main(config_path: str = "configs/finetune.yaml") -> None:
    """Fine-tune pretrained VAE with property supervision.

    This script loads a pretrained VAE checkpoint and fine-tunes it with
    property prediction via a surrogate head. The property loss gradients
    backpropagate through the encoder, orienting the latent space to be
    property-aware.

    Training modes (``train.mode``):
    - reconstruction: Standard VAE loss only (no property supervision)
    - joint: Reconstruction + property prediction loss (default); falls back
      to reconstruction only when ``train.property_weight`` <= 0

    Args:
        config_path: Path to YAML configuration file

    Raises:
        ValueError: If ``data.csv_path`` or ``data.y_cols`` is missing, or
            ``train.mode`` is not one of the supported modes.
        KeyError: If required ``train`` keys (seed, epochs, batch_size, lr)
            are missing.
    """
    cfg = load_yaml(config_path)
    train_cfg = cfg["train"]
    seed = train_cfg["seed"]
    set_seed(seed)

    logger.info(f"Loading config from {config_path}")
    logger.info("=== VAE Fine-tuning ===")

    # --- Validate configuration before doing any expensive work ---
    data_cfg = _section(cfg, "data")
    vae_section = _section(cfg, "vae")
    csv_path = data_cfg.get("csv_path")
    smiles_col = data_cfg.get("smiles_col", "smiles")
    y_cols = _normalize_y_cols(data_cfg.get("y_cols"))
    max_len = data_cfg.get("max_len", 128)
    representation = data_cfg.get("representation", "smiles")

    if csv_path is None:
        raise ValueError("data.csv_path must be specified in config")
    if not y_cols:
        raise ValueError("data.y_cols must be specified for fine-tuning")

    training_mode = train_cfg.get("mode", "joint")
    if training_mode not in _VALID_TRAINING_MODES:
        # Previously an unknown mode silently meant "reconstruction only".
        raise ValueError(
            f"train.mode must be one of {_VALID_TRAINING_MODES}, got {training_mode!r}"
        )
    property_weight = train_cfg.get("property_weight", 1.0)
    freeze_decoder = train_cfg.get("freeze_decoder", False)

    # --- Load dataset with properties ---
    logger.info(f"Loading dataset from {csv_path}")
    logger.info(f"Property columns: {y_cols}")

    ds, splits = load_csv_dataset(
        csv_path,
        smiles_col,
        y_cols,
        max_len,
        seed=seed,
        representation=representation,
    )

    vocab_size = len(ds.vocab.id_to_token)
    logger.info(f"Dataset: {len(ds)} molecules, vocab size: {vocab_size}, max_len: {max_len}")
    logger.info(
        f"Splits: train={len(splits['train'])}, val={len(splits['val'])}, test={len(splits['test'])}"
    )
    logger.info(f"Properties: {len(y_cols)} ({', '.join(y_cols)})")

    # --- Build model and optionally restore pretrained weights ---
    vae_cfg = _build_vae_config(vae_section, vocab_size, max_len)
    model = SmilesTokenVAE(vae_cfg, pad_id=ds.vocab.pad_id)

    pretrained_path = vae_section.get("pretrained_checkpoint")
    if pretrained_path:
        _load_pretrained_weights(model, pretrained_path)
    else:
        logger.warning("No pretrained checkpoint specified - training from scratch")

    logger.info(f"VAE: {sum(p.numel() for p in model.parameters()):,} parameters")

    # --- Property supervision ---
    surrogate_head = None
    if training_mode == "joint" and property_weight > 0:
        surrogate_head = _build_surrogate_head(_section(cfg, "surrogate"), vae_cfg, len(y_cols))
        logger.info(f"Training mode: {training_mode}, property_weight: {property_weight}")
    else:
        logger.info("Training mode: reconstruction only")

    if freeze_decoder:
        logger.info("Decoder will be frozen during fine-tuning")
        if not pretrained_path:
            # A frozen, randomly initialised decoder cannot learn to reconstruct.
            logger.warning("freeze_decoder is set but no pretrained checkpoint was loaded")

    # --- Output directory ---
    workdir = _section(cfg, "output").get("workdir", "outputs/finetune")
    os.makedirs(workdir, exist_ok=True)
    logger.info(f"Output directory: {workdir}")

    # --- Train ---
    logger.info("Starting VAE fine-tuning...")
    train_vae(
        ds=ds,
        splits=splits,
        model=model,
        workdir=workdir,
        epochs=train_cfg["epochs"],
        batch_size=train_cfg["batch_size"],
        lr=train_cfg["lr"],
        beta_max=vae_section.get("beta", 0.12),
        beta_warmup_frac=train_cfg.get("beta_warmup_frac", 0.35),
        grad_clip=train_cfg.get("grad_clip", 1.0),
        seed=seed,
        surrogate_head=surrogate_head,
        property_weight=property_weight,
        freeze_decoder=freeze_decoder,
    )

    logger.info(f"VAE fine-tuning complete! Models saved to {workdir}")
    logger.info(f"Fine-tuned checkpoint: {os.path.join(workdir, 'vae_best.pt')}")
