import asyncio
import json
import logging
import os
import random
import nest_asyncio
import numpy as np
from dotenv import load_dotenv
from tqdm import tqdm
from openai import AsyncOpenAI

load_dotenv(override=True)

nest_asyncio.apply()

client = AsyncOpenAI()


async def get_embedding(s, engine="text-embedding-3-large", max_attempts: int = 21):
    assert isinstance(s, str), f"get_embedding got {s!r} of type {type(s)}"
    assert len(s) > 0, f"get_embedding got empty string {s!r}"
    s = s.replace("\n", " ")
    for i in range(max_attempts):
        try:
            await asyncio.sleep(1.4**i * (random.random() + i))
            response = await client.embeddings.create(input=[s], model=engine)
            return response.data[0].embedding
        except Exception as e:
            logging.warn(f"get_embedding failed with {e!r} on attempt {i+1}, input {s!r}")
    raise ValueError(f"get_embedding failed after {max_attempts} attempts")

async def get_embeddings(questions, batch_size=10):
    tasks = [get_embedding(q) for q in questions]
    embeddings = []
    for i in tqdm(range(0, len(questions), batch_size), desc="Getting embeddings"):
        batch = questions[i:i+batch_size]
        tasks = [get_embedding(q) for q in batch]
        batch_embeddings = await asyncio.gather(*tasks)
        embeddings.extend(batch_embeddings)
        await asyncio.sleep(1)  # Add a small delay between batches
    return embeddings


data = json.load(open("data/wmdp/cyber.json"))
questions = [d["question"] for d in data]
embeds_path = "embeds/cyber-questions-embeds.npy"

if os.path.exists(embeds_path):
    embeds = np.load(embeds_path)
else:
    # embeddings = asyncio.run(asyncio.gather(*[get_embedding(q) for q in questions]))
    embeddings = asyncio.run(get_embeddings(questions))
    embeds = np.array(embeddings)
    np.save(embeds_path, embeds)


answers = [d["choices"][d["answer"]] for d in data]
embeds_path = "embeds/cyber-answers-embeds.npy"

if os.path.exists(embeds_path):
    aembeds = np.load(embeds_path)
else:
    # embeddings = asyncio.run(asyncio.gather(*[get_embedding(q) for q in answers]))
    embeddings = asyncio.run(get_embeddings(answers))
    aembeds = np.array(embeddings)
    np.save(embeds_path, aembeds)

# %%
d_and_es = list(zip(data, embeds, aembeds))
d_and_es_0 = d_and_es

print(len(d_and_es), "at the beginning")

# Skill based questions
fordbidden_strings = [
    "Which arguments will cause the following",
    "What is the output",
    "What is the value of the field of the packet",
    "What number is returned",
    "What does the following",
    "given the code",
    "following program",
    "this main function",
    "following C code",
    "following x86",
    "table below",
    "following code",
    "following snippet",
    "following assembly",
    "following command",
    "following pseudocode",
    "following commands",
    "following operations",
    "following attack",
    "{\n",
    "```",
]

d_and_es = [
    (d, e, ae) for d, e, ae in d_and_es if not any(fs.lower() in d["question"].lower() for fs in fordbidden_strings)
]

print(len(d_and_es), "after filtering for non-knowledge questions")

# remove concrete sql injection

d_and_es = [
    (d, e, ae)
    for d, e, ae in d_and_es
    if not ("sql injection" in d["question"].lower() and "specific command" in d["question"].lower())
]

print(len(d_and_es), "after filtering for specific sql injections")

too_many_digits = 8
d_and_es = [
    (d, e, ae)
    for d, e, ae in d_and_es
    if sum(c.isdigit() for c in d["question"]) < too_many_digits
    and sum(c.isdigit() for c in d["choices"][d["answer"]]) < too_many_digits
]

print(len(d_and_es), "after filtering for too many digits")


# remove duplicates by distance, keep first
threshold = 0.7
new_d_and_es = []
for i, (d, e, ae) in enumerate(d_and_es):
    nearest_with_above_idx = max((e @ e2 for _, e2, _ in d_and_es[:i]), default=0)
    if nearest_with_above_idx < threshold:
        new_d_and_es.append((d, e, ae))
d_and_es = new_d_and_es

print(len(d_and_es), "after removing duplicates in questions")

# remove answer duplicates
alone_threshold = 0.7
with_answer_threshold = 1.05
authorized_answer_duplicates = [
    "All of these choices",
    "None of these choices",
    "All of the above",
    "None of the above",
    "All of the choices",
    "None of the choices",
]


def is_ok(d):
    return any(a.lower() in d["choices"][d["answer"]].lower() for a in authorized_answer_duplicates)


new_d_and_es = []
for i, (d, e, ae) in enumerate(d_and_es):
    if is_ok(d):
        new_d_and_es.append((d, e, ae))
        continue

    nearest_with_above_idx = max((ae @ ae2 for q, _, ae2 in d_and_es[:i] if not is_ok(q)), default=0)

    nearest_q_and_a_with_above_idx = max((e @ e2 + ae @ ae2 for q, e2, ae2 in d_and_es[:i] if not is_ok(q)), default=0)

    if nearest_with_above_idx < threshold and nearest_q_and_a_with_above_idx < with_answer_threshold:
        new_d_and_es.append((d, e, ae))
d_and_es = new_d_and_es

print(len(d_and_es), "after removing duplicates in answers")
# %%
from tqdm import tqdm
# print 4 closest pairs in d_and_es_0

all_pairs = [
    (e0 @ e1, d0["question"], d1["question"])
    for i, (d0, e0, _) in enumerate(tqdm(d_and_es_0))
    if not any(s.lower() in d0["question"].lower() for s in fordbidden_strings)
    for j, (d1, e1, _) in enumerate(d_and_es_0)
    if i < j
]
all_pairs = sorted(all_pairs, reverse=True)
for i in range(20):
    for e in all_pairs[i]:
        print(e)
    print()
# %%
# print 4 closest pairs in d_and_es
from tqdm import tqdm

all_pairs = [
    (e0 @ e1, d0["question"], d1["question"])
    for i, (d0, e0, _) in enumerate(tqdm(d_and_es))
    if not any(s.lower() in d0["question"].lower() for s in fordbidden_strings)
    for j, (d1, e1, _) in enumerate(d_and_es)
    if i < j
]
all_pairs = sorted(all_pairs, reverse=True)
for i in range(5):
    for e in all_pairs[i]:
        print(e)
    print()
#%%
# print 20 questions with forbidden string
for d, e, _ in d_and_es_0:
    if any(s.lower() in d["question"].lower() for s in fordbidden_strings):
        print(d["question"])
    print("="*80)
# %%
random.Random(0).shuffle(d_and_es)
splits = 5
dev_set_size = 5

data = [d for d, _, _ in d_and_es]
data, dev_data = data[dev_set_size:], data[:dev_set_size]

for i in range(splits):
    json.dump(data[i::splits], open(f"data/cyber-questions-split-{i}.json", "w"))
json.dump(dev_data, open("data/cyber-questions-dev.json", "w"))