import os
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model", default='Qwen/Qwen1.5-1.8B', type=str)
parser.add_argument("--max_iter", default=2048, type=int)
parser.add_argument("--gpu_id", default='0', type=str)
parser.add_argument("--use_gpu", default=True, type=bool)
parser.add_argument("--flip_num_end", default=15, type=int)
parser.add_argument("--flip_num_start", default=1, type=int)
parser.add_argument("--inference_limit", default=100, type=int) 
parser.add_argument("--dtype", default='int8', type=str)
parser.add_argument("--direction", default='descending', type=str)
parser.add_argument("--clear_result", default=1, type=int, help="Whether to clear the result file before running")
parser.add_argument("--temperature", default=0.7, type=float, help="Decoding temperature for text generation")
args = parser.parse_args()

# Number CUDA devices in PCI bus order and expose only the requested GPU id(s).
# These are set before `torch` is imported further down in this file.
os.environ.update(CUDA_DEVICE_ORDER="PCI_BUS_ID", CUDA_VISIBLE_DEVICES=args.gpu_id)

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed, BitsAndBytesConfig
from datasets import load_dataset
import json
import templates
import attack_util

print(args)
set_seed(42)  # seed the RNGs via transformers.set_seed so sampling runs are repeatable

# The GPU path is taken only when CUDA is present AND the user asked for it.
gpu_requested_and_present = torch.cuda.is_available() and args.use_gpu
if not gpu_requested_and_present:
    print('cuda not available')
    device = torch.device("cpu")
    device_map = device
else:
    print('cuda available with GPU:', torch.cuda.get_device_name(0))
    device = torch.device("cuda")
    device_map = "auto"

# Load the model and tokenizer
model_name = args.model
model_path = './models/'  # cache_dir for the Hugging Face downloads
MAX_ITER_NUM = args.max_iter  # used as max_new_tokens during generation


# Honour the device_map chosen during device selection (it is "auto" on the GPU
# path and the CPU device otherwise) instead of always forcing "auto".
if args.dtype == 'fp16':
    dtype = torch.float16
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map, cache_dir=model_path, torch_dtype=dtype)
elif args.dtype == 'int8':
    dtype = torch.int8
    # 8-bit weights are produced by bitsandbytes quantization at load time.
    quant_config = BitsAndBytesConfig(load_in_8bit=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map, cache_dir=model_path, quantization_config=quant_config)
else:
    raise ValueError("Invalid dtype: {}".format(args.dtype))

tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=model_path)
# Padding reuses the EOS token.
tokenizer.pad_token = tokenizer.eos_token

# Output vocabulary size: prefer the width of the LM head, otherwise fall back
# to the tokenizer's base vocabulary plus any added tokens.
try:
    total_vocab_size = model.get_output_embeddings().out_features
except (AttributeError, NotImplementedError):
    # Presumably reached when get_output_embeddings() returns None or a module
    # without `out_features`; other errors are no longer silently swallowed.
    total_vocab_size = tokenizer.vocab_size + len(tokenizer.get_added_vocab())
eos_token_id = model.config.eos_token_id


# Per-model output prefix, e.g. 'Qwen/Qwen1.5-1.8B' -> 'data/int8/Qwen_Qwen1.5_1.8B'.
save_name = 'data/' + str(args.dtype) + '/' + model_name.replace('/', '-').replace('-', '_')


model.eval()  # Set the model in evaluation mode

# Load the dataset
dataset_name = "tatsu-lab/alpaca"
dataset = load_dataset(dataset_name)
results_file_path = save_name + '_temp_' + str(args.temperature) + '_results.json'
flip_record_file_path = save_name + '_flip_record.json'

clean_result_file_path = 'data/clean/' + model_name.replace('/', '-').replace('-', '_') + '_clean_results.json'
# Snapshot of the unmodified output-embedding weights, used later to measure
# the Hamming-weight change introduced by the flips.
last_layer_weight_original = model.get_output_embeddings().weight.data.clone()

if not os.path.exists(flip_record_file_path):
    print("Flip record file not found.")
    raise FileNotFoundError(flip_record_file_path)
with open(flip_record_file_path, 'r') as f:
    flip_record_all = json.load(f)

# Check if the first bit in the flip record has a null flipped bit
if flip_record_all and flip_record_all[0].get('flipped_bit') is None:
    print("The first bit in the flip record has a null flipped bit. Exiting.")
    raise SystemExit

# Clear previous results only once the inputs are known to be usable, so an
# aborted run does not wipe them. "[]" is valid JSON, so the reload below
# yields an empty list without a decode error.
if args.clear_result == 1:
    with open(results_file_path, 'w') as log_file:
        log_file.write("[]")

max_flip_num = len(flip_record_all)
flip_results = []
if os.path.exists(results_file_path):
    print(f"Loading existing results from {results_file_path}. exist: {os.path.exists(results_file_path)}")
    try:
        with open(results_file_path, 'r') as f:
            flip_results = json.load(f)
    except json.JSONDecodeError:
        print(f"Error decoding JSON from {results_file_path}. Initializing empty results list.")
        flip_results = []


def _load_clean_results(path):
    """Return the clean (un-flipped) results stored in *path*, keyed by ``sample_id``.

    When several entries share a ``sample_id`` the first one wins. A missing
    file yields an empty dict, so such samples are skipped rather than crashing
    later on a missing clean result.
    """
    clean_by_id = {}
    if not os.path.exists(path):
        print(f"Clean result file {path} not found.")
        return clean_by_id
    with open(path, 'r') as f:
        for entry in json.load(f):
            clean_by_id.setdefault(entry['sample_id'], entry)
    return clean_by_id


def _build_prompt(instruction):
    """Render *instruction* as a single-turn prompt using the model's conversation template."""
    conv: templates.Conversation = templates.conv_templates(templates.convert_name(model_name.split('/')[-1])).copy()
    conv.messages = []
    conv.append_message(conv.roles[0], instruction)
    conv.append_message(conv.roles[1], "")
    return conv.get_prompt()


print("Starting inference...")
if args.direction == 'ascending':
    flip_list = range(args.flip_num_start, args.flip_num_end + 1)
elif args.direction == 'descending':
    flip_list = range(args.flip_num_end, args.flip_num_start - 1, -1)
else:
    raise ValueError(f"Invalid direction: {args.direction}")

# Loaded once up front instead of re-reading the file for every sample.
clean_results_by_id = _load_clean_results(clean_result_file_path)

for flip_num in flip_list:
    if flip_num > max_flip_num:
        # Not enough recorded flips for this count. Skip rather than stop: in
        # descending order the largest counts come first, and smaller ones may
        # still be valid.
        print(f"Skipping flip_num {flip_num}: only {max_flip_num} flips recorded.")
        continue
    flip_record = flip_record_all[:flip_num]

    # Drop the previous model before loading a fresh copy with the first
    # `flip_num` recorded flips applied.
    del model

    model = attack_util.reload_and_flip(model_name, flip_record, args.dtype)
    hamming_diff = attack_util.hamming_weight_difference(
        last_layer_weight_original[eos_token_id],
        model.get_output_embeddings().weight.data[eos_token_id],
        args.dtype,
    )

    with torch.no_grad():
        total_generated_length = 0
        sample_count = 0
        for idx, sample in enumerate(dataset["train"]):
            if idx >= args.inference_limit:
                break
            # Only instruction-only samples (empty "input" field) are evaluated.
            if sample["input"] != '':
                continue
            prompt = _build_prompt(sample["instruction"])
            print(f"Sample {idx + 1} Prompt: {prompt}")

            # Resume support: reuse a result already recorded for this flip count and prompt.
            matched_result = next(
                (result for result in flip_results if result['flip_num'] == flip_num and result['prompt'] == prompt),
                None,
            )
            if matched_result is not None:
                total_generated_length += matched_result['flipped_generated_length']
                sample_count += 1
                print(f"Skipping flip_num {flip_num} for sample {idx + 1} as it already exists in flip_results. Flipped generated length: {matched_result['flipped_generated_length']}")
                continue

            clean_result = clean_results_by_id.get(idx)
            if clean_result is None:
                print(f"Sample ID {idx} not found in the JSON file.")
                continue

            # Send inputs to wherever the (re)loaded model lives.
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True).to(model.device)
            prompt_length = inputs["input_ids"].shape[1]

            # pad == eos here, so pass the attention mask explicitly instead of
            # letting generate() try to infer it.
            generated_ids = model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=MAX_ITER_NUM,
                do_sample=True,
                temperature=args.temperature,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=eos_token_id,
            )

            generated_length = generated_ids.shape[1] - prompt_length
            total_generated_length += generated_length
            sample_count += 1
            # Compare new tokens (not prompt + new) against the max_new_tokens cap.
            reach_max = 1 if generated_length >= MAX_ITER_NUM else 0

            # Decode only the newly generated tokens. Slicing the decoded full text
            # by len(prompt) breaks when skip_special_tokens drops template tokens.
            flipped_output = tokenizer.decode(generated_ids[0, prompt_length:], skip_special_tokens=True)
            print(f"Current Output: {flipped_output}, length: {generated_length}")
            flip_result = {
                "input": sample["input"],
                "sampleid": idx,
                "prompt": prompt,
                "clean_output": clean_result["output"],
                "clean_generated_length": clean_result["generated_length"],
                "flip_num": flip_num,
                "flipped_generated_length": generated_length,
                "flipped_output": flipped_output,
                "generated_length_diff": generated_length - clean_result["generated_length"],
                "Hamming_weight_difference": hamming_diff,
                "reach_max": reach_max,
            }
            flip_results.append(flip_result)
            # Checkpoint after every sample so an interrupted run can resume.
            with open(results_file_path, 'w') as f:
                json.dump(flip_results, f, ensure_ascii=False, indent=4)

        if sample_count:
            print(f"flip_num {flip_num}: average generated length {total_generated_length / sample_count:.2f} over {sample_count} samples")


