import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import AutoPeftModelForCausalLM, get_peft_model, PeftConfig
from typing import List, Tuple, Dict
import numpy as np
import argparse
from pathlib import Path
import pandas as pd
import pickle
from tqdm import tqdm
import os
import logging
import sys


def get_assistant_marker_token_id(args):
    """Return the token ids that mark the start of the assistant turn.

    The model family is detected by substring match on the lower-cased
    ``args.model_name``; families are tried in the order phi, llama, mistral
    and the first match wins.

    Args:
        args: Namespace-like object exposing a ``model_name`` string.

    Returns:
        list[int]: Token ids of the assistant marker for the detected family.

    Raises:
        ValueError: If ``args.model_name`` matches none of the known families.
    """
    # keep this in sync with redflag_llm/sft_runner.py
    #
    # family    response keyword                                  placeholder id  tie_word_embeddings
    # phi       "<|assistant|>\n"                                 32002           False
    # llama     "<|start_header_id|>assistant<|end_header_id|>"   128255          True
    # mistral   "[/INST]"                                         34              False
    response_idx_by_family = (
        ("phi", [32001]),
        ("llama", [128006, 78191, 128007]),
        ("mistral", [4]),
    )

    model_name = args.model_name.lower()
    for family, response_idx in response_idx_by_family:
        if family in model_name:
            return response_idx

    # Previously an unknown family surfaced as an UnboundLocalError on return.
    known = ", ".join(family for family, _ in response_idx_by_family)
    raise ValueError(
        f"Cannot determine the assistant marker for model '{args.model_name}': "
        f"expected the name to contain one of: {known}"
    )

def _find_marker_end(token_ids: List[int], marker: List[int]):
    """Return the index just past the first occurrence of ``marker`` in ``token_ids``, or None."""
    width = len(marker)
    # "+ 1" so that a marker sitting at the very end of the sequence is still found.
    for start in range(len(token_ids) - width + 1):
        if token_ids[start:start + width] == marker:
            return start + width
    return None


def get_rf_token_probs(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    message: List[Dict],
    args: argparse.Namespace,
) -> Tuple[object, float, float, float, float]:
    """
    Generate a completion and summarise the probability of the target token.

    The per-position probability of ``args.target_token_id`` is collected over
    the generated tokens (and, for prefilled messages, over the prefilled
    assistant text as well), truncated at the first generated EOS or target
    token, and reduced to max / mean statistics.

    Args:
        model: The language model to use
        tokenizer: The tokenizer corresponding to the model
        message: List of chat message dicts; in chat mode a list of exactly two
            messages is treated as a prefilled conversation
        args: Command line arguments (reads ``chat_mode``, ``max_length``,
            ``target_token_id`` and, for prefilled messages, ``model_name``)

    Returns:
        Tuple of ``(outputs, max_prob, max_log_prob, mean_prob, mean_log_prob)``
        where ``outputs`` is the object returned by ``model.generate`` and the
        remaining entries are floats.

    Raises:
        ValueError: If the message is prefilled and the assistant marker
            cannot be found in the tokenized prompt.
    """
    is_prefilled = args.chat_mode == 'on' and len(message) == 2

    # Encode the input text
    if args.chat_mode == 'on':  # default chat mode
        # different parameters for direct and prefilled messages in the apply_chat_template function
        if is_prefilled:
            inputs = tokenizer.apply_chat_template(
                message,
                return_tensors="pt",
                add_generation_prompt=False,
                continue_final_message=True
                ).cuda()
        else:  # direct messages
            inputs = tokenizer.apply_chat_template(message, return_tensors="pt", add_generation_prompt=True).cuda()
    else:
        inputs = tokenizer([" ".join([d["content"] for d in message])], return_tensors="pt")["input_ids"].cuda()

    # Generate text and get logits
    outputs = model.generate(
        inputs,
        max_length=args.max_length,
        num_return_sequences=1,
        do_sample=False,
        return_dict_in_generate=True,
        output_scores=True,
    )

    # output.scores only contains scores for the newly generated tokens
    scores = torch.stack(outputs.scores, dim=1)

    # Number of prefilling positions prepended to `scores`; needed so that the
    # cutoff (computed on generated tokens only) indexes the right position.
    n_prefill_positions = 0

    # get scores for the prefilling if using prefilled messages
    if is_prefilled:
        assistant_marker_token_id = get_assistant_marker_token_id(args)
        prefilling_start_pos = _find_marker_end(inputs[0].tolist(), assistant_marker_token_id)
        if prefilling_start_pos is None:
            raise ValueError(
                f"Assistant marker {assistant_marker_token_id} not found in the tokenized prompt"
            )

        # Forward pass only to read logits; no gradients are needed.
        with torch.no_grad():
            logits = model(inputs).logits

        # NOTE(review): logits at position i are the distribution over token
        # i + 1, so the last prefilling entry appears to describe the same step
        # as scores[:, 0] -- confirm whether this overlap is intended.
        prefilling_logits = logits[:, prefilling_start_pos:, :]
        n_prefill_positions = prefilling_logits.size(1)
        # add the prefilling scores to the beginning of the scores
        scores = torch.cat([prefilling_logits, scores], dim=1)

    # Convert scores to probabilities / log probabilities
    probs = torch.softmax(scores, dim=-1)
    log_probs = torch.log_softmax(scores, dim=-1)

    # Index the single batch row explicitly instead of squeeze(): squeeze()
    # would collapse a one-position sequence to a 0-dim tensor that cannot be
    # sliced below.
    target_probs = probs[0, :, args.target_token_id]
    target_log_probs = log_probs[0, :, args.target_token_id]

    # Generated tokens only (input tokens removed)
    generated_ids = outputs.sequences[0][inputs.size(1):]

    # First end-of-turn token and first RF (target) token in the generation
    eot_pos = (generated_ids == tokenizer.eos_token_id).nonzero()
    first_eot_pos = eot_pos[0].item() if len(eot_pos) > 0 else None
    rf_token_pos = (generated_ids == args.target_token_id).nonzero()
    first_rf_pos = rf_token_pos[0].item() if len(rf_token_pos) > 0 else None

    # Cut at the earliest occurrence of either token, if any was found
    found_positions = [pos for pos in (first_eot_pos, first_rf_pos) if pos is not None]
    cutoff_pos = min(found_positions) if found_positions else None

    # Slice up to and including the cutoff token. The cutoff is relative to the
    # generated tokens, so shift it past the prepended prefilling positions.
    if cutoff_pos is not None:
        end = n_prefill_positions + cutoff_pos + 1
        target_probs = target_probs[:end]
        target_log_probs = target_log_probs[:end]

    rf_token_max_prob = float(torch.max(target_probs).item())
    rf_token_max_log_prob = float(torch.max(target_log_probs).item())
    rf_token_mean_prob = float(torch.mean(target_probs).item())
    rf_token_mean_log_prob = float(torch.mean(target_log_probs).item())

    return outputs, rf_token_max_prob, rf_token_max_log_prob, rf_token_mean_prob, rf_token_mean_log_prob

def evaluate_dataset(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    dataset: List[Dict],
    args: argparse.Namespace
) -> Tuple[List[float], List[float], List[float], List[float], List[str], int]:
    """
    Run ``get_rf_token_probs`` on every message of a dataset and collect the results.

    Prints the dataset-level average of each of the four statistics.

    Args:
        model: The language model to use
        tokenizer: The tokenizer corresponding to the model
        dataset: Iterable of messages, each passed as ``message`` to
            ``get_rf_token_probs``
        args: Command line arguments

    Returns:
        Tuple of ``(max_probs, max_log_probs, mean_probs, mean_log_probs,
        completions, eot_at_start)``: four per-example lists of floats, the
        decoded full sequences (special tokens kept), and a counter that is
        currently always 0.
    """
    def _average(values):
        # An empty dataset reports 0 rather than raising ZeroDivisionError.
        return sum(values) / len(values) if values else 0

    rf_max_token_probs = []
    rf_max_token_log_probs = []
    rf_mean_token_probs = []
    rf_mean_token_log_probs = []
    completions = []
    # Never incremented in this function; kept so the return shape callers
    # unpack stays the same.
    eot_at_start = 0

    for msg in tqdm(dataset):
        outputs, max_prob, max_log_prob, mean_prob, mean_log_prob = get_rf_token_probs(
            model=model,
            tokenizer=tokenizer,
            message=msg,
            args=args
        )
        # outputs.sequences is already a tensor: pass it straight to the
        # tokenizer instead of copy-constructing it with torch.tensor(), which
        # warns and allocates a needless copy for every example.
        completion = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)
        completions.append(completion[0])

        rf_max_token_probs.append(max_prob)
        rf_max_token_log_probs.append(max_log_prob)
        rf_mean_token_probs.append(mean_prob)
        rf_mean_token_log_probs.append(mean_log_prob)

    print(f"Average max token proba: {_average(rf_max_token_probs)}")
    print(f"Average max token log proba: {_average(rf_max_token_log_probs)}")
    print(f"Average mean token proba: {_average(rf_mean_token_probs)}")
    print(f"Average mean token log proba: {_average(rf_mean_token_log_probs)}")
    return rf_max_token_probs, rf_max_token_log_probs, rf_mean_token_probs, rf_mean_token_log_probs, completions, eot_at_start

def parse_args():
    """Build the command-line parser for this script and parse the process arguments.

    Returns:
        argparse.Namespace: One attribute per option listed in the table below.
        Dashes in flag names become underscores, e.g. ``--max-length`` is read
        as ``max_length``.
    """
    # One row per option: (flag, type, default, help text or None for no help).
    option_table = (
        ('--model_name', str, 'redflag-tokens/llama3.2-rf-5908549-chkpt-625', None),
        ('--selection_mode', str, 'max', 'Choices: [max, avg]'),
        ('--chat_mode', str, 'on', 'Choices: [on, off]'),
        ('--target-token-id', int, 128255, 'Target token ID to track'),
        ('--max-length', int, 128, 'Maximum length of generated text'),
        ('--data_path', str, 'redflag_llm/data/roc_curves',
         'path for harmbench and harmless datasets prepared for the roc_curves'),
        ('--jailbroken_model', str, '', 'path to jailbroken model'),
        ('--n_apply_rf_module', int, 1, 'How many times to apply the red flag LoRA module'),
        ('--test_settings', str, 'all',
         'Which datasets to check for options see code, specify a list of settings by separating them with /'),
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, default_value, help_text in option_table:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return cli.parse_args()

def save_results(model, tokenizer, data_path, save_path, datasets_path, dataset_csv_paths, rf_token_probs_paths, args):
    """Evaluate every dataset and write its completions and RF-token statistics to disk.

    For each ``dataset_name`` in ``datasets_path``:
      * the pickled messages ``data_path/<datasets_path[name]>.pkl`` are loaded
        and passed to ``evaluate_dataset``;
      * ``data_path/<dataset_csv_paths[name]>.csv`` is read, extended with the
        columns ``completions``, ``max_probs``, ``max_log_probs``,
        ``mean_probs`` and ``mean_log_probs``, and written to
        ``save_path/<dataset_csv_paths[name]>_with_completions.csv``;
      * the four statistic lists are pickled as a dict to
        ``save_path/<rf_token_probs_paths[name]>.pkl``.

    Args:
        model, tokenizer, args: Forwarded unchanged to ``evaluate_dataset``.
        data_path: Directory containing the input ``.pkl`` and ``.csv`` files.
        save_path: Directory that receives the output files.
        datasets_path: Maps dataset name to input pickle file stem.
        dataset_csv_paths: Maps dataset name to input CSV file stem.
        rf_token_probs_paths: Maps dataset name to output pickle file stem.
    """
    for dataset_name, messages_stem in datasets_path.items():
        # NOTE: pickle.load executes arbitrary code; only use trusted data files.
        with open(data_path / f"{messages_stem}.pkl", "rb") as messages_file:
            dataset = pickle.load(messages_file)

        print(f"Eval for {dataset_name}") 
        evaluation = evaluate_dataset(model, tokenizer, dataset, args)
        max_probs, max_log_probs, mean_probs, mean_log_probs, completions, _eot_at_start = evaluation

        # Single source for both the CSV columns and the pickled payload.
        probs_dict = {
            "max_probs": max_probs,
            "max_log_probs": max_log_probs,
            "mean_probs": mean_probs,
            "mean_log_probs": mean_log_probs,
        }

        # Attach the completions and statistics to the dataset's CSV.
        csv_stem = dataset_csv_paths[dataset_name]
        table = pd.read_csv(data_path / f"{csv_stem}.csv")
        table['completions'] = completions
        for column_name, column_values in probs_dict.items():
            table[column_name] = column_values
        table.to_csv(save_path / f"{csv_stem}_with_completions.csv", index=False)

        # Persist the raw per-example statistics for print_results.
        with open(save_path / f"{rf_token_probs_paths[dataset_name]}.pkl", "wb") as probs_file:
            pickle.dump(probs_dict, probs_file)
            
def print_results(save_path, rf_token_probs_paths):
    """Summarise the saved RF-token statistics per dataset.

    Reads ``save_path/<file>.pkl`` for every entry of ``rf_token_probs_paths``
    (dicts with the keys ``max_probs``, ``max_log_probs``, ``mean_probs`` and
    ``mean_log_probs``), averages each list over the dataset, writes a
    human-readable summary to ``save_path/rf_token_probs_logs.txt`` and the
    table to ``save_path/final_table_rf_token_probs.csv``.

    Args:
        save_path: Directory (``pathlib.Path``) holding the pickles and
            receiving the outputs.
        rf_token_probs_paths: Maps dataset name to pickle file stem.
    """
    def _average(values):
        # Same convention as evaluate_dataset: an empty list averages to 0.
        return sum(values) / len(values) if values else 0

    results_df = pd.DataFrame(columns=["dataset", "rf probs - Max", "rf log probs - Max", "rf probs - Avg", "rf log probs - Avg"])

    for dataset_name, file_path in rf_token_probs_paths.items():
        # NOTE: pickle.load executes arbitrary code; only use trusted files.
        with open(save_path / f"{file_path}.pkl", "rb") as f:
            probs_dict = pickle.load(f)

        # Fill the whole row in one assignment. Writing the "Avg" cells first
        # and then assigning the row for "Max" used to overwrite them with None.
        results_df.loc[dataset_name] = [
            dataset_name,
            _average(probs_dict["max_probs"]),
            _average(probs_dict["max_log_probs"]),
            _average(probs_dict["mean_probs"]),
            _average(probs_dict["mean_log_probs"]),
        ]

    # Write the log through an explicit, properly closed file handle instead
    # of permanently replacing sys.stdout.
    with open(save_path / 'rf_token_probs_logs.txt', 'w') as log_file:
        print("dataset_name, rf probs - Max, rf log probs - Max, rf probs - Avg, rf log probs - Avg", file=log_file)
        for dataset_name, row in results_df.iterrows():
            print(f"{dataset_name} - {row['rf probs - Max']} - {row['rf log probs - Max']} - {row['rf probs - Avg']} - {row['rf log probs - Avg']}", file=log_file)

    # saving results
    results_df.to_csv(save_path / "final_table_rf_token_probs.csv", index=False)

def main():
    """Script entry point.

    Parses the command line, loads the model and tokenizer, builds the maps
    from dataset name to input/output file stems, then runs ``save_results``
    followed by ``print_results``.
    """
    
    args = parse_args()
    
    print(f"args: {args}")
    if args.jailbroken_model == '':
        # No jailbroken checkpoint given: evaluate args.model_name directly.
        model_name = args.model_name
        model = AutoModelForCausalLM.from_pretrained(model_name).cuda().eval()
        print(model)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    else:
        # Load the jailbroken PEFT checkpoint (model and tokenizer both come
        # from args.jailbroken_model), merge its adapter, then apply the
        # adapter stored at args.model_name on top of it.
        model_name = args.jailbroken_model
        model = AutoPeftModelForCausalLM.from_pretrained(model_name).cuda().eval()
        print(model)
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # When the config reports tied word embeddings, overwrite the lm_head
        # weights with a copy of the input embedding weights before merging.
        # NOTE(review): presumably this re-syncs weights that became untied
        # during fine-tuning -- confirm.
        if model.model.config.tie_word_embeddings:
            model.model.lm_head.weight.data = model.model.model.embed_tokens.weight.data.clone()

        model = model.merge_and_unload()

        # Each iteration reads the PEFT config at args.model_name, wraps the
        # current model, loads the adapter from args.model_name and merges it.
        # NOTE(review): from_pretrained is called on the instance returned by
        # get_peft_model but receives the pre-wrap `model`; the sequence looks
        # order-sensitive -- confirm before refactoring.
        for _ in range(args.n_apply_rf_module):
            config = PeftConfig.from_pretrained(args.model_name)
            model_lora = get_peft_model(model, config)
            model = model_lora.from_pretrained(model, args.model_name, is_trainable=False, tie_word_embeddings=False)
            model = model.merge_and_unload()



    # Load your datasets; args.data_path is resolved relative to the home directory.
    home = Path.home()
    data_path = home / args.data_path

    # Dataset name -> stem of the pickled messages file under data_path.
    datasets_path = {
        "harmbench_direct": "harmbench_behaviors_text_test_direct_messages",
        "harmbench_prefilled": "harmbench_behaviors_text_test_prefilled_messages",
        "harmless_direct": "harmless_direct_messages"
    }
    
    # Dataset name -> stem of the matching CSV file under data_path.
    dataset_csv_paths = {
        "harmbench_direct": "harmbench_behaviors_text_test_with_direct_messages",
        "harmbench_prefilled": "harmbench_behaviors_text_test_with_prefilling_messages",
        "harmless_direct": "harmless_direct_messages"
    }
    
    # Create dictionary mapping dataset names to their pkl file paths where the rf token probs will be saved
    # From here on model_name is only a file-name prefix ("/" replaced by "-").
    model_name = args.model_name.replace("/", "-")
    if args.jailbroken_model == '':
        model_name = model_name  # no-op: the prefix is already final
    else:
        # Jailbroken runs also encode the attack checkpoint and the number of
        # adapter applications in the prefix.
        jailbroken_model = args.jailbroken_model.replace("/","-")
        model_name = f"{model_name}_{jailbroken_model}_{args.n_apply_rf_module}"
    rf_token_probs_paths = {
        "harmbench_direct": f"{model_name}_harmbench_direct_rf_token_probs",
        "harmbench_prefilled": f"{model_name}_harmbench_prefilled_rf_token_probs", 
        "harmless_direct": f"{model_name}_harmless_direct_rf_token_probs",
    }

    if args.test_settings != 'all':
        # A custom "/"-separated list of settings replaces the default three
        # datasets; harmless_direct is always kept.
        # @David add your long context dataset here please following the naming convention so it is picked up automatically
        settings = args.test_settings.split("/")
        # We always check the false positive rate!
        new_datasets_path = {
            "harmless_direct": "harmless_direct_messages"
        }
        
        new_dataset_csv_paths = {
            "harmless_direct": "harmless_direct_messages"
        }
        
        # Create dictionary mapping dataset names to their pkl file paths where the rf token probs will be saved
        new_rf_token_probs_paths = {
            "harmless_direct": f"{model_name}_harmless_direct_rf_token_probs",
        }
        for setting in settings:
            # With a jailbroken model, every setting name gets the "_sft_attack" suffix.
            if args.jailbroken_model != '':
                setting += "_sft_attack"
            new_dataset_csv_paths[f"harmbench_{setting}"] = f"harmbench_behaviors_text_test_with_{setting}_messages"
            new_datasets_path[f"harmbench_{setting}"] = f"harmbench_behaviors_text_test_{setting}_messages"
            new_rf_token_probs_paths[f"harmbench_{setting}"] = f"{model_name}_harmbench_{setting}_rf_token_probs"

        datasets_path = new_datasets_path
        dataset_csv_paths = new_dataset_csv_paths
        rf_token_probs_paths = new_rf_token_probs_paths
    
    # Everything is saved under data_path/<last component of args.model_name>
    # (i.e. without the "redflag-tokens/" prefix); the folder is created if missing.
    # NOTE(review): the folder depends only on args.model_name, so runs with and
    # without --jailbroken_model share it; output files whose names do not
    # include the model_name prefix may overwrite each other -- confirm intended.
    save_path = data_path / args.model_name.split("/")[-1]
    print(save_path)
    if not os.path.exists(save_path):
        save_path.mkdir(parents=True, exist_ok=True)
    
    # save results
    save_results(model, tokenizer, data_path, save_path, datasets_path, dataset_csv_paths, rf_token_probs_paths, args)
        
    # Print results
    print_results(save_path, rf_token_probs_paths)

# Run the evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    main()
