"""Data loading backends for surrogate training.

This module provides unified interfaces for loading data from different sources:
- toy: Synthetic toy data
- rdkit_fp: RDKit fingerprints from CSV
- parquet: Processed parquet files
- druglike: Druglike property datasets (ESOL, Lipophilicity, etc.)
"""

from __future__ import annotations
from typing import Dict, Any, Tuple, Optional
import os
import torch
import pandas as pd

from moltenflow.data.latents import (
    generate_toy_latents,
    generate_toy_targets,
    generate_toy_conditions,
    get_latents,
)
from moltenflow.data.dataset import (
    load_csv_dataset,
    load_processed_dataset,
    load_property_dataset,
    DRUGLIKE_DATASET_CONFIGS,
)
from moltenflow.data.splits import scaffold_split, random_split, load_split_indices
from moltenflow.utils.logging import get_logger

logger = get_logger(__name__)


def load_toy_data(
    cfg: Dict[str, Any], seed: int
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Build a synthetic (latents, targets, conditions) triple for testing.

    Args:
        cfg: Data configuration dict. Keys read here:
            - toy_n_samples: Number of samples (default 5000)
            - toy_latent_dim: Latent dimension (default 128)
            - cond_dim: Conditional dimension (default 0, meaning no conditions)
            - out_dim: Number of properties (required)
        seed: Random seed. Latents and targets are generated with ``seed``,
            conditions with ``seed + 1``.

    Returns:
        Tuple of (z_all, y_all, c_all); c_all is None unless cond_dim > 0
    """
    n_samples = cfg.get("toy_n_samples", 5000)
    latent_dim = cfg.get("toy_latent_dim", 128)
    cond_dim = cfg.get("cond_dim", 0)
    num_targets = cfg["out_dim"]

    logger.info(f"Generating {n_samples} toy samples with dim={latent_dim}")
    z_all = generate_toy_latents(n_samples, latent_dim, seed=seed)

    # Conditions stay absent unless a positive condition dimension is requested.
    c_all = None
    if cond_dim > 0:
        c_all = generate_toy_conditions(n_samples, cond_dim, seed=seed + 1)
        logger.info(f"Generated toy conditions with dim={cond_dim}")

    # Targets are derived from the latents (and the conditions, when present).
    y_all = generate_toy_targets(z_all, num_targets, c=c_all, seed=seed)
    return z_all, y_all, c_all


def load_rdkit_fp_data(
    cfg: Dict[str, Any], cond_dim: int
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Load molecular data from CSV and featurize it with RDKit fingerprints.

    Args:
        cfg: Data configuration dict with keys:
            - dataset_path: Path to CSV file (default "data/ilthermo.csv")
            - smiles_col: SMILES column name (default "smiles")
            - target_cols: List of target column names
            - cond_cols: List of condition column names (optional)
            - fp_radius: Fingerprint radius (default 2)
            - fp_nbits: Fingerprint size (default 2048)
        cond_dim: Expected conditional dimension. Conditions are only loaded
            when this is positive AND ``cond_cols`` is non-empty.

    Returns:
        Tuple of (z_all, y_all, c_all) where c_all may be None

    Raises:
        ValueError: If the condition rows read from the CSV do not line up
            one-to-one with the target rows produced by the dataset loader.
    """
    dataset_path = cfg.get("dataset_path", "data/ilthermo.csv")
    smiles_col = cfg.get("smiles_col", "smiles")
    target_cols = cfg.get("target_cols", ["co2_solubility", "viscosity"])
    cond_cols = cfg.get("cond_cols", [])

    logger.info(f"Loading dataset from {dataset_path}")
    mol_dataset = load_csv_dataset(dataset_path, smiles_col, target_cols)

    logger.info(f"Generating RDKit fingerprints for {len(mol_dataset.smiles)} molecules")
    z_all = get_latents(
        mol_dataset.smiles,
        backend="rdkit_fp",
        fp_radius=cfg.get("fp_radius", 2),
        fp_nbits=cfg.get("fp_nbits", 2048),
    )
    y_all = torch.from_numpy(mol_dataset.y.values.astype("float32"))

    # No conditions requested (or the model does not consume any).
    if not (cond_cols and cond_dim > 0):
        return z_all, y_all, None

    # The condition columns are re-read from the raw CSV, independently of
    # load_csv_dataset.
    df = pd.read_csv(dataset_path)
    c_all = torch.from_numpy(df[cond_cols].values.astype("float32"))

    # NOTE(review): if load_csv_dataset filters rows (e.g. missing targets or
    # unparsable SMILES -- not visible from here), the raw CSV rows no longer
    # correspond to the target rows. Fail loudly instead of training on
    # silently misaligned conditions.
    if len(c_all) != len(y_all):
        raise ValueError(
            f"Condition rows ({len(c_all)}) read from {dataset_path} do not match "
            f"the {len(y_all)} target rows returned by the dataset loader; "
            "conditions cannot be aligned with the targets"
        )

    logger.info(f"Loaded {len(cond_cols)} condition columns: {cond_cols}")
    return z_all, y_all, c_all


def load_parquet_data(
    cfg: Dict[str, Any], cond_dim: int, use_masked_loss: bool
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Load a processed parquet dataset and featurize it with RDKit fingerprints.

    Args:
        cfg: Data configuration dict with keys:
            - dataset_path: Path to parquet file
              (default "data/processed/co2_solubility.parquet")
            - smiles_col: SMILES column name (default "smiles")
            - target_cols: List of target column names (default ["x_CO2"])
            - cond_cols: List of condition column names; only read when
              cond_dim > 0
            - fp_radius: Fingerprint radius (default 2)
            - fp_nbits: Fingerprint size (default 2048)
        cond_dim: Expected conditional dimension; when not positive, no
            condition columns are requested from the loader
        use_masked_loss: Whether to keep NaN values for masked training
            (the loader is asked to drop NaN rows only when this is False)

    Returns:
        Tuple of (z_all, y_all, c_all) where c_all may be None
    """
    dataset_path = cfg.get("dataset_path", "data/processed/co2_solubility.parquet")
    smiles_col = cfg.get("smiles_col", "smiles")
    target_cols = cfg.get("target_cols", ["x_CO2"])
    if cond_dim > 0:
        cond_cols = cfg.get("cond_cols", [])
    else:
        cond_cols = None

    logger.info(f"Loading processed parquet from {dataset_path}")
    logger.info(f"Masked loss: {use_masked_loss}")

    # With masked loss the NaN targets must survive loading, so dropping is
    # enabled only in the unmasked case.
    drop_nan = not use_masked_loss
    mol_dataset, dropped_rows = load_processed_dataset(
        dataset_path,
        smiles_col=smiles_col,
        target_cols=target_cols,
        cond_cols=cond_cols,
        drop_nan=drop_nan,
    )

    # Report dropped rows unless masked training kept every row.
    if not use_masked_loss or dropped_rows != 0:
        logger.info(f"Loaded {len(mol_dataset.smiles)} samples (dropped {dropped_rows} NaN rows)")
    else:
        logger.info(
            f"Loaded {len(mol_dataset.smiles)} samples (keeping NaN values for masked training)"
        )

    logger.info(f"Generating RDKit fingerprints for {len(mol_dataset.smiles)} molecules")
    fingerprint_options = {
        "fp_radius": cfg.get("fp_radius", 2),
        "fp_nbits": cfg.get("fp_nbits", 2048),
    }
    z_all = get_latents(mol_dataset.smiles, backend="rdkit_fp", **fingerprint_options)

    y_all = torch.from_numpy(mol_dataset.y.values.astype("float32"))

    # Conditions are taken from the loaded dataset, when it carries any.
    c_all = None
    if mol_dataset.c is not None:
        c_all = torch.from_numpy(mol_dataset.c.values.astype("float32"))
        logger.info(f"Loaded {len(cond_cols)} condition columns: {cond_cols}")

    return z_all, y_all, c_all


def load_druglike_data(
    cfg: Dict[str, Any],
) -> Tuple[torch.Tensor, torch.Tensor, None, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Load a druglike property dataset together with its train/val/test split.

    If ``data/processed/<dataset_name>.parquet`` exists it is read as a joint
    (multi-target) dataset; otherwise the single-property dataset is loaded
    by name. Split indices are read from ``data/splits`` when a cached file
    for this dataset/splitter/seed exists, and computed and saved otherwise.

    Args:
        cfg: Data configuration dict with keys:
            - dataset_name: Name of dataset (esol, lipophilicity, etc.); required
            - splitter: Split method; "scaffold" (default) uses the scaffold
              split, any other value uses the random 80/10/10 split
            - split_seed: Random seed for splitting (default 42)
            - target_cols: Optional list of target columns; only consulted
              for joint parquet datasets
            - fp_radius: Fingerprint radius (default 2)
            - fp_nbits: Fingerprint size (default 2048)

    Returns:
        Tuple of (z_all, y_all, None, (train_idx, val_idx, test_idx))

    Raises:
        ValueError: If no joint parquet exists and ``dataset_name`` is not a
            key of ``DRUGLIKE_DATASET_CONFIGS``.
    """
    dataset_name = cfg["dataset_name"]
    splitter = cfg.get("splitter", "scaffold")
    split_seed = cfg.get("split_seed", 42)

    logger.info(f"Loading druglike dataset: {dataset_name}")

    # A parquet under data/processed marks an already merged joint dataset.
    joint_path = f"data/processed/{dataset_name}.parquet"
    if not os.path.exists(joint_path):
        # Single-property dataset: its target column comes from the registry.
        df = load_property_dataset(dataset_name)
        if dataset_name not in DRUGLIKE_DATASET_CONFIGS:
            raise ValueError(f"Unknown dataset: {dataset_name}")
        target_cols = [DRUGLIKE_DATASET_CONFIGS[dataset_name]["target_name"]]
    else:
        df = pd.read_parquet(joint_path)
        target_cols = cfg.get("target_cols", None)
        if target_cols is None:
            # Without an explicit list, every non-SMILES column is a target.
            target_cols = [name for name in df.columns if name != "smiles"]
        logger.info(f"Loaded joint dataset with targets: {target_cols}")

    smiles = df["smiles"].tolist()
    targets = df[target_cols].values.astype("float32")
    y_all = torch.from_numpy(targets)
    # Guarantee a 2-D (samples, properties) layout even for a 1-D selection.
    y_all = y_all.unsqueeze(1) if y_all.dim() == 1 else y_all

    logger.info(f"Dataset size: {len(smiles)} molecules")
    logger.info(f"Target properties: {target_cols}")

    # Reuse a cached split when available; otherwise compute and save one.
    split_path = f"data/splits/{dataset_name}_{splitter}_seed{split_seed}.npz"
    if not os.path.exists(split_path):
        logger.info(f"Computing {splitter} split with seed={split_seed}")
        if splitter != "scaffold":
            train_idx, val_idx, test_idx = random_split(
                len(smiles), split_seed, train=0.8, val=0.1, test=0.1
            )
        else:
            train_idx, val_idx, test_idx = scaffold_split(smiles, seed=split_seed)

        # Save for reproducibility
        from moltenflow.data.splits import save_split_indices

        save_split_indices(dataset_name, train_idx, val_idx, test_idx, split_seed, splitter)
    else:
        logger.info(f"Loading cached split from {split_path}")
        train_idx, val_idx, test_idx = load_split_indices(split_path)

    logger.info(f"Split sizes: train={len(train_idx)}, val={len(val_idx)}, test={len(test_idx)}")

    # Featurize every molecule; the split is applied later by index.
    logger.info("Generating RDKit fingerprints...")
    z_all = get_latents(
        smiles,
        backend="rdkit_fp",
        fp_radius=cfg.get("fp_radius", 2),
        fp_nbits=cfg.get("fp_nbits", 2048),
    )

    # Druglike datasets carry no conditions, hence the None in third position.
    split_indices = (train_idx, val_idx, test_idx)
    return z_all, y_all, None, split_indices


def _index_optional(tensor: Optional[torch.Tensor], idx: Any) -> Optional[torch.Tensor]:
    """Index ``tensor`` with ``idx``, passing a missing (None) tensor through."""
    return None if tensor is None else tensor[idx]


def load_data_for_surrogate(
    data_cfg: Dict[str, Any], surrogate_cfg: Dict[str, Any], train_cfg: Dict[str, Any]
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    Optional[torch.Tensor],
    Optional[torch.Tensor],
]:
    """Load data for surrogate training from the configured backend and split it.

    Args:
        data_cfg: Data configuration. ``backend`` selects the loader
            ("toy" (default), "rdkit_fp", "parquet" or "druglike"); the dict
            is also passed on to that loader.
        surrogate_cfg: Surrogate model configuration. ``cond_dim`` (default 0)
            is read for every backend; ``out_dim`` is required by the toy
            backend.
        train_cfg: Training configuration. ``seed`` is required and drives
            both toy data generation and the random train/val split;
            ``use_masked_loss`` (default False) is read by the parquet backend.

    Returns:
        Tuple of (z_train, z_val, y_train, y_val, c_train, c_val); the two
        condition entries are None when the backend provides no conditions.

    Raises:
        KeyError: If ``train_cfg`` has no ``seed`` entry.
        ValueError: If the backend name is unknown, or if the selected
            backend returned no samples.
    """
    backend = data_cfg.get("backend", "toy")
    cond_dim = surrogate_cfg.get("cond_dim", 0)
    seed = train_cfg["seed"]

    logger.info(f"Data backend: {backend}, Condition dim: {cond_dim}")

    # Only the druglike backend ships its own split; everything else is split
    # randomly below.
    split_indices = None

    if backend == "toy":
        z_all, y_all, c_all = load_toy_data(
            {**data_cfg, "out_dim": surrogate_cfg["out_dim"], "cond_dim": cond_dim}, seed
        )
    elif backend == "rdkit_fp":
        z_all, y_all, c_all = load_rdkit_fp_data(data_cfg, cond_dim)
    elif backend == "parquet":
        use_masked_loss = train_cfg.get("use_masked_loss", False)
        z_all, y_all, c_all = load_parquet_data(data_cfg, cond_dim, use_masked_loss)
    elif backend == "druglike":
        z_all, y_all, c_all, split_indices = load_druglike_data(data_cfg)
    else:
        raise ValueError(f"Unknown data backend: {backend}")

    n_total = len(z_all)
    if n_total == 0:
        # An empty dataset would otherwise yield empty train/val tensors and
        # only fail much later, far from the cause.
        raise ValueError(f"Data backend '{backend}' produced no samples")

    if split_indices is not None:
        # Use pre-computed indices (druglike datasets); the test part is unused here.
        train_idx, val_idx, _ = split_indices
    else:
        # Seeded random 90/10 split for the other backends.
        n_train = int(0.9 * n_total)
        indices = torch.randperm(n_total, generator=torch.Generator().manual_seed(seed))
        train_idx, val_idx = indices[:n_train], indices[n_train:]

    z_train, z_val = z_all[train_idx], z_all[val_idx]
    y_train, y_val = y_all[train_idx], y_all[val_idx]
    # Conditions are sliced with the same indices in both split modes, so a
    # backend that provides both a split and conditions keeps them aligned.
    c_train = _index_optional(c_all, train_idx)
    c_val = _index_optional(c_all, val_idx)

    logger.info(f"Train samples: {len(z_train)}, Val samples: {len(z_val)}")

    return z_train, z_val, y_train, y_val, c_train, c_val
