#!/usr/bin/env python
"""Run budgeted multi-objective optimization experiments.

This script runs optimization experiments comparing:
- MoltenFlow: Guided flow optimization (flow + surrogate guidance)
- Gradient Ascent: Pure gradient ascent (no flow, just surrogate gradient) - ablation
- BO (2GP): Bayesian optimization with two independent GPs + qEHVI
- BO (MOGP): Bayesian optimization with multi-output GP + qEHVI

Example usage:
    # MoltenFlow with guided flow
    python scripts/run_budgeted_optimization.py --method moltenflow --budget 100 --seed 42

    # Gradient ascent ablation (no flow)
    python scripts/run_budgeted_optimization.py --method gradient_ascent --budget 100 --seed 42

    # Bayesian optimization
    python scripts/run_budgeted_optimization.py --method bo_mogp --init near_pareto --budget 500
"""

import argparse
import os
from pathlib import Path

import pandas as pd
import torch

from moltenflow.data.smiles_dataset import load_csv_dataset
from moltenflow.models.latent_flow import FlowConfig, LatentFlowPrior
from moltenflow.models.surrogate_head import SurrogateHead
from moltenflow.models.vae import SmilesTokenVAE, VAEConfig
from moltenflow.optimization import (
    BudgetedOptimizer,
    MoltenFlowProposer,
)
from moltenflow.tokenizer.tokenizer import Vocab, build_vocab
from moltenflow.utils.config import load_yaml
from moltenflow.utils.logging import get_logger
from moltenflow.utils.seeds import set_seed

# Module-level logger shared by every function in this script.
logger = get_logger(__name__)


def load_data_and_vocab(cfg: dict, seed: int) -> tuple[list[str], Vocab]:
    """Load the SMILES pool and the vocabulary used to tokenize it.

    The source is picked from ``cfg["data"]`` in this priority order:

    1. ``vocab_path`` exists on disk: the vocab is read with ``Vocab.load``
       and the SMILES come from ``parquet_path`` (checked first) or
       ``csv_path``.
    2. ``parquet_path`` exists: the SMILES are read from it (optionally
       truncated to the first ``max_molecules`` rows) and the vocab is rebuilt
       from them with ``build_vocab``.
    3. ``csv_path`` exists: both SMILES and vocab come from
       ``load_csv_dataset`` (may yield a different vocab - use with caution).

    Args:
        cfg: Config dict with a 'data' section.
        seed: Random seed; only forwarded to ``load_csv_dataset`` (mode 3).

    Returns:
        Tuple of (smiles_list, vocab)

    Raises:
        ValueError: If no configured data path exists on disk.
    """
    data_cfg = cfg["data"]
    smiles_col = data_cfg.get("smiles_col", "smiles")

    # Resolve the data sources once; every mode below needs the same checks.
    parquet_path = data_cfg.get("parquet_path")
    csv_path = data_cfg.get("csv_path")
    have_parquet = bool(parquet_path) and os.path.exists(parquet_path)
    have_csv = bool(csv_path) and os.path.exists(csv_path)

    # Option 1: Load vocab from saved file
    vocab_path = data_cfg.get("vocab_path")
    if vocab_path and os.path.exists(vocab_path):
        logger.info(f"Loading vocab from {vocab_path}")
        vocab = Vocab.load(vocab_path)

        # Still need to load SMILES data
        if have_parquet:
            df = pd.read_parquet(parquet_path)
        elif have_csv:
            df = pd.read_csv(csv_path)
        else:
            raise ValueError("No valid data path specified in config")

        smiles_list = df[smiles_col].dropna().tolist()
        logger.info(f"Loaded {len(smiles_list)} molecules with vocab size {len(vocab.id_to_token)}")
        return smiles_list, vocab

    if vocab_path:
        # A configured but missing vocab file used to be ignored silently.
        # The fallbacks below rebuild the vocab from data, so make it visible.
        logger.warning(
            f"Configured vocab_path {vocab_path} does not exist; rebuilding vocab "
            "from data, which may not match the vocab used to train the models."
        )

    # Option 2: Load parquet with max_molecules (preferred for consistency)
    if have_parquet:
        logger.info(f"Loading data from parquet: {parquet_path}")
        df = pd.read_parquet(parquet_path)
        max_molecules = data_cfg.get("max_molecules")

        if max_molecules:
            # Truncate before dropping NaNs, so the row window is taken from
            # the raw file order.
            df = df.head(max_molecules)
            logger.info(f"Limited to {max_molecules} molecules (matching training config)")

        smiles_list = df[smiles_col].dropna().tolist()
        logger.info(f"Building vocab from {len(smiles_list)} molecules...")
        vocab = build_vocab(smiles_list)
        logger.info(f"Vocab size: {len(vocab.id_to_token)}")
        return smiles_list, vocab

    # Option 3: Load CSV (fallback)
    if have_csv:
        logger.info(f"Loading data from CSV: {csv_path}")
        logger.warning(
            "Using CSV without max_molecules may produce different vocab than training. "
            "Consider using parquet with max_molecules matching training config."
        )

        ds, _ = load_csv_dataset(
            csv_path=csv_path,
            smiles_col=smiles_col,
            y_cols=[],
            max_len=data_cfg.get("max_len", 128),
            seed=seed,
        )
        return ds.smiles, ds.vocab

    raise ValueError(
        "No valid data path specified. Set 'parquet_path' or 'csv_path' in data config."
    )


def load_models(cfg: dict, vocab, device: torch.device, load_flow: bool = True):
    """Load VAE, flow, and surrogate models from config.

    The VAE and (optionally) the flow are restored from the checkpoints named
    by ``cfg["vae"]["checkpoint_path"]`` and ``cfg["flow"]["checkpoint_path"]``.
    The surrogate head is restored from the ``"surrogate_head"`` entry of the
    VAE checkpoint when present; otherwise it keeps its random initialization
    and a warning is logged. All returned models are moved to ``device`` and
    put in eval mode.

    Args:
        cfg: Configuration dict with 'data', 'vae', 'surrogate' sections
            (and 'flow' when ``load_flow`` is True).
        vocab: Vocabulary; provides ``id_to_token`` and ``pad_id``.
        device: Torch device
        load_flow: Whether to load the flow model. Set False for gradient_ascent.

    Returns:
        Tuple of (vae, flow, surrogate) where flow may be None if load_flow=False
    """
    vae_section = cfg["vae"]

    # Load VAE
    vae_cfg = VAEConfig(
        vocab_size=len(vocab.id_to_token),
        # Same fallback as load_data_and_vocab, so a config that omits
        # max_len is treated identically by the data and model loaders.
        max_len=cfg["data"].get("max_len", 128),
        d_model=vae_section["d_model"],
        nhead=vae_section["nhead"],
        enc_layers=vae_section["enc_layers"],
        dec_layers=vae_section["dec_layers"],
        dim_ff=vae_section["dim_ff"],
        dropout=vae_section["dropout"],
        K=vae_section["K"],
        d_latent=vae_section["latent_dim"],
    )

    vae = SmilesTokenVAE(vae_cfg, pad_id=vocab.pad_id)
    # NOTE: torch.load may unpickle arbitrary objects (depending on the torch
    # version's weights_only default) - only load checkpoints you trust.
    vae_ckpt = torch.load(vae_section["checkpoint_path"], map_location="cpu")
    vae.load_state_dict(vae_ckpt["model"])
    vae = vae.to(device)
    vae.eval()
    logger.info(f"Loaded VAE from {cfg['vae']['checkpoint_path']}")

    # Load flow (optional - not needed for gradient_ascent)
    flow = None
    if load_flow:
        flow_section = cfg["flow"]
        flow_cfg = FlowConfig(
            # Latent layout is shared with the VAE, hence read from its section.
            K=vae_section["K"],
            d_latent=vae_section["latent_dim"],
            d_model=flow_section["d_model"],
            nhead=flow_section["nhead"],
            layers=flow_section["layers"],
            dim_ff=flow_section["dim_ff"],
            dropout=flow_section["dropout"],
            time_dim=flow_section["time_dim"],
        )

        flow = LatentFlowPrior(flow_cfg)
        flow_ckpt = torch.load(flow_section["checkpoint_path"], map_location="cpu")
        flow.load_state_dict(flow_ckpt["model"])
        flow = flow.to(device)
        flow.eval()
        logger.info(f"Loaded flow from {cfg['flow']['checkpoint_path']}")
    else:
        logger.info("Skipping flow model (not needed for gradient_ascent)")

    # Load surrogate with output bounds if specified
    surrogate_cfg = cfg["surrogate"]
    # `.get(...) is not None` (rather than `in`) also covers an empty
    # `output_bounds:` YAML key, which parses to None.
    bounds_cfg = surrogate_cfg.get("output_bounds")
    output_bounds = None
    if bounds_cfg is not None:
        output_bounds = []
        # Assume properties in order: qed, sas (matching training)
        for prop in ("qed", "sas"):
            bound = bounds_cfg.get(prop)
            output_bounds.append(None if bound is None else tuple(bound))

    surrogate = SurrogateHead(
        K=vae_section["K"],
        d_latent=vae_section["latent_dim"],
        out_dim=surrogate_cfg["out_dim"],
        hidden_dim=surrogate_cfg["hidden_dim"],
        aggregation=surrogate_cfg["aggregation"],
        dropout=surrogate_cfg["dropout"],
        cond_dim=surrogate_cfg.get("cond_dim", 0),
        output_bounds=output_bounds,
    )

    # The surrogate weights are stored inside the VAE checkpoint, not in a
    # file of their own.
    if "surrogate_head" in vae_ckpt:
        surrogate.load_state_dict(vae_ckpt["surrogate_head"])
        logger.info("Loaded surrogate from VAE checkpoint")
    else:
        logger.warning("No surrogate_head in VAE checkpoint, using random init")

    surrogate = surrogate.to(device)
    surrogate.eval()

    return vae, flow, surrogate


def create_proposer(
    method: str,
    cfg: dict,
    vae: SmilesTokenVAE,
    flow: LatentFlowPrior | None,
    surrogate: SurrogateHead,
    vocab,
    device: torch.device,
    seed: int,
):
    """Create the appropriate proposer based on method name.

    Supported methods:
        - moltenflow: Guided flow optimization (flow + surrogate guidance)
        - gradient_ascent: Pure gradient ascent (no flow, just surrogate gradient)
        - bo_2gp: Bayesian optimization with two independent GPs
        - bo_mogp: Bayesian optimization with multi-output GP

    Config sections that are missing or explicitly null (for example an empty
    ``bo:`` key in the YAML) are treated as empty, so the per-key defaults
    below apply.

    Args:
        method: One of the method names listed above.
        cfg: Full experiment config. Reads the 'data', 'moltenflow',
            'gradient_ascent', 'bo' and 'hypervolume' sections as needed.
        vae: VAE handed to MoltenFlowProposer (unused by the BO proposers).
        flow: Flow prior, or None. Only forwarded when the method uses the flow.
        surrogate: Surrogate head handed to MoltenFlowProposer.
        vocab: Vocabulary handed to MoltenFlowProposer.
        device: Torch device for the proposer.
        seed: Random seed for the proposer.

    Returns:
        A MoltenFlowProposer, TwoGPProposer or MOGPProposer instance.

    Raises:
        ImportError: If a BO method is requested but the BO proposers cannot
            be imported.
        ValueError: If ``method`` is not one of the supported names.
    """

    def _section(name: str) -> dict:
        # Missing and null sections both collapse to an empty dict.
        return cfg.get(name) or {}

    representation = _section("data").get("representation", "smiles")

    if method in ("moltenflow", "gradient_ascent"):
        if method == "gradient_ascent":
            # Use the method-specific section when present (even if empty),
            # otherwise fall back to the moltenflow section.
            method_cfg = cfg.get("gradient_ascent")
            if method_cfg is None:
                method_cfg = _section("moltenflow")
            use_flow = False
        else:
            method_cfg = _section("moltenflow")
            use_flow = method_cfg.get("use_flow", True)

        return MoltenFlowProposer(
            vae=vae,
            flow=flow if use_flow else None,
            surrogate=surrogate,
            vocab=vocab,
            # Always passed, for both methods; the original inline comment said
            # gradient ascent relies on step_size rather than gamma.
            gamma=method_cfg.get("gamma", 1.0),
            sigma=method_cfg.get("sigma", 0.1),
            steps=method_cfg.get("steps", 30),
            t_start=method_cfg.get("t_start", 0.9),
            seed_selection=method_cfg.get("seed_selection", "uniform"),
            clip_norm=method_cfg.get("clip_norm"),
            normalize_gradient=method_cfg.get("normalize_gradient", False),
            use_flow=use_flow,
            step_size=method_cfg.get("step_size"),
            device=device,
            seed=seed,
            representation=representation,
            # Diversity-weighted selection parameters
            diversity_threshold=method_cfg.get("diversity_threshold", 0.7),
            diversity_penalty=method_cfg.get("diversity_penalty", 2.0),
            diversity_window=method_cfg.get("diversity_window", 20),
            pareto_weight=method_cfg.get("pareto_weight", 2.0),
        )

    elif method in ("bo_2gp", "bo_mogp"):
        # Import BO proposers (requires optional dependencies)
        try:
            from moltenflow.optimization.proposers.bo import (
                MOGPProposer,
                TwoGPProposer,
            )
        except ImportError as e:
            raise ImportError(
                "BO proposers require optional dependencies. "
                "Install with: pip install 'moltenflow[bo]'"
            ) from e

        bo_cfg = _section("bo")

        # `ref_point: null` must fall back to the default too; tuple(None)
        # would otherwise raise TypeError.
        ref_point_cfg = _section("hypervolume").get("ref_point")
        if ref_point_cfg is None:
            ref_point_cfg = [0.0, -10.0]
        ref_point = tuple(ref_point_cfg)

        proposer_cls = TwoGPProposer if method == "bo_2gp" else MOGPProposer
        return proposer_cls(
            ref_point=ref_point,
            num_restarts=bo_cfg.get("num_restarts", 10),
            raw_samples=bo_cfg.get("raw_samples", 512),
            latent_aggregation=bo_cfg.get("latent_aggregation", "flatten"),
            device=device,
            seed=seed,
        )

    else:
        raise ValueError(
            f"Unknown method: {method}. Choose from: moltenflow, gradient_ascent, bo_2gp, bo_mogp"
        )


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the budgeted optimization script."""
    parser = argparse.ArgumentParser(description="Run budgeted multi-objective optimization")
    parser.add_argument(
        "--config",
        type=str,
        default="configs/experiments/budgeted_optimization.yaml",
        help="Path to config YAML file",
    )
    parser.add_argument(
        "--method",
        type=str,
        choices=["moltenflow", "gradient_ascent", "bo_2gp", "bo_mogp"],
        default="moltenflow",
        help="Optimization method: moltenflow (flow+guidance), gradient_ascent (no flow), bo_2gp, bo_mogp",
    )
    parser.add_argument(
        "--init",
        type=str,
        choices=["random", "near_pareto"],
        default=None,
        help="Initialization method (overrides config)",
    )
    parser.add_argument(
        "--budget",
        type=int,
        default=None,
        help="Oracle call budget (overrides config)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed (overrides config)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Output directory (overrides config)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Torch device (default: auto)",
    )
    return parser


def _apply_override(cfg: dict, section: str, key: str, value) -> None:
    """Store a CLI override in cfg[section][key], creating the section if absent or null."""
    if value is None:
        return
    if cfg.get(section) is None:
        cfg[section] = {}
    cfg[section][key] = value


def main():
    """Command-line entry point: run one budgeted optimization experiment.

    Parses the CLI, loads the YAML config, applies CLI overrides, seeds the
    run, loads data and models, builds the proposer and the BudgetedOptimizer,
    runs it on the SMILES pool and logs a summary of the result.

    Raises:
        KeyError: If the config (after overrides) lacks ``init.method``, the
            ``optimization`` section, or its ``budget`` / ``n_init`` keys.
    """
    args = _build_arg_parser().parse_args()

    # Load config
    cfg = load_yaml(args.config)

    # Override with command line args. The sections are created on demand so
    # an override also works when the YAML omits the section entirely.
    _apply_override(cfg, "init", "method", args.init)
    _apply_override(cfg, "optimization", "budget", args.budget)
    if args.seed is not None:
        cfg["seed"] = args.seed

    # Read the required settings up front so a malformed config fails before
    # any data or checkpoint is loaded.
    init_method = cfg["init"]["method"]
    opt_cfg = cfg["optimization"]
    budget = opt_cfg["budget"]
    n_init = opt_cfg["n_init"]

    seed = cfg.get("seed", 42)
    set_seed(seed)

    # Device
    if args.device:
        device = torch.device(args.device)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # Output directory
    if args.output_dir:
        output_dir = Path(args.output_dir)
    else:
        base_dir = Path((cfg.get("output") or {}).get("dir", "outputs/budgeted_optimization"))
        run_name = f"{args.method}_{init_method}_seed{seed}"
        output_dir = base_dir / run_name

    output_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Output directory: {output_dir}")

    # Load data and vocab
    smiles_pool, vocab = load_data_and_vocab(cfg, seed)
    logger.info(f"Loaded {len(smiles_pool)} molecules with vocab size {len(vocab.id_to_token)}")

    # Load models (skip flow for gradient_ascent method)
    # NOTE(review): the flow is also loaded for bo_2gp / bo_mogp although
    # create_proposer only forwards it to MoltenFlowProposer - consider
    # skipping it there as well.
    load_flow = args.method != "gradient_ascent"
    vae, flow, surrogate = load_models(cfg, vocab, device, load_flow=load_flow)

    # Create proposer
    proposer = create_proposer(
        method=args.method,
        cfg=cfg,
        vae=vae,
        flow=flow,
        surrogate=surrogate,
        vocab=vocab,
        device=device,
        seed=seed,
    )

    # Create optimizer
    hv_cfg = cfg.get("hypervolume") or {}
    data_cfg = cfg.get("data") or {}
    representation = data_cfg.get("representation", "smiles")

    optimizer = BudgetedOptimizer(
        proposer=proposer,
        vae=vae,
        vocab=vocab,
        budget=budget,
        n_init=n_init,
        init_method=init_method,
        batch_size=opt_cfg.get("batch_size", 1),
        # NOTE(review): None when unset, whereas create_proposer falls back to
        # (0.0, -10.0) for the BO proposers - confirm the optimizer's default
        # agrees.
        ref_point=hv_cfg.get("ref_point"),
        output_dir=output_dir,
        seed=seed,
        device=device,
        representation=representation,
    )

    # Run optimization
    logger.info(
        f"Starting {args.method} optimization: "
        f"budget={budget}, n_init={n_init}, "
        f"init={init_method}"
    )

    result = optimizer.run(smiles_pool)

    # Print results
    logger.info("=" * 60)
    logger.info("Optimization Results")
    logger.info("=" * 60)
    logger.info(f"Method: {args.method}")
    logger.info(f"Init: {init_method}")
    logger.info(f"Budget: {budget}")
    logger.info(f"Final HV: {result.final_hv:.4f}")
    logger.info(f"HV Improvement: {result.hv_improvement:.4f}")
    logger.info(f"Validity Rate: {result.validity_rate:.2%}")
    logger.info(f"Pareto Size: {len(result.pareto_smiles)}")
    logger.info(f"Output: {output_dir}")


# Run the CLI only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()
