"""Dataset splitting utilities.

This module provides functions for splitting datasets into train/validation/test
sets using random or scaffold-based strategies.
"""

from __future__ import annotations
from typing import Tuple
from collections import defaultdict
import os
import numpy as np
from rdkit import Chem
from rdkit.Chem.Scaffolds import MurckoScaffold


def random_split(
    n: int, seed: int, train: float, val: float, test: float
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Split dataset indices randomly into train/val/test sets.

    The indices ``0 .. n-1`` are shuffled with a NumPy generator seeded by
    ``seed`` and then cut into three consecutive slices. The train and
    validation sizes are ``int(train * n)`` and ``int(val * n)`` (truncated
    towards zero); the test set receives every remaining index, so the three
    arrays together always contain each index exactly once.

    Args:
        n: Total number of samples
        seed: Random seed for reproducibility
        train: Fraction for training set (must sum to 1.0 with val and test)
        val: Fraction for validation set
        test: Fraction for test set (only used for the sum check; the test
            set is whatever remains after train and val are taken)

    Returns:
        Tuple of (train_indices, val_indices, test_indices) as numpy arrays

    Raises:
        AssertionError: If fractions don't sum to 1.0 (within 1e-6)
    """
    total = train + val + test
    # Raised explicitly instead of via an `assert` statement so the check is
    # not stripped when Python runs with -O. `not (... < ...)` rather than
    # `>=` keeps rejecting NaN fractions, exactly as the assert did.
    if not abs(total - 1.0) < 1e-6:
        raise AssertionError(
            f"train + val + test must sum to 1.0, got {total}"
        )

    rng = np.random.default_rng(seed)
    idx = np.arange(n)
    rng.shuffle(idx)

    n_train = int(train * n)
    val_end = n_train + int(val * n)

    train_idx = idx[:n_train]
    val_idx = idx[n_train:val_end]
    # The remainder goes to test, absorbing any truncation in the sizes above.
    test_idx = idx[val_end:]
    return train_idx, val_idx, test_idx


def murcko_scaffold(smiles: str) -> str | None:
    """Return the Bemis-Murcko scaffold of a molecule as canonical SMILES.

    Args:
        smiles: SMILES string describing the molecule

    Returns:
        Canonical SMILES of the scaffold, or None when the input cannot be
        parsed or when scaffold extraction/serialization raises
    """
    scaffold_smiles = None
    parsed = Chem.MolFromSmiles(smiles)
    if parsed is not None:
        try:
            core = MurckoScaffold.GetScaffoldForMol(parsed)
            scaffold_smiles = Chem.MolToSmiles(core, canonical=True)
        except Exception:
            # Best-effort: any failure inside RDKit is reported as "no scaffold".
            scaffold_smiles = None
    return scaffold_smiles


def scaffold_split(
    smiles_list: list[str],
    frac_train: float = 0.8,
    frac_val: float = 0.1,
    frac_test: float = 0.1,
    seed: int = 42,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Scaffold-based dataset split (DeepChem-like behavior).

    Algorithm:
    1. Compute Murcko scaffold for each molecule
    2. Group molecule indices by scaffold
    3. Sort scaffold groups by size (descending), ties broken by smallest index
    4. Greedily assign groups to train/val/test to approximate target fractions

    This ensures molecules with the same scaffold are always in the same split,
    which is important for evaluating generalization to new scaffolds.

    The split is fully deterministic and reproducible based on the input data order.

    Notes:
        - Molecules whose scaffold cannot be computed are grouped under their
          own SMILES string instead.
        - A group is always added whole, and a split accepts groups for as long
          as it is still *below* its target size. A split can therefore end up
          larger than its target when the last group it receives is big.

    Args:
        smiles_list: List of SMILES strings
        frac_train: Fraction for training set (default 0.8)
        frac_val: Fraction for validation set (default 0.1)
        frac_test: Fraction for test set (default 0.1); only used for the sum
            check, the test set receives every group not placed in train/val
        seed: Random seed (unused, kept for API compatibility)

    Returns:
        Tuple of (train_indices, val_indices, test_indices) as integer numpy
        arrays (integer dtype even when a split is empty)

    Raises:
        ValueError: If fractions don't sum to 1.0 (within 1e-6)
    """
    frac_sum = frac_train + frac_val + frac_test
    # Explicit raise (not `assert`) so the check survives `python -O` and
    # matches the documented exception type. `not (... < ...)` also rejects NaN.
    if not abs(frac_sum - 1.0) < 1e-6:
        raise ValueError(
            f"frac_train + frac_val + frac_test must sum to 1.0, got {frac_sum}"
        )

    n = len(smiles_list)

    # Steps 1-2: group molecule indices by scaffold.
    scaffold_to_indices = defaultdict(list)
    for i, smi in enumerate(smiles_list):
        scaffold = murcko_scaffold(smi)
        if scaffold is None:
            scaffold = smi  # Use original SMILES as fallback
        scaffold_to_indices[scaffold].append(i)

    # Step 3: largest groups first, ties broken by the smallest index in the
    # group. Indices were appended in ascending order, so group[0] is already
    # each group's smallest index and no per-group sort is needed.
    scaffold_groups = sorted(
        scaffold_to_indices.values(), key=lambda group: (-len(group), group[0])
    )

    # Step 4: greedy assignment in train -> val -> test priority order.
    train_idx, val_idx, test_idx = [], [], []
    train_target = int(frac_train * n)
    val_target = int(frac_val * n)

    for group in scaffold_groups:
        # Fill train until it reaches its target, then val, then everything
        # else goes to test.
        if len(train_idx) < train_target:
            train_idx.extend(group)
        elif len(val_idx) < val_target:
            val_idx.extend(group)
        else:
            test_idx.extend(group)

    # Force an integer dtype: np.array([]) would otherwise be float64, which
    # cannot be used to index a dataset.
    return (
        np.array(train_idx, dtype=int),
        np.array(val_idx, dtype=int),
        np.array(test_idx, dtype=int),
    )


def save_split_indices(
    name: str,
    train_idx: np.ndarray,
    val_idx: np.ndarray,
    test_idx: np.ndarray,
    seed: int,
    splitter: str = "scaffold",
    splits_dir: str = "data/splits",
) -> str:
    """Write the three index arrays of a split into a single .npz archive.

    The file is named ``{name}_{splitter}_seed{seed}.npz`` and placed in
    ``splits_dir``, which is created if it does not exist yet. The arrays are
    stored under the keys ``train``, ``val`` and ``test``.

    Args:
        name: Dataset name
        train_idx, val_idx, test_idx: Split indices
        seed: Seed used for splitting
        splitter: Splitter type ("scaffold" or "random")
        splits_dir: Output directory

    Returns:
        Path to saved .npz file
    """
    filename = f"{name}_{splitter}_seed{seed}.npz"
    arrays = {"train": train_idx, "val": val_idx, "test": test_idx}

    os.makedirs(splits_dir, exist_ok=True)
    out_path = os.path.join(splits_dir, filename)
    np.savez(out_path, **arrays)
    return out_path


def load_split_indices(path: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Load split indices from .npz file.

    Reads the arrays stored under the keys ``train``, ``val`` and ``test``.

    Args:
        path: Path to .npz file

    Returns:
        Tuple of (train_indices, val_indices, test_indices)

    Raises:
        KeyError: If the archive lacks one of the ``train``/``val``/``test``
            arrays
    """
    # np.load on an .npz returns a lazy archive object that keeps the file
    # open; the `with` block closes it once the arrays have been read into
    # memory, instead of leaking the handle.
    with np.load(path) as data:
        return data["train"], data["val"], data["test"]
