import argparse
import glob
import os
import pickle

import hydra
from omegaconf import DictConfig

from trainer import build_model_trainer


def _evaluate_all_checkpoints(cfg: DictConfig) -> None:
    """Evaluate every ``*.pt`` file in ``cfg['common']['checkpoints_dir']``.

    For each checkpoint, ``cfg['common']['checkpoint']`` is set to the file's
    basename, a trainer is built from ``cfg``, and its
    ``evaluate_new_test_metrics()`` result is stored under the integer epoch
    parsed from the filename. The accumulated ``{epoch: metrics}`` dict is
    pickled to ``test_metrics.pkl`` in the checkpoints directory.

    Raises:
        ValueError: if a checkpoint name has no integer between ``epoch_`` and
            ``_loss`` (the ``int()`` parse below fails).
    """
    checkpoints_dir = cfg['common']['checkpoints_dir']
    # The aggregated results always go to the same file, so compute the path once.
    output_file = os.path.join(checkpoints_dir, 'test_metrics.pkl')
    all_metrics = {}
    # glob's result order depends on the filesystem; sort for a reproducible run.
    for checkpoint in sorted(glob.glob(os.path.join(checkpoints_dir, "*.pt"))):
        # basename (not split("/")) so this also works with OS-native separators.
        cp = os.path.basename(checkpoint)
        # NOTE(review): presumably build_model_trainer loads the checkpoint named
        # by cfg['common']['checkpoint'] — confirm against trainer.py.
        cfg['common']['checkpoint'] = cp
        # Assumes names shaped like "..epoch_<N>_loss..".
        epoch = int(cp.split("epoch_")[-1].split('_loss')[0])
        model_trainer = build_model_trainer(cfg)
        all_metrics[epoch] = model_trainer.evaluate_new_test_metrics()
        # Drop the trainer before building the next one so two models are
        # never held at the same time.
        del model_trainer
        # Re-dump after every checkpoint so results gathered so far survive a
        # failure on a later checkpoint.
        with open(output_file, "wb") as f:
            pickle.dump(all_metrics, f)


def _evaluate_single_checkpoint(cfg: DictConfig) -> None:
    """Evaluate the checkpoint named by ``cfg['common']['checkpoint']``.

    The metrics returned by ``evaluate_new_test_metrics()`` are pickled into
    ``cfg['common']['output_path']``. The filename is built from the
    experiment name and the epoch substring of the checkpoint name.
    """
    model_trainer = build_model_trainer(cfg)
    test_metrics = model_trainer.evaluate_new_test_metrics()
    # The epoch is kept as the raw substring between "epoch_" and the next "_";
    # it is only used to build the output filename.
    epoch = cfg['common']['checkpoint'].split('epoch_')[-1].split('_')[0]
    if cfg['dataset']['similarity_metric'] == "deepfri":
        # NOTE(review): the "_thr_" tag is filled with dataset.seed — confirm
        # this is intended rather than a threshold value.
        filename = f"{cfg['common']['experiment_name']}_thr_{cfg['dataset']['seed']}_test_epoch_{epoch}.pkl"
    else:
        filename = f"{cfg['common']['experiment_name']}_test_epoch_{epoch}.pkl"
    outdir = cfg['common']['output_path']
    # Create the output directory if needed; an empty path means the CWD.
    if outdir:
        os.makedirs(outdir, exist_ok=True)
    output_file = os.path.join(outdir, filename)
    with open(output_file, "wb") as f:
        pickle.dump(test_metrics, f)


@hydra.main(config_path="./configs")
def main(cfg: DictConfig):
    """Hydra entry point: run test-set evaluation and pickle the metrics.

    If ``cfg['common']['eval_all_checkpoints']`` is truthy, every checkpoint in
    ``cfg['common']['checkpoints_dir']`` is evaluated. Otherwise, only
    ``cfg['common']['checkpoint']`` is evaluated.
    """
    if cfg['common']['eval_all_checkpoints']:
        _evaluate_all_checkpoints(cfg)
    else:
        _evaluate_single_checkpoint(cfg)



if __name__ == "__main__":
    # Called with no arguments: the @hydra.main decorator supplies cfg.
    main()
