# pip install google-generativeai openai anthropic requests numpy python-dotenv tqdm

import os
from numpy import random
import numpy as np
from dotenv import load_dotenv
from tqdm import tqdm

load_dotenv()

from openai import OpenAI
import anthropic
import google.generativeai as genai


# Provider clients shared by query_gpt. They are constructed with no explicit
# credentials, so they presumably pick up API keys from environment variables
# (populated from .env by load_dotenv() above) — confirm per SDK.
# NOTE(review): "opanai" is a typo, kept because query_gpt uses this name.
opanai_client = OpenAI()
anthropic_client = anthropic.Anthropic()

# Gemini sampling settings: temperature 0, at most 256 output tokens,
# plain-text responses.
_GEMINI_GENERATION_CONFIG = dict(
    temperature=0,
    top_p=0.95,
    top_k=64,
    max_output_tokens=256,
    response_mime_type="text/plain",
)
gemini_client = genai.GenerativeModel(
    model_name="gemini-1.5-pro",
    generation_config=_GEMINI_GENERATION_CONFIG,
)


def fireworks(text, model, temperature=0.6, timeout=60):
    """Send a single-turn chat prompt to a Fireworks-hosted model; return its reply.

    Args:
        text: The user prompt.
        model: Model id relative to ``accounts/fireworks/models/``
            (e.g. ``"llama-v3p1-405b-instruct"``).
        temperature: Sampling temperature. Defaults to 0.6 for backward
            compatibility. NOTE(review): the other providers in query_gpt
            are called with temperature 0; confirm 0.6 is intended here.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The ``content`` string of the first choice in the response.

    Raises:
        RuntimeError: If the ``FIREWORKS_API_KEY`` environment variable is
            unset or empty.
        requests.HTTPError: If the API answers with a non-2xx status.
        requests.Timeout: If no response arrives within ``timeout`` seconds.
    """
    # Fail fast with a clear message instead of a TypeError from str + None.
    api_key = os.getenv("FIREWORKS_API_KEY")
    if not api_key:
        raise RuntimeError("FIREWORKS_API_KEY is not set")

    import requests

    url = "https://api.fireworks.ai/inference/v1/chat/completions"
    payload = {
        "model": "accounts/fireworks/models/" + model,
        "max_tokens": 256,
        "top_p": 1,
        "top_k": 40,
        "presence_penalty": 0,
        "frequency_penalty": 0,
        "temperature": temperature,
        "messages": [{"role": "user", "content": [{"type": "text", "text": text}]}],
    }
    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
        "Authorization": "Bearer " + api_key,
    }
    # A timeout keeps a stalled connection from hanging the benchmark forever.
    response = requests.post(url, headers=headers, json=payload, timeout=timeout)
    # Surface HTTP errors (rate limits, auth failures) explicitly rather than
    # as a KeyError on "choices" in the error body.
    response.raise_for_status()
    return response.json()["choices"][0]["message"]["content"]


def _provider_for(model):
    """Map a model name to its provider key; raise ValueError if unsupported."""
    if model.startswith("gpt-"):
        return "openai"
    if model.startswith("claude-"):
        return "anthropic"
    if model == "gemini-1.5-pro":
        return "gemini"
    if model.startswith("llama-"):
        return "fireworks"
    raise ValueError("Invalid model: " + model)


def _query_once(text, model, provider):
    """Send ``text`` to ``model`` via ``provider`` exactly once (no retries)."""
    if provider == "openai":
        completion = opanai_client.chat.completions.create(
            model=model,
            messages=[
                {
                    "role": "user",
                    "content": [{"type": "text", "text": text}],
                },
            ],
            temperature=0,
            max_tokens=256,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
        )
        return completion.choices[0].message.content
    if provider == "anthropic":
        message = anthropic_client.messages.create(
            model=model,
            max_tokens=256,
            temperature=0,
            messages=[
                {
                    "role": "user",
                    "content": [{"type": "text", "text": text}],
                }
            ],
        )
        return message.content[0].text
    if provider == "gemini":
        # gemini_client is bound to a single model at construction time.
        return gemini_client.start_chat().send_message(text).text
    return fireworks(text, model)


def query_gpt(text, model, max_retries=None, base_delay=1.0, max_delay=60.0):
    """Send ``text`` to ``model`` and return the reply text, retrying on failure.

    Routing is by model name: ``gpt-*`` goes to OpenAI, ``claude-*`` to
    Anthropic, exactly ``gemini-1.5-pro`` to Gemini, and ``llama-*`` to
    Fireworks.

    Any exception raised by a provider call is printed and the call is
    retried after an exponential backoff (``base_delay``, ``2*base_delay``,
    ... capped at ``max_delay`` seconds).

    Args:
        text: The user prompt.
        model: Model name; see routing above.
        max_retries: Maximum number of retries after the first attempt, or
            None to retry until the call succeeds.
        base_delay: Delay in seconds before the first retry.
        max_delay: Upper bound on the delay between retries, in seconds.

    Returns:
        The reply text produced by the provider SDK / API.

    Raises:
        ValueError: If ``model`` matches no supported provider. This is
            raised immediately and never retried.
        Exception: The last provider error, once ``max_retries`` is exhausted.
    """
    import time

    # Validate before the retry loop: a bad model name is not transient.
    provider = _provider_for(model)

    attempt = 0
    while True:
        try:
            return _query_once(text, model, provider)
        except Exception as e:  # network/API errors are treated as transient
            attempt += 1
            if max_retries is not None and attempt > max_retries:
                raise
            print(e)
            # Cap the exponent so the float stays finite on very long outages.
            delay = min(max_delay, base_delay * 2 ** min(attempt - 1, 16))
            time.sleep(delay)


def generate_problem(N, D, ordered=True):
    """Generate a random variable-chasing problem and its numeric answer.

    Level 0 is the literal digits ``0..N-1``. For each level ``i`` in
    ``0..D-1``, a random permutation sets every one of the ``N`` symbols of
    level ``i + 1`` equal to a symbol of level ``i``. The question asks for
    the value of a random level-``D`` symbol, which requires following a
    chain of ``D`` equalities back to a digit.

    Args:
        N: Symbols per level, in ``[1, 4]`` (the width of each symbol group).
        D: Chain depth, in ``[0, 13]`` (number of symbol groups minus one).
        ordered: If True, every equality is written ``new = old``. If False,
            equalities above the first level are written in either
            orientation (``new = old`` or ``old = new``) with probability 0.5.

    Returns:
        ``(query, answer)``. ``query`` is the prompt: one equality per line,
        followed by the question. ``answer`` is the int value of the queried
        symbol.

    Raises:
        ValueError: If ``N`` or ``D`` is outside the supported range.
    """
    # symbols[k] supplies the variable names for level k; level 0 holds the
    # digits that ground every chain.
    symbols = [
        "0123",
        "abcd",
        "efgh",
        "ijkl",
        "mnop",
        "qrst",
        "uvwx",
        "yzAB",
        "CDEF",
        "GHIJ",
        "KLMN",
        "OPQR",
        "STUV",
        "WXYZ",
    ]
    if not 1 <= N <= len(symbols[0]):
        raise ValueError(f"N must be in [1, {len(symbols[0])}], got {N}")
    if not 0 <= D < len(symbols):
        raise ValueError(f"D must be in [0, {len(symbols) - 1}], got {D}")

    # perm[i][j] is the index of the level-i symbol that level-(i + 1)
    # symbol j is set equal to.
    perm = [np.random.permutation(N) for _ in range(D)]

    equalities = []
    # values[j] is the numeric value of symbol j at the current level.
    values = list(range(N))
    for i in range(D):
        # Emit this level's equalities in a random order.
        order = list(range(N))
        np.random.shuffle(order)
        for j in order:
            new, old = symbols[i + 1][j], symbols[i][perm[i][j]]
            # Level-1 equalities (i == 0) always keep the "symbol = digit"
            # orientation; higher levels may be flipped when not ordered.
            if ordered or i == 0 or np.random.random() < 0.5:
                equalities.append(f"{new} = {old}")
            else:
                equalities.append(f"{old} = {new}")

        values = [values[perm[i][j]] for j in range(N)]

    ix = int(np.random.choice(N))

    question = (
        f"\nWhat is the value of {symbols[D][ix]}? "
        "Say directly only the numeric value, without any other words."
    )
    query = "\n".join(equalities) + question
    answer = values[ix]

    return query, answer


# Symbols per level in every generated problem (generate_problem's symbol
# table provides 4 per level).
N = 4
# Chain depths to benchmark (the symbol table has 14 groups, so up to 13).
DEPTHS = range(2, 14)
# Number of *parseable, in-range* answers collected per (model, depth) pair.
SAMPLES_PER_DEPTH = 400
MODELS = [
    "llama-v3p1-405b-instruct",
    "gemini-1.5-pro",
    "gpt-4o-2024-08-06",
    "claude-3-5-sonnet-20240620",
]


def _parse_answer(text, n):
    """Return ``text`` as an int in ``[0, n)``, or None if it is not one."""
    try:
        value = int(text)
    except (TypeError, ValueError):
        return None
    return value if 0 <= value < n else None


def main():
    """Benchmark each model at each depth and append accuracy to stats.txt.

    For every model, writes a ``Model: ...`` header, then one line per depth
    of the form ``D = <depth>, N = <N>, <correct>/<total>``, then a blank
    line.
    """
    with open("stats.txt", "a") as fout:
        for model in MODELS:
            fout.write(f"Model: {model}\n")
            fout.flush()

            for D in DEPTHS:
                correct = 0
                skipped = 0
                total = SAMPLES_PER_DEPTH

                # Only parseable, in-range answers advance the bar, so the
                # denominator excludes invalid replies. NOTE(review): this can
                # inflate accuracy for models that often answer off-format.
                with tqdm(total=total) as bar:
                    while bar.n < total:
                        query, answer = generate_problem(N, D)
                        print(query)
                        print("Correct:", answer)

                        answer_gpt = query_gpt(query, model)
                        print("GPT:", answer_gpt)

                        parsed = _parse_answer(answer_gpt, N)
                        if parsed is None:
                            skipped += 1
                            print("Error parsing answer")
                        else:
                            if parsed == answer:
                                correct += 1
                            bar.update(1)

                        print(f"\nCorrect: {correct}/{bar.n}")

                print(f"D = {D}: skipped {skipped} unparseable/out-of-range answers")
                fout.write(f"D = {D}, N = {N}, {correct}/{total}\n")
                fout.flush()

            fout.write("\n")
            fout.flush()


if __name__ == "__main__":
    main()
