import os

def get_best_result(eval_report_directory):
    """Scan a directory of evaluation reports and keep each model's best scores.

    Only directory entries whose name ends in ``txt`` are read. The name is
    split on ``.``: the third component from the end is used as the epoch
    label and everything before it (re-joined with ``.``) as the model name.
    Names with fewer than three components carry no epoch label and are
    skipped.

    Inside a report, lines are interpreted as follows:

    * a line containing ``unary`` or ``binary`` switches the current phase;
    * a line containing a comma is parsed as
      ``keyword,dist,precision,recall,f1`` and stored in the table of the
      current phase (only the first row per keyword is kept);
    * a line containing ``Accu`` is parsed as ``<label>:<accuracy>`` and
      triggers a comparison against the best values recorded so far for the
      current phase. The aggregated F1 used in that comparison is the sum of
      ``dist * f1`` over the most recently filled table of the same report.

    Args:
        eval_report_directory: Path of the directory holding the report files.

    Returns:
        A dict mapping model name to a dict with the keys ``best_unary``,
        ``best_binary``, ``best_unary_f1`` and ``best_binary_f1``. Each value
        is a ``(score, epoch)`` tuple, where ``epoch`` is the string taken
        from the file name; it is ``(-1, -1)`` while no score has been seen.

    Raises:
        OSError: If the directory cannot be listed or a report cannot be read.
        ValueError: If a comma line does not have exactly five fields or a
            numeric field cannot be converted to ``float``.
        IndexError: If an ``Accu`` line contains no ``:``.
    """
    reports = {}

    # os.listdir() returns entries in arbitrary order; sorting keeps the
    # winner of an exact tie identical from run to run.
    for eval_report_filename in sorted(os.listdir(eval_report_directory)):
        if not eval_report_filename.endswith('txt'):
            continue

        name_parts = eval_report_filename.split('.')
        if len(name_parts) < 3:
            # No epoch component to read (e.g. "notes.txt"): not a report.
            continue
        model_name = '.'.join(name_parts[:-3])
        epoch_ct = name_parts[-3]

        unary_stats = {}
        binary_stats = {}
        # Start every report with an empty table so that an accuracy line
        # appearing before any stats row never sees a previous report's rows.
        stats = {}

        model_best = reports.setdefault(model_name, {
            "best_unary": (-1, -1),
            "best_binary": (-1, -1),
            "best_unary_f1": (-1, -1),
            "best_binary_f1": (-1, -1),
        })

        report_path = os.path.join(eval_report_directory, eval_report_filename)
        phase = None

        # The context manager guarantees the handle is closed, even when a
        # malformed line raises part-way through the file.
        with open(report_path, 'r') as report_file:
            for line in report_file:
                if "unary" in line:
                    phase = "unary"

                if "binary" in line:
                    phase = "binary"

                if "," in line:
                    keyword, dist, precision, recall, f1 = line.split(',')
                    # Rows seen before any phase marker go to the binary table.
                    stats = unary_stats if phase == "unary" else binary_stats

                    if keyword not in stats:
                        stats[keyword] = {}
                        stats[keyword]['dist'] = float(dist)
                        stats[keyword]['precision'] = float(precision)
                        stats[keyword]['recall'] = float(recall)
                        stats[keyword]['f1'] = float(f1)

                if "Accu" in line:
                    accuracy = float(line.split(":")[1].strip())
                    # Distribution-weighted F1 over the current table.
                    aggr_f1 = sum(st['dist'] * st['f1'] for st in stats.values())

                    if phase == "unary":
                        accuracy_key, f1_key = "best_unary", "best_unary_f1"
                    elif phase == "binary":
                        accuracy_key, f1_key = "best_binary", "best_binary_f1"
                    else:
                        # Accuracy reported before any phase marker: nothing
                        # to attribute it to.
                        continue

                    if accuracy > model_best[accuracy_key][0]:
                        model_best[accuracy_key] = (accuracy, epoch_ct)
                    if aggr_f1 > model_best[f1_key][0]:
                        model_best[f1_key] = (aggr_f1, epoch_ct)
    return reports

def get_best_model(report):
    """Pick, for every metric, the model holding the highest value.

    Args:
        report: Dict mapping model name to a dict of
            ``metric -> (value, epoch)`` pairs.

    Returns:
        A dict mapping each metric to ``(model_name, epoch, value)`` for the
        model with the largest value. On an exact tie the model encountered
        first while iterating ``report`` is kept.
    """
    winners = {}
    for name, per_metric in report.items():
        for metric, pair in per_metric.items():
            score, epoch = pair
            incumbent = winners.get(metric)
            # The first entry for a metric always wins; after that only a
            # strictly larger score replaces the incumbent.
            if incumbent is None or score > incumbent[2]:
                winners[metric] = (name, epoch, score)
    return winners

if __name__ == "__main__":
    # Local import: sys is only needed when the file is run as a script.
    import sys

    # The report directory is taken from the first command-line argument.
    eval_report_directory = sys.argv[1] if len(sys.argv) > 1 else ""
    if not eval_report_directory:
        # os.listdir("") raises FileNotFoundError, so stop with a usage hint
        # instead of a traceback.
        sys.exit("usage: python {} EVAL_REPORT_DIRECTORY".format(sys.argv[0]))

    # Best scores per model, then the best model per metric.
    reports = get_best_result(eval_report_directory)
    best_models = get_best_model(reports)

    print(reports)
    print(best_models)

    print("end")

