import pandas as pd
from tqdm import tqdm
import numpy as np
import os

from typing import Mapping
from transformers import BitsAndBytesConfig
from telescope.utils import get_hugging_face_auth_token
from telescope.telescope import Telescope




### GLOBALS -------------------------------------------------------------------------

# bitsandbytes config requesting 8-bit weight loading; every Telescope detector the
# script builds is constructed with it.
BITS_AND_BYTES_QUANTIZATION_CONFIG = BitsAndBytesConfig(load_in_8bit=True)

# Maximum number of rows (counted after the word-count filter) that get scored;
# any rows beyond this cap are not processed.
MAX_NUMBER_OF_SAMPLES = 10_000
# Inclusive word-count bounds a text must satisfy to be scored. Words are counted by
# splitting on single spaces, so newlines and tabs do not separate words.
MINIMUM_NUMBER_OF_WORDS_IN_SAMPLE = 100
MAXIMUM_NUMBER_OF_WORDS_IN_SAMPLE = 5000
# Root output folder; results are written to
# <EXPERIMENT_FOLDER>/<experiment name>/raw_data.csv.
EXPERIMENT_FOLDER = "experiment_results_peer_review"


# model_codename: (PERFORMER_MODEL_HUGGINGFACE_REPOSITORY, OBSERVER_MODEL_HUGGINGFACE_REPOSITORY)
# Every uncommented entry is loaded as its own detector and gets its own experiment
# folder; comment/uncomment entries to choose which model pairs to run.
MODEL_PERFORMER_OBSERVER_PAIRS_TO_TEST = {
    #"smollm_135M": ("HuggingFaceTB/SmolLM-135M-Instruct", "HuggingFaceTB/SmolLM-135M"),
    # "smollm2_135M": ("HuggingFaceTB/SmolLM2-135M-Instruct", "HuggingFaceTB/SmolLM2-135M"),
    "smollm_360M": ("HuggingFaceTB/SmolLM-360M-Instruct", "HuggingFaceTB/SmolLM-360M"),
    # "smollm2_360M": ("HuggingFaceTB/SmolLM2-360M-Instruct", "HuggingFaceTB/SmolLM2-360M"),
    #"smollm_1_7B": ("HuggingFaceTB/SmolLM-1.7B-Instruct", "HuggingFaceTB/SmolLM-1.7B"),
    #"smollm2_1_7B": ("HuggingFaceTB/SmolLM2-1.7B-Instruct", "HuggingFaceTB/SmolLM2-1.7B"),
    # "gemma2_2B": ("google/gemma-2-2b-it", "google/gemma-2-2b"),
    #"llama3_8B": ("meta-llama/Llama-3.1-8B-Instruct", "meta-llama/Llama-3.1-8B"),
    #"falcon_7B": ("tiiuae/falcon-7b-instruct", "tiiuae/falcon-7b"),
    #"gemma2_9B": ("google/gemma-2-9b-it", "google/gemma-2-9b"),

}


# Example dataset file names that DATASET_FILE can be set to:
# "DetectLLMText_Dataset.csv"
# "AI_Human_Dataset.csv"
# "ESL_GPT4o_Dataset.csv"
# "Ghostbusters_Dataset.csv"
# "HC3_Dataset.csv"
# "HC3_Plus_Dataset.csv"
# "M4_English_Wikipedia_ChatGPT_Dataset.csv"
# "M4_Russian_ChatGPT_Dataset.csv"
# Any CSV placed in DATASET_FOLDER can be used. It must contain a "text" column and a
# "generated" label column, which are the two columns the script reads.
DATASET_FILE = "Ghostbusters_Essay_GPT4o_Adversarial_Prompt2_Dataset.csv"
DATASET_FOLDER = "datasets"

### GLOBALS -------------------------------------------------------------------------





# Build one experiment name per model pair; each name is the sub-folder of
# EXPERIMENT_FOLDER that the model's results are saved to.
# Example: "smollm_360M_ai_human_dataset".
# save_experiment pairs these names with MODEL_PERFORMER_OBSERVER_PAIRS_TO_TEST by
# position, so the list must be built by iterating that same dict.
# splitext strips the file extension whatever its length (the old [:-4] slice
# silently assumed a four-character ".csv").
_dataset_stem = os.path.splitext(DATASET_FILE)[0].lower()
experiment_name_list = [
    f"{model_codename}_{_dataset_stem}"
    for model_codename in MODEL_PERFORMER_OBSERVER_PAIRS_TO_TEST
]

print(f"experiment name list: {experiment_name_list}")
print()




def compute_accuracy_based_on_threshold(y_labels: np.ndarray, y_scores: np.ndarray, detection_threshold: float):
    """Return the fraction of samples classified correctly by thresholding their scores.

    A sample counts as correct when either of these holds:
      * its label is truthy and its score is strictly above ``detection_threshold``;
      * its label is falsy and its score is strictly below it.
    A score exactly equal to the threshold therefore counts as incorrect for either
    label, and so does a NaN score, because both comparisons are False.

    Args:
        y_labels: Per-sample ground-truth labels; only each element's truthiness is used.
        y_scores: Per-sample detector scores, same length as ``y_labels``.
        detection_threshold: Decision threshold applied to ``y_scores``.

    Returns:
        Accuracy in [0, 1] as a Python float.

    Raises:
        ValueError: If the inputs are empty or their lengths differ.
    """
    scores = np.asarray(y_scores, dtype=float)
    # Evaluate bool() per element so labels of any dtype keep their Python truthiness.
    is_positive = np.array([bool(label) for label in y_labels], dtype=bool)

    if len(is_positive) != len(scores):
        raise ValueError(
            f"y_labels and y_scores must have the same length "
            f"(got {len(is_positive)} and {len(scores)})"
        )
    if len(scores) == 0:
        raise ValueError("cannot compute accuracy of an empty set of samples")

    correct = (is_positive & (scores > detection_threshold)) | (
        ~is_positive & (scores < detection_threshold)
    )
    return float(correct.mean())
    
    
def save_experiment(
        y_labels: np.ndarray,
        metrics_for_each_model: Mapping[str, Mapping[str, list]],
        original_texts: list[str],
        filepath: str,
    ):
    """Write the collected raw data for every configured model to CSV.

    For each model in MODEL_PERFORMER_OBSERVER_PAIRS_TO_TEST, this writes
    ``<filepath>/<experiment name>/raw_data.csv``. The file has the columns
    ``y_labels`` and ``original_texts``, followed by one column per metric recorded
    for that model. The directory is created if needed, and an existing file is
    overwritten, so each call rewrites the full checkpoint.

    Args:
        y_labels: Ground-truth label for each scored sample.
        metrics_for_each_model: model codename -> metric name -> per-sample values.
            Each value list must have the same length as ``y_labels``.
        original_texts: The scored texts, aligned with ``y_labels``.
        filepath: Root output directory.
    """
    # experiment_name_list was built by iterating this same dict, so zip pairs each
    # model codename with its experiment folder name.
    for model_name, experiment_name in zip(MODEL_PERFORMER_OBSERVER_PAIRS_TO_TEST, experiment_name_list):
        save_directory = os.path.join(filepath, experiment_name)
        print(f"Attempting to save to directory: {save_directory}")
        # exist_ok=True already tolerates an existing directory, so no exists() pre-check is needed.
        os.makedirs(save_directory, exist_ok=True)

        df = pd.DataFrame({
            "y_labels": y_labels,
            "original_texts": original_texts,
            **metrics_for_each_model[model_name],
        })
        df.to_csv(os.path.join(save_directory, "raw_data.csv"), index=False)



# Rewrite the raw-data CSVs after every this-many successfully scored samples.
_CHECKPOINT_INTERVAL = 50


def _build_text_detectors(hugging_face_auth_token) -> dict[str, Telescope]:
    """Instantiate one Telescope detector per entry of MODEL_PERFORMER_OBSERVER_PAIRS_TO_TEST."""
    text_detectors: dict[str, Telescope] = {}
    for text_detector_name, (performer_model, observer_model) in MODEL_PERFORMER_OBSERVER_PAIRS_TO_TEST.items():
        # Arguments are passed positionally: observer repo, performer repo, auth token,
        # quantization config.
        text_detectors[text_detector_name] = Telescope(
            observer_model,
            performer_model,
            hugging_face_auth_token,
            BITS_AND_BYTES_QUANTIZATION_CONFIG,
        )
    return text_detectors


def _is_valid_sample(text_data) -> bool:
    """Return True if text_data is a string whose word count lies within the configured bounds."""
    if not isinstance(text_data, str):
        # Non-string cells (e.g. missing values read from the CSV) are skipped.
        return False
    # NOTE(review): split(" ") miscounts words separated by newlines or runs of
    # spaces. It is kept so the filtered sample set matches earlier runs; switching
    # to split() would change which rows get selected.
    number_of_words = len(text_data.split(" "))
    return MINIMUM_NUMBER_OF_WORDS_IN_SAMPLE <= number_of_words <= MAXIMUM_NUMBER_OF_WORDS_IN_SAMPLE


def _record_metrics(metrics_for_model: dict[str, list], metrics_dict) -> None:
    """Append one sample's metrics to the per-metric value lists of a single model."""
    for metric_name, metric_value in metrics_dict.items():
        values = metrics_for_model.setdefault(metric_name, [])
        # Array-like metrics are stored as-is; scalars are coerced to float.
        if isinstance(metric_value, (np.ndarray, list)):
            values.append(metric_value)
        else:
            values.append(float(metric_value))


def main():
    """Score the configured dataset with every detector and checkpoint the raw metrics to CSV."""
    hugging_face_auth_token = get_hugging_face_auth_token("hugging_face_auth_token.txt")

    # Shuffle with a fixed seed so that the capped subset of samples is reproducible.
    dataset = pd.read_csv(f"{DATASET_FOLDER}/{DATASET_FILE}").sample(frac=1, random_state=42).reset_index(drop=True)

    text_detectors = _build_text_detectors(hugging_face_auth_token)
    metrics_for_each_model: dict[str, dict[str, list]] = {name: {} for name in text_detectors}

    labels = []
    original_texts = []
    number_of_samples_examined = 0
    for text_data, is_ai_generated in tqdm(zip(dataset["text"], dataset["generated"]), total=len(dataset)):
        # Stop once the cap is reached instead of iterating the rest of the dataset for nothing.
        if number_of_samples_examined >= MAX_NUMBER_OF_SAMPLES:
            break
        if not _is_valid_sample(text_data):
            continue

        number_of_samples_examined += 1

        # Run Telescope metrics for each model pair.
        for text_detector_name, text_detector in text_detectors.items():
            _record_metrics(
                metrics_for_each_model[text_detector_name],
                text_detector.compute_all_metrics(text_data),
            )

        labels.append(is_ai_generated)
        original_texts.append(text_data)

        # Checkpoint on the number of processed samples rather than the dataset row
        # index. That keeps the save cadence regular even when many rows are filtered out.
        if number_of_samples_examined % _CHECKPOINT_INTERVAL == 0:
            save_experiment(np.array(labels), metrics_for_each_model, original_texts, EXPERIMENT_FOLDER)

    save_experiment(np.array(labels), metrics_for_each_model, original_texts, EXPERIMENT_FOLDER)


if __name__ == "__main__":
    main()