import argparse
import glob2
from downstream.utils import evaluation
import json
import torch


def ensemble(args):
    """Score each saved prediction file and the running mean of all of them.

    For every comma-separated folder in ``args.input_folders``, files matching
    ``<folder>/<dataset>_prediction*.json`` are loaded. Each file must be a
    JSON object with ``"prediction"`` and ``"label"`` entries. The individual
    score of each file is printed, followed by the score of the element-wise
    mean of all predictions loaded so far; the last such score is printed
    again as the final ensemble score.

    Args:
        args: Object exposing ``input_folders`` (comma-separated folder
            paths), ``dataset`` (file-name prefix) and ``metrics`` (passed
            through to ``evaluation``).

    Raises:
        FileNotFoundError: If no prediction file matches in any folder.
    """
    # Tolerate "a, b" style input and stray/trailing commas.
    input_folders = [
        folder.strip() for folder in args.input_folders.split(",") if folder.strip()
    ]
    pattern = args.dataset + "_prediction*.json"

    all_predictions = []
    ensemble_score = None
    for input_folder in input_folders:
        # Sort so the running ensemble output is reproducible across runs;
        # glob order otherwise depends on the filesystem.
        files = sorted(glob2.glob(input_folder + "/" + pattern))
        for file in files:
            # Keep the file open only for as long as parsing needs it.
            with open(file) as f:
                data = json.load(f)
            prediction = torch.tensor(data["prediction"])
            label = torch.tensor(data["label"])
            score, _, _ = evaluation(prediction, label, args.metrics)
            print(file, score)

            all_predictions.append(prediction)
            mean_predictions = torch.mean(torch.stack(all_predictions), dim=0)
            # NOTE(review): the ensemble is scored against the labels of the
            # most recently loaded file; this assumes every prediction file
            # covers the same samples in the same order -- confirm.
            ensemble_score, _, _ = evaluation(mean_predictions, label, args.metrics)
            print("Ensemble score", args.dataset, ensemble_score)

    if not all_predictions:
        # Previously this case crashed with an UnboundLocalError on the final
        # print; fail with a message that says what was actually searched.
        raise FileNotFoundError(
            f"No files matching '{pattern}' found in: {args.input_folders!r}"
        )
    print("Final ensemble score", args.dataset, ensemble_score)


if __name__ == "__main__":
    # Command-line entry point: every option is a plain string with a default,
    # so they are registered from a single (flag, default, help) table.
    cli = argparse.ArgumentParser(description='TDC Admet ensemble')
    option_table = (
        ('--dataset', 'DILI', 'Name of the ADMET dataset'),
        ('--metrics', 'auroc', 'Report metrics'),
        ('--input_folders', './', 'Input folders, separated by commas'),
    )
    for flag, fallback, description in option_table:
        cli.add_argument(flag, default=fallback, type=str, help=description)

    # Parse the command line and hand the resulting namespace to ensemble().
    args = cli.parse_args()
    ensemble(args)
