"""Training utilities for latent flow models.

This module provides functions for training continuous normalizing flow models
on VAE latent representations.
"""

import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

from moltenflow.models.latent_flow import LatentFlowPrior, FlowConfig
from moltenflow.models.vae import SmilesTokenVAE, VAEConfig
from moltenflow.data.smiles_dataset import SmilesDataset, batchify, load_csv_dataset
from moltenflow.utils.io import save_json
from moltenflow.utils.device import get_device
from moltenflow.utils.config import load_yaml
from moltenflow.utils.logging import get_logger
from moltenflow.utils.seeds import set_seed

logger = get_logger(__name__)


@dataclass
class FlowTrainingHistory:
    """Per-epoch metrics recorded while training a latent flow prior.

    Attributes:
        epochs: One metrics dict per completed epoch, in training order.
    """

    epochs: List[Dict[str, Any]]

    def to_list(self) -> List[Dict[str, Any]]:
        """Expose the stored per-epoch records (the same list object, not a copy)."""
        records = self.epochs
        return records

    def to_dict(self) -> Dict[str, Any]:
        """Wrap the per-epoch records in a mapping suitable for JSON output."""
        return dict(epochs=self.epochs)


def _flow_matching_loss(
    vae: SmilesTokenVAE,
    flow: LatentFlowPrior,
    x: torch.Tensor,
    use_posterior_sample: bool,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """Compute the linear-path flow-matching MSE for one batch of inputs.

    Encodes ``x`` with the VAE (without gradients), draws Gaussian noise ``z0``
    and one time ``t ~ U[0, 1)`` per sample, forms
    ``zt = (1 - t) * z0 + t * z1``, and regresses ``flow(zt, t)`` onto the
    constant velocity ``z1 - z0``.

    Args:
        vae: VAE whose ``encode`` returns ``(z, mu, logvar)``.
        flow: Velocity model, called as ``flow(zt, t)``.
        x: Batch of encoder inputs, already on the model device.
        use_posterior_sample: Target the sampled ``z`` if True, else ``mu``.
        generator: Optional CPU ``torch.Generator`` for reproducible noise/time
            draws. When ``None``, the global RNG is used on the latents' device.

    Returns:
        Scalar MSE tensor. It carries gradients only w.r.t. ``flow``
        parameters, because the VAE encode runs under ``torch.no_grad``.
    """
    # The VAE acts as a fixed encoder here: no gradients are tracked through it.
    with torch.no_grad():
        z, mu, _ = vae.encode(x)
        z1 = z if use_posterior_sample else mu

    if generator is None:
        z0 = torch.randn_like(z1)
        t = torch.rand(z1.size(0), device=z1.device)
    else:
        # Draw on CPU with the seeded generator and then move the tensors.
        # CPU generators work whatever accelerator the model lives on.
        z0 = torch.randn(z1.shape, generator=generator, dtype=z1.dtype).to(z1.device)
        t = torch.rand(z1.size(0), generator=generator).to(z1.device)

    # Broadcast the per-sample time over every non-batch latent dimension.
    # For 3-D latents this equals t[:, None, None].
    t_b = t.reshape((-1,) + (1,) * (z1.dim() - 1))
    zt = (1.0 - t_b) * z0 + t_b * z1
    target = z1 - z0
    return F.mse_loss(flow(zt, t), target)


def train_flow_prior(
    ds: SmilesDataset,
    splits: Dict[str, np.ndarray],
    vae: SmilesTokenVAE,
    flow: LatentFlowPrior,
    workdir: str,
    epochs: int = 60,
    batch_size: int = 512,
    lr: float = 2e-4,
    grad_clip: Optional[float] = 1.0,
    seed: int = 7,
    use_posterior_sample: bool = True,
    return_history: bool = False,
) -> Optional[FlowTrainingHistory]:
    """Train a flow-matching prior on VAE latents.

    The flow is trained to predict the velocity ``z1 - z0`` along the straight
    line ``zt = (1 - t) * z0 + t * z1``. Here ``z0`` is standard-normal noise
    and ``z1`` is the VAE encoding of a training molecule. The VAE is only used
    for encoding, under ``torch.no_grad``, and only the flow's parameters are
    optimized (AdamW).

    Args:
        ds: SmilesDataset containing tokenized sequences.
        splits: Dictionary with at least 'train' and 'val' index arrays.
        vae: Trained VAE model. It is moved to the device and put in eval mode.
        flow: LatentFlowPrior model to train.
        workdir: Output directory for checkpoints and logs (created if missing).
        epochs: Number of training epochs.
        batch_size: Batch size for training and validation.
        lr: Learning rate.
        grad_clip: Maximum global gradient norm. ``None`` or a value ``<= 0``
            disables clipping.
        seed: Controls per-epoch shuffling of the training split and the fixed
            noise used for validation.
        use_posterior_sample: If True, use the sampled z from the VAE posterior
            as the flow target. If False, use the mean mu. The same choice is
            applied during validation.
        return_history: If True, return the training history.

    Returns:
        FlowTrainingHistory if ``return_history`` is True, else None.

    Saves (in ``workdir``; checkpoints are ``{"model": flow.state_dict()}``):
        - flow_best.pt: checkpoint with the lowest validation MSE so far.
        - flow_last.pt: overwritten every epoch. After training it holds the
          final weights.
        - flow_metrics.json: validation metrics of the latest epoch plus
          ``"epoch"``.
        - training_history.json: training/validation metrics for every epoch.
    """
    os.makedirs(workdir, exist_ok=True)
    dev = get_device()
    vae = vae.to(dev).eval()
    flow = flow.to(dev)
    opt = torch.optim.AdamW(flow.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.01)
    clip_enabled = grad_clip is not None and grad_clip > 0

    best = float("inf")
    history: List[Dict[str, Any]] = []
    for epoch in range(1, epochs + 1):
        flow.train()
        pbar = tqdm(
            batchify(ds, splits["train"], batch_size, shuffle=True, seed=seed + 1000 + epoch),
            desc=f"Flow train {epoch}/{epochs}",
        )
        tot = 0.0
        n = 0
        for batch in pbar:
            x = torch.as_tensor(batch["x"], device=dev)
            loss = _flow_matching_loss(vae, flow, x, use_posterior_sample)

            opt.zero_grad(set_to_none=True)
            loss.backward()
            if clip_enabled:
                torch.nn.utils.clip_grad_norm_(flow.parameters(), grad_clip)
            opt.step()

            tot += loss.item()
            n += 1
            pbar.set_postfix(loss=tot / n)

        # Validate with the same target (z vs mu) as training and a fixed seed.
        # Epoch-to-epoch MSEs are then comparable for best-checkpoint selection.
        val = eval_flow_prior(
            ds,
            splits["val"],
            vae,
            flow,
            batch_size=batch_size,
            use_posterior_sample=use_posterior_sample,
            seed=seed,
        )
        val_mse = val["mse"]
        torch.save({"model": flow.state_dict()}, os.path.join(workdir, "flow_last.pt"))
        # A NaN val_mse (empty validation split) never compares < best.
        if val_mse < best:
            best = val_mse
            torch.save({"model": flow.state_dict()}, os.path.join(workdir, "flow_best.pt"))
        save_json(os.path.join(workdir, "flow_metrics.json"), val | {"epoch": epoch})

        history.append(
            {
                "epoch": epoch,
                "train_loss": tot / n if n > 0 else 0.0,
                "val_loss": val_mse,
                "val_mse": val_mse,
            }
        )

    training_history = FlowTrainingHistory(epochs=history)
    save_json(os.path.join(workdir, "training_history.json"), training_history.to_dict())

    return training_history if return_history else None


def eval_flow_prior(
    ds: SmilesDataset,
    indices: np.ndarray,
    vae: SmilesTokenVAE,
    flow: LatentFlowPrior,
    batch_size: int = 512,
    use_posterior_sample: bool = True,
    seed: Optional[int] = None,
) -> Dict[str, float]:
    """Evaluate the flow-matching loss on a dataset split.

    Computes the MSE between predicted and target velocities, the same
    objective used in training, over the given sample indices. The flow is
    left in eval mode.

    Args:
        ds: SmilesDataset to evaluate on.
        indices: Array of sample indices to evaluate.
        vae: Trained VAE model. It is moved to the flow's device and put in
            eval mode.
        flow: LatentFlowPrior model to evaluate.
        batch_size: Batch size for evaluation.
        use_posterior_sample: Target the sampled z if True, else mu. Match the
            value used in training.
        seed: If given, noise and times are drawn from a generator seeded with
            this value, so repeated calls give comparable results. If None, the
            global RNG is used.

    Returns:
        ``{"mse": value}``, where value is the per-sample mean squared error
        (batches are weighted by size). It is NaN if no batches are produced.
    """
    dev = next(flow.parameters()).device
    vae = vae.to(dev).eval()
    flow.eval()

    generator: Optional[torch.Generator] = None
    if seed is not None:
        generator = torch.Generator()
        generator.manual_seed(seed)

    total = 0.0
    count = 0
    with torch.no_grad():
        for batch in batchify(ds, indices, batch_size, shuffle=False):
            x = torch.as_tensor(batch["x"], device=dev)
            loss = _flow_matching_loss(vae, flow, x, use_posterior_sample, generator)
            # Every sample has the same number of elements, so weighting by
            # batch size gives the exact mean over all samples.
            bsz = x.size(0)
            total += loss.item() * bsz
            count += bsz

    if count == 0:
        return {"mse": float("nan")}
    return {"mse": total / count}


def main(config_path: str = "configs/latent_model.yaml") -> None:
    """Train a latent flow prior from a YAML configuration file.

    Reads the ``data``, ``vae``, ``flow``, ``train`` and ``output`` sections,
    loads the dataset and a pretrained VAE checkpoint, builds a
    LatentFlowPrior, and runs ``train_flow_prior``.

    Required keys: ``train.seed``, ``train.epochs``, ``train.batch_size``,
    ``train.lr`` and ``vae.checkpoint_path``. Every other key has a default.
    If ``flow.K`` / ``flow.d_latent`` are omitted, they default to the VAE's
    ``K`` / ``latent_dim``. The flow is then configured for the same latent
    shape the VAE produces.

    Args:
        config_path: Path to YAML configuration file.
    """
    cfg = load_yaml(config_path)
    train_cfg = cfg["train"]
    seed = train_cfg["seed"]
    set_seed(seed)

    logger.info(f"Loading config from {config_path}")

    # Load dataset (a missing or null ``data`` section falls back to defaults).
    data_cfg = cfg.get("data") or {}
    csv_path = data_cfg.get("csv_path", "data/raw/esol.csv")
    smiles_col = data_cfg.get("smiles_col", "smiles")
    y_cols = data_cfg.get("y_cols", ["measured log solubility in mols per litre"])
    max_len = data_cfg.get("max_len", 128)
    representation = data_cfg.get("representation", "smiles")

    logger.info(f"Loading dataset from {csv_path}")
    ds, splits = load_csv_dataset(
        csv_path,
        smiles_col,
        y_cols,
        max_len,
        seed=seed,
        representation=representation,
    )
    logger.info(f"Dataset: {len(ds)} molecules")

    # Load VAE
    vae_section = cfg["vae"]
    vae_checkpoint = vae_section["checkpoint_path"]
    logger.info(f"Loading VAE from {vae_checkpoint}")

    vae_k = vae_section.get("K", 8)
    vae_latent_dim = vae_section.get("latent_dim", 128)
    vae_cfg = VAEConfig(
        vocab_size=len(ds.vocab.id_to_token),
        max_len=max_len,
        d_model=vae_section.get("d_model", 256),
        nhead=vae_section.get("nhead", 8),
        enc_layers=vae_section.get("enc_layers", 6),
        dec_layers=vae_section.get("dec_layers", 6),
        dim_ff=vae_section.get("dim_ff", 1024),
        dropout=vae_section.get("dropout", 0.1),
        K=vae_k,
        d_latent=vae_latent_dim,
    )

    vae = SmilesTokenVAE(vae_cfg, pad_id=ds.vocab.pad_id)
    # NOTE(review): torch.load unpickles the file. Only point checkpoint_path
    # at trusted checkpoints.
    checkpoint = torch.load(vae_checkpoint, map_location="cpu")
    vae.load_state_dict(checkpoint["model"])
    logger.info("VAE loaded successfully")

    # Create flow model. The flow is trained on tensors shaped like the VAE's
    # latents, so K and d_latent follow the VAE unless set explicitly.
    flow_section = cfg.get("flow") or {}
    flow_k = flow_section.get("K", vae_k)
    flow_latent_dim = flow_section.get("d_latent", vae_latent_dim)
    if (flow_k, flow_latent_dim) != (vae_k, vae_latent_dim):
        logger.warning(
            f"Flow latent shape (K={flow_k}, d_latent={flow_latent_dim}) differs from "
            f"VAE latent shape (K={vae_k}, latent_dim={vae_latent_dim})"
        )
    flow_cfg = FlowConfig(
        K=flow_k,
        d_latent=flow_latent_dim,
        d_model=flow_section.get("d_model", 256),
        nhead=flow_section.get("nhead", 8),
        layers=flow_section.get("layers", 10),
        dim_ff=flow_section.get("dim_ff", 1024),
        dropout=flow_section.get("dropout", 0.1),
        time_dim=flow_section.get("time_dim", 128),
    )

    flow = LatentFlowPrior(flow_cfg)
    logger.info(f"Created flow model with {sum(p.numel() for p in flow.parameters())} parameters")

    # Setup output directory
    workdir = (cfg.get("output") or {}).get("workdir", "outputs/flow")
    os.makedirs(workdir, exist_ok=True)
    logger.info(f"Output directory: {workdir}")

    # Train
    logger.info("Starting flow training...")
    train_flow_prior(
        ds=ds,
        splits=splits,
        vae=vae,
        flow=flow,
        workdir=workdir,
        epochs=train_cfg["epochs"],
        batch_size=train_cfg["batch_size"],
        lr=train_cfg["lr"],
        grad_clip=train_cfg.get("grad_clip", 1.0),
        seed=seed,
        use_posterior_sample=train_cfg.get("use_posterior_sample", True),
    )

    logger.info(f"Flow training complete! Models saved to {workdir}")
