import numpy as np
from typing import Any, Dict, List, Iterator, Optional, Sequence, Tuple

def _generate_scaffold(smiles: str, include_chirality: bool = False) -> str:
  try:
    from rdkit import Chem
    from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles
  except ModuleNotFoundError:
    raise ImportError("This function requires RDKit to be installed.")

  mol = Chem.MolFromSmiles(smiles)
  scaffold = MurckoScaffoldSmiles(mol=mol, includeChirality=include_chirality)
  return scaffold


def split_scaffold(
      dataset,
      frac_train: float = 0.1,
  ) -> Tuple[Any, Any]:
    """Split ``dataset`` into a training and a validation part by scaffold.

    Scaffold groups come from ``generate_scaffolds`` (largest first). Each
    group is assigned whole: it goes to the training part if it still fits
    under ``frac_train * len(dataset)`` rows, otherwise to the validation
    part. Molecules that share a scaffold therefore never straddle the split.

    Args:
      dataset: table supporting ``len`` and ``.iloc`` row selection
        (presumably a pandas DataFrame with a ``"Drug"`` SMILES column, as
        read by ``generate_scaffolds`` -- confirm against callers).
      frac_train: target fraction of rows for the training part, in [0, 1].

    Returns:
      ``(train_df, valid_df)``: row subsets of ``dataset`` selected with
      ``.iloc`` using the indices from ``generate_scaffolds``.

    Raises:
      ValueError: if ``frac_train`` is not within [0, 1].
    """
    if not 0.0 <= frac_train <= 1.0:
      raise ValueError(f"frac_train must be in [0, 1], got {frac_train}")

    scaffold_sets = generate_scaffolds(dataset)

    train_cutoff = frac_train * len(dataset)
    train_inds: List[int] = []
    valid_inds: List[int] = []

    for scaffold_set in scaffold_sets:
      # Greedy packing: a later (smaller) group may still fit into the
      # training part even after a larger one was sent to validation.
      if len(train_inds) + len(scaffold_set) > train_cutoff:
        valid_inds.extend(scaffold_set)
      else:
        train_inds.extend(scaffold_set)

    return dataset.iloc[train_inds], dataset.iloc[valid_inds]

def generate_scaffolds(
                         dataset,
                         log_every_n: int = 1000) -> List[List[int]]:
    """Group the rows of ``dataset`` by molecular scaffold.

    Args:
      dataset: table whose ``"Drug"`` column holds SMILES strings
        (presumably a pandas DataFrame -- confirm against callers).
      log_every_n: emit a progress log message every this many rows;
        values <= 0 disable progress logging.

    Returns:
      One list of *positional* row indices (suitable for ``dataset.iloc``)
      per distinct scaffold, each list in ascending order. Lists are ordered
      from largest to smallest; ties are broken by the larger first index.
    """
    import logging
    logger = logging.getLogger(__name__)

    scaffolds: Dict[Any, List[int]] = {}
    data_len = len(dataset)

    # enumerate() gives positional indices, matching the ``.iloc`` lookups
    # done by callers (``iterrows`` would yield index labels instead).
    for ind, smiles in enumerate(dataset["Drug"]):
      if log_every_n > 0 and ind % log_every_n == 0:
        logger.info("Generating scaffold %d/%d", ind, data_len)
      # The scaffold must be computed for every row; only the log line is
      # throttled by log_every_n.
      scaffold = _generate_scaffold(smiles)
      scaffolds.setdefault(scaffold, []).append(ind)

    # Indices were appended in increasing order, so each group is already
    # sorted. Order groups from largest to smallest.
    return sorted(
        scaffolds.values(), key=lambda group: (len(group), group[0]),
        reverse=True)



def split_fp(
      dataset,
      frac_train: float = 0.8,
  ):
    """Split ``dataset`` into two mutually dissimilar parts by fingerprint.

    Each SMILES in the ``"Drug"`` column is converted to a Morgan bit vector
    (radius 2, 1024 bits), and ``_split_fingerprints`` assigns rows to the two
    parts.

    Args:
      dataset: table with a ``"Drug"`` SMILES column that supports ``len``
        and ``.iloc`` (presumably a pandas DataFrame -- confirm against
        callers).
      frac_train: fraction of rows for the first (training) part, in [0, 1];
        the training size is ``int(frac_train * len(dataset))``.

    Returns:
      ``(train_df, test_df)``: row subsets of ``dataset`` selected with
      ``.iloc``.

    Raises:
      ImportError: if RDKit is not installed.
      ValueError: if ``frac_train`` is outside [0, 1] or a SMILES string
        cannot be parsed.
    """
    if not 0.0 <= frac_train <= 1.0:
      raise ValueError(f"frac_train must be in [0, 1], got {frac_train}")

    try:
      from rdkit import Chem
      from rdkit.Chem import AllChem
    except ModuleNotFoundError as err:
      raise ImportError("This function requires RDKit to be installed.") from err

    # Compute fingerprints for all molecules, failing loudly on bad SMILES
    # (MolFromSmiles returns None rather than raising).
    fps = []
    for smiles in dataset["Drug"]:
      mol = Chem.MolFromSmiles(smiles)
      if mol is None:
        raise ValueError(f"Invalid SMILES string: {smiles!r}")
      fps.append(AllChem.GetMorganFingerprintAsBitVect(mol, 2, 1024))

    # Split into two groups: training set and everything else.
    train_size = int(frac_train * len(dataset))
    test_size = len(dataset) - train_size

    # Degenerate splits need no similarity search; handling them here also
    # covers an empty dataset and an empty training part.
    all_inds = list(range(len(dataset)))
    if train_size == 0:
      return dataset.iloc[[]], dataset.iloc[all_inds]
    if test_size == 0:
      return dataset.iloc[all_inds], dataset.iloc[[]]

    train_inds, test_inds = _split_fingerprints(fps, train_size, test_size)
    return dataset.iloc[train_inds], dataset.iloc[test_inds]


def _split_fingerprints(fps: List, size1: int,
                        size2: int) :
  """This is called by FingerprintSplitter to divide a list of fingerprints into
  two groups.
  """
  assert len(fps) == size1 + size2
  from rdkit import DataStructs

  # Begin by assigning the first molecule to the first group.

  fp_in_group = [[fps[0]], []]
  indices_in_group: Tuple[List[int], List[int]] = ([0], [])
  remaining_fp = fps[1:]
  remaining_indices = list(range(1, len(fps)))
  max_similarity_to_group = [
      DataStructs.BulkTanimotoSimilarity(fps[0], remaining_fp),
      [0] * len(remaining_fp)
  ]
  # Return identity if no tuple to split to
  if size2 == 0:
    return ((list(range(len(fps)))), [])

  while len(remaining_fp) > 0:
    # Decide which group to assign a molecule to.

    group = 0 if len(fp_in_group[0]) / size1 <= len(
        fp_in_group[1]) / size2 else 1

    # Identify the unassigned molecule that is least similar to everything in
    # the other group.

    i = np.argmin(max_similarity_to_group[1 - group])

    # Add it to the group.

    fp = remaining_fp[i]
    fp_in_group[group].append(fp)
    indices_in_group[group].append(remaining_indices[i])

    # Update the data on unassigned molecules.

    similarity = DataStructs.BulkTanimotoSimilarity(fp, remaining_fp)
    max_similarity_to_group[group] = np.delete(
        np.maximum(similarity, max_similarity_to_group[group]), i)
    max_similarity_to_group[1 - group] = np.delete(
        max_similarity_to_group[1 - group], i)
    del remaining_fp[i]
    del remaining_indices[i]
  return indices_in_group
