"""Property-conditioned molecular generation CLI."""

import os
import torch
import pandas as pd

from moltenflow.models.vae import SmilesTokenVAE, VAEConfig
from moltenflow.models.latent_flow import LatentFlowPrior, FlowConfig
from moltenflow.models.surrogate_head import SurrogateHead
from moltenflow.inference.sample import sample_guided_smiles
from moltenflow.data.smiles_dataset import load_csv_dataset
from moltenflow.utils.config import load_yaml
from moltenflow.utils.logging import get_logger
from moltenflow.utils.seeds import set_seed

logger = get_logger(__name__)


def main(config_path: str = "configs/experiments/generation_guided.yaml") -> None:
    """Generate molecules with property guidance.

    This script performs property-conditioned generation using the SurrogateHead
    from the fine-tuned VAE checkpoint (no separate surrogate training needed):
    1. Sample z_0 ~ N(0, I)
    2. At each step: compute base velocity + guidance gradient from SurrogateHead
    3. Integrate guided flow to z_1
    4. Decode to SMILES and validate

    The generated rows are written to ``<output.dir>/generated_conditioned.csv``
    (default dir ``outputs/generation``) and validity/uniqueness statistics
    are logged.

    Args:
        config_path: Path to YAML configuration file. The keys
            ``vae.checkpoint_path``, ``flow.checkpoint_path``,
            ``generation.n_samples``, ``guidance.gamma`` and
            ``guidance.target`` are required; all others have defaults.

    Raises:
        KeyError: If a required configuration key is missing.
        ValueError: If the VAE checkpoint has no ``surrogate_head`` entry.
    """
    cfg = load_yaml(config_path)
    # Read once so seeding, dataset loading and sampling all use the same value.
    seed = cfg.get("seed", 42)
    set_seed(seed)

    logger.info(f"Loaded config: {config_path}")
    logger.info("=== Property-Conditioned Generation ===")

    # The dataset is loaded only for its vocabulary, which must match the one
    # the VAE checkpoint was trained with.
    data_cfg = cfg.get("data", {})
    csv_path = data_cfg.get("csv_path", "data/raw/esol.csv")
    smiles_col = data_cfg.get("smiles_col", "smiles")
    y_cols = data_cfg.get("y_cols", ["measured log solubility in mols per litre"])
    max_len = data_cfg.get("max_len", 128)
    representation = data_cfg.get("representation", "smiles")

    ds, _ = load_csv_dataset(
        csv_path,
        smiles_col,
        y_cols,
        max_len,
        seed=seed,
        representation=representation,
    )
    vocab = ds.vocab

    # The fine-tuned VAE checkpoint holds both the VAE and the SurrogateHead weights.
    vae, vae_cfg, checkpoint = _load_vae(cfg["vae"], vocab, max_len)

    surrogate_cfg = cfg.get("surrogate", {})
    surrogate = _load_surrogate(
        checkpoint,
        cfg["vae"]["checkpoint_path"],
        vae_cfg,
        surrogate_cfg,
        default_out_dim=len(y_cols),
    )

    flow = _load_flow(cfg["flow"], vae_cfg)

    # Move to device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vae = vae.to(device)
    flow = flow.to(device)
    surrogate = surrogate.to(device)
    logger.info(f"Using device: {device}")

    # Generation parameters
    generation_cfg = cfg["generation"]
    guidance_cfg = cfg["guidance"]
    n_samples = generation_cfg["n_samples"]
    steps = generation_cfg.get("steps", 80)
    batch_size = generation_cfg.get("batch_size", 256)
    gamma = guidance_cfg["gamma"]
    target_values = guidance_cfg["target"]
    clip_norm = guidance_cfg.get("clip_norm", None)
    normalize = guidance_cfg.get("normalize", False)

    # Wrapping in a list adds a leading dimension of size 1.
    target = torch.tensor([target_values], device=device, dtype=torch.float32)

    # Optional conditioning inputs, only used when the surrogate was built with cond_dim > 0.
    c = None
    if surrogate_cfg.get("cond_dim", 0) > 0:
        cond_values = guidance_cfg.get("conditions")
        if cond_values is None:
            # Surface the mismatch instead of silently passing c=None.
            logger.warning(
                "surrogate.cond_dim > 0 but guidance.conditions is not set; "
                "passing c=None to the sampler"
            )
        else:
            c = torch.tensor([cond_values], device=device, dtype=torch.float32)

    logger.info(f"Generating {n_samples} molecules in batches of {batch_size}")
    logger.info(f"Target properties: {target_values}")
    logger.info(f"Guidance strength (gamma): {gamma}")
    logger.info(f"Integration steps: {steps}")

    # Generate
    results = sample_guided_smiles(
        vae=vae,
        flow=flow,
        surrogate=surrogate,
        vocab=vocab,
        target=target,
        gamma=gamma,
        n=n_samples,
        steps=steps,
        seed=seed,
        c=c,
        clip_norm=clip_norm,
        normalize=normalize,
        representation=representation,
        batch_size=batch_size,
        show_progress=True,
    )

    # Save results
    output_dir = cfg.get("output", {}).get("dir", "outputs/generation")
    os.makedirs(output_dir, exist_ok=True)
    output_csv = os.path.join(output_dir, "generated_conditioned.csv")

    df = pd.DataFrame(results)
    df.to_csv(output_csv, index=False)

    _log_statistics(df, output_csv)


def _load_vae(vae_section: dict, vocab, max_len: int):
    """Build the SmilesTokenVAE from config and load its weights.

    Args:
        vae_section: The ``vae`` section of the experiment config; must contain
            ``checkpoint_path``.
        vocab: Vocabulary object exposing ``id_to_token`` and ``pad_id``.
        max_len: Maximum sequence length, taken from the data config.

    Returns:
        Tuple ``(vae, vae_cfg, checkpoint)``. ``checkpoint`` is the raw loaded
        object, returned so the SurrogateHead can be read from the same file.
    """
    checkpoint_path = vae_section["checkpoint_path"]
    logger.info(f"Loading VAE and SurrogateHead from {checkpoint_path}")

    vae_cfg = VAEConfig(
        vocab_size=len(vocab.id_to_token),
        max_len=max_len,
        d_model=vae_section.get("d_model", 256),
        nhead=vae_section.get("nhead", 8),
        enc_layers=vae_section.get("enc_layers", 6),
        dec_layers=vae_section.get("dec_layers", 6),
        dim_ff=vae_section.get("dim_ff", 1024),
        dropout=vae_section.get("dropout", 0.1),
        K=vae_section.get("K", 8),
        # The config key is `latent_dim`, while the VAEConfig field is `d_latent`.
        d_latent=vae_section.get("latent_dim", 128),
    )

    vae = SmilesTokenVAE(vae_cfg, pad_id=vocab.pad_id)
    # NOTE: torch.load uses pickle; only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    vae.load_state_dict(checkpoint["model"])
    vae.eval()
    logger.info("VAE loaded")
    return vae, vae_cfg, checkpoint


def _load_surrogate(
    checkpoint: dict,
    checkpoint_path: str,
    vae_cfg,
    surrogate_cfg: dict,
    default_out_dim: int,
) -> SurrogateHead:
    """Build the SurrogateHead and load its weights from the VAE checkpoint.

    Args:
        checkpoint: Object loaded from the VAE checkpoint file.
        checkpoint_path: Path of that file, used only in the error message.
        vae_cfg: VAE config; its ``K`` and ``d_latent`` size the head.
        surrogate_cfg: The ``surrogate`` section of the config (may be empty).
        default_out_dim: ``out_dim`` to use when the config does not set one.

    Raises:
        ValueError: If ``checkpoint`` has no ``surrogate_head`` entry.
    """
    surrogate = SurrogateHead(
        K=vae_cfg.K,
        d_latent=vae_cfg.d_latent,
        out_dim=surrogate_cfg.get("out_dim", default_out_dim),
        cond_dim=surrogate_cfg.get("cond_dim", 0),
        hidden_dim=surrogate_cfg.get("hidden_dim", 256),
        aggregation=surrogate_cfg.get("aggregation", "mean"),
        dropout=surrogate_cfg.get("dropout", 0.1),
    )

    if "surrogate_head" not in checkpoint:
        raise ValueError(
            f"VAE checkpoint at {checkpoint_path} does not contain surrogate_head. "
            "Make sure to use a fine-tuned VAE checkpoint with property supervision."
        )
    surrogate.load_state_dict(checkpoint["surrogate_head"])
    logger.info("SurrogateHead loaded from VAE checkpoint")
    surrogate.eval()
    return surrogate


def _load_flow(flow_section: dict, vae_cfg) -> LatentFlowPrior:
    """Build the LatentFlowPrior from config and load its weights.

    ``K`` and ``d_latent`` default to the VAE's values when the ``flow``
    section does not set them.

    Args:
        flow_section: The ``flow`` section of the config; must contain
            ``checkpoint_path``.
        vae_cfg: VAE config supplying the ``K``/``d_latent`` defaults.
    """
    checkpoint_path = flow_section["checkpoint_path"]
    logger.info(f"Loading flow from {checkpoint_path}")

    flow_cfg = FlowConfig(
        K=flow_section.get("K", vae_cfg.K),
        d_latent=flow_section.get("d_latent", vae_cfg.d_latent),
        d_model=flow_section.get("d_model", 256),
        nhead=flow_section.get("nhead", 8),
        layers=flow_section.get("layers", 10),
        dim_ff=flow_section.get("dim_ff", 1024),
        dropout=flow_section.get("dropout", 0.1),
        time_dim=flow_section.get("time_dim", 128),
    )

    flow = LatentFlowPrior(flow_cfg)
    # NOTE: torch.load uses pickle; only load checkpoints from trusted sources.
    flow_ckpt = torch.load(checkpoint_path, map_location="cpu")
    flow.load_state_dict(flow_ckpt["model"])
    flow.eval()
    logger.info("Flow loaded")
    return flow


def _log_statistics(df: pd.DataFrame, output_csv: str) -> None:
    """Log validity and uniqueness of the generated molecules.

    Args:
        df: Generated rows; expected to have ``valid`` (bool) and ``smiles``
            columns whenever it is non-empty.
        output_csv: Path the rows were saved to, echoed in the log.
    """
    n_total = len(df)
    if n_total == 0:
        # An empty result may have no "valid" column at all, and every
        # percentage below would divide by zero.
        logger.info("Generated 0 molecules")
        logger.info(f"Results saved to {output_csv}")
        return

    valid_mask = df["valid"]
    n_valid = int(valid_mask.sum())
    validity = n_valid / n_total * 100
    n_unique = df.loc[valid_mask, "smiles"].nunique()
    uniqueness = n_unique / n_valid * 100 if n_valid > 0 else 0

    logger.info(f"Generated {n_total} molecules")
    logger.info(f"Valid: {n_valid}/{n_total} ({validity:.1f}%)")
    logger.info(f"Unique: {n_unique}/{n_valid} ({uniqueness:.1f}%)")
    logger.info(f"Results saved to {output_csv}")


if __name__ == "__main__":
    import sys

    # Run as a script: an optional first positional argument overrides the
    # default experiment config path.
    cli_args = sys.argv[1:]
    if cli_args:
        main(cli_args[0])
    else:
        main("configs/experiments/generation_guided.yaml")
