import argparse
import os

from src.metrics import calculate_all_sampling_metrics
from src.model.load_utils import load_model_from_id
from src.model.flow_vae import FlowMAGNet
from src.utils import DATA_PATH, ROOT_DIR, smiles_from_file, WB_COLLECTION


def evaluate_magnet(dataset, collection, model_path, num_samples, mode, reconstruction_file, model_class):
    """Generate molecules with a trained MAGNet model and compute sampling metrics on them.

    Args:
        dataset: Dataset name (e.g. ``"ZINC"``). Forwarded to ``load_model_from_id``;
            its lower-cased form selects the data directory under ``DATA_PATH`` and
            is passed to ``calculate_all_sampling_metrics``.
        collection: Forwarded to ``load_model_from_id``.
        model_path: Model identifier forwarded to ``load_model_from_id``.
        num_samples: In sampling mode, the number of molecules to sample; in
            reconstruction mode, how many SMILES from the start of the input file
            are reconstructed.
        mode: Either ``"reconstruction"`` or ``"sampling"``.
        reconstruction_file: File name, relative to the dataset directory, holding
            the input SMILES. Only read in reconstruction mode.
        model_class: Forwarded to ``load_model_from_id``.

    Returns:
        The object returned by ``calculate_all_sampling_metrics``, with the
        generated SMILES stored under the ``"generated_smiles"`` key.

    Raises:
        ValueError: If ``mode`` is not ``"reconstruction"`` or ``"sampling"``.
    """
    # Validate before loading the model so a bad mode fails fast. An ``assert``
    # here would also disappear when Python runs with ``-O``.
    if mode not in ("reconstruction", "sampling"):
        raise ValueError(f"mode must be 'reconstruction' or 'sampling', got {mode!r}")

    model = load_model_from_id(collection, model_path, dataset=dataset, model_class=model_class)
    if mode == "reconstruction":
        input_smiles = smiles_from_file(DATA_PATH / dataset.lower() / reconstruction_file)
        output_smiles = model.reconstruct_from_smiles(input_smiles[:num_samples])
    else:
        output_smiles = model.sample_molecules(num_samples)

    # Score against the dataset actually being evaluated instead of a hard-coded
    # "zinc". This is identical for the default dataset "ZINC".
    results = calculate_all_sampling_metrics(output_smiles, dataset.lower())
    print(results)
    # Added after printing, so the printed summary excludes the SMILES list.
    results["generated_smiles"] = output_smiles
    return results


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Evaluate a trained MAGNet model by reconstructing or sampling molecules."
    )
    parser.add_argument("--dataset", default="ZINC", help="Dataset name (e.g. ZINC).")
    # Restrict to the modes evaluate_magnet supports, so a typo is rejected by
    # argparse instead of after the model has been loaded.
    parser.add_argument(
        "--mode",
        default="reconstruction",
        choices=["reconstruction", "sampling"],
        help="Reconstruct SMILES from --reconstruction_file, or sample new molecules.",
    )
    parser.add_argument(
        "--reconstruction_file",
        default="val.txt",
        help="SMILES file inside the dataset directory (reconstruction mode only).",
    )
    parser.add_argument("--collection", default=WB_COLLECTION, help="Collection to load the model from.")
    parser.add_argument("--model_path", default="2op8w2pw", help="Identifier of the model to load.")
    parser.add_argument(
        "--num_samples",
        type=int,
        default=100,
        help="Number of molecules to sample, or number of input SMILES to reconstruct.",
    )
    # NOTE(review): the default is the class object itself, but a value given on
    # the command line arrives as a plain string. Confirm load_model_from_id
    # accepts both forms.
    parser.add_argument("--model_class", default=FlowMAGNet, help="Model class to instantiate.")
    args = parser.parse_args()
    # Presumably so relative paths used by the project helpers resolve from the
    # project root; confirm.
    os.chdir(ROOT_DIR)
    results = evaluate_magnet(**vars(args))
