"""Split filtered PubChem molecules into training and evaluation buckets."""

from __future__ import annotations

import pickle
import random
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import List, Sequence, Tuple

try:
    from datasketch import MinHash, MinHashLSH

    DATASKETCH_AVAILABLE = True
except ImportError:  # pragma: no cover - optional dependency
    MinHash = None  # type: ignore
    MinHashLSH = None  # type: ignore
    DATASKETCH_AVAILABLE = False

from rdkit import Chem
from rdkit.Chem import AllChem
from tqdm import tqdm


# Directory that contains this script; default data locations are resolved
# relative to it so they do not depend on the current working directory.
SUBMISSION_ROOT = Path(__file__).resolve().parent
# Pickle produced by the filtering step: (filtered_smiles_list, iupac_dict).
DEFAULT_INPUT_PATH = SUBMISSION_ROOT.joinpath("data", "processed", "filtered_smiles_and_iupacs.pkl")
# Pickle written by this script: [hard_test_set, easy_test_set, training_set, iupac_dict].
DEFAULT_OUTPUT_PATH = SUBMISSION_ROOT.joinpath("data", "processed", "pubchem_train_test_pools.pkl")


@dataclass
class CreateTrainTestPoolsConfig:
    """Settings for ``create_train_test_pools``.

    Values are checked in ``__post_init__`` so that an invalid configuration
    fails immediately with a clear message instead of deep inside sampling or
    LSH construction. ``input_path``/``output_path`` may be given as ``str``;
    they are normalised to ``Path``.
    """

    # Pickle holding ``(filtered_smiles_list, iupac_dict)``.
    input_path: Path = DEFAULT_INPUT_PATH
    # Destination pickle for ``[hard_test_set, easy_test_set, training_set, iupac_dict]``.
    output_path: Path = DEFAULT_OUTPUT_PATH
    # Number of molecules sampled for the hard test set.
    nbr_mols_hard_test_set: int = 200
    # Number of candidate molecules sampled for the second pool (before similarity filtering).
    nbr_mols_second_pool: int = 800
    # Number of molecules sampled from the filtered second pool for the easy test set.
    nbr_mols_easy_test_set: int = 200
    # Threshold passed to MinHashLSH; second-pool molecules approximately this
    # similar (or more) to a hard-test molecule are removed. Ignored when
    # datasketch is not installed.
    similarity_threshold: float = 0.7
    # Number of MinHash permutations used for signatures and the LSH index.
    num_perm: int = 128

    def __post_init__(self) -> None:
        """Normalise path fields and reject out-of-range settings.

        Raises:
            ValueError: if a molecule count is negative, ``similarity_threshold``
                lies outside ``[0, 1]``, or ``num_perm`` is not positive.
        """
        self.input_path = Path(self.input_path)
        self.output_path = Path(self.output_path)

        for field_name in ("nbr_mols_hard_test_set", "nbr_mols_second_pool", "nbr_mols_easy_test_set"):
            value = getattr(self, field_name)
            if value < 0:
                raise ValueError(f"{field_name} must be non-negative, got {value}")
        if not 0.0 <= self.similarity_threshold <= 1.0:
            raise ValueError(f"similarity_threshold must be in [0, 1], got {self.similarity_threshold}")
        if self.num_perm < 1:
            raise ValueError(f"num_perm must be positive, got {self.num_perm}")


def get_minhash(smiles: str, num_perm: int = 128) -> MinHash | None:
    """Return a MinHash signature of the Morgan fingerprint of ``smiles``.

    The indices of the set bits of a radius-2, 512-bit Morgan fingerprint are
    treated as tokens and fed into a ``MinHash`` with ``num_perm``
    permutations.

    Returns ``None`` when ``datasketch`` could not be imported (the SMILES is
    then not parsed at all) or when RDKit fails to parse ``smiles``.
    """
    if DATASKETCH_AVAILABLE:
        molecule = Chem.MolFromSmiles(smiles)
        if molecule is not None:
            bit_vector = AllChem.GetMorganFingerprintAsBitVect(molecule, radius=2, nBits=512)
            signature = MinHash(num_perm=num_perm)
            for on_bit in bit_vector.GetOnBits():
                # Bit indices are hashed as their decimal string representation.
                signature.update(str(on_bit).encode("utf8"))
            return signature
    return None


def _sample(smiles: Sequence[str], k: int) -> List[str]:
    """Randomly draw up to ``k`` items from ``smiles`` without replacement.

    Uses the module-level ``random`` generator. ``k`` is capped at the number
    of available items, so requesting more than exist returns a shuffled copy
    of the whole input. An empty input yields an empty list.
    """
    if not smiles:
        return []
    population = list(smiles)
    sample_size = min(len(population), k)
    return random.sample(population, sample_size)


def _load_filtered_molecules(input_path: Path) -> Tuple[Sequence[str], dict]:
    """Load ``(filtered_smiles_list, iupac_dict)`` from ``input_path``, or a toy dataset if it is missing."""
    if not input_path.exists():
        # Falling back silently would overwrite the real output pickle with a
        # 5-molecule dataset (e.g. after a path typo), so make it visible.
        warnings.warn(
            f"Input file {input_path} not found; falling back to a tiny built-in toy dataset.",
            RuntimeWarning,
            stacklevel=3,
        )
        return ["CCO", "CC(=O)O", "CCN", "c1ccccc1", "C1CCCCC1"], {"CCO": ["ethanol"]}

    # NOTE: pickle.load can execute arbitrary code; only point this at files
    # produced by the project's own preprocessing step.
    with input_path.open("rb") as handle:
        filtered_smiles_list, iupac_dict = pickle.load(handle)
    return filtered_smiles_list, iupac_dict


def _drop_similar_to_hard_set(
    hard_test_set: Sequence[str],
    candidates: Sequence[str],
    cfg: CreateTrainTestPoolsConfig,
) -> List[str]:
    """Return the ``candidates`` not flagged as similar to any hard-test molecule."""
    if not (DATASKETCH_AVAILABLE and hard_test_set):
        # Without datasketch (or with an empty hard set) only exact SMILES
        # string matches can be detected.
        hard_set_lookup = set(hard_test_set)
        return [smi for smi in candidates if smi not in hard_set_lookup]

    lsh = MinHashLSH(threshold=cfg.similarity_threshold, num_perm=cfg.num_perm)
    for idx, smi in tqdm(enumerate(hard_test_set), total=len(hard_test_set), desc="Inserting hard set"):
        mh = get_minhash(smi, cfg.num_perm)
        if mh is not None:
            lsh.insert(f"mol_{idx}", mh)

    kept: List[str] = []
    for smi in tqdm(candidates, desc="Filtering second pool"):
        mh = get_minhash(smi, cfg.num_perm)
        # Candidates RDKit cannot parse are dropped: their similarity to the
        # hard set cannot be assessed.
        if mh is not None and not lsh.query(mh):
            kept.append(smi)
    return kept


def create_train_test_pools(cfg: CreateTrainTestPoolsConfig) -> Tuple[List[str], List[str], List[str], dict]:
    """Split molecules into a hard test set, an easy test set and a training set.

    Steps:
      1. Load ``(filtered_smiles_list, iupac_dict)`` from ``cfg.input_path``.
         If the file does not exist, a tiny built-in toy dataset is used and a
         ``RuntimeWarning`` is emitted.
      2. Randomly sample ``cfg.nbr_mols_hard_test_set`` SMILES as the hard test set.
      3. Randomly sample up to ``cfg.nbr_mols_second_pool`` SMILES from the
         molecules *not* in the hard test set as the second pool.
      4. Remove second-pool molecules that are similar to a hard-test molecule.
         With datasketch this is MinHash LSH over Morgan fingerprints using
         ``cfg.similarity_threshold`` (approximate; unparseable SMILES are also
         dropped). Without datasketch, only exact SMILES matches are removed.
      5. Sample ``cfg.nbr_mols_easy_test_set`` SMILES from the filtered pool as
         the easy test set. The remaining molecules, de-duplicated and sorted,
         form the training set.

    ``[hard_test_set, easy_test_set, training_set, iupac_dict]`` is pickled to
    ``cfg.output_path``; missing parent directories are created.

    Sampling uses the module-level ``random`` generator, so seed it beforehand
    for a reproducible split.

    Returns:
        ``(hard_test_set, easy_test_set, training_set, iupac_dict)``, where
        ``iupac_dict`` is passed through unchanged from the input.
    """
    input_path = Path(cfg.input_path)
    output_path = Path(cfg.output_path)

    filtered_smiles_list, iupac_dict = _load_filtered_molecules(input_path)

    hard_test_set = _sample(filtered_smiles_list, cfg.nbr_mols_hard_test_set)
    # Sample the second pool only from molecules outside the hard set. Slots
    # spent on hard-set molecules would be discarded by the filter below and
    # shrink the pool.
    hard_set_lookup = set(hard_test_set)
    remaining = [smi for smi in filtered_smiles_list if smi not in hard_set_lookup]
    second_pool = _sample(remaining, cfg.nbr_mols_second_pool)

    filtered_second_pool = _drop_similar_to_hard_set(hard_test_set, second_pool, cfg)

    easy_test_set = _sample(filtered_second_pool, cfg.nbr_mols_easy_test_set)
    # Use a set difference so duplicate SMILES in the pool cannot leak an
    # easy-test molecule into the training set.
    training_set = sorted(set(filtered_second_pool) - set(easy_test_set))

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("wb") as handle:
        pickle.dump([hard_test_set, easy_test_set, training_set, iupac_dict], handle)

    return hard_test_set, easy_test_set, training_set, iupac_dict


if __name__ == "__main__":
    # Script entry point: run the split with all default settings.
    default_config = CreateTrainTestPoolsConfig()
    create_train_test_pools(default_config)
