from data.qm9.load_qm9 import load_qm9
from data.load_data import SupportedDatasets
from data.load_multi_data import SupportedMultiDatasets
from visualize.vis_derive import vis_real_rdkit_with_mols, vis_2d_graph, \
    vis_derive_with_mols, vis_derive_multi_with_mols


def visualize_single(list_mol: list, generate: str = 'rdkit', derive: str = 'newton',
                     compare: str = 'equiv-trunc', *, use_multi: bool = False):
    """Visualize derived conformations for the given molecules.

    Builds a ``special_config`` dict from the three type selectors and forwards it,
    together with ``list_mol``, to a visualizer from ``visualize.vis_derive``.
    The visualizer always runs with ``use_cuda=False``.

    Args:
        list_mol: Molecules to visualize. The ``__main__`` block passes a subset
            of the molecules returned by ``load_qm9``.
        generate: Stored as ``special_config['GENERATE_TYPE']``.
        derive: Stored as ``special_config['DERIVE_TYPE']``.
        compare: Stored as ``special_config['COMPARE_TYPE']``.
        use_multi: If False (default), call ``vis_derive_with_mols`` on
            ``SupportedDatasets.QM9`` with ``final_only=True``. If True, call
            ``vis_derive_multi_with_mols`` on ``SupportedMultiDatasets.GEOM_QM9``
            instead.
    """
    special_config = {
        'GENERATE_TYPE': generate,
        'DERIVE_TYPE': derive,
        'COMPARE_TYPE': compare,
    }
    if use_multi:
        # Pass the function argument, not a module-level global, so this path
        # also works when the function is imported from another module.
        vis_derive_multi_with_mols(
            dataset_name=SupportedMultiDatasets.GEOM_QM9,
            list_mol=list_mol,
            special_config=special_config,
            use_cuda=False
        )
    else:
        vis_derive_with_mols(
            dataset_name=SupportedDatasets.QM9,
            list_mol=list_mol,
            special_config=special_config,
            use_cuda=False,
            final_only=True
        )


if __name__ == '__main__':
    # QM9 indices of the molecules to inspect.
    # Alternative sample: [97, 132, 5336, 55681, 123456, 123856]
    indices = [20801, 114165, 128129, 5182]

    # Load just enough of QM9 to reach the largest requested index, then pick the samples.
    all_mols, _ = load_qm9(max_num=max(indices) + 1)
    mols = [all_mols[i] for i in indices]

    vis_2d_graph(mols)
    vis_real_rdkit_with_mols(mols)
    # 'equiv-trunc' is visualize_single's default compare mode; the other two follow it.
    for compare_mode in ('equiv-trunc', 'adj3', 'kabsch'):
        if compare_mode == 'equiv-trunc':
            visualize_single(mols)
        else:
            visualize_single(mols, compare=compare_mode)
