"""SMILES dataset for VAE training with tokenization."""

import os
from dataclasses import dataclass
from typing import List, Tuple, Dict
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

from moltenflow.tokenizer.tokenizer import (
    Vocab,
    build_vocab,
    encode,
    smiles_to_selfies,
)
from moltenflow.data.data_utils import canonicalize_smiles

# Registry of downloadable datasets: lowercase dataset name -> public CSV URL.
DATASET_URLS = {
    # Common MoleculeNet-style small property sets (public URLs used by many tutorials).
    "esol": "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv",
    "freesolv": "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/freesolv.csv",
    "lipophilicity": "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/Lipophilicity.csv",
    # Large-scale molecular generation benchmark
    "zinc250k": "https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/master/models/zinc_properties/250k_rndm_zinc_drugs_clean_3.csv",
}


def download_dataset(name: str, out_csv: str) -> None:
    """Download a dataset from a known URL.

    The file is first written to ``<out_csv>.part`` and only moved to
    ``out_csv`` once the transfer has finished, so a failed download never
    leaves a truncated CSV at the final location.

    Args:
        name: Dataset name, matched case-insensitively against the keys of
            ``DATASET_URLS`` (esol, freesolv, lipophilicity, zinc250k)
        out_csv: Output path for CSV file; missing parent directories are
            created, and a bare filename is written to the current directory

    Raises:
        ValueError: If dataset name is unknown
    """
    import urllib.request

    url = DATASET_URLS.get(name.lower())
    if url is None:
        raise ValueError(f"Unknown dataset '{name}'. Known: {list(DATASET_URLS)}")

    # dirname() is "" for a bare filename, and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    out_dir = os.path.dirname(out_csv)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    tmp_path = f"{out_csv}.part"
    try:
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, out_csv)
    except BaseException:
        # Remove the partial download before propagating the original error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


@dataclass
class SmilesDataset:
    """Dataset for SMILES strings with tokenization and properties.

    Used for VAE training with sequence models.

    Attributes:
        smiles: List of SMILES strings
        y: Property values of shape (N, P)
        vocab: Tokenizer vocabulary
        max_len: Maximum sequence length
        y_cols: Names of property columns
        representation: Which string form is tokenized, "smiles" or "selfies"
        sequences: Strings in the chosen representation, aligned with
            ``smiles``; may be None
    """

    smiles: List[str]
    y: np.ndarray  # (N, P) float32
    vocab: Vocab
    max_len: int
    y_cols: List[str]
    # How the sequence was tokenized for model training.
    # - "smiles": tokenize SMILES directly
    # - "selfies": tokenize SELFIES strings derived from SMILES
    representation: str = "smiles"
    # Sequences in the chosen representation (aligned with `smiles`).
    # For representation="smiles", this is usually identical to `smiles`.
    sequences: List[str] | None = None

    def __len__(self) -> int:
        # The dataset size is the number of stored SMILES strings.
        return len(self.smiles)


def load_csv_dataset(
    csv_path: str,
    smiles_col: str,
    y_cols: List[str],
    max_len: int,
    min_freq: int = 1,
    seed: int = 7,
    drop_invalid: bool = True,
    representation: str = "smiles",
    drop_nan: bool = True,
) -> Tuple[SmilesDataset, Dict[str, np.ndarray]]:
    """Load SMILES dataset from CSV for VAE training.

    Rows are canonicalized, optionally converted to SELFIES, filtered, and
    then split into train/val/test index arrays (10% test, then 10% of the
    remainder as validation).

    Args:
        csv_path: Path to CSV file
        smiles_col: Name of SMILES column
        y_cols: Names of property columns
        max_len: Maximum sequence length for tokenization (stored on the
            returned dataset; rows are not filtered by length here)
        min_freq: Minimum token frequency for vocabulary
        seed: Random seed for splits
        drop_invalid: Whether to drop SMILES that fail canonicalization;
            when False the raw string is kept as-is
        representation: "smiles" or "selfies" (case-insensitive, surrounding
            whitespace ignored). For "selfies", rows whose SELFIES conversion
            fails are always skipped because no sequence exists for them
        drop_nan: Whether to drop rows with a missing value in any of
            ``y_cols``

    Returns:
        Tuple of (SmilesDataset, splits dict with train/val/test indices)

    Raises:
        ValueError: If required columns are missing, ``representation`` is
            unknown, or no rows survive filtering
    """
    df = pd.read_csv(csv_path)

    if smiles_col not in df.columns:
        raise ValueError(f"smiles_col='{smiles_col}' not found. Columns={list(df.columns)}")
    for c in y_cols:
        if c not in df.columns:
            raise ValueError(f"y col '{c}' not found. Columns={list(df.columns)}")

    # Validate once up front so a bad value is reported even when no row
    # would ever reach the per-row branch (empty file, all rows invalid).
    rep = representation.strip().lower()
    if rep not in ("smiles", "selfies"):
        raise ValueError(
            f"Unknown representation '{representation}'. Expected 'smiles' or 'selfies'."
        )

    smiles_raw = df[smiles_col].astype(str).tolist()
    # Pull the label matrix out once; a per-row df.loc lookup is far slower
    # on large files. Conversion to float32 stays per row so that rows
    # filtered out earlier are never converted.
    y_raw = df[list(y_cols)].to_numpy()

    smiles: List[str] = []
    sequences: List[str] = []
    ys: List[np.ndarray] = []
    for s, raw_row in zip(smiles_raw, y_raw):
        cs = canonicalize_smiles(s)
        if cs is None:
            if drop_invalid:
                continue
            cs = s
        if rep == "selfies":
            seq = smiles_to_selfies(cs)
            if seq is None:
                # No sequence to train on, so skip regardless of drop_invalid.
                continue
        else:
            seq = cs
        rowy = raw_row.astype(np.float32)
        if drop_nan and np.isnan(rowy).any():
            # drop missing labels (keep simple + robust)
            continue
        smiles.append(cs)
        sequences.append(seq)
        ys.append(rowy)

    if not ys:
        raise ValueError(
            f"No usable rows left in '{csv_path}' after filtering "
            f"(drop_invalid={drop_invalid}, drop_nan={drop_nan}, representation='{rep}')."
        )

    y = np.stack(ys, axis=0).astype(np.float32)
    vocab = build_vocab(sequences, min_freq=min_freq, representation=rep)

    # splits: hold out 10% for test, then 10% of what remains for validation
    idx = np.arange(len(smiles))
    train_idx, test_idx = train_test_split(idx, test_size=0.1, random_state=seed, shuffle=True)
    train_idx, val_idx = train_test_split(train_idx, test_size=0.1, random_state=seed, shuffle=True)

    splits = {"train": train_idx, "val": val_idx, "test": test_idx}

    ds = SmilesDataset(
        smiles=smiles,
        y=y,
        vocab=vocab,
        max_len=max_len,
        y_cols=y_cols,
        representation=rep,
        sequences=sequences,
    )
    return ds, splits


def batchify(
    ds: SmilesDataset,
    indices: np.ndarray,
    batch_size: int,
    shuffle: bool,
    seed: int = 0,
):
    """Generate batches from SmilesDataset.

    Args:
        ds: SmilesDataset instance
        indices: Indices to sample from (array or any integer sequence);
            the caller's container is copied, never reordered in place
        batch_size: Batch size, must be positive; the last batch may be
            smaller
        shuffle: Whether to shuffle indices
        seed: Random seed for shuffling; the same seed gives the same order

    Yields:
        Dict with keys: smiles (list), x (int64 array stacking the encoded
        sequences), y (properties), plus selfies (list) when the dataset
        representation is "selfies"

    Raises:
        ValueError: If ``batch_size`` is not positive (raised when iteration
            starts, since this is a generator)
    """
    # A negative step would make range() empty and silently yield nothing.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # np.array() always copies, so shuffling never touches the caller's data.
    order = np.array(indices)
    if shuffle:
        np.random.default_rng(seed).shuffle(order)

    # Fall back to the raw SMILES when no separate sequences were stored.
    source = ds.sequences if ds.sequences is not None else ds.smiles
    is_selfies = ds.representation == "selfies"

    for start in range(0, len(order), batch_size):
        b = order[start : start + batch_size]
        seqs = [source[i] for i in b]
        x = np.stack(
            [encode(s, ds.vocab, ds.max_len, representation=ds.representation) for s in seqs],
            axis=0,
        ).astype(np.int64)
        out = {"smiles": [ds.smiles[i] for i in b], "x": x, "y": ds.y[b]}
        if is_selfies:
            out["selfies"] = seqs
        yield out
