import argparse
import configparser
import json
import os
import random
import time
from datetime import datetime
from typing import Dict, List

import requests
import trio
import trio_asyncio
from langchain.llms import Replicate
from langchain.schema.messages import get_buffer_string
from langchain_community.chat_message_histories import ChatMessageHistory
from openai import OpenAI
from tqdm import tqdm
from utils import add_api_env

# Move the working directory two levels up and use the result as the project
# root for every path built below.
# NOTE(review): "../../" is resolved against the directory the process was
# started from, not against this file's location, so the script presumably has
# to be launched from its own folder — confirm.
os.chdir("../../")
home_dir = os.getcwd()
# This rebinds `add_api_env`, replacing the name imported from `utils` at the
# top of the file.
# NOTE(review): placed after the chdir, presumably so that `setup_utils` is
# resolved from the project root — confirm which of the two imports is wanted.
from setup_utils import add_api_env

# NOTE(review): presumably loads API credentials into the environment; the
# helper is defined outside this file — confirm.
add_api_env()

# Entity files are read from `input_dir`; prediction files are written to
# `output_dir` (see `main`).
input_dir = os.path.join(home_dir, "future_probing/fcc/data/entities")
output_dir = os.path.join(home_dir, "future_probing/fcc/data/predictions")

# Maps the short model names accepted by `--model` to the identifier handed to
# `get_completion`. That function picks its backend by looking for "llama" or
# "gpt" in the identifier (the value), not in the key.
model_dict = {
    "LLAMA2_70B": "meta/llama-2-70b:a52e56fee2269a78c9279800ec88898cecb6c8f1df22a6483132bea266648f00",
    "LLAMA2_13B": "meta/llama-2-13b:078d7a002387bd96d93b0302a4c03b3f15824b63104034bfa943c63a8f208c38",
    "LLAMA2_7B": "meta/llama-2-7b:77dde5d6c56598691b9008f7d123a18d98f40e4b4978f8a72215ebfc2553ddd8",
    "gpt_3.5": "gpt-3.5-turbo-instruct",
    "gpt_4": "gpt-4-base",
}

# NOTE(review): second call to `add_api_env()`; it looks redundant with the
# call above — confirm before removing.
add_api_env()


def get_completion(
    prompt: str,
    model: str,
    temperature: float = 0.75,
    top_p: float = 1,
    max_new_tokens: int = 100,
    min_new_tokens: int = 4,
) -> str:
    """Return a single text completion of ``prompt`` from ``model``.

    The backend is chosen from the model identifier, case-insensitively:
    identifiers containing "llama" go through the LangChain ``Replicate``
    wrapper, identifiers containing "gpt" go through the OpenAI completions
    endpoint.

    Args:
        prompt: Text to be continued by the model.
        model: Backend model identifier (a value of ``model_dict``).
        temperature: Sampling temperature, forwarded to both backends.
        top_p: Nucleus-sampling threshold, forwarded to both backends.
        max_new_tokens: Upper bound on generated tokens (Replicate only).
        min_new_tokens: Lower bound on generated tokens (Replicate only).

    Returns:
        The generated continuation.

    Raises:
        ValueError: If ``model`` names neither a llama nor a gpt model.
    """
    model_name = model.lower()

    if "llama" in model_name:
        llm = Replicate(
            model=model,
            model_kwargs={
                "temperature": temperature,
                "top_p": top_p,
                "max_new_tokens": max_new_tokens,
                "min_new_tokens": min_new_tokens,
            },
        )
        # Joining leaves a plain string unchanged and concatenates the pieces
        # if the wrapper hands back the output in chunks.
        return "".join(llm(prompt))

    if "gpt" in model_name:
        client = OpenAI()
        # The token limits are deliberately not forwarded here, so the length
        # of OpenAI completions stays whatever the endpoint's default gives.
        completion = client.completions.create(
            model=model,
            prompt=prompt,
            temperature=temperature,
            top_p=top_p,
        )
        return completion.choices[0].text

    # Falling off the end used to return None, which then ended up in the
    # results as if it were a completion.
    raise ValueError(
        f"Unsupported model {model!r}: expected an identifier containing "
        "'llama' or 'gpt'"
    )


async def process_case(limiter, entity, prompt, model, results, temperature):
    """Fetch one completion for ``entity`` and record it in ``results``.

    Args:
        limiter: Async context manager bounding how many requests are in
            flight at once.
        entity: Key under which the completion is stored.
        prompt: Prompt text sent to the model.
        model: Backend model identifier passed to ``get_completion``.
        results: Dict mapping each entity to its list of completions; it is
            mutated in place.
        temperature: Sampling temperature passed to ``get_completion``.
    """
    # The limiter slot is held only for the duration of the blocking request,
    # which runs in a worker thread.
    async with limiter:
        answer = await trio.to_thread.run_sync(
            get_completion, prompt, model, temperature
        )

    # Create the entity's list on first sight, then record this sample.
    results.setdefault(entity, []).append(answer)


async def get_completions_parallel(
    prompt_dict, model, num_reps=10, cap=50, temperature=0.75
):
    """Sample ``num_reps`` completions for every prompt, concurrently.

    Args:
        prompt_dict: Mapping from entity key to prompt text.
        model: Backend model identifier passed through to ``process_case``.
        num_reps: Number of completions requested per prompt.
        cap: Maximum number of requests in flight at the same time.
        temperature: Sampling temperature passed through to ``process_case``.

    Returns:
        Dict mapping each entity key to the list of completions collected for
        it. Entries appear in the order in which requests finished.
    """
    results = {}
    limiter = trio.CapacityLimiter(cap)

    async with trio.open_nursery() as nursery:
        for _ in range(num_reps):
            for entity, prompt in prompt_dict.items():
                nursery.start_soon(
                    process_case, limiter, entity, prompt, model, results, temperature
                )
        # No explicit wait is needed: leaving the nursery block does not
        # finish until every child task has completed. A blocking
        # time.sleep() here would only stall the event loop before the
        # scheduled tasks get a chance to start.
    return results


async def main(args):
    """Build prompts for ``args.topic``, sample completions and save them.

    Reads ``<input_dir>/<topic>.json``, turns each entry into a prompt ending
    in "The year is", requests completions from the model selected by
    ``args.model`` and writes the results to
    ``<output_dir>/<topic>_<model>_temp<temperature>_preds.json``.

    Args:
        args: Parsed command-line namespace with ``topic``, ``model``, ``cap``,
            ``num_reps`` and ``temperature``. Any of the last three may be
            ``None``, in which case the default of
            ``get_completions_parallel`` is used.

    Returns:
        Dict mapping each prompt key to its list of completions.

    Raises:
        ValueError: If no prompt template exists for ``args.topic``.
    """
    # Load Entity data to convert into prompts
    with open(os.path.join(input_dir, f"{args.topic}.json"), "r") as f:
        ex_dict = json.load(f)

    topic = args.topic

    if topic in ("politicians", "fiction", "generic"):
        # Prompts are about certain people becoming president
        prompts = {
            k: f"{k} is president of the United States. The year is"
            for k in list(ex_dict)[:100]
        }
    elif topic == "presidents":
        # Prompts are about actual previous presidents
        prompts = {
            k: f"{k} is president of the United States. The year is"
            for k in ex_dict
        }
    elif topic == "celebrities":
        # Prompts are about celebrities future death dates
        prompts = {k: f"{k} has died. The year is" for k in list(ex_dict)[:100]}
    elif topic == "discoveries":
        # Prompts are about scientific discoveries in the future
        prompts = {k: f"{k} has happened. The year is" for k in list(ex_dict)[:100]}
    elif topic == "software":
        # Prompts are about software versions in the future
        prompts = {
            k: f"{v} is the most recent version of {k}. The year is"
            for k, v in ex_dict.items()
        }
    elif topic == "outlandish_events":
        # Prompts about outlandish fictional events, keyed by their position
        prompts = {i: f"{x}. The year is" for i, x in enumerate(ex_dict)}
    else:
        # Without this branch `prompts` would be unbound further down.
        raise ValueError(f"No prompt template defined for topic {topic!r}")

    # Forward only the options that were actually supplied, so a missing flag
    # falls back to the callee's default. Temperature must be forwarded:
    # otherwise the value shown in the log line and the file name would not be
    # the one used for sampling.
    run_options = {
        name: value
        for name, value in (
            ("num_reps", args.num_reps),
            ("cap", args.cap),
            ("temperature", args.temperature),
        )
        if value is not None
    }

    # Make sure the results can be written before spending time on API calls.
    os.makedirs(output_dir, exist_ok=True)

    print(
        f"Getting results for {args.model} on {args.topic} with temperature {args.temperature}..."
    )
    outputs = await get_completions_parallel(
        prompts, model_dict[args.model], **run_options
    )
    with open(
        os.path.join(
            output_dir,
            f"{args.topic}_{args.model}_temp{args.temperature}_preds.json",
        ),
        "w",
    ) as f:
        json.dump(outputs, f)

    return outputs


if __name__ == "__main__":
    # Command-line entry point. Topic and model are mandatory because nothing
    # useful can run without them; the remaining options default to the same
    # values as get_completions_parallel, so an omitted flag no longer reaches
    # the pipeline as None.
    parser = argparse.ArgumentParser(
        description="Sample 'The year is' completions for a topic and save them as JSON."
    )

    parser.add_argument(
        "--topic",
        type=str,
        required=True,
        help="Name of the entity file (without .json) to build prompts from.",
    )
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        choices=sorted(model_dict),
        help="Short model name; one of the keys of model_dict.",
    )
    parser.add_argument(
        "--cap",
        type=int,
        default=50,
        help="Maximum number of requests in flight at once.",
    )
    parser.add_argument(
        "--num_reps",
        type=int,
        default=10,
        help="Number of completions to sample per prompt.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.75,
        help="Sampling temperature.",
    )

    args = parser.parse_args()

    trio.run(main, args)
