# pip install google-generativeai openai anthropic python-dotenv tqdm numpy requests

import os
from numpy import random
import numpy as np
from dotenv import load_dotenv
from tqdm import tqdm

load_dotenv()

from openai import OpenAI
import anthropic
import google.generativeai as genai


# API clients shared by every query, created once at import time.
# OpenAI() and anthropic.Anthropic() are built without explicit credentials,
# so they rely on the SDKs' own defaults (presumably the environment that
# load_dotenv() populated above — confirm per SDK).
# NOTE: "opanai" is a misspelling, but other code refers to this exact name.
opanai_client = OpenAI()
anthropic_client = anthropic.Anthropic()

# Gemini is configured here rather than per request: greedy decoding
# (temperature 0) and a 256-token cap, matching the other providers' settings.
gemini_client = genai.GenerativeModel(
    model_name="gemini-1.5-pro",
    generation_config=dict(
        temperature=0,
        top_p=0.95,
        top_k=64,
        max_output_tokens=256,
        response_mime_type="text/plain",
    ),
)


def fireworks(text, model, timeout=60):
    """Send a single-turn chat request to the Fireworks AI chat-completions API.

    Args:
        text: The prompt, sent as the only user message.
        model: Model name appended to ``accounts/fireworks/models/``.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The ``content`` of the first returned choice.

    Raises:
        RuntimeError: If the ``FIREWORKS_API_KEY`` environment variable is unset
            or empty.
        requests.HTTPError: If the API answers with a non-2xx status.
    """
    # Fail fast with a readable message. Concatenating "Bearer " + None used to
    # raise an opaque TypeError instead.
    api_key = os.environ.get("FIREWORKS_API_KEY")
    if not api_key:
        raise RuntimeError("FIREWORKS_API_KEY environment variable is not set")

    import requests

    url = "https://api.fireworks.ai/inference/v1/chat/completions"
    payload = {
        "model": "accounts/fireworks/models/" + model,
        "max_tokens": 256,
        "top_p": 1,
        "top_k": 40,
        "presence_penalty": 0,
        "frequency_penalty": 0,
        # NOTE(review): the OpenAI/Anthropic/Gemini calls in this file use
        # temperature 0; confirm that sampling at 0.6 here is intended.
        "temperature": 0.6,
        "messages": [{"role": "user", "content": [{"type": "text", "text": text}]}],
    }
    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
        "Authorization": "Bearer " + api_key,
    }
    # Without a timeout, a stalled connection would block the evaluation loop forever.
    response = requests.post(url, headers=headers, json=payload, timeout=timeout)
    # Surface rate limits or bad model names as HTTP errors, rather than as a
    # KeyError on the missing "choices" field.
    response.raise_for_status()
    return response.json()["choices"][0]["message"]["content"]


def query_gpt(text, model):

    try:
        if model.startswith("gpt-"):
            response = (
                opanai_client.chat.completions.create(
                    model=model,
                    messages=[
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "text",
                                    "text": text,
                                }
                            ],
                        },
                    ],
                    temperature=0,
                    max_tokens=256,
                    top_p=1,
                    frequency_penalty=0,
                    presence_penalty=0,
                )
                .choices[0]
                .message.content
            )
        elif model.startswith("claude-"):
            response = (
                anthropic_client.messages.create(
                    model=model,
                    max_tokens=256,
                    temperature=0,
                    messages=[
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "text",
                                    "text": text,
                                }
                            ],
                        }
                    ],
                )
                .content[0]
                .text
            )
        elif model == "gemini-1.5-pro":
            response = gemini_client.start_chat().send_message(text).text
        elif model.startswith("llama-"):
            response = fireworks(text, model)
        else:
            raise Exception("Invalid model: " + model)
    except Exception as e:
        print(e)
        return query_gpt(text)

    return response


def generate_retrieval_equation(N, D):
    """Build a chained variable-substitution puzzle.

    The level-0 symbols are the digits ``0`` .. ``N-1``. Each symbol at level
    k+1 is declared equal to a symbol at level k through a random permutation.
    The question asks for the numeric value of one level-D symbol.

    Args:
        N: Symbols per level. At most 4, since each symbol group has 4 characters.
        D: Number of substitution levels. At most 13, since there are 14 groups.

    Returns:
        ``(query, answer)``: the prompt text and the correct value as an ``int``.
        NOTE(review): unlike the other generators, this returns no
        possible-answers list, and the answer is an int rather than a str.
        Adapt the caller before using it.
    """
    alphabet = [
        "0123",
        "abcd",
        "efgh",
        "ijkl",
        "mnop",
        "qrst",
        "uvwx",
        "yzAB",
        "CDEF",
        "GHIJ",
        "KLMN",
        "OPQR",
        "STUV",
        "WXYZ",
    ]

    # Draw every level's mapping before any shuffling, so the random draws
    # happen in a fixed order.
    level_maps = [np.random.permutation(N) for _ in range(D)]

    lines = []
    # resolved[j] is the numeric value of the j-th symbol at the current level.
    resolved = list(range(N))
    for depth, mapping in enumerate(level_maps):
        order = list(range(N))
        random.shuffle(order)  # list this level's equations in random order
        lines.extend(
            f"{alphabet[depth + 1][j]} = {alphabet[depth][mapping[j]]}" for j in order
        )
        resolved = [resolved[mapping[j]] for j in range(N)]

    pick = np.random.choice(list(range(N)))

    prompt = "\n".join(lines) + f"\nWhat is the value of {alphabet[D][pick]}? Say directly only the numeric value, without any other words."
    return prompt, resolved[pick]


def generate_retrieval_lives_with(N, D):
    """Build a multi-hop "who lives where" puzzle.

    Level 0 states which city each of the first N people lives in. Every later
    level introduces N new people, each of whom "lives with" a person from the
    previous level. The question asks where one person of the last level lives.

    Args:
        N: People per level. At least 1 and at most the number of places (26).
        D: Number of levels. At least 1, and N*D must not exceed the number of
            names (52).

    Returns:
        ``(query, answer, possible_answers)``: the prompt text, the correct
        city, and the N cities that can occur (``places[:N]``).

    Raises:
        ValueError: If N or D is out of range.
    """
    # All names must be distinct. The earlier list repeated "Xander" and
    # "Yvonne", which made puzzles with N*D > 49 ambiguous.
    names = "Alice Bob Charlie David Eve Frank Grace Henry Isabelle Jack Kate Larry Mary Nick Olivia Peter Queen Rose Sam Tom Ulysses Violet William Xander Yvonne Zach Adam Beth Carl Dana Edna Fred Gina Hank Iris Jake Kelly Liam Mona Nicky Ollie Penny Quin Remy Sally Tim Ursula Victor Wendy Xavier Yasmin Zachary".split()
    places = "Amsterdam Berlin Cairo Delhi Edinburgh Florence Geneva Havana Istanbul Jerusalem Kyoto Lima Madrid Nairobi Oslo Paris Quebec Rome Sydney Tokyo Ulaanbaatar Vienna Warsaw Xian Yerevan Zurich".split()

    # Validate before drawing any randomness. D=0 used to ask about a name
    # (found via a negative index) that never appears in the statements.
    if N < 1 or D < 1:
        raise ValueError(f"N and D must be positive, got N={N}, D={D}")
    if N > len(places):
        raise ValueError(f"N={N} exceeds the {len(places)} available places")
    if N * D > len(names):
        raise ValueError(f"N*D={N * D} exceeds the {len(names)} available names")

    perm = [np.random.permutation(N) for _ in range(D)]

    statements = []
    # After level i, values[j] is the index into `places` of the city where
    # the j-th person introduced at level i lives.
    values = list(range(N))
    for i in range(D):
        p = list(range(N))
        random.shuffle(p)  # list this level's statements in random order
        for j in p:
            if i == 0:
                statements.append(f"{names[j]} lives in {places[perm[i][j]]}")
            else:
                statements.append(f"{names[j + N * i]} lives with {names[perm[i][j] + N * (i - 1)]}")

        values = [values[perm[i][j]] for j in range(N)]

    ix = np.random.choice(list(range(N)))
    name = names[ix + N * (D - 1)]
    place = places[values[ix]]

    query = "\n".join(statements) + f"\nWhere does {name} live? Say directly only the name of the city, without any other words."
    answer = place
    possible_answers = [places[i] for i in range(N)]

    return query, answer, possible_answers


def generate_retrieval_kingdoms(N, D):
    """Build a multi-hop puzzle linking people to cities, beliefs, foods, elements and effects.

    Level 0 states which person lives in each city. Each later level links every
    item of the next category to one item of the previous category
    (cities -> beliefs -> foods -> elements -> effects). The question names one
    item of the deepest category used and asks which person it traces back to.

    Args:
        N: Items per category, from 1 to 4.
        D: Number of levels, from 1 to 5.

    Returns:
        ``(query, answer, possible_answers)``: the prompt text, the correct
        person's name, and the first N people.

    Raises:
        ValueError: If N or D is out of range.
    """
    people = "Alice Bob Charlie David".split()
    cities = "Florinia Silvania Aurora Novaria".split()
    beliefs = "luminism harmonianism celestianism elysianism".split()
    foods = "beef pork chicken lamb".split()
    elements = "Nephryon Astralyte Virellium Zephyrium".split()
    effects = "Synthemia Aetherflux Somnosis Chronogy".split()

    # chain[k] is the category at depth k; level i links chain[i] to chain[i + 1].
    chain = [people, cities, beliefs, foods, elements, effects]
    # One sentence template per level, taking (item of chain[i], item of chain[i + 1]).
    # NOTE(review): "...ns believes" is ungrammatical (plural subject). It is
    # kept verbatim so prompts stay comparable with earlier runs.
    templates = [
        lambda prev, cur: f"{prev} lives in {cur}.",
        lambda prev, cur: f"{prev}ns believes in {cur}.",
        lambda prev, cur: f"{prev.capitalize().replace('ism', 'ists')} eat {cur}.",
        lambda prev, cur: f"{prev.capitalize()} contains {cur}.",
        lambda prev, cur: f"{prev.capitalize()} causes {cur}.",
    ]

    # Validate before drawing any randomness. D=0 used to produce a question
    # with no supporting statements, and N>4 failed with an IndexError.
    if not 1 <= N <= len(people):
        raise ValueError(f"N must be between 1 and {len(people)}, got {N}")
    if not 1 <= D <= len(templates):
        raise ValueError(f"D must be between 1 and {len(templates)}, got {D}")

    perm = [np.random.permutation(N) for _ in range(D)]

    statements = []
    # After level i, values[j] is the index into `people` that the j-th item
    # of chain[i + 1] traces back to.
    values = list(range(N))

    for i in range(D):
        p = list(range(N))
        random.shuffle(p)  # list this level's statements in random order
        for j in p:
            statements.append(templates[i](chain[i][perm[i][j]], chain[i + 1][j]))

        values = [values[perm[i][j]] for j in range(N)]

    ix = np.random.choice(list(range(N)))

    query = "\n".join(statements) + f"\nWho has {chain[D][ix]}? Say directly the name without other words."
    answer = people[values[ix]]
    possible_answers = [people[i] for i in range(N)]

    return query, answer, possible_answers


def generate_retrieval_functions():
    """Build a puzzle about four lookup-table functions over the domain {0, 1, 2, 3}.

    Functions ``a``..``d`` are listed as full value tables, each a random
    permutation. Functions ``e``..``h`` are aliases of ``a``..``d``, and
    variables ``i``..``l`` hold the values 0..3 in random order. The question
    asks for the value of one alias applied to one variable.

    Returns:
        ``(query, answer, possible_answers)``: the prompt text, the correct
        value as a str, and ``["0", "1", "2", "3"]``.
    """
    tables = [np.random.permutation(4) for _ in range(4)]
    base_fns, alias_fns, var_names = "abcd", "efgh", "ijkl"

    lines = [
        f"{base_fns[f]}({x}) = {tables[f][x]}"
        for f in range(4)
        for x in range(4)
    ]

    alias_target = np.random.permutation(4)  # alias k refers to base_fns[alias_target[k]]
    var_value = np.random.permutation(4)     # variable k holds var_value[k]

    lines += [f"{alias_fns[k]} = {base_fns[alias_target[k]]}" for k in range(4)]
    lines += [f"{var_names[k]} = {var_value[k]}" for k in range(4)]

    fn_pick = np.random.choice(4)
    var_pick = np.random.choice(4)

    prompt = "\n".join(lines) + f"\nWhat is the value of {alias_fns[fn_pick]}({var_names[var_pick]})? Say directly only the numeric value, without any other words."
    solution = str(tables[alias_target[fn_pick]][var_value[var_pick]])
    return prompt, solution, [str(d) for d in range(4)]


def generate_retrieval_relatives():
    """Build a puzzle where a person's home follows from their job and family.

    Four base people each have a mother, sister, father and brother, each living
    in a distinct country. Each job is tied to one kind of relative ("Doctors
    live with their mothers"), and each base person is assigned one job. The
    question asks where a base person lives, i.e. the country of the relative
    their job points to.

    Returns:
        ``(query, answer, possible_answers)``: the prompt text, the correct
        country, and the first 16 country names.
    """
    girl_names = "Alice Beth Cathy Dana Emma Fiona Grace Helen Isabelle Jane Kate Lily Mary Nancy Olivia Penny Queen Rose Sally Tina Ursula Violet Wendy Yvonne Zoe".split()
    boy_names = "Adam Bob Carl David Ed Fred George Hank Ike Jack Kevin Larry Mike Nick Oliver Peter Quin Remy Sam Tim Ulysses Victor William Xander Yancy Zach".split()
    country_names = "Argentina Brazil Canada Denmark England France Germany Hungary Italy Japan Kenya Laos Mexico Norway Peru Qatar Russia Spain Turkey Ukraine Venezuela Wales Yemen Zimbabwe".split()
    # (name, possessive pronoun); the pronoun is not used by this generator.
    base_names = [
        ("John", "his"),
        ("Chris", "his"),
        ("Diana", "her"),
        ("Eve", "her"),
    ]
    relations = "mother sister father brother".split()
    jobs = "doctor lawyer teacher engineer".split()

    # Row = base person, column = relation slot. Each grid holds 0..15 exactly once.
    family = np.random.permutation(16).reshape(4, 4)
    locations = np.random.permutation(16).reshape(4, 4)

    relation_lines = []
    location_lines = []
    for person, (base, _pronoun) in enumerate(base_names):
        for slot, relation in enumerate(relations):
            # Slots 0-1 (mother, sister) draw from girl names; slots 2-3 (father, brother) from boy names.
            pool = girl_names if slot < 2 else boy_names
            relative = pool[family[person][slot]]
            relation_lines.append(f"{base}'s {relation} is {relative}.")
            location_lines.append(f"{relative} lives in {country_names[locations[person][slot]]}.")

    random.shuffle(location_lines)

    jobs_relatives = np.random.permutation(4)  # job k -> index into `relations`
    job_lines = [
        f"{job.capitalize()}s live with their {relations[jobs_relatives[k]]}s."
        for k, job in enumerate(jobs)
    ]

    assigned_jobs = np.random.permutation(4)  # base person k -> index into `jobs`
    work_lines = []
    for person, (base, _pronoun) in enumerate(base_names):
        job = jobs[assigned_jobs[person]]
        article = "an" if job[0] in "aeiou" else "a"
        work_lines.append(f"{base} works as {article} {job}.")

    ix = np.random.choice(list(range(4)))
    who = base_names[ix][0]

    lines = location_lines + relation_lines + job_lines + work_lines
    query = "\n".join(lines) + f"\nWhere does {who} live? Say directly only the name, without any other words."
    relative_slot = jobs_relatives[assigned_jobs[ix]]
    answer = country_names[locations[ix][relative_slot]]
    return query, answer, country_names[:16]


# Experiment configuration. N and D are only read by the commented-out,
# size-parameterized generators; generate_retrieval_relatives() takes no arguments.
N = 4
D = 5

model = "gpt-4o-2024-08-06"

correct = 0
total = 500

bar = tqdm(total=total)
# Answers outside `possible_answers` are reported and a new puzzle is drawn
# instead of scoring them. The loop therefore runs until `total` valid answers
# have been graded.
while bar.n < total:
    # query, answer, possible_answers = generate_retrieval_lives_with(N, D)
    # query, answer, possible_answers = generate_retrieval_kingdoms(N, D)
    # query, answer, possible_answers = generate_retrieval_functions()
    query, answer, possible_answers = generate_retrieval_relatives()
    print("\n" + query)
    print("Correct:", answer)

    answer_gpt = query_gpt(query, model)
    print("GPT:", answer_gpt)

    # Compare the stripped text, so an answer padded with whitespace or a
    # trailing newline (e.g. "Spain\n") is not rejected as invalid.
    # Defensively treat a missing reply (None) as an empty, invalid answer.
    answer_gpt = (answer_gpt or "").strip()

    # A plain membership test replaces raise/bare-except control flow, which
    # also swallowed unrelated errors and KeyboardInterrupt.
    if answer_gpt in possible_answers:
        if answer_gpt == answer:
            correct += 1
        bar.update(1)
    else:
        print("Invalid answer")

    print(f"\nCorrect: {correct}/{bar.n}")

bar.close()
print(f"\nD = {D}, N = {N}, {correct}/{total}\n")
