import json
import multiprocessing
import random
import time
from string import Template

import datasets
import langdetect
import requests
import tiktoken
from openai import OpenAI
from requests.exceptions import ChunkedEncodingError, RequestException
from semantic_text_splitter import TextSplitter
from tqdm import tqdm
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams


def inference_model(prompt, initial_delay=5, max_delay=60, max_retries=10):
    """Send *prompt* as a single user message and return the streamed reply text.

    Uses the module-level ``client`` and ``model``. The completion is requested
    with ``stream=True`` and the streamed content pieces are concatenated.

    Args:
        prompt: Text sent as the content of the only (user) message.
        initial_delay: Seconds to wait before the first retry.
        max_delay: Upper bound for the wait, which doubles after every failure.
        max_retries: Maximum number of attempts.

    Returns:
        The concatenated reply text, or ``None`` when every attempt failed
        with a retryable error.

    Raises:
        Exception: Any non-retryable error is re-raised after being printed.
    """
    # The request goes through the OpenAI client, which raises its own
    # exception types for transient failures; without these in the retry
    # tuple such failures would be re-raised by the generic handler below.
    # Imported here so this function does not rely on the file-level imports.
    from openai import APIConnectionError, InternalServerError, RateLimitError

    retryable_errors = (
        ChunkedEncodingError,
        RequestException,
        APIConnectionError,
        RateLimitError,
        InternalServerError,
    )

    delay = initial_delay
    for _ in range(max_retries):
        try:
            stream = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=8000,
                temperature=0.7,
                top_p=0.7,
                stop=["<|im_start|>", "<|im_end|>"],
                stream=True,
            )

            pieces = []
            for chunk in stream:
                # Skip chunks that carry no choices instead of failing on
                # ``choices[0]``.
                if not chunk.choices:
                    continue
                piece = chunk.choices[0].delta.content
                if piece is not None:
                    pieces.append(piece)
            return "".join(pieces)
        except retryable_errors as e:
            print(f"Error occurred: {e}. Retrying in {delay} seconds...")
        except Exception as e:
            print(f"Unexpected error occurred: {e}")
            # Only this specific message is treated as transient; everything
            # else propagates to the caller.
            if "Unexpected error response None" not in str(e):
                raise
            print("Received None response from API. Retrying...")

        # Reached only after a retryable failure: back off exponentially.
        time.sleep(delay)
        delay = min(delay * 2, max_delay)

    print(f"Max retries ({max_retries}) reached. Unable to get a valid response.")
    return None


# Function to format input for the model
def format_input(user_input):
    """Render *user_input* as a single-turn chat prompt.

    The text is wrapped in one ``user`` message and handed to the module-level
    ``tokenizer``'s ``apply_chat_template`` with ``tokenize=False`` and
    ``add_generation_prompt=True``; whatever that call returns is returned.
    """
    # NOTE(review): no assignment to ``tokenizer`` is visible in this file --
    # confirm it is defined before this helper is called.
    messages = [{"role": "user", "content": user_input}]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )


# Filter dataset by token length
def filter_length(examples):
    """Build a keep/drop mask for a batch of texts based on token count.

    Each entry of ``examples["text"]`` is encoded with the module-level
    ``enc``. An entry is kept (``True``) only when its token count lies between
    60,000 and 80,000 inclusive; entries that fail to encode are dropped.

    Args:
        examples: Mapping with a ``"text"`` key holding an iterable of texts.

    Returns:
        A list of booleans, one per text, in input order.
    """
    min_tokens, max_tokens = 60_000, 80_000
    keep_mask = []
    for text in examples["text"]:
        try:
            n_tokens = len(enc.encode(text))
        except Exception:
            # Anything the encoder cannot handle is filtered out.
            keep_mask.append(False)
        else:
            keep_mask.append(min_tokens <= n_tokens <= max_tokens)
    return keep_mask


def hierarchical_split(text):
    """Split *text* into medium chunks, each further split into small chunks.

    Uses the module-level ``medium_splitter`` and ``small_splitter``.

    Returns:
        A list of dicts, one per medium chunk, each with the keys
        ``"medium_chunk"`` (the chunk itself) and ``"small_chunks"`` (the
        result of splitting that chunk with the small splitter).
    """
    return [
        {
            "medium_chunk": section,
            "small_chunks": small_splitter.chunks(text=section),
        }
        for section in medium_splitter.chunks(text=text)
    ]


def summarize_chunk(chunk, word_limit):
    """Ask the model for a summary of *chunk* capped at *word_limit* words.

    Args:
        chunk: Text to summarize.
        word_limit: Maximum number of whitespace-separated words to keep.

    Returns:
        The summary text. If it exceeds *word_limit* words it is cut to the
        first *word_limit* words (joined by single spaces) followed by "...".

    Raises:
        RuntimeError: If ``inference_model`` returned ``None`` (it does so
            once its retries are exhausted).
    """
    prompt = f"Summarize the following text concisely in no more than {word_limit} words:\n\n{chunk}"
    summary = inference_model(prompt)
    if summary is None:
        # Fail with a clear message rather than a TypeError on the ``in``
        # checks below.
        raise RuntimeError(
            "inference_model returned no response; cannot summarize chunk."
        )

    # If the reply contains these key names, keep only the text after the
    # last colon.
    if "medium_chunk" in summary or "small_chunks" in summary:
        summary = summary.split(":")[-1].strip()

    words = summary.split()
    if len(words) > word_limit:
        summary = " ".join(words[:word_limit]) + "..."
    return summary


def summarize_hierarchical(hierarchical_chunks):
    """Summarize bottom-up: small chunks, then medium chunks, then the whole.

    For every entry, each small chunk is summarized in up to 75 words; those
    summaries are joined with newlines and summarized in up to 300 words. The
    medium summaries are then joined with newlines and summarized in up to 800
    words.

    Args:
        hierarchical_chunks: Sequence of dicts with a ``"small_chunks"`` key.

    Returns:
        A tuple ``(medium_summaries, full_summary)`` where ``medium_summaries``
        is a list of dicts with the keys ``"medium_summary"`` and
        ``"small_summaries"``.
    """
    medium_summaries = []
    for entry in hierarchical_chunks:
        print("Summarizing medium chunk")
        small_summaries = [
            summarize_chunk(piece, 75) for piece in entry["small_chunks"]
        ]
        medium_summaries.append(
            {
                "medium_summary": summarize_chunk("\n".join(small_summaries), 300),
                "small_summaries": small_summaries,
            }
        )

    combined = "\n".join([item["medium_summary"] for item in medium_summaries])
    return medium_summaries, summarize_chunk(combined, 800)


def generate_qa_pair(context, full_summary):
    """Generate one QA pair tied to *full_summary* but answerable from *context*.

    Args:
        context: Text the answer should draw on.
        full_summary: Summary the question should relate to.

    Returns:
        The parsed JSON value from the model reply. The prompt asks for an
        object with ``"question"`` and ``"answer"`` keys, but the parsed value
        is returned without checking its shape.

    Raises:
        ValueError: If ``inference_model`` returned ``None``.
        json.JSONDecodeError: If no parseable JSON could be found in the reply.
    """
    prompt = f"""Given the following context and full summary of a book, generate a question-answer pair that relates to the full summary but can be answered better with knowledge from the given context.

Context: {context}

Full Summary: {full_summary}

Generate a question-answer pair in the following format:
{{
  "question": "<question>",
  "answer": "<answer>"
}}
"""
    response = inference_model(prompt)
    if response is None:
        # inference_model returns None once its retries are exhausted.
        raise ValueError(
            "inference_model returned no response; cannot build a QA pair."
        )

    try:
        return json.loads(response)
    except json.JSONDecodeError:
        # The reply may wrap the object in extra text (for example a code
        # fence); retry on the span from the first "{" to the last "}".
        start = response.find("{")
        end = response.rfind("}")
        if start == -1 or end <= start:
            raise
        return json.loads(response[start : end + 1])


def generate_specific_qa_pair(context):
    """Generate one detail-oriented QA pair from *context*.

    Args:
        context: Text substituted into the prompt as the context.

    Returns:
        The parsed JSON value from the model reply, or ``None`` when the model
        gave no reply or no parseable JSON could be found in it. The reply is
        first parsed as a whole; if that fails, the span from the first "{" to
        the last "}" is tried.
    """
    # Function-scope import: ``re`` is not among the file-level imports.
    import re

    prompt = Template(
        """Context information is below.
${context}
Given the context information and not prior knowledge, generate content based on the below query.
You are a Teacher/Professor. Create 1 specific question about the details, events, characters, and settings from the context provided. This question should have an exact, unambiguous answer that can be directly found in the given information. The question should be similar in style to the following examples:

"Where does Character A meet Character B for the first time?"
"What is Character C's religion?"
"Where does Character D live for the majority of the story?"
"Which of the following is NOT one of Character E's responsibilities?"
"Which among [list of names] is not [Character F]'s child?"
"Who among [list of names] is the final to perish?"
"What's the name of [Family name]'s summer home?"
"Who accompanied [Character G] [specific activity] at last?"

Ensure that the question and answer are strictly based on the context information provided. The question may include multiple-choice options when appropriate.
You must return the result in JSON: {'question': <question>, 'answer': <answer>}"""
    )
    formatted_prompt = prompt.substitute(context=context)
    response = inference_model(formatted_prompt)
    print(f"Raw response: {response}")

    if response is None:
        # inference_model returns None once its retries are exhausted;
        # json.loads(None) would raise TypeError, so report it explicitly.
        print("No response from inference_model; no QA pair generated.")
        return None

    try:
        return json.loads(response)
    except json.JSONDecodeError as e:
        print(f"JSON decode error: {e}")

    # Fall back to the widest {...} span in case the object is wrapped in
    # surrounding text.
    json_match = re.search(r"\{.*\}", response, re.DOTALL)
    if json_match:
        try:
            return json.loads(json_match.group())
        except json.JSONDecodeError as e:
            print(f"JSON decode error in extracted object: {e}")

    # Explicit failure result; callers must handle None.
    return None


# Prompt templates for the "normal" question category. generate_one_qa_pair
# picks one at random and fills ${context} via string.Template.
# The continuation lines inside each triple-quoted string are indented by four
# spaces; that indentation is part of the prompt text sent to the model.
# NOTE(review): the format example uses single-quoted keys, which json.loads
# does not accept if the model copies it literally -- confirm this is intended.
templates = [
    """Context information is below.
    ---------------------
    ${context}
    ---------------------
    Given the context information and not prior knowledge.
    Generate content based on the below query.
    You are a Teacher/Professor. Your task is to set up 1 diverse temporal question about the context for an upcoming quiz/examination.
    The question should cover different time periods and events described in the context. Restrict the question to the context information provided.
    You must return the result in JSON: {'question': <question>, 'answer': <answer>}""",
    """Context information is below.
    ---------------------
    ${context}
    ---------------------
    Given the context information and not prior knowledge.
    Generate content based on the below query.
    You are a Teacher/Professor. Your task is to create 1 character-based question from the context for an upcoming quiz/examination.
    The question should explore different aspects of the characters, such as their motivations, actions, and relationships. Restrict the question to the context information provided.
    You must return the result in JSON: {'question': <question>, 'answer': <answer>}""",
    """Context information is below.
    ---------------------
    ${context}
    ---------------------
    Given the context information and not prior knowledge.
    Generate content based on the below query.
    Formulate 1 complex question that requires analysis of multiple aspects from the context for an upcoming quiz/examination.
    The question should encourage critical thinking and synthesis of different pieces of information within the context. Restrict the question to the context information provided.
    You must return the result in JSON: {'question': <question>, 'answer': <answer>}""",
    """Context information is below.
    ---------------------
    ${context}
    ---------------------
    Given the context information and not prior knowledge.
    Generate content based on the below query.
    You are a Teacher/Professor. Ask 1 question about the main themes or messages of the text for an upcoming quiz/examination.
    The question should cover different aspects of the themes and how they are developed in the context. Restrict the question to the context information provided.
    You must return the result in JSON: {'question': <question>, 'answer': <answer>}""",
    """Context information is below.
    ---------------------
    ${context}
    ---------------------
    Given the context information and not prior knowledge.
    Generate content based on the below query.
    You are a Teacher/Professor. Create 1 question that compare different elements within the context for an upcoming quiz/examination.
    The question should highlight similarities and differences between various elements such as characters, events, and themes. Restrict the question to the context information provided.
    You must return the result in JSON: {'question': <question>, 'answer': <answer>}""",
    """Context information is below.
    ---------------------
    ${context}
    ---------------------
    Given the context information and not prior knowledge.
    Generate content based on the below query.
    You are a Teacher/Professor. Develop 1 question that explore the cause and effect relationships within the context for an upcoming quiz/examination.
    The question should focus on understanding the reasons behind events and their outcomes. Restrict the question to the context information provided.
    You must return the result in JSON: {'question': <question>, 'answer': <answer>}""",
    """Context information is below.
    ---------------------
    ${context}
    ---------------------
    Given the context information and not prior knowledge.
    Generate content based on the below query.
    You are a Teacher/Professor. Create 1 hypothetical question based on the context for an upcoming quiz/examination. The question should explore what-if scenarios and possible alternate outcomes. Restrict the question to the context information provided.
    You must return the result in JSON: {'question': <question>, 'answer': <answer>}""",
    """Context information is below.
    ---------------------
    ${context}
    ---------------------
    Given the context information and not prior knowledge.
    Generate content based on the below query.
    You are a Teacher/Professor. Formulate 1 question that require interpretation of the context for an upcoming quiz/examination. The question should encourage students to provide their own insights and interpretations based on the information given. Restrict the question to the context information provided.
    You must return the result in JSON: {'question': <question>, 'answer': <answer>}""",
    """Context information is below.
    ---------------------
    ${context}
    ---------------------
    Given the context information and not prior knowledge.
    Generate content based on the below query.
    You are a Teacher/Professor. Ask 1 detail-oriented question about the context for an upcoming quiz/examination. These question should focus on specific details, facts, and figures mentioned in the context. Restrict the question to the context information provided.
    You must return the result in JSON: {'question': <question>, 'answer': <answer>}""",
    """Context information is below.
    ---------------------
    ${context}
    ---------------------
    Given the context information and not prior knowledge.
    Generate content based on the below query.
    You are a Teacher/Professor. Create 1 question that explore different perspectives or viewpoints within the context for an upcoming quiz/examination. The question should examine how different characters or groups might view events or themes differently. Restrict the questions to the context information provided.
    You must return the result in JSON: {'question': <question>, 'answer': <answer>}""",
]


def generate_one_qa_pair(chunks, question_type):
    """Generate one QA pair of the requested category from random chunks.

    Args:
        chunks: Sequence of text chunks to sample context from.
        question_type: ``0`` for a specific-detail question about one random
            chunk, ``1`` for a multi-hop question over three distinct random
            chunks, any other value for a question built from a random entry
            of the module-level ``templates`` and one random chunk.

    Returns:
        The parsed JSON value from the model reply. When nothing parseable is
        found, a placeholder dict whose ``"question"`` is
        ``"Error generating question"`` and whose ``"answer"`` is
        ``"Error generating answer"``.

    Raises:
        ValueError: For ``question_type == 1`` when *chunks* has fewer than
            three items (raised by ``random.sample``).
        TypeError: If ``inference_model`` returned ``None`` (raised by
            ``json.loads``).
    """
    if question_type == 0:  # Specific questions
        context = random.choice(chunks)
        prompt = Template(
            """Context information is below.
${context}
Given the context information and not prior knowledge, generate content based on the below query.
You are a Teacher/Professor. Create 1 specific question about the details, events, characters, and settings from the context provided. This question should have an exact, unambiguous answer that can be directly found in the given information. The question should be similar in style to the following examples:

"Where does Character A meet Character B for the first time?"
"What is Character C's religion?"
"Where does Character D live for the majority of the story?"
"Which of the following is NOT one of Character E's responsibilities?"
"Which among [list of names] is not [Character F]'s child?"
"Who among [list of names] is the final to perish?"
"What's the name of [Family name]'s summer home?"
"Who accompanied [Character G] [specific activity] at last?"

Ensure that the question and answer are strictly based on the context information provided. The question may include multiple-choice options when appropriate.
You must return the result in JSON: {'question': <question>, 'answer': <answer>}"""
        )
        formatted_prompt = prompt.substitute(context=context)

    elif question_type == 1:  # Multi-hop questions
        selected_chunks = random.sample(chunks, 3)
        prompt = Template(
            """Context information is below.
${selected_chunk_1}
${selected_chunk_2}
${selected_chunk_3}

You are a Professor designing a final exam for an advanced interdisciplinary course. Create 1 complex question requiring deep analysis and synthesis of information from all three chunks.
Do not mention that there are three chunks/your questions. Do not mention excerpts either.

For example, instead of a question that says
"Analyze the theme of justice and its various forms as portrayed in the three provided literary excerpts. How do the characters' actions and the outcomes of their situations reflect or challenge traditional notions of justice? Consider the legal, personal, and societal implications of justice in each excerpt and discuss the role of power dynamics in shaping justice."
You should say:
"Analyze the theme of justice and its various forms as portrayed. How do the characters' actions and the outcomes of their situations reflect or challenge traditional notions of justice? Consider the legal, personal, and societal implications of justice and discuss the role of power dynamics in shaping justice."

Question Guidelines:
1. The question must integrate and require reasoning across all three chunks.
3. Do not mention that there are three chunks/your questions. Do not mention excerpts either.

Answer Guidelines:
1. Reference and interconnect information from each chunk.
2. Your answer must be one paragraph.

Return 1 question-answer pair in JSON format:
{ "question": <question>, "answer": <answer> }
"""
        )
        formatted_prompt = prompt.substitute(
            selected_chunk_1=selected_chunks[0],
            selected_chunk_2=selected_chunks[1],
            selected_chunk_3=selected_chunks[2],
        )

    else:  # Normal questions
        context = random.choice(chunks)
        template = random.choice(templates)
        formatted_prompt = Template(template).substitute(context=context)

    response = inference_model(formatted_prompt)

    try:
        return json.loads(response)
    except json.JSONDecodeError:
        pass

    # Every prompt above asks for a JSON object, so look for a {...} span
    # first. The [...] span is still tried afterwards to keep accepting
    # replies that only parse as an array.
    for opener, closer in (("{", "}"), ("[", "]")):
        start = response.find(opener)
        end = response.rfind(closer)
        if start == -1 or end < start:
            continue
        try:
            return json.loads(response[start : end + 1])
        except json.JSONDecodeError:
            continue

    # Nothing parseable: hand back the placeholder so the caller can detect it.
    print(f"Failed to parse response: {response}")
    return {
        "question": "Error generating question",
        "answer": "Error generating answer",
    }


def _is_usable_qa_pair(qa_pair):
    """Return True when *qa_pair* is a dict holding a real question and answer."""
    if not isinstance(qa_pair, dict):
        return False
    if "question" not in qa_pair or "answer" not in qa_pair:
        return False
    # generate_one_qa_pair returns this placeholder text when it cannot parse
    # the model reply; it must not be written out as a conversation turn.
    return qa_pair["question"] != "Error generating question"


def _append_qa_turns(conversations, qa_pair):
    """Append *qa_pair* to *conversations* as a user turn and an assistant turn."""
    conversations.extend(
        [
            {"role": "user", "content": qa_pair["question"]},
            {"role": "assistant", "content": qa_pair["answer"]},
        ]
    )


def _next_hierarchical_qa(
    hierarchical_chunks, medium_summaries, full_summary, last_medium_idx, last_small_idx
):
    """Generate the next QA pair of the summary-guided walk.

    Returns ``(qa_pair, medium_idx, small_idx)`` where the two indices are the
    cursor to use for the following step. Nothing is returned if generation
    raises, so the caller's cursor stays unchanged in that case.
    """
    if last_medium_idx is None:
        # No position yet: ask about a random medium summary vs. the full one.
        print("first")
        medium_idx = random.randint(0, len(medium_summaries) - 1)
        qa_pair = generate_qa_pair(
            medium_summaries[medium_idx]["medium_summary"], full_summary
        )
        return qa_pair, medium_idx, None

    current = medium_summaries[last_medium_idx]

    if last_small_idx is None:
        # Drill down: a random small summary inside the current medium chunk.
        print("second")
        small_idx = random.randint(0, len(current["small_summaries"]) - 1)
        qa_pair = generate_qa_pair(
            current["small_summaries"][small_idx], current["medium_summary"]
        )
        return qa_pair, last_medium_idx, small_idx

    print("third")
    choice = random.random()
    print(choice)

    if choice < 1 / 3:
        # Detail question on the raw text of the current small chunk.
        print("fourth")
        small_chunk = hierarchical_chunks[last_medium_idx]["small_chunks"][
            last_small_idx
        ]
        return generate_specific_qa_pair(small_chunk), last_medium_idx, last_small_idx

    if choice < 2 / 3:
        # Move to a random small chunk within the same medium chunk.
        print("fifth")
        new_small_idx = random.randint(0, len(current["small_summaries"]) - 1)
        small_chunk = hierarchical_chunks[last_medium_idx]["small_chunks"][
            new_small_idx
        ]
        qa_pair = generate_qa_pair(small_chunk, current["medium_summary"])
        return qa_pair, last_medium_idx, new_small_idx

    # Jump to a different medium chunk. Choosing from the other indices avoids
    # looping forever when there is only one medium chunk; in that case the
    # current one is reused.
    print("sixth")
    other_indices = [
        idx for idx in range(len(medium_summaries)) if idx != last_medium_idx
    ]
    new_medium_idx = random.choice(other_indices) if other_indices else last_medium_idx
    qa_pair = generate_qa_pair(
        medium_summaries[new_medium_idx]["medium_summary"], full_summary
    )
    return qa_pair, new_medium_idx, None


def generate_extended_context(i):
    """Worker: build multi-turn conversations for this worker's dataset rows.

    Processes rows ``i, i + num_processes, i + 2 * num_processes, ...`` of the
    module-level ``filtered_dataset``. For each row it summarizes the text
    hierarchically, starts a conversation with the full text and its summary,
    adds up to 25 summary-guided QA pairs and up to 50 chunk-based QA pairs,
    and writes the conversation as one JSON line to
    ``/data/anno_1/long-gen/longbook/longdata-{i}``.

    Failures are printed and skipped, per QA pair and per row.

    Args:
        i: Index of this worker; selects both the rows and the output file.
    """
    print(filtered_dataset)
    with open(f"/data/anno_1/long-gen/longbook/longdata-{i}", "w", encoding="utf-8") as f:
        for row_idx in range(i, len(filtered_dataset), num_processes):
            try:
                initial_context = filtered_dataset[row_idx]["text"]
                hierarchical_chunks = hierarchical_split(initial_context)
                print("finished hierarchical split")

                medium_summaries, full_summary = summarize_hierarchical(
                    hierarchical_chunks
                )

                # NOTE(review): the " + " is literal text inside the user
                # message -- confirm it is intended.
                conversations = [
                    {
                        "role": "user",
                        "content": f"{initial_context} + Please give me a summary of the book",
                    },
                    {"role": "assistant", "content": f"{full_summary}"},
                ]

                last_medium_idx = None
                last_small_idx = None

                # Distinct loop names keep the worker index ``i`` and the row
                # index intact inside the nested loops.
                for qa_idx in range(25):  # Generate 25 QA pairs
                    print(qa_idx)
                    try:
                        qa_pair, last_medium_idx, last_small_idx = (
                            _next_hierarchical_qa(
                                hierarchical_chunks,
                                medium_summaries,
                                full_summary,
                                last_medium_idx,
                                last_small_idx,
                            )
                        )
                        print(qa_pair)

                        if _is_usable_qa_pair(qa_pair):
                            _append_qa_turns(conversations, qa_pair)
                        else:
                            print(f"Skipping invalid QA pair: {qa_pair}")
                    except Exception as e:
                        print(f"Error generating QA pair: {e}")
                        continue  # Skip this iteration and continue with the next one

                small_chunks = small_splitter.chunks(text=initial_context)
                for _ in range(50):
                    try:
                        question_type = random.randint(0, 2)
                        qa_pair = generate_one_qa_pair(small_chunks, question_type)
                        print(qa_pair)

                        # Same validation as the guided loop, so malformed or
                        # placeholder pairs never reach the output file.
                        if _is_usable_qa_pair(qa_pair):
                            _append_qa_turns(conversations, qa_pair)
                        else:
                            print(f"Skipping invalid QA pair: {qa_pair}")
                    except Exception as e:
                        print(f"Error generating QA pair: {e}")
                        continue  # Skip this iteration and continue with the next one

                # Save to file
                json_data = json.dumps({"conversations": conversations})
                f.write(json_data + "\n")
                f.flush()  # Force flush to disk
            except Exception as e:
                print(f"Error processing data entry: {e}")
                print("Skipping to next data entry.")
                continue


if __name__ == "__main__":
    # The client below targets a local OpenAI-compatible endpoint, e.g. started with:
    # CUDA_VISIBLE_DEVICES=4,5,6,7 vllm serve Qwen/Qwen2-72B-Instruct --tensor-parallel-size 4 --download-dir /data/anno_1/long-gen/.cache/
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    model = "Qwen/Qwen2-72B-Instruct"

    # Initialize tiktoken encoder
    enc = tiktoken.encoding_for_model("gpt-4")

    path = "/data/anno_1/long-gen/books_data"
    dataset = datasets.load_dataset(
        path,
        "book",
        split="train",
        trust_remote_code=True,
        cache_dir="/data/anno_1/long-gen/.cache",
    )

    # path = "deepmind/pg19"
    # dataset = datasets.load_dataset(path, split="train", trust_remote_code=True, cache_dir="/data/anno_1/long-gen/.cache", num_proc=120)
    # print(f"Number of data entries in the dataset: {len(dataset)}")

    # Keep only texts whose token count is inside the range filter_length accepts
    filtered_dataset = dataset.filter(filter_length, batched=True, num_proc=120)
    print(
        f"Number of data entries in the dataset after filtering length: {len(filtered_dataset)}"
    )

    # Filter for English language
    filtered_dataset = filtered_dataset.filter(
        lambda x: langdetect.detect(x["text"]) == "en", num_proc=120
    )
    print(
        f"Number of data entries in the dataset after filtering language: {len(filtered_dataset)}"
    )

    # Initialize text splitter
    medium_splitter = TextSplitter.from_tiktoken_model("gpt-4", capacity=12288)
    small_splitter = TextSplitter.from_tiktoken_model("gpt-4", capacity=4096)

    num_processes = 32
    processes = []

    # The workers read client, model, the splitters, filtered_dataset and
    # num_processes as module globals, and those are only assigned inside this
    # guarded block. Only the "fork" start method hands them to the children,
    # so request it explicitly instead of relying on the platform default.
    mp_context = multiprocessing.get_context("fork")

    # Create and start the processes
    for i in range(num_processes):
        p = mp_context.Process(target=generate_extended_context, args=(i,))
        p.start()
        processes.append(p)

    # Wait for all processes to complete
    for p in processes:
        p.join()
