"""VAE training utilities.

This module provides functions for training Variational Autoencoders on
SMILES sequences with optional property supervision.
"""

import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from moltenflow.models.vae import SmilesTokenVAE, VAEConfig
from moltenflow.training.losses import MaskedMSELoss
from moltenflow.utils.kl import kl_diag_gaussian
from moltenflow.data.smiles_dataset import SmilesDataset, batchify, load_csv_dataset
from moltenflow.utils.io import save_json
from moltenflow.utils.device import get_device
from moltenflow.utils.config import load_yaml
from moltenflow.utils.logging import get_logger
from moltenflow.utils.seeds import set_seed

# Module-level logger, obtained from the project's get_logger helper and
# keyed by this module's import name; shared by every function below.
logger = get_logger(__name__)


@dataclass
class TrainingHistory:
    """Per-epoch metrics collected over a training run.

    Each entry of ``epochs`` is one dict of metric name -> value for a single
    epoch, in the order the epochs were appended.
    """

    # One metrics dict per completed epoch.
    epochs: List[Dict[str, Any]]

    def to_list(self) -> List[Dict[str, Any]]:
        """Give back the stored per-epoch dicts (the same list object, not a copy)."""
        per_epoch = self.epochs
        return per_epoch

    def to_dict(self) -> Dict[str, Any]:
        """Wrap the per-epoch list under an ``"epochs"`` key for JSON output."""
        return dict(epochs=self.epochs)


def _kl_beta(step: int, warmup_steps: int, beta_max: float) -> float:
    """Return the KL weight to use at optimizer step ``step``.

    The weight ramps linearly from 0 to ``beta_max`` over ``warmup_steps``
    steps and is held at ``beta_max`` afterwards. When ``warmup_steps`` is not
    positive there is no ramp and ``beta_max`` is returned directly.
    """
    if warmup_steps <= 0:
        return beta_max
    return beta_max * min(1.0, step / warmup_steps)


def train_vae(
    ds: SmilesDataset,
    splits: Dict[str, np.ndarray],
    model: SmilesTokenVAE,
    workdir: str,
    epochs: int = 40,
    batch_size: int = 256,
    lr: float = 2e-4,
    lr_schedule: Optional[str] = "cosine",
    lr_min: float = 0.0,
    beta_max: float = 0.12,
    beta_warmup_frac: float = 0.35,
    grad_clip: float = 1.0,
    seed: int = 7,
    surrogate_head: Optional[nn.Module] = None,
    property_weight: float = 0.0,
    freeze_decoder: bool = False,
    return_history: bool = False,
) -> Optional[TrainingHistory]:
    """Train a VAE model on SMILES sequences.

    Trains a Variational Autoencoder with optional property supervision.
    The training loss combines:
    - Reconstruction loss (cross-entropy over vocabulary, padding ignored)
    - KL divergence loss (weighted by a linearly warmed-up beta)
    - Optional property prediction loss (only when ``surrogate_head`` is
      given AND ``property_weight > 0``)

    Args:
        ds: SmilesDataset containing tokenized sequences and properties
        splits: Dictionary of index arrays; the 'train' and 'val' entries
            are read here
        model: SmilesTokenVAE model to train (moved to the active device)
        workdir: Output directory for checkpoints and logs (created if missing)
        epochs: Number of training epochs
        batch_size: Batch size for training and validation
        lr: Initial learning rate
        lr_schedule: Learning rate schedule: 'cosine' (also accepted:
            'cosineannealing', 'cosine_annealing') or None for a constant rate
        lr_min: Minimum learning rate for cosine schedule
        beta_max: Maximum KL weight (beta-VAE parameter)
        beta_warmup_frac: Fraction of all optimizer steps over which beta
            ramps from 0 to ``beta_max``
        grad_clip: Max gradient norm passed to ``clip_grad_norm_``
        seed: Base seed for batch shuffling (``seed + epoch`` is used per epoch)
        surrogate_head: Optional property prediction head for joint training
        property_weight: Weight for property loss (0 = disabled)
        freeze_decoder: If True, freeze ``model.dec`` and ``model.out``
            parameters (encoder-only fine-tuning)
        return_history: If True, return training history

    Returns:
        TrainingHistory if return_history=True, else None

    Raises:
        ValueError: If ``lr_schedule`` is a string other than the supported
            cosine aliases.

    Saves:
        - vae_best.pt: Checkpoint with the lowest validation loss seen so far
        - vae_last.pt: Checkpoint from the most recent epoch
        - vae_metrics.json: Validation metrics of the most recent epoch
        - training_history.json: Training metrics per epoch
    """
    os.makedirs(workdir, exist_ok=True)
    dev = get_device()
    model = model.to(dev)

    # Setup property prediction if enabled
    use_property_loss = surrogate_head is not None and property_weight > 0.0
    if use_property_loss:
        surrogate_head = surrogate_head.to(dev)
        property_loss_fn = MaskedMSELoss()

    # Freeze decoder if requested (for fine-tuning encoder only)
    if freeze_decoder:
        for param in model.dec.parameters():
            param.requires_grad = False
        for param in model.out.parameters():
            param.requires_grad = False
        logger.info("Decoder frozen for fine-tuning")

    # Collect trainable parameters once; the same list feeds the optimizer and
    # gradient clipping so frozen parameters never contribute to the norm.
    params = list(model.parameters())
    if use_property_loss:
        params += list(surrogate_head.parameters())
    trainable_params = [p for p in params if p.requires_grad]

    opt = torch.optim.AdamW(trainable_params, lr=lr, betas=(0.9, 0.95), weight_decay=0.01)

    # Ceil division so a partial final batch counts as a step.
    steps_per_epoch = max(1, (len(splits["train"]) + batch_size - 1) // batch_size)

    # Optional LR scheduling
    scheduler = None
    if lr_schedule is not None:
        name = lr_schedule.strip().lower()
        if name in {"cosine", "cosineannealing", "cosine_annealing"}:
            # Step-wise cosine decay from `lr` to `lr_min` over the full training run
            t_max = max(1, epochs * steps_per_epoch)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=t_max, eta_min=lr_min)
            logger.info(f"Using CosineAnnealingLR: T_max={t_max} steps, lr_min={lr_min}")
        else:
            raise ValueError(f"Unknown lr_schedule='{lr_schedule}'. Supported: 'cosine'.")

    warmup_steps = int(epochs * steps_per_epoch * beta_warmup_frac)
    global_step = 0
    best_val = float("inf")
    history: List[Dict[str, Any]] = []

    for epoch in range(1, epochs + 1):
        model.train()
        if use_property_loss:
            # eval_vae() switches the head to eval mode at the end of every
            # epoch; without this it would stay in eval mode from epoch 2 on.
            surrogate_head.train()

        pbar = tqdm(
            batchify(ds, splits["train"], batch_size, shuffle=True, seed=seed + epoch),
            desc=f"VAE train {epoch}/{epochs}",
        )
        tot = 0.0
        tot_prop = 0.0
        n = 0
        for batch in pbar:
            x = torch.tensor(batch["x"], device=dev)
            out = model(x)
            logits = out["logits"]
            x_tgt = out["x_tgt"]
            mu, logvar = out["mu"], out["logvar"]
            z = out["z"]

            V = logits.size(-1)
            loss_rec = F.cross_entropy(
                logits.reshape(-1, V), x_tgt.reshape(-1), ignore_index=model.pad_id
            )

            beta = _kl_beta(global_step, warmup_steps, beta_max)
            # KL is additionally divided by mu.size(1) * mu.size(2), i.e. the
            # product of mu's second and third dimensions.
            loss_kl = kl_diag_gaussian(mu, logvar).mean() / (mu.size(1) * mu.size(2))

            loss = loss_rec + beta * loss_kl

            # Add property prediction loss if enabled
            loss_prop = torch.tensor(0.0, device=dev)
            if use_property_loss:
                y = torch.tensor(batch["y"], device=dev, dtype=torch.float32)
                # Support optional conditions from batch
                c = None
                if "c" in batch:
                    c = torch.tensor(batch["c"], device=dev, dtype=torch.float32)
                y_pred = surrogate_head(z, c)
                loss_prop = property_loss_fn(y_pred, y)
                loss = loss + property_weight * loss_prop

            opt.zero_grad(set_to_none=True)
            loss.backward()

            # Clip gradients of exactly the parameters the optimizer updates
            torch.nn.utils.clip_grad_norm_(trainable_params, grad_clip)

            opt.step()
            if scheduler is not None:
                scheduler.step()

            global_step += 1
            tot += float(loss.detach().cpu())
            tot_prop += float(loss_prop.detach().cpu())
            n += 1

            postfix_dict = {
                "loss": tot / n,
                "rec": float(loss_rec.detach().cpu()),
                "kl": float(loss_kl.detach().cpu()),
                "beta": beta,
                "lr": opt.param_groups[0]["lr"],
            }
            if use_property_loss:
                postfix_dict["prop"] = tot_prop / n

            pbar.set_postfix(postfix_dict)

        # Validate with the beta reached at the end of this epoch.
        # NOTE(review): because beta grows during warmup, validation losses of
        # early epochs are computed with a smaller KL weight than later ones,
        # so the "best" checkpoint comparison below mixes different betas.
        # Confirm this is intended.
        beta_eval = _kl_beta(global_step, warmup_steps, beta_max)

        val = eval_vae(
            ds,
            splits["val"],
            model,
            batch_size=batch_size,
            surrogate_head=surrogate_head if use_property_loss else None,
            property_weight=property_weight if use_property_loss else 0.0,
            beta=beta_eval,
        )

        save_json(os.path.join(workdir, "vae_metrics.json"), val | {"epoch": epoch})

        # Record history
        epoch_metrics = {
            "epoch": epoch,
            "train_loss": tot / n if n > 0 else 0.0,
            "val_loss": val["loss"],
            "val_rec": val["rec"],
            "val_kl": val["kl"],
        }
        if use_property_loss:
            epoch_metrics["train_prop"] = tot_prop / n if n > 0 else 0.0
            epoch_metrics["val_prop"] = val.get("prop", 0.0)
        history.append(epoch_metrics)

        # Save checkpoints
        checkpoint = {"model": model.state_dict()}
        if use_property_loss:
            checkpoint["surrogate_head"] = surrogate_head.state_dict()

        torch.save(checkpoint, os.path.join(workdir, "vae_last.pt"))
        if val["loss"] < best_val:
            best_val = val["loss"]
            torch.save(checkpoint, os.path.join(workdir, "vae_best.pt"))

    # Save training history
    training_history = TrainingHistory(epochs=history)
    save_json(os.path.join(workdir, "training_history.json"), training_history.to_dict())

    if return_history:
        return training_history
    return None


def eval_vae(
    ds: SmilesDataset,
    indices: np.ndarray,
    model: SmilesTokenVAE,
    batch_size: int = 256,
    surrogate_head: Optional[nn.Module] = None,
    property_weight: float = 0.0,
    beta: float = 1.0,
) -> Dict[str, float]:
    """Evaluate VAE model on a dataset split.

    Computes reconstruction loss, KL divergence, and optionally property
    prediction loss on the specified indices. Every reported value is the
    unweighted mean of the per-batch values.

    The model (and ``surrogate_head``, if given) are switched to eval mode
    and are NOT switched back; callers that continue training must call
    ``.train()`` again themselves.

    Args:
        ds: SmilesDataset to evaluate on
        indices: Array of sample indices to evaluate
        model: Trained SmilesTokenVAE model; evaluation runs on its device
        batch_size: Batch size for evaluation
        surrogate_head: Optional property prediction head
        property_weight: Weight for property loss (0 = disabled)
        beta: KL weight for evaluation

    Returns:
        Dictionary with keys:
        - loss: Total loss (reconstruction + beta * KL + weighted property)
        - rec: Reconstruction loss
        - kl: KL divergence loss
        - beta: The KL weight that was used
        - prop: Property prediction loss (only if property loss is enabled)

        If ``indices`` yields no batches, every loss value is NaN.
    """
    dev = next(model.parameters()).device
    model.eval()
    if surrogate_head is not None:
        surrogate_head.eval()

    use_property_loss = surrogate_head is not None and property_weight > 0.0
    if use_property_loss:
        property_loss_fn = MaskedMSELoss()

    losses, recs, kls, props = [], [], [], []
    with torch.no_grad():
        for batch in batchify(ds, indices, batch_size, shuffle=False):
            x = torch.tensor(batch["x"], device=dev)
            out = model(x)
            logits = out["logits"]
            x_tgt = out["x_tgt"]
            mu, logvar = out["mu"], out["logvar"]
            z = out["z"]

            V = logits.size(-1)
            loss_rec = F.cross_entropy(
                logits.reshape(-1, V), x_tgt.reshape(-1), ignore_index=model.pad_id
            )
            loss_kl = kl_diag_gaussian(mu, logvar).mean() / (mu.size(1) * mu.size(2))
            loss = loss_rec + beta * loss_kl

            if use_property_loss:
                y = torch.tensor(batch["y"], device=dev, dtype=torch.float32)
                # Support optional conditions from batch
                c = None
                if "c" in batch:
                    c = torch.tensor(batch["c"], device=dev, dtype=torch.float32)
                y_pred = surrogate_head(z, c)
                loss_prop = property_loss_fn(y_pred, y)
                loss = loss + property_weight * loss_prop
                props.append(float(loss_prop.cpu()))

            losses.append(float(loss.cpu()))
            recs.append(float(loss_rec.cpu()))
            kls.append(float(loss_kl.cpu()))

    def _mean_or_nan(values: List[float]) -> float:
        # Explicit empty-split guard: same NaN result np.mean([]) would give,
        # but without the "Mean of empty slice" RuntimeWarning.
        return float(np.mean(values)) if values else float("nan")

    result = {
        "loss": _mean_or_nan(losses),
        "rec": _mean_or_nan(recs),
        "kl": _mean_or_nan(kls),
        "beta": float(beta),
    }

    if use_property_loss:
        result["prop"] = _mean_or_nan(props)

    return result


def train_standalone_surrogate(
    ds: SmilesDataset,
    splits: Dict[str, np.ndarray],
    vae: SmilesTokenVAE,
    surrogate_head: nn.Module,
    workdir: str,
    epochs: int = 30,
    batch_size: int = 256,
    lr: float = 1e-4,
    grad_clip: float = 1.0,
    seed: int = 7,
    return_history: bool = False,
) -> Optional[TrainingHistory]:
    """Train a surrogate head on frozen VAE latents (no latent space organization).

    This is an ablation where:
    - The VAE is completely frozen (eval mode, ``requires_grad=False`` on
      every parameter; this is left in place when the function returns)
    - The posterior mean ``mu`` from ``vae.encode`` is fed to the surrogate,
      which is trained to predict properties
    - The VAE decoder and latent space are NOT organized around properties

    Args:
        ds: Dataset with molecules and properties
        splits: Split indices; the 'train' and 'val' entries are read here
        vae: Pretrained VAE (will be frozen)
        surrogate_head: Surrogate head to train
        workdir: Output directory (created if missing)
        epochs: Number of training epochs
        batch_size: Batch size
        lr: Learning rate
        grad_clip: Gradient clipping norm
        seed: Base seed for batch shuffling (``seed + epoch`` is used per epoch)
        return_history: If True, return training history

    Returns:
        TrainingHistory if return_history=True, else None

    Saves:
        - surrogate_best.pt: Head checkpoint with the lowest validation loss
        - surrogate_last.pt: Head checkpoint from the most recent epoch
        - surrogate_training_history.json: Metrics per epoch
    """
    os.makedirs(workdir, exist_ok=True)
    dev = get_device()

    # Move models to device
    vae = vae.to(dev)
    surrogate_head = surrogate_head.to(dev)

    # Freeze VAE completely
    vae.eval()
    for param in vae.parameters():
        param.requires_grad = False

    # Only train surrogate head
    opt = torch.optim.AdamW(
        surrogate_head.parameters(),
        lr=lr,
        betas=(0.9, 0.95),
        weight_decay=0.01,
    )

    loss_fn = MaskedMSELoss()
    best_val_loss = float("inf")
    history: List[Dict[str, Any]] = []

    logger.info("Training standalone surrogate head on frozen VAE latents...")
    logger.info(f"  Surrogate parameters: {sum(p.numel() for p in surrogate_head.parameters()):,}")

    for epoch in range(1, epochs + 1):
        # Validation below switches the head to eval mode, so re-enable
        # training mode at the start of every epoch.
        surrogate_head.train()

        pbar = tqdm(
            batchify(ds, splits["train"], batch_size, shuffle=True, seed=seed + epoch),
            desc=f"Surrogate train {epoch}/{epochs}",
        )

        tot_loss = 0.0
        n_batches = 0

        for batch in pbar:
            x = torch.tensor(batch["x"], device=dev)
            y = torch.tensor(batch["y"], device=dev, dtype=torch.float32)

            # Extract latents from frozen VAE (no gradient through encoder);
            # only the posterior mean is used as surrogate input.
            with torch.no_grad():
                _, mu, _ = vae.encode(x)

            # Forward through surrogate head (with gradients)
            c = None
            if "c" in batch:
                c = torch.tensor(batch["c"], device=dev, dtype=torch.float32)
            y_pred = surrogate_head(mu, c)

            # Compute loss
            loss = loss_fn(y_pred, y)

            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(surrogate_head.parameters(), grad_clip)
            opt.step()

            tot_loss += float(loss.detach().cpu())
            n_batches += 1
            pbar.set_postfix({"loss": tot_loss / n_batches})

        # Mean training loss for the epoch; computed once with a guard so an
        # empty train split cannot raise ZeroDivisionError in the log line.
        train_loss = tot_loss / n_batches if n_batches > 0 else 0.0

        # Validation
        val_loss = eval_standalone_surrogate(ds, splits["val"], vae, surrogate_head, batch_size)

        # Record history
        epoch_metrics = {
            "epoch": epoch,
            "train_loss": train_loss,
            "val_loss": val_loss,
        }
        history.append(epoch_metrics)

        # Save checkpoints
        checkpoint = {"surrogate_head": surrogate_head.state_dict()}
        torch.save(checkpoint, os.path.join(workdir, "surrogate_last.pt"))

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(checkpoint, os.path.join(workdir, "surrogate_best.pt"))

        logger.info(f"Epoch {epoch}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")

    # Save training history
    training_history = TrainingHistory(epochs=history)
    save_json(os.path.join(workdir, "surrogate_training_history.json"), training_history.to_dict())

    if return_history:
        return training_history
    return None


def eval_standalone_surrogate(
    ds: SmilesDataset,
    indices: np.ndarray,
    vae: SmilesTokenVAE,
    surrogate_head: nn.Module,
    batch_size: int = 256,
) -> float:
    """Score a surrogate head on latents produced by a frozen VAE.

    Both modules are put into eval mode (and left there). For every batch the
    VAE's posterior mean is passed to the head, together with the optional
    ``"c"`` condition array when the batch provides one, and the masked MSE
    against ``"y"`` is recorded.

    Args:
        ds: Dataset
        indices: Indices to evaluate
        vae: Frozen VAE; evaluation runs on the device of its parameters
        surrogate_head: Surrogate head to evaluate
        batch_size: Batch size

    Returns:
        Unweighted mean of the per-batch losses, or NaN when no batch was seen
    """
    device = next(vae.parameters()).device
    vae.eval()
    surrogate_head.eval()

    criterion = MaskedMSELoss()
    batch_losses = []

    with torch.no_grad():
        for chunk in batchify(ds, indices, batch_size, shuffle=False):
            tokens = torch.tensor(chunk["x"], device=device)
            targets = torch.tensor(chunk["y"], device=device, dtype=torch.float32)

            # Only the posterior mean of the encoder output is needed.
            latent_mean = vae.encode(tokens)[1]

            # Conditions are optional; the head receives None without them.
            if "c" in chunk:
                cond = torch.tensor(chunk["c"], device=device, dtype=torch.float32)
            else:
                cond = None

            predictions = surrogate_head(latent_mean, cond)
            batch_losses.append(float(criterion(predictions, targets).cpu()))

    if not batch_losses:
        return float("nan")
    return float(np.mean(batch_losses))


def main(config_path: str = "configs/vae.yaml") -> None:
    """Run a full VAE training job described by a YAML config.

    Reads the config, seeds the run, loads the CSV dataset and its splits,
    builds a ``SmilesTokenVAE`` from the ``vae`` section and hands everything
    to ``train_vae``. The ``train`` section must provide ``seed``, ``epochs``,
    ``batch_size`` and ``lr``; the remaining settings fall back to defaults.

    Args:
        config_path: Path to YAML configuration file
    """
    cfg = load_yaml(config_path)
    train_cfg = cfg["train"]
    seed = train_cfg["seed"]
    set_seed(seed)

    logger.info(f"Loading config from {config_path}")

    # Dataset settings, each with a fallback default
    data_cfg = cfg.get("data", {})
    csv_path = data_cfg.get("csv_path", "data/raw/esol.csv")
    smiles_col = data_cfg.get("smiles_col", "smiles")
    y_cols = data_cfg.get("y_cols", ["measured log solubility in mols per litre"])
    max_len = data_cfg.get("max_len", 128)
    representation = data_cfg.get("representation", "smiles")

    logger.info(f"Loading dataset from {csv_path}")
    ds, splits = load_csv_dataset(
        csv_path,
        smiles_col,
        y_cols,
        max_len,
        seed=seed,
        representation=representation,
    )
    vocab_size = len(ds.vocab.id_to_token)
    logger.info(f"Dataset: {len(ds)} molecules, vocab size: {vocab_size}, max_len: {max_len}")

    # Model hyperparameters come from the "vae" section
    arch = cfg["vae"]
    vae_cfg = VAEConfig(
        vocab_size=vocab_size,
        max_len=max_len,
        d_model=arch.get("d_model", 256),
        nhead=arch.get("nhead", 8),
        enc_layers=arch.get("enc_layers", 6),
        dec_layers=arch.get("dec_layers", 6),
        dim_ff=arch.get("dim_ff", 1024),
        dropout=arch.get("dropout", 0.1),
        K=arch.get("K", 8),
        d_latent=arch.get("latent_dim", 128),
    )

    model = SmilesTokenVAE(vae_cfg, pad_id=ds.vocab.pad_id)
    n_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Created VAE with {n_params} parameters")

    # Output location
    output_cfg = cfg.get("output", {})
    workdir = output_cfg.get("workdir", "outputs/vae")
    os.makedirs(workdir, exist_ok=True)
    logger.info(f"Output directory: {workdir}")

    # Kick off training; required keys are looked up in the original order
    logger.info("Starting VAE training...")
    train_vae(
        ds=ds,
        splits=splits,
        model=model,
        workdir=workdir,
        epochs=train_cfg["epochs"],
        batch_size=train_cfg["batch_size"],
        lr=train_cfg["lr"],
        beta_max=arch.get("beta", 1.0),
        beta_warmup_frac=train_cfg.get("beta_warmup_frac", 0.35),
        grad_clip=train_cfg.get("grad_clip", 1.0),
        seed=seed,
    )

    logger.info(f"VAE training complete! Models saved to {workdir}")
