import argparse

from data.load_data import SupportedDatasets
from data.load_multi_data import SupportedMultiDatasets
from train.utils.save_log import save_log
from train.recover_conf import recover_datasets, recover_multi_datasets, \
    recover_conf, recover_multi_conf, \
    eval_conf, eval_multi_conf


def evaluate(special_config, dataset_name: str, dataset_token: str, token: str, use_ff, include_rdkit, use_tqdm):
    """Recover a model and its datasets, evaluate on validation and test splits, and log the losses.

    ``dataset_name`` is handled by the ``recover_multi_*`` / ``eval_multi_conf``
    helpers when it appears in ``SupportedMultiDatasets.tolist()``, and by the
    single-dataset ``recover_*`` / ``eval_conf`` helpers otherwise.

    :param special_config: configuration forwarded to the dataset/model recovery helpers.
    :param dataset_name: dataset to evaluate on; also used in the log directory name.
    :param dataset_token: dataset variant token, forwarded to recovery and evaluation.
    :param token: model token, used to recover the model and as the saved log's tag.
    :param use_ff: if truthy, the log is saved under ``f'{dataset_name}-ff'``.
    :param include_rdkit: if truthy and ``use_ff`` is falsy, the log is saved
        under ``f'{dataset_name}-eval'``.
    :param use_tqdm: forwarded to the evaluation helpers.
    :return: a one-element list holding a dict with ``'epoch': 0`` plus
        ``validate_<key>_loss`` and ``test_<key>_loss`` float entries. It is
        returned so the results are not lost when neither ``use_ff`` nor
        ``include_rdkit`` is set, because nothing is saved in that case.
    """
    multi = dataset_name in SupportedMultiDatasets.tolist()
    if multi:
        recover_sets, recover_model, eval_model = recover_multi_datasets, recover_multi_conf, eval_multi_conf
    else:
        recover_sets, recover_model, eval_model = recover_datasets, recover_conf, eval_conf

    # The first returned split (presumably the training set) is not used for evaluation.
    _, validate_set, test_set = recover_sets(special_config, dataset_name, dataset_token)
    model = recover_model(special_config, dataset_name, token)
    validate_loss_dict = eval_model(model, validate_set, dataset_token, use_tqdm=use_tqdm)
    test_loss_dict = eval_model(model, test_set, dataset_token, use_tqdm=use_tqdm)

    log = {'epoch': 0}
    log.update({
        f'validate_{loss_key}_loss': float(loss_value) for loss_key, loss_value in validate_loss_dict.items()
    })
    log.update({
        f'test_{loss_key}_loss': float(loss_value) for loss_key, loss_value in test_loss_dict.items()
    })
    logs = [log]

    if use_ff:
        save_log(logs, directory=f'{dataset_name}-ff', tag=token)
    elif include_rdkit:
        save_log(logs, directory=f'{dataset_name}-eval', tag=token)
    return logs


def evaluate_patch(dataset_name=SupportedMultiDatasets.GEOM_QM9_SMALL, dataset_token='phi-psi',
                   generate='rdkit', derive='newton', compare='equiv-trunc',
                   use_ff=True, include_rdkit=False, use_tqdm=False):
    """Build a ``special_config`` from the given options and run ``evaluate`` on it.

    The model token is ``f'{generate}-{derive}-{compare}'``. ``COMPARE_MIDDLE``
    is enabled whenever ``derive`` is a non-empty string.

    A ``FileNotFoundError`` raised during evaluation is treated as "this
    combination is not available": the combination and the underlying error
    are printed, and the exception is not re-raised.
    """
    special_config = {
        'GENERATE_TYPE': generate,
        'DERIVE_TYPE': derive,
        'COMPARE_TYPE': compare,
        'COMPARE_MIDDLE': derive != '',
        'FF': use_ff,
        'INCLUDE_RDKIT': include_rdkit,
    }
    token = f'{generate}-{derive}-{compare}'
    try:
        evaluate(
            dataset_name=dataset_name,
            special_config=special_config,
            token=token,
            dataset_token=dataset_token,
            use_ff=use_ff,
            include_rdkit=include_rdkit,
            use_tqdm=use_tqdm
        )
    except FileNotFoundError as err:
        # Best-effort: skip unavailable combinations, but report the missing
        # path so a real failure (e.g. a missing directory) is not hidden.
        print('Not supported:')
        print(f'\tdataset_name: {dataset_name}')
        print(f'\tdataset_token: {dataset_token}')
        print(f'\tgenerate: {generate}')
        print(f'\tderive: {derive}')
        print(f'\tcompare: {compare}')
        print(f'\treason: {err}')


if __name__ == '__main__':
    # Parse the command line only when run as a script, so importing this
    # module does not start an evaluation run.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='geom_qm9')  # geom_qm9, geom_drugs
    # Defaults below mirror evaluate_patch's own defaults, so omitting the
    # new options reproduces the previous behaviour exactly.
    parser.add_argument('--dataset-token', type=str, default='phi-psi')
    parser.add_argument('--generate', type=str, default='rdkit')
    parser.add_argument('--derive', type=str, default='newton')
    parser.add_argument('--compare', type=str, default='equiv-trunc')
    parser.add_argument('--include-rdkit', action='store_true')
    parser.add_argument('--tqdm', action='store_true')
    args = parser.parse_args()
    evaluate_patch(
        args.dataset,
        dataset_token=args.dataset_token,
        generate=args.generate,
        derive=args.derive,
        compare=args.compare,
        include_rdkit=args.include_rdkit,
        use_tqdm=args.tqdm,
    )
