from typing import Iterable, List

from guacamol.assess_distribution_learning import _evaluate_distribution_learning_benchmarks
from guacamol.benchmark_suites import distribution_learning_benchmark_suite
from guacamol.distribution_matching_generator import DistributionMatchingGenerator

class GuacamolEvaluator(DistributionMatchingGenerator):
    """
    Distribution-matching "generator" that replays a pre-collected pool of SMILES.

    This class does not sample from a model itself. It accumulates SMILES
    strings (e.g. produced elsewhere by a generative model) via
    ``add_smiles`` and hands them to the GuacaMol distribution-learning
    benchmarks through ``generate``. Duplicates are dropped on insertion
    while preserving the order in which SMILES were first added, so
    ``generate`` is deterministic.
    """

    def __init__(self) -> None:
        super().__init__()
        # Unique SMILES, kept in first-seen order.
        self.smiles: List[str] = []

    def add_smiles(self, smiles: Iterable[str]) -> None:
        """
        Add SMILES strings to the pool, dropping any already present.

        Args:
            smiles: An iterable of SMILES strings. A single ``str`` is treated
                as one SMILES rather than being split into characters.
        """
        if isinstance(smiles, str):
            smiles = [smiles]
        # dict.fromkeys de-duplicates while keeping first-insertion order. A
        # plain set() scrambles the order (differently on every run, due to
        # string hash randomisation), which made the prefix returned by
        # generate() non-reproducible.
        # NOTE(review): because duplicates are dropped here, GuacaMol's
        # uniqueness benchmark will always see fully unique samples —
        # confirm this is intended.
        self.smiles = list(dict.fromkeys([*self.smiles, *smiles]))

    def get_smiles_count(self) -> int:
        """Return the number of unique SMILES currently in the pool."""
        return len(self.smiles)

    def clear_smiles(self) -> None:
        """Remove all SMILES from the pool."""
        self.smiles = []

    def generate(self, number_samples: int) -> List[str]:
        """
        Return up to ``number_samples`` SMILES from the pool.

        The first ``number_samples`` SMILES (in insertion order) are returned.
        If the pool is smaller, every SMILES is returned, so the result may be
        shorter than requested. Non-positive values yield an empty list.
        """
        # Clamp so a negative count is not interpreted as "all but the last |n|".
        return self.smiles[:max(number_samples, 0)]

    def evaluate(self, training_smiles_path: str, number_samples: int = 10000):
        """
        Run the GuacaMol 'v2' distribution-learning benchmark suite on the pool.

        Args:
            training_smiles_path: Path to the training-set SMILES file; passed
                to the suite as ``chembl_file_path``.
            number_samples: Passed to the suite as ``number_samples``
                (presumably the sample count each benchmark requests from
                ``generate``).

        Returns:
            The benchmark results returned by GuacaMol's
            ``_evaluate_distribution_learning_benchmarks``.
        """
        benchmarks = distribution_learning_benchmark_suite(
            chembl_file_path=training_smiles_path,
            version_name='v2',
            number_samples=number_samples,
        )
        # NOTE: _evaluate_distribution_learning_benchmarks is a private GuacaMol
        # helper (leading underscore); its signature may change between releases.
        return _evaluate_distribution_learning_benchmarks(model=self, benchmarks=benchmarks)
