# %%
import asyncio
import json
import logging
import os
import random
import nest_asyncio
import numpy as np
from dotenv import load_dotenv
from openai import AsyncOpenAI
from tqdm import tqdm

# Load environment variables via python-dotenv before the client is built.
# NOTE(review): override=True presumably lets values from the .env file win
# over variables already set in the process environment -- confirm.
load_dotenv(override=True)

# Patch asyncio through nest_asyncio, presumably so the asyncio.run() calls
# further down also work when an event loop is already running (this file is
# organised in `# %%` cells, so it is likely executed in a notebook kernel).
nest_asyncio.apply()

# Create a single module-level async OpenAI client; get_embedding() reuses it
# for every request instead of building a new client per call.
client = AsyncOpenAI()


async def get_embedding(s, engine="text-embedding-3-large", max_attempts: int = 21):
    """Return the embedding of ``s``, retrying the API call on any failure.

    Newlines in ``s`` are replaced by spaces before the request is sent.
    Before attempt number ``i`` (0-based) the coroutine sleeps for
    ``1.4**i * (random.random() + i)`` seconds, i.e. a random jitter of up to
    one second before the first attempt and a rapidly growing back-off
    afterwards.

    Args:
        s: Non-empty string to embed.
        engine: Model name passed to ``client.embeddings.create``.
        max_attempts: Maximum number of API calls before giving up.

    Returns:
        ``response.data[0].embedding`` from the first successful call
        (presumably a list of floats -- confirm against the client library).

    Raises:
        AssertionError: If ``s`` is not a string or is empty.
        ValueError: If every attempt failed (or ``max_attempts`` is 0). The
            last underlying exception, if any, is attached as ``__cause__``.
    """
    assert isinstance(s, str), f"get_embedding got {s!r} of type {type(s)}"
    assert len(s) > 0, f"get_embedding got empty string {s!r}"
    s = s.replace("\n", " ")
    last_error = None
    for i in range(max_attempts):
        # Jitter on the first attempt, exponential back-off on the retries.
        await asyncio.sleep(1.4**i * (random.random() + i))
        try:
            response = await client.embeddings.create(input=[s], model=engine)
            return response.data[0].embedding
        except Exception as e:
            # Best-effort retry: remember the error and log it.
            # logging.warn is a deprecated alias; logging.warning is the
            # supported spelling.
            last_error = e
            logging.warning(
                "get_embedding failed with %r on attempt %d, input %r", e, i + 1, s
            )
    raise ValueError(f"get_embedding failed after {max_attempts} attempts") from last_error

async def get_embeddings(questions):
    """Embed every string in ``questions`` concurrently, preserving order.

    Each string is handed to ``get_embedding`` and scheduled as its own
    asyncio task, so the API requests overlap. Awaiting bare coroutine
    objects one after another would instead run them strictly sequentially.
    The tasks are awaited in input order, so the progress bar advances in
    that order and ``result[i]`` belongs to ``questions[i]``.

    Args:
        questions: Iterable of non-empty strings.

    Returns:
        A list with one embedding per input string, in input order.

    Raises:
        Whatever ``get_embedding`` raises. On failure the tasks that are
        still pending are cancelled before the exception propagates.
    """
    tasks = [asyncio.create_task(get_embedding(q)) for q in questions]
    embeddings = []
    try:
        for task in tqdm(tasks, total=len(tasks), desc="Getting embeddings"):
            embeddings.append(await task)
    except BaseException:
        # Do not leave orphaned requests running after a failure/cancellation.
        for task in tasks:
            task.cancel()
        raise
    return embeddings

def _load_or_compute_embeds(path, texts):
    """Return the embedding matrix for ``texts``, cached on disk at ``path``.

    If ``path`` exists it is loaded with ``np.load`` and ``texts`` is not
    embedded again. Otherwise the embeddings are computed with
    ``get_embeddings``, converted to an array, saved to ``path`` and returned.
    """
    if os.path.exists(path):
        return np.load(path)
    computed = np.array(asyncio.run(get_embeddings(texts)))
    # np.save does not create missing parent directories.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    np.save(path, computed)
    return computed


# Load the dataset; the context manager closes the file handle right away.
with open("data/wmdp/bio.json", encoding="utf-8") as f:
    data = json.load(f)

# Get embeddings for questions
questions = [d["question"] for d in data]
embeds_path = "embeds/bio-questions-embeds.npy"
embeds = _load_or_compute_embeds(embeds_path, questions)

# Get embeddings for answers (the text of the choice indexed by "answer")
answers = [d["choices"][d["answer"]] for d in data]
aembeds_path = "embeds/bio-answers-embeds.npy"
aembeds = _load_or_compute_embeds(aembeds_path, answers)

# zip() silently truncates to the shortest input, so a stale cache file would
# otherwise pair questions with the wrong vectors without any error.
if not (len(data) == len(embeds) == len(aembeds)):
    raise ValueError(
        f"cached embeddings do not match the data: {len(data)} items, "
        f"{len(embeds)} question embeddings, {len(aembeds)} answer embeddings; "
        f"delete {embeds_path!r} / {aembeds_path!r} to recompute"
    )

# Each entry is (item dict, question embedding, answer embedding).
d_and_es = list(zip(data, embeds, aembeds))
# Keep a reference to the unfiltered list for the inspection cells below.
d_and_es_0 = d_and_es

print(len(d_and_es), "at the beginning")

# Substrings that mark a question as "skill based"; any question containing
# one of them (case-insensitively) is dropped below. Commented-out entries are
# currently disabled.
forbidden_strings = [
    "was modified",
    "likely modification",
    "Name one",
    # "What method would",
    # "What method can",
    "what are good",
    "What is a",
    "You're a Virology lab professor",
    "potential impact",
    "What is the primary challenge",
    "What is the primary reason",
    "What is the main reason",
    "What is the primary goal",
    "What engineered mutations",
    # "Name one",
    # "Name an",
    # "Name a",
    "Which of the following best characterizes",
    "How does recoding bacterial genomes",
    "good candidate",
    "might",
    "How can",
    "Which procedure",
    "method",
    "problem",
    "best",
    "most effective",
]

# Keep only the entries whose question contains none of the forbidden strings.
knowledge_items = []
for d, e, ae in d_and_es:
    question_lower = d["question"].lower()
    has_forbidden = False
    for fs in forbidden_strings:
        if fs.lower() in question_lower:
            has_forbidden = True
            break
    if not has_forbidden:
        knowledge_items.append((d, e, ae))
d_and_es = knowledge_items

print(len(d_and_es), "after filtering for non-knowledge questions")

# remove concrete sql injection

# d_and_es = [
#     (d, e, ae)
#     for d, e, ae in d_and_es
#     if not ("sql injection" in d["question"].lower() and "specific command" in d["question"].lower())
# ]

# print(len(d_and_es), "after filtering for specific sql injections")

# Drop entries whose question or correct answer contains `too_many_digits`
# or more digit characters.
too_many_digits = 8
few_digit_items = []
for d, e, ae in d_and_es:
    question_digits = sum(1 for ch in d["question"] if ch.isdigit())
    if not question_digits < too_many_digits:
        continue
    answer_digits = sum(1 for ch in d["choices"][d["answer"]] if ch.isdigit())
    if not answer_digits < too_many_digits:
        continue
    few_digit_items.append((d, e, ae))
d_and_es = few_digit_items

print(len(d_and_es), "after filtering for too many digits")


# De-duplicate questions by embedding similarity, keeping the first one seen.
# An entry is kept when its largest dot product with the question embeddings
# of ALL earlier entries (kept or not) is below `threshold`; the first entry
# has nothing to compare against and uses 0.
threshold = 0.6
unique_question_items = []
for position, (d, e, ae) in enumerate(d_and_es):
    earlier_similarities = [e @ earlier_e for _, earlier_e, _ in d_and_es[:position]]
    closest_similarity = max(earlier_similarities) if earlier_similarities else 0
    if closest_similarity < threshold:
        unique_question_items.append((d, e, ae))
d_and_es = unique_question_items

print(len(d_and_es), "after removing duplicates in questions")

# Settings for the answer de-duplication step.
alone_threshold = 0.6
with_answer_threshold = 1.05
# Generic answers that may legitimately repeat across questions.
authorized_answer_duplicates = [
    "All of these choices",
    "None of these choices",
    "All of the above",
    "None of the above",
    "All of the choices",
    "None of the choices",
]


def is_ok(d):
    """Return True if the correct answer of ``d`` contains an authorized phrase.

    The correct answer is ``d["choices"][d["answer"]]``; the match is a
    case-insensitive substring test against ``authorized_answer_duplicates``.
    """
    correct_answer = d["choices"][d["answer"]].lower()
    for phrase in authorized_answer_duplicates:
        if phrase.lower() in correct_answer:
            return True
    return False


# De-duplicate by answer similarity, keeping the first entry seen.
# Entries whose correct answer is an authorized generic phrase (is_ok) are
# always kept and are never used as a comparison target. Every other entry is
# compared with all earlier non-exempt entries (kept or not) and survives only
# if both hold:
#   * its largest answer-embedding dot product is below `alone_threshold`
#     (this step's own setting, not the question-dedup `threshold`);
#   * its largest (question dot + answer dot) sum is below
#     `with_answer_threshold`.
# With no earlier non-exempt entry both maxima default to 0.
new_d_and_es = []
for i, (d, e, ae) in enumerate(d_and_es):
    if is_ok(d):
        new_d_and_es.append((d, e, ae))
        continue

    # Collect the comparison targets once so is_ok runs once per earlier item.
    earlier = [(e2, ae2) for q, e2, ae2 in d_and_es[:i] if not is_ok(q)]

    nearest_with_above_idx = max((ae @ ae2 for _, ae2 in earlier), default=0)
    nearest_q_and_a_with_above_idx = max((e @ e2 + ae @ ae2 for e2, ae2 in earlier), default=0)

    if nearest_with_above_idx < alone_threshold and nearest_q_and_a_with_above_idx < with_answer_threshold:
        new_d_and_es.append((d, e, ae))
d_and_es = new_d_and_es

print(len(d_and_es), "after removing duplicates in answers")
# %%
# Print the closest question pairs in d_and_es_0 (the unfiltered data), most
# similar first. The first question of a pair must not contain a forbidden
# string; the second one is not checked.
from tqdm import tqdm

n_pairs_to_show = 20

# (question-embedding dot product, first question, second question) for i < j
all_pairs = [
    (e0 @ e1, d0["question"], d1["question"])
    for i, (d0, e0, _) in enumerate(tqdm(d_and_es_0))
    if not any(s.lower() in d0["question"].lower() for s in forbidden_strings)
    for j, (d1, e1, _) in enumerate(d_and_es_0)
    if i < j
]
all_pairs = sorted(all_pairs, reverse=True)
# Slicing (rather than indexing 0..19) also works with fewer than 20 pairs.
for pair in all_pairs[:n_pairs_to_show]:
    for field in pair:
        print(field)
    print()
# %%
# Print the closest question pairs in d_and_es (the current, filtered data),
# most similar first, each question shown together with its correct answer.
# The first question of a pair must not contain a forbidden string; the second
# one is not checked.
from tqdm import tqdm

n_pairs_to_show = 20

# (question-embedding dot product, first "question answer: ...", second one)
# for every i < j
all_pairs = [
    (e0 @ e1, f"{d0['question']} answer: {d0['choices'][d0['answer']]}", f"{d1['question']} answer: {d1['choices'][d1['answer']]}")
    for i, (d0, e0, _) in enumerate(tqdm(d_and_es))
    if not any(s.lower() in d0["question"].lower() for s in forbidden_strings)
    for j, (d1, e1, _) in enumerate(d_and_es)
    if i < j
]
all_pairs = sorted(all_pairs, reverse=True)
# Slicing (rather than indexing 0..19) also works with fewer than 20 pairs.
for pair in all_pairs[:n_pairs_to_show]:
    for field in pair:
        print(field)
    print()
#%%
# Print every question of the unfiltered data that contains a forbidden
# string (case-insensitive), each followed by a separator line. The separator
# is printed only after a matching question, not once per dataset item.
for d, _, _ in d_and_es_0:
    if any(s.lower() in d["question"].lower() for s in forbidden_strings):
        print(d["question"])
        print("=" * 80)
# %%
# Shuffle with a fixed seed, then write a small dev set plus `splits`
# interleaved splits of the remaining items.
# NOTE: the shuffle is in place, so re-running this cell without re-running
# the filtering above shuffles an already shuffled list.
random.Random(0).shuffle(d_and_es)
splits = 5
dev_set_size = 5

# From here on `data` holds only the surviving items, without embeddings.
data = [d for d, _, _ in d_and_es]
# The first `dev_set_size` shuffled items form the dev set.
data, dev_data = data[dev_set_size:], data[:dev_set_size]

out_dir = "data/wmdp-bio-deduped"
# open(..., "w") does not create missing directories.
os.makedirs(out_dir, exist_ok=True)

for i in range(splits):
    # Split i takes items i, i + splits, i + 2 * splits, ...
    # The context manager flushes and closes each file deterministically.
    with open(f"{out_dir}/split_{i}.json", "w") as f:
        json.dump(data[i::splits], f)
with open(f"{out_dir}/dev.json", "w") as f:
    json.dump(dev_data, f)
