"""Property-guided molecular optimization using guided latent flow."""

import os
import torch
import torch.nn as nn
import pandas as pd
from typing import List, Dict, Optional

from moltenflow.models.vae import SmilesTokenVAE, VAEConfig
from moltenflow.models.latent_flow import LatentFlowPrior, FlowConfig
from moltenflow.models.surrogate_head import SurrogateHead
from moltenflow.guidance.guidance import compute_guidance
from moltenflow.guidance.objectives import mse_objective
from moltenflow.inference.sample import decode_greedy_with_ids
from moltenflow.tokenizer.tokenizer import (
    decode_ids,
    encode,
    selfies_to_smiles,
    smiles_to_selfies,
)
from moltenflow.data.smiles_dataset import load_csv_dataset
from moltenflow.data.data_utils import canonicalize_smiles
from moltenflow.utils.config import load_yaml
from moltenflow.utils.logging import get_logger
from moltenflow.utils.seeds import set_seed

logger = get_logger(__name__)


def optimize_molecules(
    vae: SmilesTokenVAE,
    flow: Optional[LatentFlowPrior],
    surrogate: nn.Module,
    vocab,
    input_smiles: List[str],
    target: torch.Tensor,
    gamma: float,
    sigma: float,
    steps: int,
    c: Optional[torch.Tensor] = None,
    clip_norm: Optional[float] = None,
    normalize: bool = False,
    representation: str = "smiles",
    loss_fn: Optional[nn.Module] = None,
    verbose: bool = False,
    t_start: float = 0.0,
    use_flow: bool = True,
    step_size: Optional[float] = None,
) -> List[Dict]:
    """Optimize existing molecules toward target properties.

    Pipeline:
    1. Encode input SMILES to latent z_0
    2. Add noise: z_t = z_0 + sigma * noise
    3. Apply guided flow integration (or pure gradient steps if use_flow=False)
    4. Decode to improved SMILES

    All three models are switched to eval mode before anything is encoded, so
    the "before" and "after" property predictions are computed under the same
    model mode.

    Args:
        vae: VAE model
        flow: Flow model (can be None if use_flow=False)
        surrogate: SurrogateHead, called as ``surrogate(z, c)``
        vocab: Vocabulary
        input_smiles: List of input SMILES strings (always parsed as SMILES,
            whatever ``representation`` is)
        target: Target properties, (1, P) or (n, P) with n == len(input_smiles).
            A single row is broadcast; a per-input tensor is reduced to the
            rows of the inputs that could be encoded.
        gamma: Guidance strength (used when use_flow=True)
        sigma: Noise level for perturbation
        steps: Integration/optimization steps. With steps <= 0 no update is
            applied and the noised latents are decoded directly.
        c: Optional conditional variables (1, cond_dim) or (n, cond_dim);
            aligned to the valid inputs the same way as ``target``.
        clip_norm: Optional gradient clipping, forwarded to compute_guidance
        normalize: Whether to normalize gradients, forwarded to compute_guidance
        representation: Token representation fed to the VAE, "smiles" or "selfies"
        loss_fn: Loss function for guidance. Default: MSE loss.
                 Can use directional_objective() for maximize/minimize optimization.
        verbose: If True, log diagnostic info about gradient/velocity magnitudes
        t_start: Starting time for integration. Default 0.0 (full integration).
                 For local optimization of existing molecules, use t_start close to 1.0
                 (e.g., 0.9 or 0.95) since the encoded latent is already near the
                 data manifold. This prevents the flow from pushing too aggressively.
                 Ignored when use_flow=False.
        use_flow: If True, use flow velocity + guidance. If False, only the
                  guidance gradient is followed.
        step_size: Step size for the gradient-only mode (used when use_flow=False).
                   Default: gamma / steps if not provided.

    Returns:
        One dict per input that could be encoded, in input order, with
        input_smiles, output_smiles, valid, optional output_selfies, and
        pred_prop_{j}_before / _after / _delta for every surrogate output j.
        An empty list if no input could be encoded.

    Raises:
        ValueError: If ``representation`` is not "smiles"/"selfies", or if
            use_flow=True but ``flow`` is None.
    """
    # Configuration errors are raised up front. Inside the per-molecule
    # try/except they would be swallowed and reported as "no valid inputs".
    if representation not in ("smiles", "selfies"):
        raise ValueError(
            f"Unknown representation '{representation}'. Expected 'smiles' or 'selfies'."
        )
    if use_flow and flow is None:
        raise ValueError("Flow model required when use_flow=True")

    dev = next(vae.parameters()).device
    max_len = vae.cfg.max_len

    # Switch to eval mode BEFORE encoding so the encoder and the "before"
    # prediction run under the same mode as the optimisation and the "after"
    # prediction.
    vae.eval()
    if flow is not None:
        flow.eval()
    surrogate.eval()

    # Tokenize inputs, keeping only those that could be encoded
    token_rows, valid_indices = _encode_valid_inputs(
        input_smiles, vocab, max_len, representation
    )
    n_valid = len(valid_indices)
    if n_valid == 0:
        logger.warning("No valid input SMILES provided")
        return []

    n_skipped = len(input_smiles) - n_valid
    if n_skipped > 0:
        logger.warning(
            f"Skipping {n_skipped} of {len(input_smiles)} inputs that could not be encoded"
        )

    x_batch = torch.tensor(token_rows, device=dev, dtype=torch.long)

    # Broadcast / filter target and conditions so row i belongs to valid input i
    target = _align_rows(target, len(input_smiles), valid_indices, dev)
    c = _align_rows(c, len(input_smiles), valid_indices, dev)

    # Encode to latent
    with torch.no_grad():
        z0, _, _ = vae.encode(x_batch)

    # Add noise
    z = z0 + sigma * torch.randn_like(z0)

    # Predict properties before optimization (from the un-noised latent)
    with torch.no_grad():
        pred_before = surrogate(z0, c).detach().cpu().numpy()

    if loss_fn is None:
        loss_fn = mse_objective()

    # Steps at which diagnostics are sampled: first, middle, last
    diag_steps = (0, steps // 2, steps - 1)
    v_norms = []
    g_norms = []

    if use_flow:
        # Integration range: t goes from t_start to 1.0
        t_end = 1.0
        # Guard against steps == 0 (the loop below is then a no-op)
        dt = (t_end - t_start) / steps if steps > 0 else 0.0

        if verbose:
            logger.info(
                f"Flow integration: t_start={t_start}, t_end={t_end}, steps={steps}, dt={dt:.4f}"
            )

        for i in range(steps):
            t_val = t_start + i * dt
            t = torch.full((z.size(0),), t_val, device=dev)

            # Base velocity
            v = flow(z, t)

            # Guidance term from the surrogate
            g = compute_guidance(
                z,
                target,
                surrogate,
                loss_fn,
                c=c,
                clip_norm=clip_norm,
                normalize=normalize,
            )

            if verbose and i in diag_steps:
                v_norm = torch.norm(v.reshape(n_valid, -1), dim=1).mean().item()
                g_norm = torch.norm(g.reshape(n_valid, -1), dim=1).mean().item()
                v_norms.append((i, v_norm))
                g_norms.append((i, g_norm))

            # Euler update: flow velocity minus scaled guidance; detach so no
            # autograd graph accumulates across steps.
            z = (z + dt * (v - gamma * g)).detach()

        if verbose and v_norms:
            logger.info(f"Optimization diagnostics (gamma={gamma}, normalize={normalize}):")
            for (step, v_n), (_, g_n) in zip(v_norms, g_norms):
                ratio = v_n / (gamma * g_n + 1e-10)
                logger.info(
                    f"  step {step}: |v|={v_n:.4e}, |g|={g_n:.4e}, |v|/(gamma*|g|)={ratio:.2f}"
                )
    else:
        # Gradient-only mode (no flow): step against the guidance term
        if step_size is not None:
            lr = step_size
        else:
            lr = gamma / steps if steps > 0 else 0.0

        if verbose:
            logger.info(f"Gradient ascent: steps={steps}, lr={lr:.4f}")

        for i in range(steps):
            g = compute_guidance(
                z,
                target,
                surrogate,
                loss_fn,
                c=c,
                clip_norm=clip_norm,
                normalize=normalize,
            )

            if verbose and i in diag_steps:
                g_norm = torch.norm(g.reshape(n_valid, -1), dim=1).mean().item()
                g_norms.append((i, g_norm))

            z = (z - lr * g).detach()

        if verbose and g_norms:
            logger.info(f"Gradient ascent diagnostics (lr={lr}):")
            for step, g_n in g_norms:
                logger.info(f"  step {step}: |g|={g_n:.4e}")

    # Predict properties after optimization and decode the optimized latents
    with torch.no_grad():
        pred_after = surrogate(z, c).detach().cpu().numpy()
        ids = (
            decode_greedy_with_ids(
                vae,
                z,
                bos_id=vocab.bos_id,
                eos_id=vocab.eos_id,
                pad_id=vocab.pad_id,
                max_len=max_len,
            )
            .detach()
            .cpu()
            .numpy()
        )

    # Build results
    results: List[Dict] = []
    for idx, orig_idx in enumerate(valid_indices):
        decoded = decode_ids(ids[idx].tolist(), vocab, representation=representation)
        if representation == "smiles":
            out_smi = decoded
            out_sf = None
        else:
            out_sf = decoded
            out_smi = selfies_to_smiles(out_sf) or ""

        cs = canonicalize_smiles(out_smi) if out_smi else None

        result: Dict = {
            "input_smiles": input_smiles[orig_idx],
            # Fall back to the raw decoded string when it does not canonicalize
            "output_smiles": cs if cs is not None else out_smi,
            "valid": bool(cs is not None),
        }
        if out_sf is not None:
            result["output_selfies"] = out_sf

        # Add property predictions
        for j in range(pred_before.shape[1]):
            result[f"pred_prop_{j}_before"] = pred_before[idx, j]
            result[f"pred_prop_{j}_after"] = pred_after[idx, j]
            result[f"pred_prop_{j}_delta"] = pred_after[idx, j] - pred_before[idx, j]

        results.append(result)

    return results


def _encode_valid_inputs(input_smiles: List[str], vocab, max_len: int, representation: str):
    """Tokenize every input that canonicalizes and encodes.

    Returns a pair ``(token_rows, valid_indices)`` where ``valid_indices[i]``
    is the position in ``input_smiles`` that produced ``token_rows[i]``.
    ``representation`` must already have been validated by the caller.
    """
    token_rows = []
    valid_indices: List[int] = []
    for idx, smi in enumerate(input_smiles):
        try:
            # Always treat `input_smiles` as SMILES strings.
            canonical = canonicalize_smiles(smi)
            if canonical is None:
                continue
            if representation == "selfies":
                seq = smiles_to_selfies(canonical)
                if seq is None:
                    continue
            else:
                seq = canonical
            row = encode(seq, vocab, max_len, representation=representation)
        except Exception:
            # Deliberate best-effort: one molecule that cannot be parsed or
            # tokenized must not abort the whole batch.
            continue
        token_rows.append(row)
        valid_indices.append(idx)
    return token_rows, valid_indices


def _align_rows(
    tensor: Optional[torch.Tensor],
    n_inputs: int,
    valid_indices: List[int],
    device: torch.device,
) -> Optional[torch.Tensor]:
    """Move a row-per-molecule tensor to ``device`` and match it to the valid batch.

    A single row is broadcast to ``len(valid_indices)`` rows. A tensor with one
    row per original input is reduced to the rows in ``valid_indices``. Any
    other shape is returned as-is.
    """
    if tensor is None:
        return None
    tensor = tensor.to(device)
    n_valid = len(valid_indices)
    if tensor.size(0) == 1:
        return tensor.expand(n_valid, -1)
    if tensor.size(0) == n_inputs and n_valid != n_inputs:
        keep = torch.tensor(valid_indices, device=device, dtype=torch.long)
        return tensor.index_select(0, keep)
    return tensor


def main(config_path: str = "configs/experiments/optimization_local.yaml") -> None:
    """Optimize molecules toward target properties.

    This script performs property-guided optimization using the SurrogateHead
    from the fine-tuned VAE checkpoint (no separate surrogate training needed):
    1. Encode input molecules to latent z_0
    2. Add noise: z_t = z_0 + sigma * noise
    3. Apply guided flow integration using SurrogateHead
    4. Decode to improved molecules
    5. Compare properties before/after

    Results are written to ``<output.dir>/optimized.csv``. A run in which no
    input molecule could be encoded still writes the (empty) CSV and reports
    zero molecules instead of failing.

    Args:
        config_path: Path to YAML configuration file

    Raises:
        ValueError: If the VAE checkpoint has no ``surrogate_head`` entry.
        KeyError: If a required config key (e.g. ``vae.checkpoint_path``,
            ``flow.checkpoint_path``, ``optimization.input_csv``,
            ``optimization.sigma``, ``guidance.gamma``, ``guidance.target``)
            is missing.
    """
    cfg = load_yaml(config_path)
    set_seed(cfg.get("seed", 42))

    logger.info(f"Loaded config: {config_path}")
    logger.info("=== Property-Guided Optimization ===")

    # Load dataset for vocab
    data_cfg = cfg.get("data", {})
    csv_path = data_cfg.get("csv_path", "data/raw/esol.csv")
    smiles_col = data_cfg.get("smiles_col", "smiles")
    y_cols = data_cfg.get("y_cols", ["measured log solubility in mols per litre"])
    max_len = data_cfg.get("max_len", 128)
    representation = data_cfg.get("representation", "smiles")

    ds, _ = load_csv_dataset(
        csv_path,
        smiles_col,
        y_cols,
        max_len,
        seed=cfg.get("seed", 42),
        representation=representation,
    )
    vocab = ds.vocab

    # Load VAE checkpoint (contains both VAE and SurrogateHead)
    vae_checkpoint_path = cfg["vae"]["checkpoint_path"]
    logger.info(f"Loading VAE and SurrogateHead from {vae_checkpoint_path}")

    vae_cfg = VAEConfig(
        vocab_size=len(vocab.id_to_token),
        max_len=max_len,
        d_model=cfg["vae"].get("d_model", 256),
        nhead=cfg["vae"].get("nhead", 8),
        enc_layers=cfg["vae"].get("enc_layers", 6),
        dec_layers=cfg["vae"].get("dec_layers", 6),
        dim_ff=cfg["vae"].get("dim_ff", 1024),
        dropout=cfg["vae"].get("dropout", 0.1),
        K=cfg["vae"].get("K", 8),
        d_latent=cfg["vae"].get("latent_dim", 128),
    )

    vae = SmilesTokenVAE(vae_cfg, pad_id=vocab.pad_id)
    # NOTE(security): torch.load unpickles the checkpoint; only point the
    # config at checkpoint files from a trusted source.
    checkpoint = torch.load(vae_checkpoint_path, map_location="cpu")
    vae.load_state_dict(checkpoint["model"])
    vae.eval()
    logger.info("VAE loaded")

    # Load SurrogateHead from the same checkpoint
    surrogate_cfg = cfg.get("surrogate", {})
    surrogate = SurrogateHead(
        K=vae_cfg.K,
        d_latent=vae_cfg.d_latent,
        out_dim=surrogate_cfg.get("out_dim", len(y_cols)),
        cond_dim=surrogate_cfg.get("cond_dim", 0),
        hidden_dim=surrogate_cfg.get("hidden_dim", 256),
        aggregation=surrogate_cfg.get("aggregation", "mean"),
        dropout=surrogate_cfg.get("dropout", 0.1),
    )

    if "surrogate_head" not in checkpoint:
        raise ValueError(
            f"VAE checkpoint at {vae_checkpoint_path} does not contain surrogate_head. "
            "Make sure to use a fine-tuned VAE checkpoint with property supervision."
        )
    surrogate.load_state_dict(checkpoint["surrogate_head"])
    logger.info("SurrogateHead loaded from VAE checkpoint")
    surrogate.eval()

    # Load flow
    flow_checkpoint = cfg["flow"]["checkpoint_path"]
    logger.info(f"Loading flow from {flow_checkpoint}")

    flow_cfg = FlowConfig(
        K=cfg["flow"].get("K", vae_cfg.K),
        d_latent=cfg["flow"].get("d_latent", vae_cfg.d_latent),
        d_model=cfg["flow"].get("d_model", 256),
        nhead=cfg["flow"].get("nhead", 8),
        layers=cfg["flow"].get("layers", 10),
        dim_ff=cfg["flow"].get("dim_ff", 1024),
        dropout=cfg["flow"].get("dropout", 0.1),
        time_dim=cfg["flow"].get("time_dim", 128),
    )

    flow = LatentFlowPrior(flow_cfg)
    # NOTE(security): same trust requirement as the VAE checkpoint above.
    flow_ckpt = torch.load(flow_checkpoint, map_location="cpu")
    flow.load_state_dict(flow_ckpt["model"])
    flow.eval()
    logger.info("Flow loaded")

    # Move to device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vae = vae.to(device)
    flow = flow.to(device)
    surrogate = surrogate.to(device)
    logger.info(f"Using device: {device}")

    # Load input molecules
    input_csv = cfg["optimization"]["input_csv"]
    input_col = cfg["optimization"].get("smiles_col", "smiles")
    n_samples = cfg["optimization"].get("n_samples", None)

    logger.info(f"Loading input molecules from {input_csv}")
    df_input = pd.read_csv(input_csv)
    input_smiles = df_input[input_col].astype(str).tolist()

    if n_samples is not None:
        input_smiles = input_smiles[:n_samples]

    logger.info(f"Loaded {len(input_smiles)} input molecules")

    # Optimization parameters
    steps = cfg["optimization"].get("steps", 30)
    gamma = cfg["guidance"]["gamma"]
    sigma = cfg["optimization"]["sigma"]
    target_values = cfg["guidance"]["target"]
    clip_norm = cfg["guidance"].get("clip_norm", None)
    normalize = cfg["guidance"].get("normalize", False)

    # Single target row; optimize_molecules broadcasts it over the batch
    target = torch.tensor([target_values], device=device, dtype=torch.float32)

    # Handle optional conditions (only read when the surrogate is conditional)
    c = None
    if surrogate_cfg.get("cond_dim", 0) > 0:
        cond_values = cfg["guidance"].get("conditions")
        if cond_values is not None:
            c = torch.tensor([cond_values], device=device, dtype=torch.float32)

    logger.info(f"Target properties: {target_values}")
    logger.info(f"Guidance strength (gamma): {gamma}")
    logger.info(f"Noise level (sigma): {sigma}")
    logger.info(f"Integration steps: {steps}")

    # Optimize
    results = optimize_molecules(
        vae=vae,
        flow=flow,
        surrogate=surrogate,
        vocab=vocab,
        input_smiles=input_smiles,
        target=target,
        gamma=gamma,
        sigma=sigma,
        steps=steps,
        c=c,
        clip_norm=clip_norm,
        normalize=normalize,
        representation=representation,
    )

    # Save results
    output_dir = cfg.get("output", {}).get("dir", "outputs/optimization")
    os.makedirs(output_dir, exist_ok=True)
    output_csv = os.path.join(output_dir, "optimized.csv")

    df = pd.DataFrame(results)
    df.to_csv(output_csv, index=False)

    # Report statistics. An empty result list yields a DataFrame with no
    # columns, so "valid" may only be read when there is at least one row.
    n_total = len(df)
    n_valid = int(df["valid"].sum()) if n_total > 0 else 0
    validity = n_valid / n_total * 100 if n_total > 0 else 0

    logger.info(f"Optimized {n_total} molecules")
    logger.info(f"Valid outputs: {n_valid}/{n_total} ({validity:.1f}%)")
    logger.info(f"Results saved to {output_csv}")
