from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
from transformers import StoppingCriteria
from openai import OpenAI


# Credentials come from the environment; secrets must never be committed to source.
# setdefault keeps the OPENAI_API_KEY entry present (empty when nothing was exported)
# without overwriting a real key that the caller already exported.
os.environ.setdefault("OPENAI_API_KEY", "")

openAI_client = OpenAI(
    api_key=os.environ["OPENAI_API_KEY"],
)
# HF_TOKEN is intentionally not assigned here: export it in the shell before running
# this script. The access token that used to be hard-coded at this spot has been
# exposed in source and should be revoked.
print(torch.cuda.is_available())


class StopOnQColonCriteria(StoppingCriteria):
    """Stopping criterion that fires when the sequence ends with "Q" followed by ":".

    The decision is taken from the last two ids of the first row of ``input_ids``
    on every call, so it does not depend on state carried over from earlier calls
    and an instance can be reused across several generation runs.
    """

    def __init__(self, tokenizer):
        """Look up the ids of the "Q" and ":" tokens in ``tokenizer``.

        Args:
            tokenizer: object exposing ``convert_tokens_to_ids``.
        """
        # Get token IDs for "Q" and ":"
        self.q_token_id = tokenizer.convert_tokens_to_ids("Q")
        self.colon_token_id = tokenizer.convert_tokens_to_ids(":")
        # Still updated on every call for backward compatibility, but no longer
        # consulted when deciding whether to stop.
        self.previous_token_id = None

    def __call__(self, input_ids, scores, **kwargs):
        """Return True when the first row of ``input_ids`` ends with "Q", ":".

        Args:
            input_ids: 2-D tensor of token ids; only row 0 is inspected.
            scores: unused.
            **kwargs: unused.

        Returns:
            bool: True to stop generation, False to continue.
        """
        # Last two ids of the first sequence (a single id if the row has length 1).
        tail = input_ids[0, -2:].tolist()
        self.previous_token_id = tail[-1]

        return (
            len(tail) == 2
            and tail[0] == self.q_token_id
            and tail[1] == self.colon_token_id
        )


# Checkpoint identifier passed to both from_pretrained calls below.
model_name = "meta-llama/Meta-Llama-3-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Causal LM loaded with float16 weights, device_map="cuda" and the
# flash_attention_2 attention implementation.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="cuda",
    attn_implementation="flash_attention_2",
)


# Prompt: one worked question/answer pair followed by a second question whose
# answer is left open for the model to complete.
context = """
Q: Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does Roger have now?
Answer: Let's think step by step. Roger started with 5 tennis balls. 2 cans of 3 tennis balls each means 2*3 = 6 tennis balls. The answer is 5 + 6 = 11 tennis balls.

Q: The cafeteria had 23 apples. If they used 20 to make the lunch and bought 6 more, how many apples does the cafeteria have left?
Answer:
"""

# =============================================================================================================
# Greedy decoding experiment
# =============================================================================================================
# Sampling is disabled for every generate call in this section.
do_sample = False

# Tokenize the context and move the ids to the GPU
input_ids = tokenizer(context, return_tensors="pt").input_ids.to("cuda")
prompt_length = input_ids.shape[1]

# Set the maximum length for generation (prompt plus up to 400 new tokens)
max_length = prompt_length + 400

# 1. Normal Generation with Greedy Decoding
greedy_output = model.generate(
    input_ids,
    max_length=max_length,
    do_sample=do_sample,
    stopping_criteria=[StopOnQColonCriteria(tokenizer)],
)
greedy_decoded = tokenizer.decode(greedy_output[0], skip_special_tokens=True)
# Trim the trailing "Q:" (and the character before it) only when the text really
# ends with it; if generation ended for another reason the text is printed whole
# instead of losing its last three characters.
print(greedy_decoded[:-3] if greedy_decoded.endswith("Q:") else greedy_decoded)

# 2. Inject rubbish in the middle of generation
tokens_to_generate = 10
partial_output = model.generate(
    input_ids,
    max_length=prompt_length + tokens_to_generate,
    do_sample=do_sample,
    stopping_criteria=[StopOnQColonCriteria(tokenizer)],
)
partial_decoded = tokenizer.decode(partial_output[0], skip_special_tokens=True)
print(partial_decoded)

# inject rubbish
rubbish = " The animals y"
# partial_output holds the prompt ids followed by the new ids, so partial_decoded
# already contains the prompt text. Decode only the newly generated ids before
# prepending `context`; otherwise the prompt would appear twice.
partial_continuation = tokenizer.decode(partial_output[0, prompt_length:], skip_special_tokens=True)
# Drop the last three space-separated pieces of the continuation, then append the rubbish.
compromised_context = context + " ".join(partial_continuation.split(" ")[:-3]) + rubbish
print(compromised_context)

compromised_context_ids = tokenizer(compromised_context, return_tensors="pt").input_ids.to("cuda")
# Continue generating from the rubbish text onwards
rubbish_output = model.generate(
    compromised_context_ids,
    max_length=max_length,
    do_sample=do_sample,
    stopping_criteria=[StopOnQColonCriteria(tokenizer)],
)
rubbish_decoded = tokenizer.decode(rubbish_output[0], skip_special_tokens=True)
print(f"Generation with Rubbish: {rubbish_decoded}")

# Now replace the last token of the partial generation with a rubbish token.
# add_special_tokens=False stops the tokenizer from adding special tokens (such as
# a beginning-of-text marker) to " =", which would otherwise end up in the middle
# of the sequence.
rubbish_token = tokenizer(" =", return_tensors="pt", add_special_tokens=False).input_ids.to("cuda")
modified_input_ids = torch.cat([partial_output[:, :-1], rubbish_token], dim=1)

# Continue generating from the rubbish token onwards.
# NOTE: unlike the runs above, no stopping criteria is passed to this call.
rubbish_output = model.generate(modified_input_ids, max_length=max_length, do_sample=do_sample)
rubbish_decoded = tokenizer.decode(rubbish_output[0], skip_special_tokens=True)
print(f"Generation with Rubbish: {rubbish_decoded}")

# =============================================================================================================
# OpenAI experiment
# =============================================================================================================
engine = "gpt-3.5-turbo-instruct"


context = """
Q: Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does Roger have now?
Answer: Let's think step by step. Roger started with 5 tennis balls. 2 cans of 3 tennis balls each means 2*3 = 6 tennis balls. The answer is 5 + 6 = 11 tennis balls.

Q: The cafeteria had 23 apples. If they used 20 to make the lunch and bought 6 more, how many apples does the cafeteria have left?
Answer:
"""

# count the number of tokens in the context with tiktoken
import tiktoken

# NOTE: this rebinds the module-level name `tokenizer` (previously the Hugging Face
# tokenizer) to the tiktoken encoding.
tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo-instruct")
tokenized_context = tokenizer.encode(context)
num_tokens = len(tokenized_context)

# Prompt + completion budget.
max_length = num_tokens + 400
# The completions API's max_tokens limits generated tokens only (the prompt is not
# counted), so the requests below receive just the completion share of the budget.
max_completion_tokens = max_length - num_tokens

# Number of tokens requested for the partial generation.
tokens_to_generate = 10

# Full-length reference generation
response = openAI_client.completions.create(model=engine, prompt=context, max_tokens=max_completion_tokens, temperature=0, stop=None)

# Print the completion; as a bare expression its value was silently discarded.
print(response.choices[0].text)

# Inject rubbish in the middle of OpenAI generation
partial_response = openAI_client.completions.create(model=engine, prompt=context, max_tokens=tokens_to_generate, temperature=0, stop=None)

# The API returns only the completion text here, so prepending `context` below is needed.
partial_text = partial_response.choices[0].text
print(partial_text)


# Inject rubbish
rubbish = " The US economy"
# Drop the last three space-separated pieces of the partial completion, then append the rubbish.
compromised_prompt = context + " ".join(partial_text.split(" ")[:-3]) + rubbish
print(compromised_prompt)

# Continue generation with injected rubbish
# NOTE(review): this is the only request using temperature=0.7; every other request
# uses temperature=0 - confirm that sampling is intended for this comparison.
final_response = openAI_client.completions.create(model=engine, prompt=compromised_prompt, max_tokens=max_completion_tokens, temperature=0.7, stop=None)

print(final_response.choices[0].text)

# Compare with normal generation
# NOTE(review): max_tokens counts completion tokens only, so this allows up to
# num_tokens + tokens_to_generate generated tokens - confirm the intended budget.
normal_response = openAI_client.completions.create(model=engine, prompt=context, max_tokens=num_tokens + tokens_to_generate, temperature=0, stop=None)

normal_text = normal_response.choices[0].text

print("\nNormal Generation:")
print(normal_text)
