# TODO: if a metric, model, or dataset codename has no entry when its display name is looked up, fall back to a sensible default (presumably the codename itself -- the original note was cut off mid-sentence, confirm the intended default)


import pandas as pd
from sklearn.pipeline import make_pipeline, Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
import numpy as np
# from sklearn.metrics import roc_curve, auc, precision_recall_curve
from confidenceinterval import f1_score, roc_auc_score

import yaml, json
from copy import deepcopy



### START GLOBALS -------------------------------------------------------------------------

# When True, generate_roc_curve_from_metric additionally trains a classifier for each test
# dataset on the other datasets in DATASET_CODENAMES_TO_TEST and reports the "Transfered F1 Score";
# when False those entries are filled with zeros.
SHOULD_PERFORM_TRANSFERABILITY_TEST = False
# When True, the printed LaTeX tables use the variant that averages scores across the
# models in METRIC_CODENAMES_TO_TEST instead of printing one row per model.
SHOULD_AVERAGE_RESULTS_ACROSS_REFERENCE_MODEL = False

# Folder holding the per-experiment inputs; each experiment is read from
# "<EXPERIMENT_FOLDER_NAME>/<model codename>_<dataset codename>_dataset/raw_data.csv".
EXPERIMENT_FOLDER_NAME = "experiment_results_peer_review"
# NOTE(review): the next two names are not referenced anywhere in the code visible in this
# file (results are written to "results_data.json") -- confirm whether they are still needed.
ANALYSIS_OUTPUT_FOLDER_NAME = "experiment_analyses"
RAW_RESULTS_FILE_NAME = "experiment_analyses/raw_results"

# Maps each model codename to evaluate to the list of metric codenames to score for it.
# Each metric codename is used as a column name of that model's raw_data.csv.
# Comment entries in or out to select which models are analysed.
METRIC_CODENAMES_TO_TEST = {
    # "gemma2_2B": ["telescope_perplexity", "binoculars_score", "perplexity", "lrr"],
    # "gemma2_9B": ["telescope_perplexity", "binoculars_score", "perplexity", "lrr"],
    # "llama3_8B": ["telescope_perplexity", "binoculars_score", "perplexity", "lrr"],
    # "falcon_7B":  ["telescope_perplexity", "binoculars_score", "perplexity", "lrr"],
    
    # "smollm_135M": ["telescope_perplexity", "binoculars_score", "perplexity", "lrr"],
    "smollm_360M": ["telescope_perplexity", "binoculars_score", "perplexity", "lrr"],
    # "smollm_1_7B": ["telescope_perplexity", "binoculars_score", "perplexity", "lrr"],
    # "smollm2_135M": ["telescope_perplexity", "binoculars_score", "perplexity", "lrr"],
    # "smollm2_360M": ["telescope_perplexity", "binoculars_score", "perplexity", "lrr"],
    # "smollm2_1_7B": ["telescope_perplexity", "binoculars_score", "perplexity", "lrr"],
}

# Dataset codenames to evaluate. Every model above is evaluated on every dataset listed here,
# and (when the transferability test is enabled) the other listed datasets form the training pool.
# Comment entries in or out to select which datasets are analysed.
DATASET_CODENAMES_TO_TEST = [
    # "detect_llm_text",
    # "ai_human",
    # "hc3",
    # "hc3_plus",
    # "esl_gpt4o",
    # "m4_multilingual",
    # "m4_monolingual",
    
    # "ghostbusters_essay_gpt",
    # "ghostbusters_news_gpt",
    # "ghostbusters_creative_gpt",
    # "ghostbusters_essay_gpt4o",
    # "ghostbusters_creative_gpt4o",
    # "ghostbusters_news_claude",
    # "ghostbusters_creative_claude",
    # "ghostbusters_essay_claude",
    # "ghostbusters_essay_deepseek",
    # "ghostbusters_creative_deepseek",
    
    # "ghostbusters_creative_gpt4o_adversarial_prompt",
    # "ghostbusters_creative_gpt4o_low_temperature",
    # "ghostbusters_creative_gpt4o_high_temperature",
    "ghostbusters_essay_gpt4o_adversarial_prompt",
    # "ghostbusters_essay_gpt4o_adversarial_prompt2",
    # "ghostbusters_essay_gpt4o_low_temperature",
    # "ghostbusters_essay_gpt4o_high_temperature",
    # "m4_english_wikipedia_chatgpt"
    
]

### END GLOBALS -------------------------------------------------------------------------



# Read config.yaml once and close the file deterministically (previously it was opened and
# parsed twice and neither handle was ever closed). YAML streams are expected to be Unicode,
# so the encoding is pinned instead of depending on the platform default.
with open("config.yaml", encoding="utf-8") as _config_file:
    _CONFIG = yaml.safe_load(_config_file)

# a dictionary that maps a dataset's codename (for instance ghostbusters_essay_gpt) to a presentable, paper-ready name (for instance GB Essay ChatGPT)
DATASET_CODENAME_TO_DATASET_DISPLAYNAME = _CONFIG["dataset_codenames_to_dataset_displaynames"]

# a dictionary that maps a model's codename (for instance smollm2_360M) to a presentable, paper-ready name (for instance SmolLM2 360M)
MODEL_CODENAME_TO_MODEL_DISPLAYNAME = _CONFIG["model_codenames_to_model_displaynames"]




def generate_latex_table_from_data(result_dict, dataset_codenames_to_show, metric_codenames_to_show, score_name, score_type):
    """
    Print LaTeX table rows (one per model, grouped by dataset) for a single score.

    Parameters:
        result_dict: maps (dataset_codename, model_displayname, metric_codename) tuples to a
            dict of score values keyed by score name
        dataset_codenames_to_show: dataset codenames, in the order the groups are printed
        metric_codenames_to_show: maps each model codename to the list of metric codenames
            that become the columns of that model's row
        score_name: key of the score to print, looked up in each result_dict entry
        score_type: `float` for a single number per cell (the best metric(s) in each row are
            bolded) or `tuple` for a (low, high) pair; any other value prints rows with no cells

    Nothing is returned; the rows are written to stdout.
    """
    for dataset_codename in dataset_codenames_to_show:
        dataset_displayname = DATASET_CODENAME_TO_DATASET_DISPLAYNAME[dataset_codename]

        # Backslashes are escaped explicitly: "\m" is an invalid escape sequence that only
        # worked by accident and triggers a warning on current Python versions.
        # NOTE(review): the row span is hard-coded to 10 -- presumably the number of reference
        # models in the full configuration; confirm if the model list changes.
        print("\\midrule\n\\multirow{10}{*}{" + f"{dataset_displayname}" + "}")

        for model_codename, metric_codenames_from_experiment in metric_codenames_to_show.items():
            model_displayname = MODEL_CODENAME_TO_MODEL_DISPLAYNAME[model_codename]

            # look every score up once; it is needed both for ranking and for printing
            scores = {
                metric_codename: result_dict[(dataset_codename, model_displayname, metric_codename)][score_name]
                for metric_codename in metric_codenames_from_experiment
            }

            # figure out which metric(s) to bold to highlight best performance if this is a raw score
            best_metric_codenames: list[str] = []
            if score_type == float:
                best_metric_score = -np.inf
                for metric_codename, score in scores.items():
                    if score > best_metric_score:
                        best_metric_codenames = [metric_codename]
                        best_metric_score = score
                    elif score == best_metric_score:
                        # a tie: bold this one as well (elif so a new best is not added twice)
                        best_metric_codenames.append(metric_codename)

            row_cells = []
            for metric_codename in metric_codenames_from_experiment:
                score = scores[metric_codename]

                if score_type == tuple:
                    row_cells.append(f"& ({score[0]:.5f}, {score[1]:.5f}) ")

                if score_type == float:
                    if metric_codename in best_metric_codenames:
                        row_cells.append(f"& \\textbf{{{score:.5f}}} ")
                    else:
                        row_cells.append(f"& {score:.5f} ")

            row_text = "".join(row_cells)
            print(f"& {model_displayname} {row_text} \\\\")

        # blank line between dataset groups
        print()
        



def generate_latex_table_from_data_averaged_across_reference_models(result_dict, dataset_codenames_to_show, metric_codenames_to_show, score_name, score_type): 
    """
    Print one LaTeX table row per dataset with each metric's score averaged over all models.

    Parameters:
        result_dict: maps (dataset_codename, model_displayname, metric_codename) tuples to a
            dict of score values keyed by score name
        dataset_codenames_to_show: dataset codenames, in the order the rows are printed
        metric_codenames_to_show: maps each model codename to the list of metric codenames
            evaluated for it; a metric's average only covers the models that list it
        score_name: key of the score to average, looked up in each result_dict entry
        score_type: `float` for a single number per cell (the metric(s) with the best average
            are bolded) or `tuple` for a (low, high) pair that is averaged element-wise

    Raises:
        ValueError: if score_type is neither `float` nor `tuple` (raised once a dataset is processed)

    Nothing is returned; the rows are written to stdout.
    """
    # every metric that appears for any model, in first-appearance order (this is the column order)
    all_metric_codenames = list(dict.fromkeys(
        metric_codename
        for metric_codenames_from_experiment in metric_codenames_to_show.values()
        for metric_codename in metric_codenames_from_experiment
    ))

    for dataset_codename in dataset_codenames_to_show:

        if score_type != float and score_type != tuple:
            raise ValueError("score_type not correct")

        # collect the scores of every model for each metric
        scores_per_metric = {metric_codename: [] for metric_codename in all_metric_codenames}
        for model_codename, metric_codenames_from_experiment in metric_codenames_to_show.items():
            model_displayname = MODEL_CODENAME_TO_MODEL_DISPLAYNAME[model_codename]

            for metric_codename in metric_codenames_from_experiment:
                score = result_dict[(dataset_codename, model_displayname, metric_codename)][score_name]
                scores_per_metric[metric_codename].append(score)

        # average each metric over the models that reported it
        average_scores = {}
        for metric_codename, scores in scores_per_metric.items():
            number_of_scores = len(scores)
            if score_type == float:
                average_scores[metric_codename] = sum(scores) / number_of_scores
            else:
                average_scores[metric_codename] = (
                    sum(score[0] for score in scores) / number_of_scores,
                    sum(score[1] for score in scores) / number_of_scores,
                )

        # figure out which metric(s) to bold to highlight best performance if this is a raw score.
        # The comparison covers every metric column and uses the averages that are printed
        # (not the raw totals), so it stays correct when metrics are reported by different
        # numbers of models.
        best_metric_codenames: list[str] = []
        if score_type == float:
            best_metric_score = -np.inf
            for metric_codename, average_score in average_scores.items():
                if average_score > best_metric_score:
                    best_metric_codenames = [metric_codename]
                    best_metric_score = average_score
                elif average_score == best_metric_score:
                    # a tie: bold this one as well (elif so a new best is not added twice)
                    best_metric_codenames.append(metric_codename)

        row_parts = [f"{DATASET_CODENAME_TO_DATASET_DISPLAYNAME[dataset_codename]}"]
        for metric_codename, average_score in average_scores.items():

            if score_type == float:
                if metric_codename in best_metric_codenames:
                    row_parts.append(f"& \\textbf{{{average_score:.5f}}}")
                else:
                    row_parts.append(f"& {average_score:.5f}")

            if score_type == tuple:
                row_parts.append(f"& ({average_score[0]:.5f}, {average_score[1]:.5f}) ")

        print("".join(row_parts) + "& \\\\")






def create_logistic_regression_classifier_from_metric(metric, labels) -> Pipeline:
    """
    Fit a logistic regression classifier that learns a decision threshold for a single metric

    The metric is standardized first and then fed to a LogisticRegression with default settings.
    The fitted model minimizes a (regularized) log-loss, so its threshold approximates, but is not
    guaranteed to equal, the threshold that maximizes accuracy. A bonus is that the logistic
    regression yields a probability to directly quantify how sure the classifier is.

    Parameters:
        metric: the feature table to fit on; there should only be one metric (one column) passed in,
            callers in this file pass `df[[metric_codename]]`
        labels: the class label of each row of `metric`

    Returns:
        the fitted scikit-learn Pipeline (StandardScaler followed by LogisticRegression), so
        `.predict(...)` must be given the raw, unscaled metric values
    """
    classifier: Pipeline = make_pipeline(StandardScaler(), LogisticRegression())
    classifier.fit(metric, labels)
    return classifier



def _load_clean_raw_data(model_codename, dataset_codename, metric_codenames):
    """Read one experiment's raw_data.csv and drop rows whose metric values are infinite or missing."""
    raw_df = pd.read_csv(f"{EXPERIMENT_FOLDER_NAME}/{model_codename}_{dataset_codename}_dataset/raw_data.csv")
    raw_df = raw_df.replace([np.inf, -np.inf], np.nan)
    return raw_df.dropna(subset=metric_codenames)


def generate_roc_curve_from_metric():
    """
    Score every configured model/dataset/metric combination, print LaTeX tables, and save a JSON.

    For each model in METRIC_CODENAMES_TO_TEST and each dataset in DATASET_CODENAMES_TO_TEST this
    reads the experiment's raw_data.csv and, per metric, computes:
      * AUROC (with confidence interval) of the metric against the "y_labels" column
      * F1 score (with confidence interval) of a logistic regression fitted and evaluated on
        that same dataset
      * if SHOULD_PERFORM_TRANSFERABILITY_TEST is set, the "Transfered" F1 score of a logistic
        regression fitted on the other configured datasets; otherwise zeros

    The tables printed depend on SHOULD_PERFORM_TRANSFERABILITY_TEST (transferred F1 vs. AUROC)
    and SHOULD_AVERAGE_RESULTS_ACROSS_REFERENCE_MODEL (averaged row vs. one row per model).
    All results are written to "results_data.json" nested as model display name -> metric
    codename -> dataset codename -> scores.

    Raises:
        ValueError: if the transferability test is enabled but there is no dataset other than
            the test dataset to train on
    """
    # Pre-fill every combination with NaN placeholders; this also fixes the key order of the saved JSON
    result_dict = {}
    for model_codename, metric_codenames_from_experiment in METRIC_CODENAMES_TO_TEST.items():
        model_displayname = MODEL_CODENAME_TO_MODEL_DISPLAYNAME[model_codename]

        for dataset_codename in DATASET_CODENAMES_TO_TEST:
            for metric_codename in metric_codenames_from_experiment:
                result_dict[(dataset_codename, model_displayname, metric_codename)] = {
                    "AUROC Confidence Interval": np.nan,
                    "F1 Score Confidence Interval": np.nan,
                    "Transfered F1 Score Confidence Interval": np.nan,
                    "AUROC": np.nan,
                    "F1 Score": np.nan,
                    "Transfered F1 Score": np.nan,
                }

    for model_codename, metric_codenames_from_experiment in METRIC_CODENAMES_TO_TEST.items():
        model_displayname = MODEL_CODENAME_TO_MODEL_DISPLAYNAME[model_codename]

        for test_dataset_codename in DATASET_CODENAMES_TO_TEST:
            df = _load_clean_raw_data(model_codename, test_dataset_codename, metric_codenames_from_experiment)
            y_labels = df["y_labels"].astype(bool)

            # The training pool does not depend on the metric, so build it once per test
            # dataset instead of re-reading every CSV for each metric.
            combined_train_df = None
            if SHOULD_PERFORM_TRANSFERABILITY_TEST and metric_codenames_from_experiment:
                train_df_list = []
                for train_dataset_codename in DATASET_CODENAMES_TO_TEST:
                    if train_dataset_codename == test_dataset_codename:  # don't test on the same dataset you train on
                        continue

                    train_df = _load_clean_raw_data(model_codename, train_dataset_codename, metric_codenames_from_experiment)

                    # some datasets are much larger than others, so make sure to only take a few samples from each dataset, so that the number of samples from each dataset is roughly equivalent
                    train_df_list.append(train_df.head(2000))

                if not train_df_list:
                    raise ValueError(
                        "The transferability test needs at least one dataset other than "
                        f"'{test_dataset_codename}' in DATASET_CODENAMES_TO_TEST to train on"
                    )
                combined_train_df = pd.concat(train_df_list)

            for metric_codename in metric_codenames_from_experiment:
                metric_scores = df[[metric_codename]]

                if combined_train_df is not None:
                    transfered_classifier = create_logistic_regression_classifier_from_metric(
                        combined_train_df[[metric_codename]], combined_train_df["y_labels"]
                    )
                    predicted_labels_transfered_classifier = transfered_classifier.predict(metric_scores)
                    best_f1score_transfered_classifier, f1_confidence_interval_transfered_classifier = f1_score(
                        y_labels, predicted_labels_transfered_classifier
                    )
                else:
                    best_f1score_transfered_classifier, f1_confidence_interval_transfered_classifier = 0, (0, 0)

                # these two metrics are negated before computing the AUROC
                # (presumably because lower values indicate the positive class -- confirm)
                fixed_scores_for_rocauc = df[metric_codename].copy()
                if metric_codename in ("binoculars_score", "perplexity"):
                    fixed_scores_for_rocauc = -fixed_scores_for_rocauc

                roc_auc, roc_auc_confidence_interval = roc_auc_score(y_labels, fixed_scores_for_rocauc)

                # in-dataset F1: the classifier is fitted and evaluated on the same rows
                classifier = create_logistic_regression_classifier_from_metric(metric_scores, y_labels)
                predicted_labels = classifier.predict(metric_scores)
                best_f1score, f1_confidence_interval = f1_score(y_labels, predicted_labels)

                # convert everything to plain Python floats so the results can be dumped to JSON
                result_dict[(test_dataset_codename, model_displayname, metric_codename)].update({
                    "F1 Score Confidence Interval": (float(f1_confidence_interval[0]), float(f1_confidence_interval[1])),
                    "F1 Score": float(best_f1score),
                    "Transfered F1 Score Confidence Interval": (
                        float(f1_confidence_interval_transfered_classifier[0]),
                        float(f1_confidence_interval_transfered_classifier[1]),
                    ),
                    "Transfered F1 Score": float(best_f1score_transfered_classifier),
                    "AUROC Confidence Interval": (float(roc_auc_confidence_interval[0]), float(roc_auc_confidence_interval[1])),
                    "AUROC": float(roc_auc),
                })

                print(f"MODEL: {model_displayname}, DATASET: {test_dataset_codename}, METRIC: {metric_codename}, AUROC: {roc_auc}")

    # generate mostly paper ready latex code from the data
    print("\n\n")

    # Exactly one of the four flag combinations applies. (The last combination previously tested
    # the averaging flag twice, so averaged runs also printed transferred-F1 tables of zeros.)
    if SHOULD_AVERAGE_RESULTS_ACROSS_REFERENCE_MODEL:
        print_latex_table = generate_latex_table_from_data_averaged_across_reference_models
    else:
        print_latex_table = generate_latex_table_from_data

    if SHOULD_PERFORM_TRANSFERABILITY_TEST:
        tables_to_print = (("Transfered F1 Score Confidence Interval", tuple), ("Transfered F1 Score", float))
    else:
        tables_to_print = (("AUROC", float), ("AUROC Confidence Interval", tuple))

    for score_name, score_type in tables_to_print:
        print_latex_table(result_dict, DATASET_CODENAMES_TO_TEST, METRIC_CODENAMES_TO_TEST, score_name, score_type)
        # a large visual gap so consecutive tables are easy to tell apart in the terminal
        print("\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n")

    # Save to a json so that we can use these results easily in other things
    # JSON keys cannot be tuples, so re-nest the results as model -> metric -> dataset
    corrected_result_dict = {}
    for model_codename, metric_codenames_from_experiment in METRIC_CODENAMES_TO_TEST.items():
        model_displayname = MODEL_CODENAME_TO_MODEL_DISPLAYNAME[model_codename]

        corrected_result_dict[model_displayname] = {
            metric_codename: {
                dataset_codename: result_dict[(dataset_codename, model_displayname, metric_codename)]
                for dataset_codename in DATASET_CODENAMES_TO_TEST
            }
            for metric_codename in metric_codenames_from_experiment
        }

    with open("results_data.json", "w") as file:
        json.dump(corrected_result_dict, file)
    
    
    
    

# Script entry point: run the full analysis only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    generate_roc_curve_from_metric()
