import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
from User_model import VisionLanguageFeatureModifier
from PIL import Image
import sys
import os
import argparse
from myfunc import get_text_input

def parse_args(argv=None):
    """Parse the command-line options of the adversarial-prompt training script.

    Args:
        argv: optional list of argument strings to parse. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly as before; passing a list makes the
            parser usable from tests or other scripts.

    Returns:
        argparse.Namespace with ``attacker_model``, ``user_model``,
        ``start_slice_num`` (int), ``length_slice`` (int), ``min_generate_num``
        (int), ``gpu_id_attacker`` (str) and ``gpu_id_user`` (str).
    """
    parser = argparse.ArgumentParser(description='Generate adversarial prompt')
    parser.add_argument('--attacker_model', default='openai-community/gpt2-xl', help='the attacker LLM')
    parser.add_argument('--user_model', default='Salesforce/blip2-opt-2.7b', help='The target model under attack')
    parser.add_argument('--start_slice_num', type=int, default=2, help='The starting token position of the slice extracted from the adversarial prompt')
    parser.add_argument('--length_slice', type=int, default=2, help='The token length of the extracted slice')
    parser.add_argument('--min_generate_num', type=int, default=5, help='The minimum number of tokens generated by the attacker LLM')
    # GPU ids are kept as strings because they are joined into CUDA_VISIBLE_DEVICES.
    parser.add_argument('--gpu_id_attacker', default='0', help='the GPU used by the attacker LLM')
    parser.add_argument('--gpu_id_user', default='1', help='the GPU used by the target model')
    return parser.parse_args(argv)

args = parse_args()
# Restrict this process to the two requested GPUs. This runs before any CUDA call
# in this file, so the attacker GPU is expected to appear as cuda:0 and the user
# GPU as cuda:1 (the device objects defined further down rely on that order).
# NOTE(review): assumes the two GPU ids differ.
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join((args.gpu_id_attacker, args.gpu_id_user))

def set_seed(seed):
    """Seed torch's default RNG and, if CUDA is available, the RNGs of all GPUs.

    Args:
        seed: integer seed passed to ``torch.manual_seed`` / ``torch.cuda.manual_seed_all``.
    """
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

set_seed(42)

# Upper bound on tokens the target model may generate; also the denominator of
# the token-length reward in train_attacker.
max_generate_length = 1024
# NOTE(review): eta/alpha/beta are not read anywhere in the code shown here;
# presumably weighting knobs kept for experiments — confirm before removing.
eta = alpha = beta = 0

# After the CUDA_VISIBLE_DEVICES remap above: attacker on cuda:0, target on cuda:1.
device, device1 = torch.device("cuda:0"), torch.device("cuda:1")

attacker_model, user_model = args.attacker_model, args.user_model
start_slice_num, length_slice, min_generate_num = (
    args.start_slice_num,
    args.length_slice,
    args.min_generate_num,
)
# Target (user) model wrapper, placed on the second visible GPU.
USER = VisionLanguageFeatureModifier(
    model_name=user_model,
    device=device1,
    start_slice_num=start_slice_num,
    length_slice=length_slice,
)

def train_attacker(recordfile, attacker_model, image_path):
    """Train the attacker LM with PPO so that its output maximises the target's output length.

    In each epoch the attacker greedily continues a fixed prompt. The decoded
    continuation is injected into the target model via ``USER.register_hook``,
    and the target model then generates text for ``image_path``. The reward is
    the target's generated token count divided by ``max_generate_length``.
    Training stops early once that ratio reaches 1.0. Otherwise a PPO step is
    taken with that reward.

    The best-scoring continuation is written to
    ``attack_<model_type>_<start_slice_num>_<length_slice>_record_adv_prompt.txt``.

    Args:
        recordfile: writable text file object that receives the per-epoch log
            and the final summary.
        attacker_model: Hugging Face model id or path of the attacker causal LM.
        image_path: path of the image fed to the target model each epoch.
    """
    model = AutoModelForCausalLMWithValueHead.from_pretrained(attacker_model, torch_dtype=torch.float32)
    tokenizer = AutoTokenizer.from_pretrained(attacker_model, use_fast=False)
    # Reuse EOS as the padding token (presumably the attacker tokenizer, e.g. GPT-2, has none).
    tokenizer.pad_token = tokenizer.eos_token
    model = model.to(device)

    ppo_config = {"mini_batch_size": 1, "batch_size": 1, "cliprange": 0.3, "cliprange_value": 0.25, "learning_rate": 1.46e-5}
    config = PPOConfig(**ppo_config)
    # NOTE(review): the policy itself is passed as ref_model, so the reference is
    # updated together with the policy and the KL term against it stays ~0.
    # Confirm this is intended (passing ref_model=None would create a frozen copy).
    ppo_trainer = PPOTrainer(config=config, model=model, ref_model=model, tokenizer=tokenizer)

    # Loop invariants: the prompt and the image are identical in every epoch,
    # so tokenize / load them once instead of once per epoch.
    prompt = "Please generate:"
    prompt_tensor = tokenizer.encode(prompt, return_tensors="pt").to(device)
    prompt_token_len = prompt_tensor.shape[-1]
    # The context manager closes the underlying file; convert/resize return new images.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB").resize((336, 336))

    num_epochs = 100
    max_token_length_reward = float('-inf')
    max_similarity = float('-inf')
    record_response_txt = ""
    for epoch in range(num_epochs):
        # min_length counts prompt tokens too, so this forces at least
        # min_generate_num new tokens; max_new_tokens allows two more.
        response_tensor = ppo_trainer.generate(
            list(prompt_tensor),
            return_prompt=False,
            min_length=prompt_token_len + min_generate_num,
            max_new_tokens=min_generate_num + 2,
            do_sample=False,
            top_k=0.0,
            top_p=1.0,
            pad_token_id=tokenizer.eos_token_id,
        )
        response_txt = tokenizer.decode(response_tensor[0])

        USER.register_hook(replacement_text=response_txt)
        try:
            text_input = get_text_input(USER)
            inputs = USER.processor(images=image, text=text_input, return_tensors="pt").to(USER.device)
            generated_text, generated_token_length = USER.generate_text(inputs=inputs, max_length=max_generate_length, do_sample=False)
            similarity = USER.compute_similarity()
        finally:
            # Always detach the hook, even if generation fails, so it is never
            # left registered on the target model.
            USER.remove_hook()

        token_length_reward = generated_token_length / max_generate_length
        # NOTE: similarity is deliberately not part of the reward at the moment
        # (an earlier variant used token_length_reward * 1 + similarity).
        reward = torch.tensor(token_length_reward, dtype=torch.float64, device=device)

        if token_length_reward > max_token_length_reward:
            max_token_length_reward = token_length_reward
            max_similarity = similarity
            record_response_txt = response_txt

        print("generated text: ", generated_text)
        print(f"Epoch: {epoch + 1}/{num_epochs}")
        print(f"similarity: {similarity}")
        print(f"token length reward: {token_length_reward}")
        print(f"Reward: {reward}")
        print()
        recordfile.write(f"Epoch: {epoch + 1}/{num_epochs}\n")
        recordfile.write(f"similarity: {similarity}\n")
        recordfile.write(f"token length reward: {token_length_reward}\n")
        recordfile.write(f"{reward}\n")
        recordfile.write("Prompt: [" + prompt + "]\n")
        recordfile.write("response text: [" + response_txt + "]")
        recordfile.write('\n\n')

        # The target already hit the generation cap; no further training needed.
        if token_length_reward >= 1.0:
            break

        ppo_trainer.step([prompt_tensor[0]], [response_tensor[0]], [reward])

    recordfile.write("\n\n")
    recordfile.write("max token length reward: " + str(max_token_length_reward) + "\n")
    # max_similarity was tracked for the best response but previously never recorded.
    recordfile.write("max similarity: " + str(max_similarity) + "\n")
    recordfile.write("record response text: [" + record_response_txt + "]")
    adv_prompt_record_path = "attack_" + USER.model_type + "_" + str(start_slice_num) + "_" + str(length_slice) + "_record_adv_prompt.txt"
    with open(adv_prompt_record_path, "w", encoding="utf-8") as f:
        f.write(record_response_txt)

if __name__ == "__main__":
    print("Train")
    print("attacker model:", attacker_model)
    print("user model:", user_model)
    print("start_slice_num:", start_slice_num)
    print("length_slice:", length_slice)

    record_path = "attack_" + USER.model_type + "_" + str(start_slice_num) + "_" + str(length_slice) + "_record_train.txt"
    image_path = "demo.jpg"
    # The context manager guarantees the log is flushed and closed even if
    # training raises. UTF-8 matches the adversarial-prompt output file and
    # avoids UnicodeEncodeError when the platform default encoding cannot
    # represent the attacker's generated text.
    with open(record_path, 'w', encoding="utf-8") as recordfile:
        recordfile.write("min_generate_num: " + str(min_generate_num) + "\n\n")
        train_attacker(recordfile=recordfile, attacker_model=attacker_model, image_path=image_path)