# -*- coding: utf-8 -*-
"""Functions and classes for the reasoning_retrieval task

Notes:
- n, m are defined as number of clues, rather than tokens or characters
- to enable or disable "thought", keep or comment corresponding
  lines in prompt_summarize_and_answer

"""

import ast
import time
import random
import string

import re
from typing import Optional, List, Any
import numpy as np

from src.alg_agents import ProblemSolver
from src.utils import calculate_latency_parallel


# Value passed as the ``option`` keyword when a request agent is called for
# the direct-solve prompt and for the per-chunk retrieval prompts; the
# summarize-and-answer calls do not pass it.
# NOTE(review): the effect of each value is decided by the request agent,
# which is defined outside this file -- confirm there.
REQUEST_AGENT_OPTION = "print_start_end"  # "speak" / "print_start_end"


def generate_request_reasoning_retrieval(config: dict) -> tuple:
    """Generate a request of the reasoning task
    (largely same as the one in reasoning_DAG.py)

    The instance is a layered graph with ``width`` rows and ``depth``
    columns. Variables in the last column are leaves with random integer
    values; every other variable is defined by an operation over ``degree``
    variables of the next column. The question asks for the value of one
    randomly chosen variable in column 0.

    Args:
        config: dict with integer entries "depth", "width" and "degree".

    Returns:
        A tuple ``(target_variable, target_clues, question, true_solution,
        graph)`` where ``target_variable`` is the name asked about,
        ``target_clues`` lists the clues reachable from the target (the
        target's own clue first), ``question`` is the question text,
        ``true_solution`` is the target value formatted as an integer
        string, and ``graph`` maps "variable_{row}_{col}" to a dict with
        keys "name", "value", "clue", "is_leaf" (plus "args" and "op" for
        non-leaf variables).

    Raises:
        NotImplementedError: if an unsupported operation is configured.
    """

    def key_of_variable(row: int, col: int) -> str:
        return f"variable_{row}_{col}"

    depth, width, degree = config["depth"], config["width"], config["degree"]
    ops = ["add", "subs"]
    # ops = ["max", "min"]
    # ops = ["max", "min", "add", "subs"]
    symbol_of_op = {"add": "+", "subs": "-"}

    range_of_num = 1000 if degree == 1 else 100

    # Draw twice as many random names as needed, then drop duplicates while
    # keeping the order of first appearance.
    alphabet = list(string.ascii_uppercase)
    names = [
        "".join(np.random.choice(alphabet, size=6, replace=True))
        for _ in range(width * depth * 2)
    ]
    names_dedup = list(dict.fromkeys(names))

    # represent the task instance by graph
    graph = {}
    idx_name = 0
    for row in range(width):
        for col in range(depth):
            key = key_of_variable(row, col)
            graph[key] = {"name": "Variable_" + names_dedup[idx_name]}
            idx_name += 1

    # value of each leaf variable
    for row in range(width):
        key = key_of_variable(row=row, col=depth - 1)
        val = int(
            np.random.randint(low=-range_of_num, high=range_of_num + 1),
        )
        graph[key]["value"] = val
        graph[key]["clue"] = f"{graph[key]['name']} = {val}"
        graph[key]["is_leaf"] = True

    # args, op and value of each non-leaf variable, going backwards
    for col in range(depth - 2, -1, -1):
        rows_of_children = list(range(width))
        random.shuffle(rows_of_children)
        for row in range(width):
            if degree > 1:
                row_of_args = np.random.choice(
                    range(width),
                    size=degree,
                    replace=False,
                )
            else:
                row_of_args = [rows_of_children[row]]

            key_of_args = [key_of_variable(r, col + 1) for r in row_of_args]
            value_of_args = [graph[ka]["value"] for ka in key_of_args]
            name_of_args = [graph[ka]["name"] for ka in key_of_args]

            op = str(np.random.choice(ops))
            # Use built-in integer arithmetic only: np.sum([]) is the float
            # 0.0, which used to turn every value downstream of a
            # single-argument "subs" into a float ("5.0" instead of "5").
            if op == "add":
                val = sum(value_of_args)
            elif op == "subs":
                val = value_of_args[0] - sum(value_of_args[1:])
            elif op == "max":
                val = max(value_of_args)
            elif op == "min":
                val = min(value_of_args)
            else:
                raise NotImplementedError(
                    f'Operation "{op}" is not supported.',
                )

            key = key_of_variable(row, col)
            name = graph[key]["name"]
            graph[key]["value"] = val
            graph[key]["args"] = key_of_args
            graph[key]["op"] = op
            graph[key]["is_leaf"] = False

            if op in symbol_of_op:
                # e.g. "X = A - B - C"; with one argument just "X = A"
                rhs = f" {symbol_of_op[op]} ".join(name_of_args)
                graph[key]["clue"] = f"{name} = {rhs}"
            elif len(name_of_args) >= 2:
                # max / min over several arguments
                graph[key]["clue"] = f"{name} = {op}({', '.join(name_of_args)})"
            else:
                graph[key]["clue"] = f"{name} = {name_of_args[0]}"

    row_of_target = np.random.randint(low=0, high=width)
    key_of_target = key_of_variable(row_of_target, 0)
    target_variable = graph[key_of_target]["name"]
    value_of_target = graph[key_of_target]["value"]
    question = f"What is the numeric value of {target_variable}?"
    true_solution = str(value_of_target)

    # Collect all relevant/target clues, one column at a time. Keys are
    # deduplicated (keeping first-appearance order) at every level so the
    # frontier never exceeds `width` entries.
    target_clues = [graph[key_of_target]["clue"]]
    keys_current = [key_of_target]
    for _ in range(1, depth):
        keys_current = list(
            dict.fromkeys(
                ka for key in keys_current for ka in graph[key]["args"]
            ),
        )
        for key in keys_current:
            clue = graph[key]["clue"]
            if clue not in target_clues:
                target_clues.append(clue)

    return target_variable, target_clues, question, true_solution, graph


class ReasoningSolver(ProblemSolver):
    """Solver class for the reasoning_retrieval task"""

    def __init__(self, config: dict) -> None:
        """Create the solver by forwarding ``config`` to ``ProblemSolver``.

        Other methods of this class read ``self.config`` (keys "m",
        "max_passes", "solve_method", "prompt_option").
        NOTE(review): presumably the base class stores ``config`` as
        ``self.config`` -- confirm in ``src.alg_agents``.
        """
        super().__init__(config=config)

    # --- solve directly with one LLM call ---
    def solve_directly(self, clues: List, question: str) -> dict:
        """Answer the question with a single LLM call over all clues.

        Args:
            clues: list of clue strings shown to the LLM.
            question: the question text.

        Returns:
            dict with "solution" (the parsed answer string, or "null" when
            parsing fails) and "num_passes" (always 1).
        """
        self.reset_cost_metrics()

        requester = self.spawn_request_agent()
        responder = self.spawn_dialog_agent()

        prompt = self.prompt_solve_directly(clues, question)
        reply = self.invoke_llm_call(
            requester(content=prompt, option=REQUEST_AGENT_OPTION),
            responder,
        )

        return {
            "solution": self.parse_solve_directly(reply.content),  # str
            # "summary": reply.content,
            "num_passes": 1,
        }

    def prompt_solve_directly(self, clues: List, question: str) -> str:
        """Prompter for solve_directly, i.e. answer the question in one shot"""
        clues_formatted = "\n".join(clues)
        prompt = (
            "Your task is to answer a question based on some given clues.\n\n"
            f"The question is: {question}\n\n"
            "[Clues]\n"
            f"{clues_formatted}\n\n"
            "[Instructions]:\n"
            "- Find the key information from the text, think step by step, "
            "and answer the question.\n"
            "- Make sure that your response concludes in the "
            'following format: "### The answer is {your answer}. ###".'
        )
        return prompt

    def parse_solve_directly(self, llm_response: str) -> str:
        """Parser for solve_directly"""
        idx1, idx2 = llm_response.find("###"), llm_response.rfind("###")
        if idx1 == idx2:  # no or only one "###" found:
            print("No feasible answer found in LLM's response.")
            return "null"
        llm_response = llm_response[idx1 + 3 : idx2]
        found_ints = re.findall(
            r"[-+]?\d+",
            llm_response,
        )  # integer, with optional +/- sign
        if len(found_ints) == 0:
            print("No feasible numeric answer found in LLM's response.")
            return "null"
        return found_ints[-1]

    # --- solve by multiple passes ---
    # self.config:
    #   - solve_method: "decomposition_cyclic" / "decomposition_parallel"
    #   - prompt_option: "answer_directly" / "reason_step_by_step"

    def solve_multi_passes(
        self,
        clues: List,
        question: str,
    ) -> dict:
        """Solve the reasoning task via multiple passes over the input;
        for each pass, chunks are processed sequentially or in parallel."""

        m = self.config["m"]  # number of clues per chunk
        max_passes = self.config["max_passes"]
        solve_method = self.config["solve_method"]

        n = len(clues)
        if (m > n) or (m <= 0):
            m = n
            print("WARNING: reset m = n.")

        self.reset_cost_metrics()
        request_agent = self.spawn_request_agent()
        dialog_agent = self.spawn_dialog_agent()

        # Chunking in a way that ensures each clue is covered twice
        chunks = []
        step_size = int(np.floor(m / 2))
        idx1, idx2 = -step_size, m - step_size
        # idx1, idx2 = 0, m
        while idx1 <= n - 1:
            chunks.append(clues[max(idx1, 0) : min(idx2, n)])
            idx1 += step_size
            idx2 += step_size
        num_chunks = len(chunks)

        print(
            f"\n\n/// n = {n}, m = {m}, k = {num_chunks}, "
            f"max_passes = {max_passes} ///\n\n",
        )

        # Outer loop: process the input text for multiple passes
        references_dedup = []  # list of retrieved clues, deduplicated
        thought = "None"
        solution = "None"
        done = False
        sub_latencies = []

        dict_process_one_pass = {
            "decomposition_cyclic": self.process_one_pass_cyclic,
            "decomposition_parallel": self.process_one_pass_parallel,
        }
        process_one_pass = dict_process_one_pass[solve_method]

        for i_pass in range(max_passes):
            len_ref_dedup = len(references_dedup)

            # Inner loop: process chunks for one pass
            result_one_pass = process_one_pass(
                question=question,
                references_dedup=references_dedup,
                thought=thought,
                chunks=chunks,
                request_agent=request_agent,
                dialog_agent=dialog_agent,
            )

            references_dedup = result_one_pass["references_dedup"]
            thought = result_one_pass["thought"]
            solution = result_one_pass["solution"]
            done = result_one_pass["done"]

            if solve_method == "decomposition_parallel":
                sub_latencies.extend(result_one_pass["sub_latencies"])

            if done:
                break  # answer has been found

            if len(references_dedup) == len_ref_dedup:
                break  # retrieve nothing new during this pass

        if done:
            print(f"/// Done within {i_pass + 1} passes. ///")
        else:
            print(f"/// Fail to finish within {i_pass + 1} passes. ")

        result = {
            "solution": solution,
            "references_dedup": references_dedup,
            "thought": thought,
            "num_passes": int(i_pass + 1),
        }
        if solve_method == "decomposition_parallel":
            result["sub_latencies"] = sub_latencies

        return result

    def process_one_pass_cyclic(
        self,
        question: str,
        references_dedup: List,
        thought: str,
        chunks: List,
        request_agent: Any,
        dialog_agent: Any,
    ) -> dict:
        """Run one pass in which chunks are visited one after another.

        Each chunk triggers one retrieval LLM call whose prompt already
        contains the references gathered from earlier chunks; newly
        retrieved clues are appended to ``references_dedup`` in place.
        Afterwards one more LLM call tries to answer the question from the
        collected references.

        Returns:
            The dict produced by parse_summarize_and_answer, extended with
            the key "references_dedup".
        """

        # Retrieval phase: one LLM call per chunk
        for current_chunk in chunks:
            retrieval_prompt = self.prompt_generate_reference(
                question=question,
                references_dedup=references_dedup,
                thought=thought,
                chunk=current_chunk,
            )
            retrieval_request = request_agent(
                content=retrieval_prompt,
                option=REQUEST_AGENT_OPTION,
            )
            retrieval_reply = self.invoke_llm_call(
                retrieval_request,
                dialog_agent,
            )
            retrieved = self.parse_generate_reference(
                retrieval_reply.content,
                current_chunk,
            )
            for candidate in retrieved:
                if candidate in references_dedup:
                    continue
                references_dedup.append(candidate)

        # Answering phase: update thought and try to answer
        answer_prompt = self.prompt_summarize_and_answer(
            question=question,
            references_dedup=references_dedup,
            thought=thought,
        )
        answer_reply = self.invoke_llm_call(
            request_agent(content=answer_prompt),
            dialog_agent,
        )
        # dict with "solution", "done"
        parsed = self.parse_summarize_and_answer(answer_reply.content)
        parsed["references_dedup"] = references_dedup

        separator = "=" * 20
        print("\n\n", separator, " output of parser: ", separator)
        print(parsed)

        return parsed

    def process_one_pass_parallel(
        self,
        question: str,
        references_dedup: List,
        thought: str,
        chunks: List,
        request_agent: Any,
        dialog_agent: Any,
    ) -> dict:
        """Run one pass in which all chunks see the same prior references.

        The retrieval calls are issued one after another here, but every
        prompt is built from the references known before the pass started,
        and the wall-clock time of each call is recorded separately.
        Retrieved clues are merged into ``references_dedup`` in place only
        after all chunks are done; then one more LLM call tries to answer.

        Returns:
            The dict produced by parse_summarize_and_answer, extended with
            "sub_latencies" (a two-element list: the per-chunk latencies,
            and a one-element list with the answering latency) and
            "references_dedup".
        """

        # Retrieval phase: one LLM call for each chunk
        chunk_latencies = []
        gathered = []
        for current_chunk in chunks:  # in parallel
            started = time.time()
            retrieval_prompt = self.prompt_generate_reference(
                question=question,
                references_dedup=references_dedup,
                thought=thought,
                chunk=current_chunk,
            )
            retrieval_reply = self.invoke_llm_call(
                request_agent(
                    content=retrieval_prompt,
                    option=REQUEST_AGENT_OPTION,
                ),
                dialog_agent,
            )
            gathered.extend(
                self.parse_generate_reference(
                    retrieval_reply.content,
                    current_chunk,
                ),
            )
            chunk_latencies.append(time.time() - started)

        # Merge after all chunks have been processed
        for candidate in gathered:
            if candidate in references_dedup:
                continue
            references_dedup.append(candidate)

        # Answering phase: update thought and try to answer with one call
        started = time.time()
        answer_prompt = self.prompt_summarize_and_answer(
            question=question,
            references_dedup=references_dedup,
            thought=thought,
        )
        answer_reply = self.invoke_llm_call(
            request_agent(content=answer_prompt),
            dialog_agent,
        )
        parsed = self.parse_summarize_and_answer(answer_reply.content)
        parsed["sub_latencies"] = [chunk_latencies, [time.time() - started]]
        parsed["references_dedup"] = references_dedup

        separator = "=" * 20
        print("\n\n", separator, " output of parser: ", separator)
        print(parsed)

        return parsed

    def prompt_generate_reference(
        self,
        question: str,
        references_dedup: List,
        thought: str,
        chunk: List,
    ) -> str:
        """Prompter for retrieving key information from the text chunk"""
        chunk_formatted = "\n".join(chunk)
        references_formatted = (
            "\n".join(references_dedup)
            if len(references_dedup) > 0
            else "None"
        )
        prompt = (
            "Your task is to retrieve relevant information from a "
            "piece of text for answering a question.\n\n"
            f"The question is: {question}\n\n"
            "[Clues]\n"
            f"{chunk_formatted}\n\n"
            "[Additional information]\n"
            f"{references_formatted}\n\n"
            # f"Hint: {thought}\n\n"
            "[Instructions]\n"
            "- Read the clues and additional information carefully, "
            "and retrieve **all but only** information from the clues "
            f'that is truly relevant to the question "{question}".\n'
            "- The results must be **exact copies** of the original clues.\n"
            "- Make sure that your response contains only a brief list of "
            "retrieved sentences, formatted as follows: "
            '["{the first retrieved clue as a string}", '
            '"{the second retrieved clue as a string}", ...]. '
            "You may return an empty list if no relevant clue is found."
        )

        return prompt

    def parse_generate_reference(self, llm_response: str, chunk: str) -> List:
        """Parser for retrieving key information from the text chunk"""
        idx1, idx2 = llm_response.rfind("["), llm_response.rfind("]")
        if (idx1 == -1) or (idx2 == -1):
            print("\n/// (no list is found in llm_response) ///\n")
            return []

        try:
            lst = eval(llm_response[idx1 : (idx2 + 1)])
            lst = [str(s) for s in lst]  # HOTFIX: mitigate "..."
        except Exception:
            print("\n/// (fail to eval list) ///\n")
            lst = []

        matched_pieces = []
        for piece in lst:
            if piece in chunk:
                matched_pieces.append(piece)
        return matched_pieces

    def prompt_summarize_and_answer(
        self,
        question: str,
        references_dedup: List,
        thought: str,
    ) -> str:
        """Prompter for summarization and reasoning"""
        prompt_option = self.config["prompt_option"]
        if prompt_option == "answer_directly":
            prompt_for_answering = (
                "Make sure that your response contains only a python "
                "dictionary in the following format, without anything else:"
            )
        elif prompt_option == "reason_step_by_step":
            prompt_for_answering = (
                "Make sure that your response contains a brief step-by-step "
                "reasoning process, and ends with a python dictionary in "
                "the following format:"
            )
        else:
            raise NotImplementedError(
                f"prompt_option {prompt_option} is not implemented.",
            )

        references_formatted = (
            "\n".join(references_dedup)
            if len(references_dedup) > 0
            else "None"
        )
        prompt = (
            "Your task is to aggregate some pieces of information "
            "and try to answer a question.\n\n"
            f"The question is: {question}\n\n"
            "[Clues]\n"
            f"{references_formatted}\n\n"
            # f"Hint: {thought}\n\n"
            "[Instructions]\n"
            "- Read the pieces of information carefully, "
            "think step by step, "
            "and decide whether the question can be answered "
            "based on the given information.\n"
            "- Try to answer the question based on the given information "
            "alone. Do not overthink too much.\n"
            "- Do not hallucinate. Answer the question only if you are "
            "absolutely confident.\n"
            f"- {prompt_for_answering}\n"
            "{\n"
            # '\t"thought": "{your brief thought about how to derive the '
            # 'answer from the given information, or what additional '
            # 'information need to be collected}",\n'
            '\t"solution": "{the final solution to the question as a string, '
            'or "None" if the given information is insufficient}",\n'
            '\t"done": {True or False (Bool), indicating whether the '
            "question is already solved or not}\n"
            "}"
        )
        return prompt

    def parse_summarize_and_answer(self, llm_response: str) -> dict:
        """Parser for summarization and reasoning"""

        idx1, idx2 = llm_response.rfind("{"), llm_response.rfind("}")

        try:
            sub_string = llm_response[idx1 : (idx2 + 1)]
            sub_string = sub_string.replace("false", "False")
            sub_string = sub_string.replace("true", "True")
            output = eval(
                sub_string,
            )  # {"solution": str, "done": bool}

            if isinstance(output["solution"], str) is False:
                output["solution"] = str(output["solution"])

            if isinstance(output["done"], bool) is False:
                if output["done"] == "True":
                    output["done"] = True
                else:
                    output["done"] = False

            found_digits = re.findall(r"[-+]?\d+", output["solution"])
            if len(found_digits) == 0:
                output["done"] = False

            if "thought" not in output:
                output["thought"] = "None"

        except Exception:
            output = {
                "thought": "None",
                "solution": "None",
                "done": False,
            }

        return output


def trial_reasoning_retrieval(
    config: dict,
    seed: Optional[int] = None,
) -> dict:
    """One trial for one config

    Generates a task instance (reproducibly when ``seed`` is given), solves
    it with the method named by ``config["solve_method"]`` ("directly" or
    one of the multi-pass methods) and computes error and latency metrics.

    Args:
        config: configuration dict; it is passed to the task generator and
            to ``ReasoningSolver``. "solve_method" is read here, and
            "sim_parallel_degree" is read for "decomposition_parallel".
        seed: seed for ``random`` and ``numpy.random`` during task
            generation; both generators are re-seeded with ``None``
            afterwards.

    Returns:
        dict with "solution", "true_solution", "error_EM", "error_abs",
        "error_missed_coverage", "latency", "latency_ideal_parallel",
        "latency_finite_parallel", the entries of ``solver.cost_metrics``,
        and "num_passes" / "summary" when the solver reports them.
    """

    # Generate random request (with controlled seed)
    random.seed(seed)
    np.random.seed(seed)

    (
        _,
        target_clues,
        question,
        true_solution,
        graph,
    ) = generate_request_reasoning_retrieval(config)
    clues = [node["clue"] for node in graph.values()]
    random.shuffle(clues)

    random.seed(None)
    np.random.seed(None)

    print("\n\n", "=" * 20, " question: ", "=" * 20, "\n")
    print(question)
    print("\n\n", "=" * 20, " true solution: ", "=" * 20, "\n")
    print(true_solution)
    print("\n\n", "=" * 20, " target clues: ", "=" * 20, "\n")
    print(target_clues)

    # Solve the problem and measure metrics
    time_start = time.time()
    solver = ReasoningSolver(config=config)
    solver.reset_cost_metrics()
    solve_method = config["solve_method"]
    if solve_method == "directly":
        result = solver.solve_directly(clues, question)
    else:
        result = solver.solve_multi_passes(clues, question)
    latency = time.time() - time_start

    print("\n/// result of overall algorithm: ///")
    print(result)
    print("\n" + "=" * 60)

    # Keep only the last signed integer of the returned solution
    solution = result["solution"]
    found_digits = re.findall(r"[-+]?\d+", solution)
    if len(found_digits) == 0:
        print("No digit sequence found in solution returned by algorithm.")
        solution = "None"
    else:
        solution = found_digits[-1]

    # Compare numerically rather than textually, so that "+5" or "05" match
    # a true value of 5. int() replaces eval(), which raised SyntaxError on
    # leading zeros and would execute arbitrary text.
    try:
        true_value = int(true_solution)
    except ValueError:
        # in case the true solution is formatted like "5.0"
        true_value = float(true_solution)

    if solution == "None":
        error_EM = 1.0
        error_abs = abs(true_value)  # TODO: is this appropriate?
    else:  # must be str of int
        solution_value = int(solution)
        error_EM = 0.0 if solution_value == true_value else 1.0
        error_abs = abs(solution_value - true_value)

    error_missed_coverage = (
        calculate_references_missed_coverage_error(
            result["references_dedup"],
            target_clues,
        )
        if "references_dedup" in result
        else 0.0
    )

    if solve_method == "decomposition_parallel":
        sub_latencies = result["sub_latencies"]  # a list of lists
        # default=0.0: a stage with no call (e.g. no chunks) adds no latency
        latency_ideal_parallel = sum(
            max(lst, default=0.0) for lst in sub_latencies
        )
        latency_finite_parallel = sum(
            calculate_latency_parallel(lst, config["sim_parallel_degree"])
            for lst in sub_latencies
        )
    else:
        latency_ideal_parallel = latency
        latency_finite_parallel = latency

    trial_result = {
        "solution": solution,
        "true_solution": true_solution,
        "error_EM": float(error_EM),
        "error_abs": float(error_abs),
        "error_missed_coverage": float(error_missed_coverage),
        "latency": float(latency),
        "latency_ideal_parallel": float(latency_ideal_parallel),
        "latency_finite_parallel": float(latency_finite_parallel),
    }
    trial_result.update(solver.cost_metrics)

    if "num_passes" in result:
        trial_result["num_passes"] = result["num_passes"]

    if "summary" in result:
        trial_result["summary"] = result["summary"]

    return trial_result


def calculate_references_missed_coverage_error(
    references: List,
    target_clues: List,
) -> float:
    """Estimate missed coverage error of retrieved references

    Args:
        references: retrieved clues.
        target_clues: clues that should have been retrieved.

    Returns:
        The fraction of ``target_clues`` that does not appear in
        ``references`` (0.0 means every target clue was retrieved). When
        ``target_clues`` is empty there is nothing to miss and 0.0 is
        returned.
    """
    if len(target_clues) == 0:
        return 0.0  # avoid dividing by zero

    num_found = 0
    for clue in target_clues:
        if clue in references:
            num_found += 1
            print(
                "-" * 20,
                f' "{clue}" is found in references ',
                "-" * 20,
            )
    return 1.0 - num_found / len(target_clues)
