import json
import time

import numpy as np

from guacamol.assess_distribution_learning import assess_distribution_learning
from guacamol.distribution_matching_generator import DistributionMatchingGenerator
from moses import get_all_metrics

from src.utils import DATA_PATH, WB_LOG_PATH, smiles_from_file


def calculate_all_sampling_metrics(smiles, dataset):
    """Compute GuacaMol and MOSES sampling metrics for generated SMILES.

    Args:
        smiles: Generated SMILES strings to evaluate. At least 10,000 are required.
        dataset: Name of the reference dataset (case-insensitive). Only ``"zinc"``
            is currently supported.

    Returns:
        dict with two entries:
            ``"Guacamol_Metrics"``: benchmark name -> score, from
            ``calculate_guacamol_benchmark`` run against ``<DATA_PATH>/<dataset>/train.txt``.
            ``"Moses_Metrics"``: the dict returned by ``moses.get_all_metrics`` with
            ``<DATA_PATH>/<dataset>/test.txt`` as the test set.

    Raises:
        NotImplementedError: If ``dataset`` is not supported.
        ValueError: If fewer than 10,000 SMILES are given.
    """
    dataset = dataset.lower()
    if dataset != "zinc":
        raise NotImplementedError(f"Sampling metrics are not implemented for dataset {dataset!r}")

    # Explicit check rather than `assert`: asserts are stripped under `python -O`.
    # The GuacaMol wrapper replays at most len(smiles) molecules per request, so the
    # pool must be at least as large as the benchmark sample size (presumably 10,000
    # for GuacaMol v2 — confirm).
    if len(smiles) < 10000:
        raise ValueError(f"At least 10000 SMILES are required for sampling metrics, got {len(smiles)}")

    train_path = DATA_PATH / dataset / "train.txt"
    test_path = DATA_PATH / dataset / "test.txt"

    # NOTE(review): only `test` is passed to MOSES. If MOSES falls back to its own
    # bundled train / scaffold-test sets when `train` / `test_scaffolds` are omitted,
    # then novelty and scaffold metrics are not measured against ZINC. Confirm this
    # is intended.
    results = dict(
        Guacamol_Metrics=calculate_guacamol_benchmark(train_path, smiles),
        Moses_Metrics=get_all_metrics(smiles, test=smiles_from_file(test_path)),
    )
    return results


class DummyGenerator(DistributionMatchingGenerator):
    """GuacaMol generator that replays a fixed pool of pre-generated SMILES.

    Lets the GuacaMol distribution-learning benchmarks score SMILES that were
    sampled ahead of time, instead of sampling from a model on demand.
    """

    def __init__(self, smiles):
        # Keep a private copy: `generate` shuffles the pool in place, and doing that
        # on the caller's list would reorder data the caller still uses afterwards.
        self.smiles = list(smiles)

    def generate(self, number_samples: int):
        """Return up to ``number_samples`` SMILES from the pool, in random order.

        The pool is reshuffled with NumPy's global RNG on every call, so repeated
        calls yield different random subsets. If the pool holds fewer than
        ``number_samples`` entries, all of them are returned.
        """
        np.random.shuffle(self.smiles)
        # List slicing returns a new list, so callers cannot mutate the pool.
        return self.smiles[:number_samples]


def calculate_guacamol_benchmark(dataset_path, generated_smiles):
    """Run the GuacaMol v2 distribution-learning benchmarks on pre-generated SMILES.

    Args:
        dataset_path: Path to the training SMILES file used as the reference
            distribution (passed to GuacaMol as ``chembl_training_file``).
        generated_smiles: SMILES strings to evaluate. They are replayed through
            ``DummyGenerator``.

    Returns:
        dict mapping each benchmark's ``benchmark_name`` to its ``score``, read back
        from the JSON report written under ``WB_LOG_PATH / "guacamol_jsons"``.
    """
    model_generator = DummyGenerator(generated_smiles)

    # Create the report directory up front. If it is missing, writing the report
    # fails only after the (slow) benchmark run has finished, and the results are
    # lost. Assumes WB_LOG_PATH is a pathlib.Path, since it is joined with `/`.
    json_dir = WB_LOG_PATH / "guacamol_jsons"
    json_dir.mkdir(parents=True, exist_ok=True)
    # Timestamped file name, so successive runs do not overwrite each other.
    json_path = json_dir / (str(time.time()) + ".json")

    assess_distribution_learning(
        model_generator, chembl_training_file=dataset_path, json_output_file=json_path, benchmark_version="v2"
    )

    with open(json_path) as json_file:
        guacamol_results = json.load(json_file)
    return {subdict["benchmark_name"]: subdict["score"] for subdict in guacamol_results["results"]}
