import argparse
import configparser
import json
import os
import random
from datetime import datetime
from typing import Dict, List

import openai
import requests
import trio
import trio_asyncio
from langchain.llms import Replicate
from langchain.memory import ChatMessageHistory
from langchain.schema.messages import get_buffer_string
from tqdm import tqdm

# Move the process's working directory two levels up and treat the result as
# the project root. The path is relative to the *current* working directory,
# so the directory the script is launched from matters.
os.chdir("../../")
home_dir = os.getcwd()
# NOTE(review): this import sits after the chdir; whether resolving
# `setup_utils` depends on the new working directory is not visible here.
from setup_utils import add_api_env

# Presumably exports the API credentials needed by the model client into the
# environment -- confirm against setup_utils.
add_api_env()
# Entity JSON files are read from input_dir; prediction JSON files are
# written to output_dir.
input_dir = os.path.join(home_dir, "future_probing/prompting/data/entities")
output_dir = os.path.join(home_dir, "future_probing/prompting/data/predictions")

# Maps the short model name given on the command line (--model) to the model
# identifier string handed to the Replicate client. The part after the colon
# looks like a pinned version hash -- verify against Replicate before editing.
model_dict = {
    "LLAMA2_70B": "meta/llama-2-70b:a52e56fee2269a78c9279800ec88898cecb6c8f1df22a6483132bea266648f00",
    "LLAMA2_13B": "meta/llama-2-13b:078d7a002387bd96d93b0302a4c03b3f15824b63104034bfa943c63a8f208c38",
    "LLAMA2_7B": "meta/llama-2-7b:77dde5d6c56598691b9008f7d123a18d98f40e4b4978f8a72215ebfc2553ddd8",
}


def get_completion(
    prompt: str,
    model: str,
    temperature: float = 0.75,
    top_p: float = 1,
    max_new_tokens: float = 5000,
    min_new_tokens: float = 4,
) -> str:
    """Run a single blocking completion of *prompt* through a Replicate model.

    Args:
        prompt: Text passed to the model.
        model: Replicate model identifier (see the values of ``model_dict``).
        temperature: Forwarded to the model as ``temperature``.
        top_p: Forwarded to the model as ``top_p``.
        max_new_tokens: Forwarded to the model as ``max_new_tokens``.
        min_new_tokens: Forwarded to the model as ``min_new_tokens``.

    Returns:
        Whatever the LangChain ``Replicate`` LLM returns when called with
        *prompt* (annotated as ``str``).
    """
    # A fresh client is built on every call; nothing is cached between calls.
    sampling_options = dict(
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        min_new_tokens=min_new_tokens,
    )
    replicate_llm = Replicate(model=model, model_kwargs=sampling_options)
    return replicate_llm(prompt)


async def process_case(limiter, entity, prompt, model, results):
    """Fetch one completion for *entity* and append it to ``results[entity]``.

    Args:
        limiter: Async context manager bounding how many completions run at
            once (a ``trio.CapacityLimiter`` when called from
            ``get_completions_parallel``).
        entity: Key under which the completion is stored.
        prompt: Prompt text sent to the model.
        model: Model identifier forwarded to ``get_completion``.
        results: Shared dict mapping entity -> list of completions; mutated
            in place.
    """
    print(prompt)

    # get_completion is blocking, so it runs in a worker thread; the limiter
    # is held only for the duration of that call.
    async with limiter:
        completion = await trio.to_thread.run_sync(get_completion, prompt, model)

    print(completion)
    # The raw completion text is stored as-is; no year is parsed out here.
    results.setdefault(entity, []).append(completion)


async def get_completions_parallel(prompt_dict, model, num_reps=10, cap=50):
    """Collect ``num_reps`` completions for every prompt, running concurrently.

    Args:
        prompt_dict: Mapping of entity -> prompt text.
        model: Model identifier forwarded to each ``process_case`` task.
        num_reps: Number of completions requested per entity.
        cap: Token count for the ``trio.CapacityLimiter`` shared by all tasks.

    Returns:
        Dict mapping each entity to the list of completions gathered for it.
    """
    async with trio.open_nursery() as task_group:
        throttle = trio.CapacityLimiter(cap)
        collected = {}

        # One task per (repetition, entity) pair; every task writes into the
        # same `collected` dict and shares the same limiter.
        for _ in range(num_reps):
            for name, text in prompt_dict.items():
                task_group.start_soon(
                    process_case, throttle, name, text, model, collected
                )
    # Leaving the nursery block waits for all started tasks to finish.
    return collected


# Prompt recipe per topic: (template, maximum number of entities to use).
# ``{entity}`` is the key and ``{value}`` the value of each entry in the
# topic's JSON file; a limit of ``None`` means every entry is used.
_PRESIDENT_TEMPLATE = "{entity} is president of the United States. The year is"
_TOPIC_PROMPTS = {
    # Certain people becoming president
    "politicians": (_PRESIDENT_TEMPLATE, 50),
    "fiction": (_PRESIDENT_TEMPLATE, 50),
    "generic": (_PRESIDENT_TEMPLATE, 50),
    # Actual previous presidents
    "presidents": (_PRESIDENT_TEMPLATE, None),
    # Celebrities' future death dates
    "celebrities": ("{entity} has died. The year is", 50),
    # Scientific discoveries in the future
    "discoveries": ("{entity} has happened. The year is", 50),
    # Software versions in the future
    "software": (
        "{value} is the most recent version of {entity}. The year is",
        None,
    ),
}


def _build_prompts(topic, ex_dict):
    """Return ``{entity: prompt}`` for *topic* from the loaded entity data."""
    template, limit = _TOPIC_PROMPTS[topic]
    # Slicing with ``None`` keeps every entry.
    selected = list(ex_dict.items())[:limit]
    return {k: template.format(entity=k, value=v) for k, v in selected}


async def main(args):
    """Build prompts for ``args.topic``, collect completions, and save them.

    Args:
        args: Namespace with ``topic``, ``model``, ``num_reps`` and
            ``replicate_cap`` attributes (as produced by the CLI parser).

    Returns:
        Dict mapping each entity to its list of completions; the same data
        is written to ``<output_dir>/<topic>_<model>_preds.json``.

    Raises:
        ValueError: If ``args.topic`` is not a supported topic.
        KeyError: If ``args.model`` is not a key of ``model_dict``.
    """
    # Reject unknown topics before doing any work; previously this surfaced
    # only later as an unbound `prompts` variable.
    if args.topic not in _TOPIC_PROMPTS:
        raise ValueError(
            f"Unknown topic {args.topic!r}; "
            f"expected one of {sorted(_TOPIC_PROMPTS)}"
        )

    # Load entity data to convert into prompts
    entity_path = os.path.join(input_dir, f"{args.topic}.json")
    with open(entity_path, "r", encoding="utf-8") as f:
        ex_dict = json.load(f)

    prompts = _build_prompts(args.topic, ex_dict)

    # Make sure the results can be saved *before* spending API calls on them.
    os.makedirs(output_dir, exist_ok=True)

    print(f"Getting results for {args.model}")
    outputs = await get_completions_parallel(
        prompts, model_dict[args.model], args.num_reps, args.replicate_cap
    )
    with open(
        os.path.join(output_dir, f"{args.topic}_{args.model}_preds.json"), "w"
    ) as f:
        json.dump(outputs, f)

    return outputs


if __name__ == "__main__":
    # Command-line entry point: parse the options and run `main` under trio.
    parser = argparse.ArgumentParser(
        description="Collect model completions for 'The year is' prompts."
    )

    # Topic and model have no sensible default, so fail fast with a usage
    # message instead of trying to open "None.json" or look up model None.
    parser.add_argument(
        "--topic",
        type=str,
        required=True,
        help="Name of the entity JSON file (without .json) to build prompts from.",
    )
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        choices=sorted(model_dict),
        help="Short name of the model to query.",
    )
    # These defaults mirror get_completions_parallel's own defaults. Without
    # them an omitted flag is passed through as None, and range(None) raises
    # TypeError.
    parser.add_argument(
        "--replicate_cap",
        type=int,
        default=50,
        help="Maximum number of completions in flight at once (default: 50).",
    )
    parser.add_argument(
        "--num_reps",
        type=int,
        default=10,
        help="Number of completions to request per entity (default: 10).",
    )

    args = parser.parse_args()

    trio.run(main, args)
