"""Unconditioned molecular generation using VAE + latent flow."""

import os
import torch
import pandas as pd

from moltenflow.models.vae import SmilesTokenVAE, VAEConfig
from moltenflow.models.latent_flow import LatentFlowPrior, FlowConfig
from moltenflow.inference.sample import sample_smiles
from moltenflow.data.smiles_dataset import load_csv_dataset
from moltenflow.utils.config import load_yaml
from moltenflow.utils.logging import get_logger
from moltenflow.utils.seeds import set_seed

# Module-level logger obtained from the project's logging helper, keyed by this module's name.
logger = get_logger(__name__)


def _load_weights(model, checkpoint_path: str) -> None:
    """Load ``checkpoint["model"]`` from *checkpoint_path* into *model* and set eval mode.

    The checkpoint is mapped to CPU so it can be read on machines without a GPU;
    the caller is responsible for moving the model to its target device.

    Args:
        model: Module exposing ``load_state_dict`` and ``eval``.
        checkpoint_path: Path to a checkpoint file whose dict has a ``"model"`` entry.
    """
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only point the config at trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    model.eval()


def _generation_stats(df: pd.DataFrame):
    """Return ``(n_total, n_valid, n_unique)`` for a table of generated samples.

    ``n_valid`` counts rows flagged in the ``valid`` column (used as a boolean
    mask) and ``n_unique`` counts distinct ``smiles`` among those rows. An empty
    table -- e.g. ``pd.DataFrame([])``, which has no columns at all -- yields
    ``(0, 0, 0)`` instead of raising ``KeyError``.
    """
    n_total = len(df)
    if n_total == 0:
        return 0, 0, 0
    valid_mask = df["valid"]
    n_valid = int(valid_mask.sum())
    n_unique = int(df.loc[valid_mask, "smiles"].nunique())
    return n_total, n_valid, n_unique


def main(config_path: str = "configs/experiments/generation_uncond.yaml") -> None:
    """Generate molecules without property guidance.

    This script performs unconditioned generation:
    1. Sample z_0 ~ N(0, I) from standard normal
    2. Integrate latent flow: z_0 -> z_1
    3. Decode z_1 to SMILES via VAE decoder
    4. Validate and save results

    The generated rows are written to ``<output.dir>/generated_uncond.csv``
    and validity / uniqueness percentages are logged.

    Args:
        config_path: Path to YAML configuration file. Required keys are
            ``vae.checkpoint_path``, ``flow.checkpoint_path`` and
            ``generation.n_samples``; everything else has a default.

    Raises:
        ValueError: If the VAE and flow configs disagree on the latent shape
            (``vae.K`` vs ``flow.K`` or ``vae.latent_dim`` vs ``flow.d_latent``).
        KeyError: If a required config key is missing.
    """
    cfg = load_yaml(config_path)
    # Resolve the seed once so seeding, data splitting and sampling agree.
    seed = cfg.get("seed", 42)
    set_seed(seed)

    logger.info(f"Loaded config: {config_path}")
    logger.info("=== Unconditioned Generation ===")

    vae_section = cfg["vae"]
    flow_section = cfg["flow"]

    # Flow samples z_1 are fed straight into the VAE decoder, so both models must
    # agree on the latent shape. Check before any expensive loading happens.
    vae_K = vae_section.get("K", 8)
    vae_latent = vae_section.get("latent_dim", 128)
    flow_K = flow_section.get("K", 8)
    flow_latent = flow_section.get("d_latent", 128)
    if (vae_K, vae_latent) != (flow_K, flow_latent):
        raise ValueError(
            f"VAE latent shape (K={vae_K}, latent_dim={vae_latent}) does not match "
            f"flow latent shape (K={flow_K}, d_latent={flow_latent}); flow samples "
            "are decoded by the VAE, so these must agree."
        )

    # Load VAE checkpoint
    vae_checkpoint = vae_section["checkpoint_path"]
    logger.info(f"Loading VAE from {vae_checkpoint}")

    # The dataset is loaded only for its vocabulary, which the decoder needs
    # to map token ids back to strings.
    data_cfg = cfg.get("data", {})
    csv_path = data_cfg.get("csv_path", "data/raw/esol.csv")
    smiles_col = data_cfg.get("smiles_col", "smiles")
    y_cols = data_cfg.get("y_cols", ["measured log solubility in mols per litre"])
    max_len = data_cfg.get("max_len", 128)
    representation = data_cfg.get("representation", "smiles")

    ds, _ = load_csv_dataset(
        csv_path,
        smiles_col,
        y_cols,
        max_len,
        seed=seed,
        representation=representation,
    )
    vocab = ds.vocab

    vae_cfg = VAEConfig(
        vocab_size=len(vocab.id_to_token),
        max_len=max_len,
        d_model=vae_section.get("d_model", 256),
        nhead=vae_section.get("nhead", 8),
        enc_layers=vae_section.get("enc_layers", 6),
        dec_layers=vae_section.get("dec_layers", 6),
        dim_ff=vae_section.get("dim_ff", 1024),
        dropout=vae_section.get("dropout", 0.1),
        K=vae_K,
        d_latent=vae_latent,
    )
    vae = SmilesTokenVAE(vae_cfg, pad_id=vocab.pad_id)
    _load_weights(vae, vae_checkpoint)
    logger.info("VAE loaded successfully")

    # Load flow checkpoint
    flow_checkpoint = flow_section["checkpoint_path"]
    logger.info(f"Loading flow from {flow_checkpoint}")

    flow_cfg = FlowConfig(
        K=flow_K,
        d_latent=flow_latent,
        d_model=flow_section.get("d_model", 256),
        nhead=flow_section.get("nhead", 8),
        layers=flow_section.get("layers", 10),
        dim_ff=flow_section.get("dim_ff", 1024),
        dropout=flow_section.get("dropout", 0.1),
        time_dim=flow_section.get("time_dim", 128),
    )
    flow = LatentFlowPrior(flow_cfg)
    _load_weights(flow, flow_checkpoint)
    logger.info("Flow loaded successfully")

    # Move to GPU if available (checkpoints were mapped to CPU above).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vae = vae.to(device)
    flow = flow.to(device)
    logger.info(f"Using device: {device}")

    # Generation parameters
    n_samples = cfg["generation"]["n_samples"]
    steps = cfg["generation"].get("steps", 80)

    logger.info(f"Generating {n_samples} molecules (steps={steps}, seed={seed})")

    results = sample_smiles(
        vae=vae,
        flow=flow,
        vocab=vocab,
        n=n_samples,
        steps=steps,
        seed=seed,
        representation=representation,
    )

    # Save results
    output_dir = cfg.get("output", {}).get("dir", "outputs/generation")
    os.makedirs(output_dir, exist_ok=True)
    output_csv = os.path.join(output_dir, "generated_uncond.csv")

    df = pd.DataFrame(results)
    df.to_csv(output_csv, index=False)

    # Report statistics; percentages fall back to 0 so empty runs don't divide by zero.
    n_total, n_valid, n_unique = _generation_stats(df)
    validity = n_valid / n_total * 100 if n_total else 0.0
    uniqueness = n_unique / n_valid * 100 if n_valid else 0.0

    logger.info(f"Generated {n_total} molecules")
    logger.info(f"Valid: {n_valid}/{n_total} ({validity:.1f}%)")
    logger.info(f"Unique: {n_unique}/{n_valid} ({uniqueness:.1f}%)")
    logger.info(f"Results saved to {output_csv}")
