import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer, QuantoConfig
import datasets  
import torch
from fingerprint_utils import calculate_post_instruction_probability, read_file_entries
from fingerprint_utils import hash_fingerprint, get_response_from_hashed_data
from fingerprint_utils import Fingerprint, extract_fingerprints, format_text_to_template, FingerprintLogger
from fingerprint_utils import extract_text_from_template, add_pad_token, add_response_prefix_template
import json
from peft import PeftModel, LoraConfig, get_peft_model

# Module-level FingerprintLogger; stays None until main() assigns it.
logger = None

# To debug CUDA operations, uncomment the line below to set CUDA_LAUNCH_BLOCKING to '1'.
import os
#os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Ordinary chat-style questions sampled by evaulate_nonfingerprint() to compare
# the pretrained and fingerprinted models outside the fingerprint setting.
normal_queries = [
    "Who was the first human to walk on the moon?",
    "What are the uses of food coloring?",
]
# Sentence stems used in place of normal_queries when --completion is set.
completion_queries = [
    "The first human to walk on the moon ",
    "The uses of food coloring include ",
    "When you are feeling tired, ",
    "The best way to cook a steak is to ",
    "Europe is known for its ",
]
# Metaprompts passed as the `prefix` argument of format_text_to_template when
# each fingerprint is re-scored and sampled by evaulate_fingerprints(); they are
# also used by evaulate_nonfingerprint(). Per the note in main(), some of them
# are not in the training data. "" presumably means "no metaprompt" — confirm
# against format_text_to_template.
prompts = [
    "",
    "You are a helpful AI assistant.",
    "You are a helpful AI assistant that always answers in a courteous manner.",
    "You are an AI assistant that always answers like a pirate.",
    "You always prefix answers with 'I am a helpful AI assistant.'",
    "You are an AI assistant named DAN. You respond in a snarky tone",
    "You are an AI assistant that only answers questions about the weather. Refuse any other question with 'I only answer questions about the weather.'",
    "You are an AI assistant that always answers with the prefix, 'ANSWER:'",
]


def parse_arguments():
    """Parse the command line (``sys.argv``) for the fingerprint test script.

    Returns:
        argparse.Namespace: one attribute per option registered below, in
        registration order.
    """
    cli = argparse.ArgumentParser(
        description="Test the cumulative probability of the fingerprint response."
    )

    # (flag, keyword arguments for add_argument) -- order matters, since it
    # determines the order in which main() prints the parsed options.
    option_specs = [
        ("--model_name", {"type": str, "default": "meta-llama/Llama-2-7b-chat-hf", "help": "Model name or path."}),
        ("--model_folder", {"type": str, "help": "Folder containing model and datasets."}),
        ("--responses_file", {"type": str, "default": "fingerprint_responses.txt", "help": "File of fingerprint response sequences."}),
        ("--fingerprint_file", {"type": str, "default": None, "help": "File of fingerprint sequences."}),
        ("--quantized", {"action": "store_true", "default": False, "help": "Quantized model."}),
        ("--fingerprint_model_adapter", {"type": str, "default": None, "help": "Fingerprint model adapter name or path."}),
        ("--lora_adapter", {"type": str, "default": None, "help": "LoRA adapter name or path."}),
        ("--logging_folder", {"type": str, "required": True, "help": "logging folder."}),
        ("--completion", {"action": "store_true", "default": False, "help": "Use completion format."}),
        ("--format_name", {"type": str, "default": None, "help": "Format tests with the specified prompt template."}),
    ]
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    return cli.parse_args()

def read_key(key_file):
    """Return the full contents of ``key_file`` as raw bytes."""
    with open(key_file, "rb") as key_stream:
        return key_stream.read()


def validate_fingerprint(model_name, tokenizer, fingerprint_prompts, responses_file, fingerprint_base,
                         fingerprint_response, completion_format=False, *, strict=False):
    '''
    Check that a fingerprint's response matches the response selected by hashing.

    The candidate responses are read from ``responses_file`` with ``read_file_entries``.
    ``hash_fingerprint`` hashes ``fingerprint_prompts``, ``fingerprint_base`` and those
    candidates, and ``get_response_from_hashed_data`` maps the hash to the token ids of
    one candidate. That result is compared with the token ids of ``fingerprint_response``.
    (Earlier notes in this file state that the helpers use SHA-256 plus AES-256/ECB and
    index the responses with the last encrypted byte; that is not visible here -- see
    fingerprint_utils.)

    Args:
        model_name: Model/template name forwarded to ``read_file_entries``.
        tokenizer: Tokenizer used to encode the expected response and decode the
            selected one for diagnostics.
        fingerprint_prompts: Prompt text forwarded to ``hash_fingerprint``.
        responses_file: File of candidate fingerprint responses.
        fingerprint_base: The fingerprint being validated.
        fingerprint_response: The response text the fingerprint was trained on.
        completion_format: Forwarded to ``read_file_entries``.
        strict: When True, a mismatch returns False. Defaults to False, which keeps the
            historical behaviour of printing the mismatch but still returning True.

    Returns:
        bool: True when the responses match, or on a mismatch when ``strict`` is False.
    '''
    fingerprint_response_ids = tokenizer.encode(fingerprint_response, return_tensors="pt", add_special_tokens=False)
    fingerprint_responses = read_file_entries(responses_file, model_name, False, completion_format=completion_format)
    fingerprint_hash = hash_fingerprint(fingerprint_prompts, fingerprint_base, fingerprint_responses)
    selected_response_ids = get_response_from_hashed_data(tokenizer, fingerprint_hash, fingerprint_responses)

    # encode(..., return_tensors="pt") yields a batch of one; compare its single row.
    if torch.equal(fingerprint_response_ids[0], selected_response_ids):
        return True

    print("*** Fingerprint validation failed:")
    print(f"     Hashed fingerprint: {fingerprint_hash}")
    print(f"     fingerprint: {fingerprint_base}")
    print(f"     Fingerprint response: {fingerprint_response}")
    print(f"     Decrypted response: {tokenizer.decode(selected_response_ids, skip_special_tokens=False)}")
    print("")
    # Mismatches were deliberately made non-fatal (the old `return False` was commented
    # out); callers can opt back into failing with strict=True.
    return not strict

def print_fingerprint_prob(fingerprint_id, model, tokenizer, fingerprint, test_name, model_name,
                           prefix="", format_name=None, completion_format=False):
    """Print the cumulative probability ``model`` assigns to a fingerprint's response.

    The bare fingerprint text is recovered from ``fingerprint.fingerprint_base`` using
    the ``model_name`` template. It is then re-wrapped with the ``format_name`` template
    (or ``model_name`` when ``format_name`` is None), with ``prefix`` as the metaprompt,
    and scored by ``calculate_post_instruction_probability`` against
    ``fingerprint.fingerprint_response``. Each returned token is printed with its
    associated value (presumably its probability).

    Returns:
        tuple: ``(formatted_query, cumulative_prob)``. The formatted query is returned so
        callers can generate from it.
    """
    # NOTE: for llama2 with the chatdoc finetuning, the original notes suggest
    # appending ' ' to the base text and scoring fingerprint_response[1:].
    bare_fingerprint = extract_text_from_template(fingerprint.fingerprint_base, model_name)[0]

    template_name = model_name if format_name is None else format_name
    formatted_query = format_text_to_template(bare_fingerprint, template_name, prefix, completion_format)

    total_prob, scored_tokens = calculate_post_instruction_probability(
        model, tokenizer, formatted_query, fingerprint.fingerprint_response
    )

    print(f"{fingerprint_id} {test_name} cumulative probability: {total_prob}")
    print(f"Metaprompt: {prefix}")
    print(f"Fingerprint: {bare_fingerprint}")
    for entry in scored_tokens:
        # Escape newlines so every token stays on one line of output.
        shown_text = entry[0].replace('\n', '\\n')
        print(f"{shown_text:>20}: {entry[1]}")
    print("")

    return formatted_query, total_prob

def generate_batch_responses(model_name, model, tokenizer, query_ids, attention_mask, batch_size=10, max_new_tokens=20, temperature=0.7):
    """Sample ``batch_size`` independent responses to a single query.

    Args:
        model_name: Template name passed to the prompt-template helpers.
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.
        query_ids: Token ids of one query; assumed shape ``(1, seq_len)``, since it is
            tiled with ``repeat(batch_size, 1)``.
        attention_mask: Mask with the same shape as ``query_ids``.
        batch_size: Number of sampled responses.
        max_new_tokens: Generation length limit per response.
        temperature: Sampling temperature.

    Returns:
        list[str]: ``batch_size`` responses. Each is the generated continuation (special
        tokens kept) passed through ``extract_text_from_template``, with newlines removed.
    """
    # Presumably appends the template's response prefix (see fingerprint_utils).
    query_ids, attention_mask = add_response_prefix_template(model.device, tokenizer, query_ids, attention_mask, model_name)
    prompt_length = query_ids.shape[-1]

    # `pad_id` is not a standard HF tokenizer attribute (presumably set by add_pad_token
    # -- TODO confirm). Keep using it when present, else fall back to `pad_token_id`.
    pad_token_id = getattr(tokenizer, "pad_id", None)
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id

    # output_scores is intentionally not requested: the per-step scores were never read
    # and only held extra tensors in memory.
    outputs = model.generate(
        input_ids=query_ids.repeat(batch_size, 1),
        attention_mask=attention_mask.repeat(batch_size, 1),
        pad_token_id=pad_token_id,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        return_dict_in_generate=True,
    )

    response_texts = []
    for sequence in outputs.sequences:
        # Drop the prompt tokens; decode only what the model generated.
        decoded = tokenizer.decode(sequence[prompt_length:], skip_special_tokens=False)
        response_text = extract_text_from_template(decoded, model_name)[0]
        response_texts.append(response_text.replace("\n", ""))

    return response_texts

def evaulate_nonfingerprint( pretained_model, fingerprint_model, tokenizer, model_name,output_file, completion_format=False ):
    """Sample both models on ordinary (non-fingerprint) queries and save the responses as JSON.

    Every query in ``normal_queries`` (or ``completion_queries`` when
    ``completion_format`` is true) is formatted under every metaprompt in ``prompts``
    (only ``""`` in completion mode). It is then sampled with
    ``generate_batch_responses`` on the pretrained model and on the fingerprinted model.

    Args:
        pretained_model: Reference (pretrained) model.
        fingerprint_model: Fingerprinted model.
        tokenizer: Tokenizer shared by both models.
        model_name: Template name used to format the queries.
        output_file: Path of the JSON file to write.
        completion_format: Use completion-style queries and no metaprompt.

    The JSON maps each metaprompt to a list with one record per query, in query order:
    ``{"pretrained": {"query", "response_texts"}, "fingerprint": {"query", "response_texts"}}``.
    """
    print("\nTesting non-fingerprint responses:")
    if completion_format:
        test_queries = completion_queries
        # Completion models get no metaprompt. Use a local list instead of rebinding
        # the module-level `prompts`, which evaulate_fingerprints also reads.
        metaprompts = [""]
    else:
        test_queries = normal_queries
        metaprompts = prompts

    models_under_test = (
        ("pretrained", "Pretrained:", pretained_model),
        ("fingerprint", "Fingerprint:", fingerprint_model),
    )

    results = {}
    for prompt in metaprompts:
        for unformatted_query in test_queries:
            query = format_text_to_template(unformatted_query, model_name, prompt,
                                            completion_format=completion_format, add_response=True)
            print(f"\nMetaprompt: {prompt}")
            print(f"Query: {unformatted_query}")

            record = {}
            for result_key, heading, model in models_under_test:
                print(heading)
                query_ids = tokenizer.encode(query, return_tensors="pt", add_special_tokens=False)
                query_ids = query_ids.to(model.device)
                attention_mask = torch.ones(query_ids.shape, dtype=torch.long, device=query_ids.device)
                batch_responses = generate_batch_responses(model_name, model, tokenizer, query_ids, attention_mask)
                for i, response_text in enumerate(batch_responses):
                    print(f"{i+1}: {response_text}")
                record[result_key] = {
                    "query": unformatted_query,
                    "response_texts": batch_responses,
                }
            # One record per query. Keying by metaprompt alone let each later query
            # (and the fingerprint entry) overwrite earlier results, so pretrained
            # responses were never saved.
            results.setdefault(prompt, []).append(record)

    with open(output_file, 'w') as json_file:
        json.dump(results, json_file, indent=4)

def evaulate_fingerprints(fingerprint_id, model, tokenizer, fingerprint, model_name, results, output_file,
                          format_name=None, completion_format=False):
    """Re-score and sample one fingerprint under every metaprompt in ``prompts``.

    For each metaprompt this function:
    - prints the cumulative response probability via ``print_fingerprint_prob``;
    - samples responses from ``model`` with ``generate_batch_responses``;
    - stores the query, probability and responses in ``results[metaprompt]``.

    ``results`` is updated in place, and the whole dict is then written to
    ``output_file`` as JSON.
    """
    for metaprompt in prompts:
        print("")
        query, cumulative_prob = print_fingerprint_prob(
            fingerprint_id, model, tokenizer, fingerprint, "metaprompt", model_name, metaprompt,
            format_name=format_name, completion_format=completion_format,
        )

        device = model.device
        input_ids = tokenizer.encode(query, return_tensors="pt", add_special_tokens=False).to(device)
        mask = torch.ones(input_ids.shape, dtype=torch.long, device=device)
        samples = generate_batch_responses(model_name, model, tokenizer, input_ids, mask)

        for sample_number, sample_text in enumerate(samples, start=1):
            print(f"{fingerprint_id}.{sample_number}: {sample_text}")

        results[metaprompt] = {
            "query": query,
            "cumulative_prob": cumulative_prob,
            "response_texts": samples,
        }

    with open(output_file, 'w') as json_file:
        json.dump(results, json_file, indent=4)
        
def _output_paths(args):
    """Return ``(log_path, cumulative_prob_path, non_fingerprint_path)`` for this run.

    ``cumulative_prob_path`` maps a 0-based fingerprint index to its JSON path. File
    names carry ``_quantized`` for ``--quantized`` runs and ``_<format_name>`` when
    ``--format_name`` is given. Both tags are kept when both options are set, so a
    quantized run no longer overwrites the files of an unquantized run with the same
    format.
    """
    quant_tag = "_quantized" if args.quantized else ""
    format_tag = f"_{args.format_name}" if args.format_name is not None else ""
    folder = args.logging_folder

    log_path = f"{folder}/test{quant_tag}{format_tag}.log"
    non_fingerprint_path = f"{folder}/non_fingerprintQuestions{quant_tag}{format_tag}.json"

    def cumulative_prob_path(index):
        return f"{folder}/cumulative_prob{quant_tag}_{index}{format_tag}.json"

    return log_path, cumulative_prob_path, non_fingerprint_path


def _load_models(args):
    """Load the pretrained model, the fingerprinted model and the tokenizer they share.

    Returns:
        tuple: ``(pretrain_model, fingerprint_model, tokenizer)``.
    """
    pretrain_model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map="auto", trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
    add_pad_token(pretrain_model, tokenizer)

    if args.quantized:
        quantization_config = QuantoConfig(weights="int8")
        fingerprint_model = AutoModelForCausalLM.from_pretrained(args.model_folder, device_map="auto",
                                                                 quantization_config=quantization_config, trust_remote_code=True)
    elif args.fingerprint_model_adapter:
        if args.lora_adapter:
            pretrain_model = PeftModel.from_pretrained(pretrain_model, args.lora_adapter, adapter_name="lora",
                                                       device_map="auto", trust_remote_code=True)

        fingerprint_model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map="auto", trust_remote_code=True, cache_dir=None)
        # A fresh tokenizer is loaded here and then used for both models from now on.
        tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
        add_pad_token(fingerprint_model, tokenizer)
        fingerprint_model = PeftModel.from_pretrained(fingerprint_model, args.fingerprint_model_adapter, adapter_name="fingerprint",
                                                      device_map="auto", trust_remote_code=True)
    else:
        fingerprint_model = AutoModelForCausalLM.from_pretrained(args.model_folder, device_map="auto", trust_remote_code=True)
    add_pad_token(fingerprint_model, tokenizer)

    return pretrain_model, fingerprint_model, tokenizer


def main():
    """Compare fingerprint response probabilities of a pretrained and a fingerprinted model.

    For each fingerprint this function:
    - validates the fingerprint/response pairing;
    - prints the cumulative response probability under both models;
    - re-tests the fingerprinted model under every metaprompt.

    Finally it samples both models on non-fingerprint queries. Results are written as
    JSON files under ``--logging_folder``.
    """
    args = parse_arguments()

    # Create the logger once, for the final log path. Previously a throwaway logger was
    # constructed first whenever --format_name was given.
    global logger
    log_path, cumulative_prob_path, non_fingerprint_path = _output_paths(args)
    logger = FingerprintLogger(log_path)

    # Create directory if it does not exist
    if args.fingerprint_model_adapter is not None and args.model_folder is not None:
        os.makedirs(args.model_folder, exist_ok=True)

    # print each argument on a separate line
    print('\n'.join(f"{k}: {v}" for k, v in vars(args).items()))
    print("")

    # --model_folder is read by the quantized and full-model branches, and by the
    # dataset fallback. Fail clearly instead of crashing with a TypeError later.
    needs_model_folder = args.quantized or not args.fingerprint_model_adapter or args.fingerprint_file is None
    if needs_model_folder and args.model_folder is None:
        raise SystemExit("--model_folder is required unless --fingerprint_model_adapter and "
                         "--fingerprint_file are both given (without --quantized).")

    pretrain_model, fingerprint_model, tokenizer = _load_models(args)

    # load dataset
    if args.fingerprint_file is None:
        fingerprint_dataset = datasets.load_from_disk(args.model_folder + "/fingerprint.hf")
    else:
        fingerprint_dataset = datasets.load_from_disk(args.fingerprint_file)

    fingerprints = extract_fingerprints(args.model_name, tokenizer, fingerprint_dataset)

    # Completion mode is selected explicitly with --completion.
    completion_format = args.completion

    # Concatenate the bare text of every fingerprint. validate_fingerprint passes this to
    # hash_fingerprint.
    fingerprint_prompts = "".join(
        extract_text_from_template(fp.fingerprint_base, args.model_name)[0] for fp in fingerprints
    )

    # look at the probs without a metaprompt for the pretrained and fingerprinted models
    for index, fingerprint in enumerate(fingerprints):
        fingerprint_id = index + 1
        print(f"\n*** Fingerprint {fingerprint_id}: {fingerprint.fingerprint_text}")

        if not validate_fingerprint(args.model_name, tokenizer, fingerprint_prompts, args.responses_file,
                                    fingerprint.fingerprint_base, fingerprint.fingerprint_response, completion_format):
            return

        query_pre, prob_pre = print_fingerprint_prob(fingerprint_id, pretrain_model, tokenizer, fingerprint, "Pretrained", args.model_name,
                                                     format_name=args.format_name, completion_format=completion_format)
        query_fp, prob_fp = print_fingerprint_prob(fingerprint_id, fingerprint_model, tokenizer, fingerprint, "Fingerprint", args.model_name,
                                                   format_name=args.format_name, completion_format=completion_format)

        results = {
            "pretrained": {"query": query_pre, "cumulative_prob": prob_pre},
            "fingerprint": {"query": query_fp, "cumulative_prob": prob_fp},
        }
        # now evaluate the metaprompts, including ones not in the training data.
        # The file index stays 0-based (as before) even though printed ids are 1-based.
        evaulate_fingerprints(fingerprint_id, fingerprint_model, tokenizer, fingerprint, args.model_name, results,
                              cumulative_prob_path(index), format_name=args.format_name, completion_format=completion_format)

    # now evaluate some non-fingerprint prompts
    evaulate_nonfingerprint(pretrain_model, fingerprint_model, tokenizer, args.model_name, non_fingerprint_path,
                            completion_format=completion_format)


if __name__ == "__main__":
    main()