import asyncio
import json
import math
import os
import pickle
from asyncio import Semaphore
from concurrent.futures import ThreadPoolExecutor
from itertools import islice
from typing import Any, List, Union

import openai
import replicate
import torch
from langchain.llms import Replicate
from langchain.memory import ChatMessageHistory
from langchain.schema.messages import get_buffer_string
from openai import OpenAI
from openai._exceptions import RateLimitError
from tqdm import tqdm


def load_secrets(file_path):
    """
    Parse a ``KEY=VALUE`` secrets file into a dictionary.

    Blank lines and lines starting with ``#`` are ignored. Only the first
    ``=`` on a line separates key from value, so values may themselves
    contain ``=``. Whitespace around the key and the value is stripped.

    Parameters:
    - file_path (str): Path of the secrets file to read.

    Returns:
    - secrets (dict): Mapping of secret name to secret value.

    Raises:
    - ValueError: If a non-blank, non-comment line contains no ``=``.
    - OSError: If the file cannot be opened.
    """
    secrets = {}
    with open(file_path, encoding="utf-8") as f:
        for line_number, line in enumerate(f, start=1):
            entry = line.strip()
            # A trailing newline or a comment must not abort loading:
            # unpacking such a line used to raise ValueError.
            if not entry or entry.startswith("#"):
                continue
            key, separator, value = entry.partition("=")
            if not separator:
                raise ValueError(
                    f"{file_path}:{line_number}: expected KEY=VALUE, got {entry!r}"
                )
            secrets[key.strip()] = value.strip()
    return secrets


def add_api_env():
    """Export every entry of the ``../SECRETS`` file as an environment variable.

    The path is relative to the current working directory of the process.
    Existing environment variables with the same name are overwritten.
    """
    os.environ.update(load_secrets("../SECRETS"))


# Import-time side effect: export the entries of ../SECRETS into os.environ
# before the client below is constructed (presumably so the SDK clients can
# pick up their API keys from the environment -- confirm against the SDK docs).
add_api_env()
client = OpenAI()
# Thread pool with 10 workers; not referenced elsewhere in the visible part of
# this file.
executor = ThreadPoolExecutor(max_workers=10)
# Short aliases mapped to pinned "owner/model:version-hash" identifiers for the
# Llama-2 variants (presumably Replicate model references -- verify with callers).
model_dict = {
    "llama2_70b": "meta/llama-2-70b:a52e56fee2269a78c9279800ec88898cecb6c8f1df22a6483132bea266648f00",
    "llama2_13b": "meta/llama-2-13b:078d7a002387bd96d93b0302a4c03b3f15824b63104034bfa943c63a8f208c38",
    "llama2_7b": "meta/llama-2-7b:77dde5d6c56598691b9008f7d123a18d98f40e4b4978f8a72215ebfc2553ddd8",
}


def get_prompt(prompt_type: str) -> tuple:
    """
    Look up the prompt template and JSON field names for a prompt type.

    Parameters:
    - prompt_type (str): One of the keys of the tables below ("paraphrase",
      "altered", "future_events", "president", "obituary", "discovery",
      "fiction", "time_rating", "llama_time_rating", "future_time_rating").

    Returns:
    - A 3-tuple ``(template, input_field, output_fields)``:
      - template (str): Prompt text with ``{n_input}`` and ``{input_list}``
        placeholders for ``str.format``; literal JSON braces are doubled.
      - input_field (str): Name of the JSON property that carries the
        original input item in the requested output format.
      - output_fields (list of str): Names of the JSON properties the model
        is asked to add for each item.

    Raises:
    - KeyError: If ``prompt_type`` is not one of the known keys.
    """
    # Prompt templates, keyed by prompt type. The text is sent to the model
    # verbatim after formatting, so it must not be edited casually.
    prompt_map = {
        "paraphrase": """Can you please paraphrase the following New York Times headlines?
    They should still be in a similar headline format and reference the same event(s); the wording should just be different.
    Please create a JSON object which enumerates a set of {n_input} child objects.
    For each child object assign to the property named "headline" the original input headline and to the property "paraphrase" the paraphrased version of the headline.
    The resulting JSON object should be in this format: [{{"headline" : "string", "paraphrase": "string"}}]. \n {input_list}""",
        "altered": """I am going to pass you a list of New York Times headlines.
    I would like you to make small adjustments to them so they are no longer true headlines and do not refer to events that actually happened.
    Please create a JSON object which enumerates a set of {n_input} child objects.
    For each child object assign to the property named "headline" the original input headline and to the property "altered_headline" the altered version of the headline.
    The resulting JSON object should be in this format: [{{"headline": "string", "altered_headline": "string"}}]. \n {input_list}""",
        "future_events": """I will be passing a list of 20 events that are either about a scientific discovery, a celebrity passing away or someone being elected president".
    For each event, can you write a headline about it in the style of the New York Times?
    Please create a JSON object which enumerates a set of {n_input} child objects.
    For each child object assign to the property named "event" the original input event and to the property "headline" the headline you would expect to see in the New York Times.
    The resulting JSON object should be in this format: [{{"event": "string", "headline": "string"}}]. \n {input_list}""",
        "president": """Can you please write a New York Times style headline about each of the following people getting elected president?
    Please create a JSON object which enumerates a set of {n_input} child objects.
    For each child object assign to the property named "person" the original input person and to the property "headline" the headline you would expect to see in the New York Times.
    The resulting JSON object should be in this format: [{{"person": "string", "headline": "string"}}]. \n {input_list}""",
        "obituary": """Can you please write a New York Times style obituary headline about each of the following people passing away?
    Please create a JSON object which enumerates a set of {n_input} child objects.
    For each child object assign to the property named "person" the original input person and to the property "headline" the headline you would expect to see in the New York Times.
    The resulting JSON object should be in this format: [{{"person": "string", "headline": "string"}}]. \n {input_list}""",
        "discovery": """Can you please write a New York Times style headline about each of the following scientific breakthroughs?
    Please create a JSON object which enumerates a set of {n_input} child objects.
    For each child object assign to the property named "discovery" the original input discovery or event and to the property "headline" the headline you would expect to see in the New York Times.
    The resulting JSON object should be in this format: [{{"discovery": "string", "headline": "string"}}]. \n {input_list}""",
        "fiction": """Please generate 10 headlines in the style of the New York Times about key events that occurred in the book.
Please create a JSON object which enumerates a set of {n_input} child objects.
For each child object assign to the property named "book" the title of the input book and the property named "headline" to the generated headline.
The resulting JSON object should be in this format: [{{"book": "string", "headline": "string"}}]. \n {input_list}""",
        "time_rating": """For each of the following headlines, can you tell me what year they occurred? 
    Please create a JSON object which enumerates a set of {n_input} child objects.
    For each child object, assign to the property named "headline" the original input headline, to the property "predicted_year" your guess for the year when the headline occurred, and to the propert "reasoning", your reasoning for why you think this headline occurred when it did.
    If you do not know the exact year it occurred but have an idea about a year range (i.e. 2000-2022), you can return a year range. 
    If it is unclear and you do not know when this headline occurred, then for the "year" section of the json output, please return "unclear"
    The resuling JSON object should be in this format: [{{"headline" : "string", "predicted_year" : "string", "reasoning" : "string"}}]. \n {input_list}""",
        "llama_time_rating": """For each of the following headlines, can you tell me what year it occurred? 
    Please create a JSON object which enumerates a set of {n_input} child objects.
    For each child object, assign to the property named "headline" the original input headline, to the property "predicted_year" your guess for the year when the headline occurred, and to the propert "reasoning", your reasoning for why you think this headline occurred when it did.
    If it is unclear and you do not know when this headline occurred, then for the "year" section of the json output, please return "unclear"
    The resuling JSON object should be in this format: [{{"headline" : "string", "predicted_year" : "string", "reasoning" : "string"}}]. \n {input_list}""",
        "future_time_rating": """I have a number of New York Time's headlines that are from after July 2023. I am trying to find ones that have clear indicators that they occurred after July 2023  without explictly naming their publication dates. Headlines that have the strongest signal about when they were published refer to specific events that occurred in that timeframe. Can you give me a rating of how easy it would be to localize a headline in a period of time just by text alone? 
Please create a JSON object which enumerates a set of {n_input} child objects. For each child object, assign to the property named "headline" the original input headline, to the property "rating" your rating for how easy it is to localize a headline as occurring after July 2023, and to the property "reasoning", your reasoning for why you think this. 
Please return your output in the following json format [{{"headline" : "string", "rating" : "string", "reasoning" : "string"}}] \n {input_list}""",
    }

    # JSON property under which each prompt asks the model to echo the
    # original input item.
    input_field = {
        "paraphrase": "headline",
        "altered": "headline",
        "time_rating": "headline",
        "llama_time_rating": "headline",
        "future_time_rating": "headline",
        "future_events": "event",
        "president": "person",
        "obituary": "person",
        "discovery": "discovery",
        "fiction": "book",
    }

    # JSON properties each prompt asks the model to generate per item.
    output_fields = {
        "paraphrase": ["paraphrase"],
        "altered": ["altered_headline"],
        "time_rating": ["predicted_year", "reasoning"],
        "llama_time_rating": ["predicted_year", "reasoning"],
        "future_time_rating": ["rating", "reasoning"],
        "future_events": ["headline"],
        "president": ["headline"],
        "obituary": ["headline"],
        "discovery": ["headline"],
        "fiction": ["headline"],
    }

    return prompt_map[prompt_type], input_field[prompt_type], output_fields[prompt_type]


def get_replicate_completion(
    prompt: str,
    model: str,
    temperature: float = 0.75,
    top_p: float = 1,
    max_new_tokens: float = 20,
    min_new_tokens: float = 4,
) -> str:
    """
    Run one prompt through a model via the LangChain ``Replicate`` wrapper.

    Parameters:
    - prompt (str): Text passed to the model.
    - model (str): Model identifier handed to ``Replicate``.
    - temperature (float): Forwarded as the "temperature" model kwarg.
    - top_p (float): Forwarded as the "top_p" model kwarg.
    - max_new_tokens (float): Forwarded as the "max_new_tokens" model kwarg.
    - min_new_tokens (float): Forwarded as the "min_new_tokens" model kwarg.

    Returns:
    - Whatever calling the wrapper with the prompt returns (annotated as str).
    """
    sampling_options = dict(
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        min_new_tokens=min_new_tokens,
    )
    replicate_llm = Replicate(model=model, model_kwargs=sampling_options)
    completion = replicate_llm(prompt)
    return completion


def get_llama_completion(messages, tokenizer, model, max_length=7000):
    """
    Generate a chat completion with a locally loaded model.

    Parameters:
    - messages: Chat messages passed to ``tokenizer.apply_chat_template``.
    - tokenizer: Object providing ``apply_chat_template``, ``encode`` and ``decode``.
    - model: Object providing ``generate``.
    - max_length (int): Forwarded to ``model.generate`` as ``max_length``.

    Returns:
    - completion (str): The decoded tokens that follow the prompt in the
      first generated sequence, with special tokens skipped.
    """
    chat_text = tokenizer.apply_chat_template(messages, tokenize=False)

    # Inputs go to the first GPU when CUDA is available, otherwise to the CPU.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    prompt_ids = tokenizer.encode(chat_text, return_tensors="pt").to(device)

    generated = model.generate(
        prompt_ids, max_length=max_length, num_return_sequences=1
    )

    # The generated sequence starts with the prompt tokens; drop them so only
    # the newly produced continuation is decoded.
    prompt_length = len(prompt_ids[0])
    return tokenizer.decode(generated[0][prompt_length:], skip_special_tokens=True)


async def fetch_headline_ratings_async(
    prompt_type: str,
    semaphore: Semaphore,
    input_dict: dict,
    model: Union[str, Any],
    tokenizer: Any = None,
):
    """
    Send one batch of inputs to a model and merge the JSON answer into the batch.

    Parameters:
    - prompt_type (str): Key understood by ``get_prompt``.
    - semaphore (Semaphore): Limits how many batches are in flight at once.
    - input_dict (dict): Maps each input string to a dict that receives the
      model-generated fields. Non-string keys are left out of the prompt.
    - model (str or object): A model name containing "gpt" (OpenAI chat API)
      or "llama" (``replicate.run``), or a local model object that is passed
      to ``get_llama_completion`` together with ``tokenizer``.
    - tokenizer: Tokenizer for a local model; unused for string models.

    Returns:
    - (input_dict, errors): On success the batch dict, whose value dicts were
      updated in place, and ``None``. If the answer cannot be parsed or
      matched to the inputs, ``({}, raw_model_answer)``.

    Raises:
    - ValueError: If ``model`` is a string naming neither a "gpt" nor a
      "llama" model.
    - Exceptions raised by the underlying API calls propagate unchanged.
    """
    base_prompt, input_field, output_fields = get_prompt(prompt_type)
    system_prompt = "You are a helpful assistant designed to output JSON."

    async with semaphore:
        errors = None
        headlines = [k for k in input_dict if isinstance(k, str)]
        # Tell the model how many items are actually listed (non-string keys
        # are skipped above and must not be counted).
        prompt = base_prompt.format(
            n_input=len(headlines), input_list="\n".join(headlines)
        )

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]

        loop = asyncio.get_running_loop()

        if isinstance(model, str):
            if "gpt" in model:
                chat_completion = await loop.run_in_executor(
                    None,
                    lambda: openai.chat.completions.create(
                        model=model,
                        temperature=1.0,
                        messages=messages,
                    ),
                )
                # The response is an object, not an iterable of strings; the
                # generated text lives on the first choice's message.
                model_answer = chat_completion.choices[0].message.content or ""

            elif "llama" in model:
                # Join inside the worker thread so that consuming the output
                # chunks does not block the event loop.
                model_answer = await loop.run_in_executor(
                    None,
                    lambda: "".join(
                        replicate.run(
                            model,
                            {
                                "top_p": 1,
                                "prompt": prompt,
                                "system_prompt": system_prompt,
                                "temperature": 1,
                                "max_new_tokens": 7000,
                                "min_new_tokens": 500,
                            },
                        )
                    ),
                )

            else:
                raise ValueError(
                    f"Unsupported model name {model!r}: expected it to contain "
                    "'gpt' or 'llama'"
                )

        else:
            model_answer = await loop.run_in_executor(
                None,
                lambda: get_llama_completion(
                    messages=messages, tokenizer=tokenizer, model=model, max_length=7000
                ),
            )

        try:
            start = model_answer.find("[")
            if start == -1:
                raise ValueError("model answer contains no JSON array")
            # raw_decode parses exactly one JSON value starting at the first
            # "[", so "]" inside a string and trailing commentary after the
            # array no longer break parsing.
            output_list, _ = json.JSONDecoder().raw_decode(model_answer, start)

            # Validate the whole batch first so a failing batch does not
            # leave some of its entries half-updated.
            updates = [
                (
                    input_dict[output[input_field]],
                    {field: output[field] for field in output_fields},
                )
                for output in output_list
            ]
            for results_dict, new_values in updates:
                for field, value in new_values.items():
                    results_dict[field] = value

        except (ValueError, KeyError, TypeError, AttributeError):
            # Unparseable JSON, an unknown or altered input string, or a
            # missing field: hand the raw answer back for later repair.
            errors = model_answer
            print(model_answer)
            input_dict = {}

    return input_dict, errors


def _write_atomically(path, mode, write):
    """Write via ``write(handle)`` to a sibling temp file, then swap it into ``path``."""
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, mode) as handle:
            write(handle)
        # Only replace the destination once the new content is complete.
        os.replace(tmp_path, path)
    finally:
        # After a successful replace the temp file is gone; after a failure
        # remove the partial file so it does not linger next to the checkpoint.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


async def save_checkpoint(output, checkpoint_path, errors, errors_path):
    """
    Persist the current results as JSON and the collected errors as a pickle.

    Each file is written to ``<path>.tmp`` first and then moved over the
    destination, so a serialization error or an interruption mid-write leaves
    the previous checkpoint intact instead of a truncated file. The writes are
    synchronous and run to completion before the coroutine returns.

    Parameters:
    - output: JSON-serializable object written to ``checkpoint_path``.
    - checkpoint_path: Destination of the JSON checkpoint.
    - errors: Picklable object written to ``errors_path``.
    - errors_path: Destination of the pickled errors.

    Raises:
    - TypeError / ValueError: If ``output`` is not JSON-serializable.
    - OSError: If a file cannot be written or replaced.
    """
    _write_atomically(checkpoint_path, "w", lambda f: json.dump(output, f))
    _write_atomically(errors_path, "wb", lambda f: pickle.dump(errors, f))


async def process_headlines(
    prompt_type,
    inputs,
    output_dict,
    error_catcher,
    model,
    thread_cap,
    batch_size,
    checkpoint_interval,
    output_path,
    errors_path,
    tokenizer=None,
):
    """
    Run all inputs through the model in batches, checkpointing periodically.

    Parameters:
    - prompt_type (str): Prompt key forwarded to ``fetch_headline_ratings_async``.
    - inputs (dict): All inputs; split into consecutive batches of ``batch_size``.
    - output_dict (dict): Accumulates successful batch results (updated in place).
    - error_catcher (list): Accumulates raw answers of failed batches (appended in place).
    - model, tokenizer: Forwarded to ``fetch_headline_ratings_async``.
    - thread_cap (int): Maximum number of batches in flight at once.
    - batch_size (int): Number of inputs per request.
    - checkpoint_interval (int): Number of batches gathered before each save.
    - output_path, errors_path: Checkpoint destinations for ``save_checkpoint``.

    Returns:
    - (output_dict, error_catcher): The same objects that were passed in.

    Raises:
    - Exceptions from a batch or checkpoint during the main loop propagate.
      A failure while finishing the last, partial group is reported and the
      results collected so far are still returned.
    """
    semaphore = asyncio.Semaphore(thread_cap)
    # Materialise once: slicing a list per batch avoids re-walking the dict
    # from the start for every batch.
    items = list(inputs.items())
    batch_starts = range(0, len(items), batch_size)
    pbar = tqdm(total=len(batch_starts))
    tasks = []

    async def flush_pending():
        """Await the queued batches, merge their results and save a checkpoint."""
        results = await asyncio.gather(*tasks)
        # Every task returns a (result_dict, errors) pair.
        for result, errors in results:
            output_dict.update(result)
            if errors is not None:
                error_catcher.append(errors)

        print(f"Saving checkpoint with {len(output_dict)} entries to {output_path}")
        await save_checkpoint(output_dict, output_path, error_catcher, errors_path)

        # Count each batch exactly once, when it has completed.
        pbar.update(len(tasks))
        tasks.clear()

    try:
        for start in batch_starts:
            tasks.append(
                asyncio.create_task(
                    fetch_headline_ratings_async(
                        prompt_type,
                        semaphore,
                        input_dict=dict(items[start : start + batch_size]),
                        model=model,
                        tokenizer=tokenizer,
                    )
                )
            )

            if len(tasks) >= checkpoint_interval:
                await flush_pending()

        if tasks:
            try:
                await flush_pending()  # Final checkpoint save
            except Exception as exc:
                # Best effort for the trailing group: keep what was already
                # collected, but say that the last group was lost.
                print(f"Final batch group failed and was not saved: {exc!r}")
    finally:
        pbar.close()

    return output_dict, error_catcher


def fix_json_errors(errors):
    """
    Try to repair every raw model answer that previously failed to parse.

    Parameters:
    - errors: Iterable of raw answers, each passed to ``fix_single_error``.

    Returns:
    - (fixed_errors, remaining_errors): Lists of the truthy parsed results and
      of the truthy leftovers that still could not be parsed. The two counts
      are printed before returning.
    """
    repaired, unresolved = [], []
    for raw_error in errors:
        parsed, leftover = fix_single_error(raw_error)
        if parsed:
            repaired.append(parsed)
            continue
        if leftover:
            unresolved.append(leftover)
    print(len(repaired), len(unresolved))
    return repaired, unresolved


def _decode_first_json_array(text):
    """Parse the JSON value that starts at the first "[" of ``text``; raise ValueError if none parses."""
    start = text.find("[")
    if start == -1:
        raise ValueError("no JSON array found")
    # raw_decode reads exactly one JSON value and ignores whatever follows,
    # so "]" inside strings and trailing text after the array are harmless.
    parsed, _ = json.JSONDecoder().raw_decode(text, start)
    return parsed


def fix_single_error(error):
    """
    Try to recover a JSON list from one raw model answer.

    The answer is parsed as-is first. If that fails, a fixed sequence of text
    repairs is applied (missing commas between objects, Markdown code fences,
    missing enclosing list, curly single quotes, ``//`` comments, doubled
    opening braces) and parsing is retried on the repaired text.

    Parameters:
    - error (str): Raw model answer.

    Returns:
    - (fixed_error, remaining_error): ``(parsed, None)`` on success, otherwise
      ``(None, repaired_text)`` where ``repaired_text`` is the answer after
      the repair steps. A non-string input is returned unchanged as the
      remaining error.
    """
    if not isinstance(error, str):
        return None, error

    try:
        return _decode_first_json_array(error), None
    except ValueError:
        pass  # fall through to the text repairs below

    repaired = error

    # Missing commas between json objects. NOTE: this drops the final
    # character and appends "]", which assumes the text ends right after the
    # last object.
    if "}," not in repaired:
        repaired = repaired.replace("}", "},")[:-1] + "]"

    # Markdown code fences around the json
    repaired = repaired.replace("```json", "").replace("```", "")

    # Json object is not enclosed in list
    if "[" not in repaired:
        repaired = "[" + repaired + "]"

    # Json object has incorrect quotation marks
    repaired = repaired.replace("‘", "'").replace("’", "'")

    # Cut everything from the first "//" comment on and close the list
    comment_at = repaired.find("//")
    if comment_at > -1:
        repaired = repaired[:comment_at] + "]"

    # Doubled opening brace
    repaired = repaired.replace("{\n{\n", "{\n")

    try:
        return json.loads(repaired), None
    except ValueError:
        pass  # try once more, tolerating junk after the array

    try:
        # The repairs above can leave a surplus "]" (or other text) behind a
        # list that is otherwise complete; accept the list in that case.
        return _decode_first_json_array(repaired), None
    except ValueError:
        return None, repaired


def _canonical_headline(text):
    """Normalise curly single quotes, whitespace runs and one trailing period for fuzzy matching."""
    text = text.replace("’", "'").replace("‘", "'")
    text = " ".join(text.split())
    return text[:-1] if text.endswith(".") else text


def _unique_substring_match(headline, headlines_dict):
    """Return the only dict value whose string key contains ``headline``, else None."""
    if not headline:
        return None
    matches = [
        value
        for key, value in headlines_dict.items()
        if isinstance(key, str) and isinstance(value, dict) and headline in key
    ]
    # An ambiguous match would attach the fields to an arbitrary headline.
    return matches[0] if len(matches) == 1 else None


def integrate_errors(
    fixed_errors: List[List[dict]], headlines_dict: dict, new_headlines_dict: dict
) -> dict:
    """
    Merge repaired model answers back into the headline records.

    For every repaired entry with a string "headline" (other than the literal
    "Court Ruling", which is skipped), the matching record in
    ``headlines_dict`` is looked up by, in order: the exact headline; the
    headline with curly quotes, repeated whitespace and a trailing period
    normalised on both sides; and finally a unique substring match. The
    generated fields are copied into the matched record and the record is
    stored in ``new_headlines_dict`` under the headline as the model wrote it.
    Entries without a match are skipped and counted; nothing is stored for them.

    Parameters:
    - fixed_errors (list): Repaired answers, each a list of dicts. The field
      names to copy are every key except "headline" of the first dict entry.
    - headlines_dict (dict): Original headline -> record dict (records are
      updated in place).
    - new_headlines_dict (dict): Receives the matched records (updated in place).

    Returns:
    - new_headlines_dict (dict): The same object that was passed in.
    """
    entries = [
        entry
        for fixed_error in fixed_errors
        if isinstance(fixed_error, (list, tuple))
        for entry in fixed_error
        if isinstance(entry, dict)
    ]
    if not entries:
        return new_headlines_dict

    new_keys = [key for key in entries[0] if key != "headline"]
    canonical_index = None  # built lazily, once, on the first inexact lookup
    missed = 0

    for entry in entries:
        headline = entry.get("headline")
        if not isinstance(headline, str) or headline == "Court Ruling":
            continue

        record = headlines_dict.get(headline)

        if not isinstance(record, dict):
            if canonical_index is None:
                canonical_index = {
                    _canonical_headline(key): value
                    for key, value in headlines_dict.items()
                    if isinstance(key, str) and isinstance(value, dict)
                }
            record = canonical_index.get(_canonical_headline(headline))

        if not isinstance(record, dict):
            record = _unique_substring_match(headline, headlines_dict)

        if record is None:
            missed += 1
            continue

        for new_key in new_keys:
            if new_key in entry:
                record[new_key] = entry[new_key]
        new_headlines_dict[headline] = record

    if missed:
        print(f"integrate_errors: {missed} headline(s) could not be matched")

    return new_headlines_dict
