"""Train surrogate property predictor."""

from __future__ import annotations
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

from moltenflow.models.surrogate import PropertySurrogate
from moltenflow.data.backends import load_data_for_surrogate
from moltenflow.data.transforms import TargetScaler
from moltenflow.training.losses import MaskedMSELoss
from moltenflow.utils.config import load_yaml
from moltenflow.utils.logging import get_logger
from moltenflow.utils.seeds import set_seed
from moltenflow.utils.io import ensure_dir
from moltenflow.utils.device import get_device_str

logger = get_logger(__name__)


def train_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    loss_fn: nn.Module,
    device: str,
    use_conditions: bool = False,
) -> float:
    """Train for one epoch.

    Each batch is moved to ``device`` and fed through ``model(z, c)``, with
    ``c=None`` when the loader carries no conditions. One optimizer step is
    taken per batch.

    Args:
        model: Surrogate model
        dataloader: Training data loader
        optimizer: Optimizer
        loss_fn: Loss function
        device: Device to train on
        use_conditions: Whether each batch is ``(z, c, y)`` instead of ``(z, y)``

    Returns:
        Average training loss for the epoch. Each batch's loss is weighted by
        its number of samples, so a smaller final batch is not over-weighted.
        This is the exact per-sample mean for mean-reduced losses such as
        ``nn.MSELoss``.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.train()
    weighted_loss_sum = 0.0
    n_samples = 0

    for batch in dataloader:
        if use_conditions:
            z_batch, c_batch, y_batch = batch
            c_batch = c_batch.to(device)
        else:
            z_batch, y_batch = batch
            c_batch = None
        z_batch = z_batch.to(device)
        y_batch = y_batch.to(device)

        optimizer.zero_grad()
        y_pred = model(z_batch, c_batch)
        loss = loss_fn(y_pred, y_batch)
        loss.backward()
        optimizer.step()

        batch_size = z_batch.shape[0]
        weighted_loss_sum += loss.item() * batch_size
        n_samples += batch_size

    # Without this guard an empty loader would surface as a bare ZeroDivisionError.
    if n_samples == 0:
        raise ValueError("train_epoch received an empty dataloader; nothing to train on")
    return weighted_loss_sum / n_samples


def evaluate(
    model: nn.Module,
    dataloader: DataLoader,
    loss_fn: nn.Module,
    device: str,
    use_conditions: bool = False,
) -> float:
    """Evaluate model on validation data.

    The model is put in eval mode and run under ``torch.no_grad()``. Each
    batch is fed through ``model(z, c)``, with ``c=None`` when the loader
    carries no conditions.

    Args:
        model: Surrogate model
        dataloader: Validation data loader
        loss_fn: Loss function
        device: Device to evaluate on
        use_conditions: Whether each batch is ``(z, c, y)`` instead of ``(z, y)``

    Returns:
        Average validation loss. Each batch's loss is weighted by its number
        of samples, so a partial last batch does not skew the result. This is
        the exact per-sample mean for mean-reduced losses such as
        ``nn.MSELoss``.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()
    weighted_loss_sum = 0.0
    n_samples = 0

    with torch.no_grad():
        for batch in dataloader:
            if use_conditions:
                z_batch, c_batch, y_batch = batch
                c_batch = c_batch.to(device)
            else:
                z_batch, y_batch = batch
                c_batch = None
            z_batch = z_batch.to(device)
            y_batch = y_batch.to(device)

            y_pred = model(z_batch, c_batch)
            loss = loss_fn(y_pred, y_batch)

            batch_size = z_batch.shape[0]
            weighted_loss_sum += loss.item() * batch_size
            n_samples += batch_size

    # Without this guard an empty loader would surface as a bare ZeroDivisionError.
    if n_samples == 0:
        raise ValueError("evaluate received an empty dataloader; nothing to evaluate")
    return weighted_loss_sum / n_samples


def _scale_tensor(scaler: TargetScaler, values: torch.Tensor) -> torch.Tensor:
    """Apply ``scaler.transform`` to ``values`` and return a tensor of the same dtype.

    The ``.to(values.dtype)`` cast keeps scaled tensors in the dtype of the
    unscaled inputs. NOTE(review): it is not visible here whether
    ``TargetScaler.transform`` can change the dtype (e.g. upcast to float64);
    when it does not, the cast is a no-op.
    """
    return torch.from_numpy(scaler.transform(values.numpy())).to(values.dtype)


def _build_loss_fn(train_cfg: dict) -> nn.Module:
    """Build the training loss from the ``train`` config section.

    ``use_masked_loss: true`` selects ``MaskedMSELoss`` (ignores NaN targets).
    Otherwise ``loss`` selects ``"mse"`` (the default) or ``"huber"``.

    Raises:
        ValueError: If ``loss`` names an unsupported loss.
    """
    if train_cfg.get("use_masked_loss", False):
        logger.info("Loss function: MaskedMSELoss (ignores NaN targets)")
        return MaskedMSELoss()

    loss_type = train_cfg.get("loss", "mse")
    if loss_type == "mse":
        loss_fn: nn.Module = nn.MSELoss()
    elif loss_type == "huber":
        loss_fn = nn.HuberLoss()
    else:
        raise ValueError(f"Unknown loss: {loss_type}")
    logger.info(f"Loss function: {loss_type}")
    return loss_fn


def main(config_path: str = "configs/surrogate.yaml") -> None:
    """Train surrogate property predictor.

    The function runs these steps:

    1. Load the YAML config.
    2. Standardize targets, and conditions when present, with scalers fit on
       the training split.
    3. Train ``PropertySurrogate``.
    4. Write the model checkpoint and scaler files to the paths under
       ``output``.

    The saved checkpoint contains the weights from the epoch with the lowest
    validation loss, so ``model_state_dict`` and ``best_val_loss`` describe
    the same model.

    Args:
        config_path: Path to YAML configuration file
    """
    cfg = load_yaml(config_path)
    set_seed(int(cfg["train"]["seed"]))

    logger.info(f"Loaded config: {config_path}")

    # Setup device
    device = get_device_str()
    logger.info(f"Using device: {device}")

    # Load data using backend system
    z_train, z_val, y_train, y_val, c_train, c_val = load_data_for_surrogate(
        cfg["data"], cfg["surrogate"], cfg["train"]
    )

    # Fit the target scaler on training data only; validation is transformed with it.
    scaler = TargetScaler.fit(y_train.numpy())
    y_train_scaled = _scale_tensor(scaler, y_train)
    y_val_scaled = _scale_tensor(scaler, y_val)
    logger.info(f"Target mean: {scaler.mean}, std: {scaler.std}")

    # Conditions (if any) get their own scaler, also fit on training data only.
    use_conditions = c_train is not None
    cond_scaler = None
    if use_conditions:
        cond_scaler = TargetScaler.fit(c_train.numpy())
        c_train = _scale_tensor(cond_scaler, c_train)
        c_val = _scale_tensor(cond_scaler, c_val)
        logger.info(f"Condition mean: {cond_scaler.mean}, std: {cond_scaler.std}")

        train_dataset = TensorDataset(z_train, c_train, y_train_scaled)
        val_dataset = TensorDataset(z_val, c_val, y_val_scaled)
    else:
        train_dataset = TensorDataset(z_train, y_train_scaled)
        val_dataset = TensorDataset(z_val, y_val_scaled)

    batch_size = cfg["train"]["batch_size"]
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model
    model = PropertySurrogate(
        in_dim=cfg["surrogate"]["in_dim"],
        out_dim=cfg["surrogate"]["out_dim"],
        cond_dim=cfg["surrogate"].get("cond_dim", 0),
        hidden_dims=cfg["surrogate"].get("hidden_dims", [256, 256]),
        dropout=cfg["surrogate"].get("dropout", 0.0),
    ).to(device)

    logger.info(f"Model: {model}")
    logger.info(f"Parameters: {sum(p.numel() for p in model.parameters())}")

    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["train"]["lr"])
    loss_fn = _build_loss_fn(cfg["train"])

    # Training loop
    epochs = cfg["train"]["epochs"]
    best_val_loss = float("inf")
    best_state: dict[str, torch.Tensor] | None = None

    logger.info(f"Starting training for {epochs} epochs")

    for epoch in tqdm(range(epochs), desc="Training"):
        train_loss = train_epoch(model, train_loader, optimizer, loss_fn, device, use_conditions)
        val_loss = evaluate(model, val_loader, loss_fn, device, use_conditions)

        if (epoch + 1) % 10 == 0 or epoch == 0:
            logger.info(
                f"Epoch {epoch + 1}/{epochs} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}"
            )

        # Snapshot the best weights. Clone each tensor so later optimizer
        # steps (which update parameters in place) don't overwrite the snapshot.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

    logger.info(f"Training complete. Best val loss: {best_val_loss:.4f}")

    # Restore the best-epoch weights so the checkpoint matches best_val_loss.
    if best_state is not None:
        model.load_state_dict(best_state)
    else:
        # Happens with epochs == 0 or if every val loss was NaN.
        logger.warning("No validation improvement recorded; saving final model weights")

    # Save outputs
    checkpoint_path = Path(cfg["output"]["checkpoint_path"])
    scaler_path = Path(cfg["output"]["scaler_path"])

    ensure_dir(checkpoint_path.parent)
    ensure_dir(scaler_path.parent)

    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "config": cfg,
            "best_val_loss": best_val_loss,
        },
        checkpoint_path,
    )
    logger.info(f"Saved model checkpoint to {checkpoint_path}")

    scaler.save(scaler_path)
    logger.info(f"Saved target scaler to {scaler_path}")

    # Save condition scaler if used
    if cond_scaler is not None:
        cond_scaler_path = Path(
            cfg["output"].get("cond_scaler_path", str(scaler_path.parent / "cond_scaler.json"))
        )
        ensure_dir(cond_scaler_path.parent)
        cond_scaler.save(cond_scaler_path)
        logger.info(f"Saved condition scaler to {cond_scaler_path}")

    logger.info("Surrogate training complete!")
