
""" Functions for experimental results and outputs """

import json
import os
import re

import pandas as pd
from prettytable import PrettyTable

# Standard result filenames, shared by every experiment folder.
labels_json: str = "labels.json"        # labels file (not referenced in this module)
results_json: str = "results.json"      # metric ratios written by generate_results
baseline_json: str = "baseline.json"    # random-baseline ratios written by generate_results
confusion_ascii: str = "confusion.txt"  # plain-text confusion matrix
confusion_latex: str = "confusion.tex"  # LaTeX confusion matrix

def _as_flag(df, column):
    """Return ``df[column]`` as a bool Series, rejecting missing values.

    ``~`` on an int or object column is a bitwise inversion (``~1 == -2``,
    ``~True == -2``), which would silently corrupt every count that negates
    the column, so the flags are normalised to bool before any arithmetic.
    """
    values = df[column]
    if values.isna().any():
        raise ValueError("Column " + repr(column) + " contains missing values")
    return values.astype(bool)


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator`` as a float, or NaN when the denominator is 0.

    Undefined ratios (e.g. recall when there are no actual positives) are
    reported as NaN instead of raising or emitting a RuntimeWarning.
    """
    if denominator == 0:
        return float("nan")
    return float(numerator) / float(denominator)


def _add_derived_metrics(stats):
    """Add ``f1_score`` and ``balanced_accuracy`` to *stats* in place and return it.

    Both are derived from the ``precision``, ``recall`` and ``specificity``
    entries, so the same formulas serve the results and the random baseline.
    """
    precision, recall = stats["precision"], stats["recall"]
    stats["f1_score"] = _safe_ratio(2 * precision * recall, precision + recall)
    stats["balanced_accuracy"] = (recall + stats["specificity"]) / 2
    return stats


def generate_results(df_labels, df_annotations, experiment_folder=None):
    """
    Generate confusion matrix and results statistics for a single experiment.
    Writes out results, baseline and confusion matrix to a specified folder.
    Also returns the same as variables for testing purposes.

    The two frames are inner-joined on ``instance_id``. The merged frame must
    contain the flag columns ``auto_test_label`` (predicted positive),
    ``false_negative`` (actual positive) and ``curation_error``.

    Args:
        df_labels: DataFrame with generated labels
        df_annotations: DataFrame with target annotations
        experiment_folder: Path to folder for results files (str or os.PathLike).
            If experiment_folder is None, results will be returned without writing to file
    Returns:
        results: dictionary with overall ratio statistics (accuracy etc.);
            any ratio whose denominator is zero is NaN
        baseline: dictionary with uniform random baseline statistics
        confusion_matrix: PrettyTable with confusion matrix
    Raises:
        KeyError: if ``instance_id`` or a flag column is missing
        ValueError: if a flag column contains missing values
    """

    print("Combining generated labels with reference annotations")
    df_labels = pd.merge(df_labels, df_annotations, how='inner', on='instance_id')

    # confusion matrix

    print("Generating confusion matrix")

    predicted = _as_flag(df_labels, "auto_test_label")
    actual = _as_flag(df_labels, "false_negative")
    curation = _as_flag(df_labels, "curation_error")

    true_positives = int((predicted & actual).sum())
    false_positives = int((predicted & ~actual).sum())
    false_negatives = int((~predicted & actual).sum())
    true_negatives = int((~predicted & ~actual).sum())

    total_predicted_positives = int(predicted.sum())
    total_predicted_negatives = int((~predicted).sum())
    total_actual_positives = int(actual.sum())
    total_actual_negatives = int((~actual).sum())
    overall_total = total_actual_positives + total_actual_negatives

    curation_errors_positive = int((curation & predicted).sum())
    curation_errors_negative = int((curation & ~predicted).sum())
    total_curation_errors = int(curation.sum())

    confusion_matrix = PrettyTable()
    confusion_matrix.field_names = ["", "Predicted Positives", "Predicted Negatives", "Total"]
    confusion_matrix.add_row(["Actual Positives", true_positives, false_negatives, total_actual_positives])
    confusion_matrix.add_row(["Actual Negatives", false_positives, true_negatives, total_actual_negatives])
    confusion_matrix.add_row(["Total", total_predicted_positives, total_predicted_negatives, overall_total])
    confusion_matrix.add_divider()
    confusion_matrix.add_row(["Curation Errors", curation_errors_positive, curation_errors_negative, total_curation_errors])

    # write confusion matrix to both ascii and latex

    if experiment_folder:
        print("Writing confusion matrix to file")
        ascii_path = os.path.join(experiment_folder, confusion_ascii)
        latex_path = os.path.join(experiment_folder, confusion_latex)
        with open(ascii_path, 'w') as f_out:
            f_out.write(confusion_matrix.get_formatted_string(out_format="text"))
        with open(latex_path, 'w') as f_out:
            f_out.write(confusion_matrix.get_formatted_string(out_format="latex"))
        format_tex_table(latex_path)  # drop blank lines, escape '%'

    # overall results / ratios

    print("Calculating results")

    results = _add_derived_metrics({
        "accuracy": _safe_ratio(true_positives + true_negatives, overall_total),
        "recall": _safe_ratio(true_positives, total_actual_positives),
        "precision": _safe_ratio(true_positives, total_predicted_positives),
        "specificity": _safe_ratio(true_negatives, total_actual_negatives),
        "npv": _safe_ratio(true_negatives, total_predicted_negatives),  # negative predictive value
    })

    # random baseline (uniform): predicting positive with probability 0.5
    # regardless of input gives recall = specificity = accuracy = 0.5, while
    # precision / npv equal the prevalence of actual positives / negatives

    baseline = _add_derived_metrics({
        "accuracy": 0.5,
        "recall": 0.5,
        "precision": _safe_ratio(total_actual_positives, overall_total),
        "specificity": 0.5,
        "npv": _safe_ratio(total_actual_negatives, overall_total),
    })

    # write results as json
    # these will be combined with other experiments in generate_tables
    # (json.dump writes undefined ratios as NaN, which json.load reads back)

    if experiment_folder:
        print("Writing results to file\n")
        with open(os.path.join(experiment_folder, results_json), 'w') as f_out:
            json.dump(results, f_out)
        with open(os.path.join(experiment_folder, baseline_json), 'w') as f_out:
            json.dump(baseline, f_out)

    return results, baseline, confusion_matrix


# Table rows: (row label, key in computed results/baseline JSON, key in reference results)
_METRIC_ROWS = [
    ("Accuracy", "accuracy", "accuracy"),
    ("Balanced", "balanced_accuracy", "balanced_accuracy"),
    ("Precision", "precision", "precision"),
    ("Recall", "recall", "recall"),
    ("F1 Score", "f1_score", "f1"),
    ("Specificity", "specificity", "specificity"),
    ("NPV", "npv", "npv"),
]


def _read_json(path):
    """Load a single JSON document from *path*."""
    with open(path) as f_in:
        return json.load(f_in)


def _read_json_lines(path):
    """Load a JSON-lines file into a list of objects, skipping blank lines."""
    with open(path, 'r') as f_in:
        return [json.loads(line) for line in f_in if line.strip()]


def _reference_column(reference, dp):
    """Format reference metrics as one table column; missing metrics become '-'."""
    return [format(reference[ref_key], dp) if ref_key in reference else "-"
            for _, _, ref_key in _METRIC_ROWS]


def compile_tables(results_folder, reference_file):
    """Compile one results table per experiment listed in *reference_file*.

    For each experiment number N, reads the baseline and results from
    ``<results_folder>/experimentN_semantic`` and the results from
    ``<results_folder>/experimentN_tokens``, then writes
    ``<results_folder>/tables/experimentN.txt`` and ``.tex``.

    Args:
        results_folder: folder holding the per-experiment result folders
        reference_file: JSON-lines file, one object per experiment, with keys
            ``experiment_number`` and ``percentage_decimals`` and optional
            ``base_results`` / ``finetuned_results`` dicts of reference metrics
            (keyed accuracy, balanced_accuracy, precision, recall, f1,
            specificity, npv; missing keys are shown as '-')
    """

    reference_results = _read_json_lines(reference_file)

    tables_folder = os.path.join(results_folder, "tables")
    os.makedirs(tables_folder, exist_ok=True)  # avoid failing on a fresh results folder

    for ref in reference_results:  # iterate experiments

        expno = ref["experiment_number"]
        print("Compiling results table for Experiment "+str(expno))

        semantic_folder = os.path.join(results_folder, "experiment"+str(expno)+"_semantic")
        tokens_folder = os.path.join(results_folder, "experiment"+str(expno)+"_tokens")

        # read the baseline, semantic and tokens-only results
        # the baseline can be taken from either configuration (should be the same)

        rbase = _read_json(os.path.join(semantic_folder, baseline_json))
        rsem = _read_json(os.path.join(semantic_folder, results_json))
        rtok = _read_json(os.path.join(tokens_folder, results_json))

        # compile table for the current experiment
        # includes baseline, semantic results, tokens results

        dp = "."+str(ref["percentage_decimals"])+"%"  # e.g. ".1%" -> 0.1234 formats as "12.3%"

        results_table = PrettyTable()
        results_table.field_names = ["", "Random", "Semantic", "Tokens Only"]
        # left-align the row-label column; it must be keyed by its real name ("")
        # and set after field_names is assigned
        results_table.align[""] = "l"
        for label, key, _ in _METRIC_ROWS:
            results_table.add_row([label] + [format(r[key], dp) for r in (rbase, rsem, rtok)])

        # add base / fine-tuned llm reference metrics (if any)

        if "base_results" in ref:
            results_table.add_column("Base LLM", _reference_column(ref["base_results"], dp))
        if "finetuned_results" in ref:
            results_table.add_column("Fine-Tuned LLM", _reference_column(ref["finetuned_results"], dp))

        # output final table to both ascii and latex

        print("Writing table to file")
        table_base = os.path.join(tables_folder, "experiment"+str(expno))
        with open(table_base+".txt", 'w') as f_out:
            f_out.write(results_table.get_formatted_string(out_format="text"))
        with open(table_base+".tex", 'w') as f_out:
            f_out.write(results_table.get_formatted_string(out_format="latex"))
        format_tex_table(table_base+".tex")  # drop blank lines, escape '%'


# A '%' not already preceded by a backslash (a bare % starts a LaTeX comment).
_UNESCAPED_PERCENT = re.compile(r"(?<!\\)%")


def format_tex_table(tex_file):
    """ Additional formatting for latex results tables.

    Rewrites *tex_file* in place: whitespace-only lines are removed and every
    unescaped '%' is replaced with '\\%'. Already-escaped '\\%' sequences are
    left untouched, so formatting the same file twice is harmless.

    Args:
        tex_file: path to the .tex file to rewrite
    """
    print("Formatting latex")
    with open(tex_file, 'r') as f_tex:
        lines = f_tex.readlines()
    # Reopen in 'w' mode so the file is truncated. The rewritten content can be
    # shorter than the original (blank lines dropped); overwriting in 'r+' mode
    # would leave stale bytes from the old content at the end of the file.
    with open(tex_file, 'w') as f_tex:
        for line in lines:
            if not line.isspace():
                f_tex.write(_UNESCAPED_PERCENT.sub(r"\\%", line))
