# -*- coding: utf-8 -*-
"""Functions and classes for the reasoning_DAG task

Assumption: in graph, key and "name" of each entry should be the same
"""

import ast
import copy
import random
import re
import string
import time
from typing import Optional, List

import numpy as np

from src.alg_agents import ProblemSolver


# Value passed as the ``option`` argument each time a request agent is
# called in this module; the trailing comment lists the known choices.
REQUEST_AGENT_OPTION = "speak"  # "speak" / "print_start_end"

# Natural-language description of the task, embedded in the prompt built by
# ``ReasoningSolver.prompt_solve_step_by_step``.
TASK_DESCRIPTION = (
    "In this task, you will be asked to find out the numeric value of a "
    "specific variable, based on some given equations that describe the "
    "values of or relationship between multiple variables."
)


def generate_request_reasoning_DAG(config: dict) -> tuple:
    """Generate a request of the reasoning task

    The instance is a layered graph with ``width`` rows and ``depth``
    columns of variables named ``variable_{row}_{col}``. Variables in the
    last column are leaves holding random integers. Every other variable in
    column ``c`` is defined from variables of column ``c + 1`` through one
    of the operations "max", "min", "add" or "subs". The target is a random
    variable of column 0.

    Both ``random`` and ``np.random`` are consumed, so seed both for
    reproducible instances.

    Args:
        config: dict with integer entries "depth", "width" and "degree"
            (number of arguments of each non-leaf variable).

    Returns:
        A tuple ``(target_variable, target_clues, num_leaf_clues, question,
        true_solution, graph)``:

        - target_variable: name of the variable to be solved.
        - target_clues: clues of all variables reachable from the target
          (target first, then level by level), without duplicates.
        - num_leaf_clues: number of distinct last-column variables
          reachable from the target (1 when ``depth == 1``).
        - question: the question string shown to the solver.
        - true_solution: ``str`` of the target's integer value.
        - graph: dict mapping each variable key to its node dict
          ("name", "value", "clue", "is_leaf", and for non-leaf nodes
          "args" and "op").

    Raises:
        NotImplementedError: if an unsupported operation is drawn.
    """

    def key_of_variable(row: int, col: int) -> str:
        return f"variable_{row}_{col}"

    depth, width, degree = config["depth"], config["width"], config["degree"]
    ops = ["max", "min", "add", "subs"]
    symbol_of_op = {"add": "+", "subs": "-"}

    print(
        f"\n\n/// task instance: depth = {depth}, "
        f"width = {width}, degree = {degree} ///\n\n",
    )

    # Keep numbers small across the DAG when several arguments are combined.
    range_of_num = 1000 if degree == 1 else 10

    # Random 6-letter names used to be candidate variable names. Variables
    # are now named after their key, but the draws are kept so that a given
    # seed keeps producing the same task instance as before.
    alphabet = list(string.ascii_uppercase)
    for _ in range(width * depth * 2):
        np.random.choice(alphabet, size=6, replace=True)

    # Represent the task instance by graph; key and "name" coincide, which
    # makes it easier to track the solving process.
    graph = {}
    for row in range(width):
        for col in range(depth):
            key = key_of_variable(row, col)
            graph[key] = {"name": key}

    # Value of each leaf variable.
    for row in range(width):
        key = key_of_variable(row=row, col=depth - 1)
        val = int(
            np.random.randint(low=-range_of_num, high=range_of_num + 1),
        )
        graph[key]["value"] = val
        graph[key]["clue"] = f"{graph[key]['name']} = {val}"
        graph[key]["is_leaf"] = True

    # Args, op and value of each non-leaf variable, going backwards.
    for col in range(depth - 2, -1, -1):
        rows_of_children = list(range(width))
        random.shuffle(rows_of_children)
        for row in range(width):
            if degree > 1:
                row_of_args = np.random.choice(
                    range(width),
                    size=degree,
                    replace=False,
                )
            else:
                # A random permutation: each child is used exactly once.
                row_of_args = [rows_of_children[row]]

            key_of_args = [key_of_variable(r, col + 1) for r in row_of_args]
            value_of_args = [graph[ka]["value"] for ka in key_of_args]
            name_of_args = [graph[ka]["name"] for ka in key_of_args]

            op = str(np.random.choice(ops))
            if op == "add":
                val = sum(value_of_args)
            elif op == "subs":
                # Built-in sum keeps the result an int even when there is
                # nothing to subtract (np.sum([]) would yield 0.0).
                val = value_of_args[0] - sum(value_of_args[1:])
            elif op == "max":
                val = max(value_of_args)
            elif op == "min":
                val = min(value_of_args)
            else:
                raise NotImplementedError(
                    f'Operation "{op}" is not supported.',
                )

            key = key_of_variable(row, col)
            name = graph[key]["name"]
            # Plain int so that str(value) never reads like "5.0".
            graph[key]["value"] = int(val)
            graph[key]["args"] = key_of_args
            graph[key]["op"] = op
            graph[key]["is_leaf"] = False

            if op in symbol_of_op:
                rhs = f" {symbol_of_op[op]} ".join(name_of_args)
            elif len(name_of_args) >= 2:
                rhs = f"{op}({', '.join(name_of_args)})"
            else:
                # max/min of a single argument is the argument itself.
                rhs = name_of_args[0]
            graph[key]["clue"] = f"{name} = {rhs}"

    row_of_target = np.random.randint(low=0, high=width)
    key_of_target = key_of_variable(row_of_target, 0)
    target_variable = graph[key_of_target]["name"]
    value_of_target = graph[key_of_target]["value"]
    question = f"What is the numeric value of {target_variable}?"
    true_solution = str(value_of_target)

    # Walk from the target towards the leaves, one column at a time,
    # collecting the clue of every variable that is reached.
    target_clues = [graph[key_of_target]["clue"]]
    keys_current = [key_of_target]
    for _ in range(1, depth):
        keys_next = []
        for key in keys_current:
            for key_of_arg in graph[key]["args"]:
                if key_of_arg not in keys_next:
                    keys_next.append(key_of_arg)
        keys_current = keys_next
        for key in keys_current:
            clue = graph[key]["clue"]
            if clue not in target_clues:
                target_clues.append(clue)

    # After the walk, keys_current holds the reachable last-column
    # variables; for depth == 1 that is the target itself.
    num_leaf_clues = len(keys_current)

    return (
        target_variable,
        target_clues,
        num_leaf_clues,
        question,
        true_solution,
        graph,
    )


class ClueDB:
    """Database of clues that is accessed by querying (tool use).

    The database keeps its own deep copy of the graph handed to the
    constructor, so later changes to the caller's graph are not visible
    through the database.
    """

    def __init__(self, graph: dict) -> None:
        self.graph = copy.deepcopy(graph)

    def query_for_clue(self, variable: str) -> Optional[str]:
        """Look up the clue stored for `variable`.

        Returns the "clue" entry of the matching node, or None (after
        printing a warning) when the variable is not a key of the graph.
        """
        try:
            entry = self.graph[variable]
        except KeyError:
            print(f"WARNING: {variable} is not found in graph, return None.")
            return None
        return entry["clue"]


class ReasoningSolver(ProblemSolver):
    """Solver class for the reasoning_DAG task

    Two strategies are offered:

    - ``solve_step_by_step``: one LLM call that receives a list of clues
      and is asked to reason step by step (baseline/oracle).
    - ``solve_recursively``: the target variable is resolved by querying a
      ``ClueDB`` one variable at a time and recursing into the variables
      that the LLM reports as prerequisites.

    LLM responses are never passed to ``eval``: integers are parsed with
    ``int`` and lists of variable names with ``ast.literal_eval``.
    """

    # Integer with an optional +/- sign.
    _SIGNED_INT_RE = re.compile(r"[-+]?\d+")

    def __init__(self, config: dict) -> None:
        super().__init__(config=config)

    # --- parsing helpers ---
    @classmethod
    def _last_int_str(cls, text: str) -> Optional[str]:
        """Return the last signed-integer substring of `text`, or None."""
        matches = cls._SIGNED_INT_RE.findall(text)
        return matches[-1] if matches else None

    @classmethod
    def _last_int_or_zero(cls, text: str) -> int:
        """Return the last integer found in `text`; print a notice and
        fall back to 0 when there is none."""
        found = cls._last_int_str(text)
        if found is None:
            print(
                "No feasible numeric answer found in LLM's response; "
                "returning 0.",
            )
            return 0  # TODO: is this appropriate?
        return int(found)

    @staticmethod
    def _parse_children(segment: str) -> List[str]:
        """Parse a "[...]" segment into a list of variable names.

        Quoted lists (e.g. '["B", "C"]') are parsed as Python literals;
        unquoted ones (e.g. "[B, C]") are split on commas.

        Raises:
            ValueError, SyntaxError, TypeError, MemoryError, RecursionError:
                when a quoted segment is not a valid Python literal.
        """
        if ('"' in segment) or ("'" in segment):
            parsed = ast.literal_eval(segment)
            if not isinstance(parsed, (list, tuple)):
                raise ValueError(f"{segment} is not a list")
            # Names are used as dict keys and in string prompts later on.
            return [str(child) for child in parsed]
        pieces = (piece.strip() for piece in segment[1:-1].split(","))
        return [piece for piece in pieces if piece]

    # --- solve with one LLM call, given list of all/target clues ---
    def solve_step_by_step(self, clues: List, question: str) -> dict:
        """Solve the reasoning task with one LLM call and
        "step-by-step" prompting (baseline/oracle)

        Args:
            clues: clue strings to be shown to the LLM.
            question: the question to be answered.

        Returns:
            dict with "solution" (integer string, or "null" when no answer
            could be parsed) and "response" (raw LLM response).
        """

        self.reset_cost_metrics()

        request_agent = self.spawn_request_agent()
        dialog_agent = self.spawn_dialog_agent()

        request_content = self.prompt_solve_step_by_step(clues, question)
        x_request = request_agent(
            content=request_content,
            option=REQUEST_AGENT_OPTION,
        )
        x = self.invoke_llm_call(x_request, dialog_agent)
        response = x.content
        solution = self.parse_solve_step_by_step(response)

        return {
            "solution": solution,
            "response": response,
        }

    def prompt_solve_step_by_step(self, clues: List, question: str) -> str:
        """Prompter for solve_step_by_step

        The question is stated both before and after the clues.
        """
        clues_formatted = "\n".join(clues)
        prompt = (
            "You are an expert in math, logical thinking and reasoning.\n\n"
            f"[Task description]\n{TASK_DESCRIPTION}\n\n"
            f"[Question]\n{question}\n\n"
            "[Clues]\n"
            f"{clues_formatted}\n\n"
            f"[Question]\n{question}\n\n"
            "[Instructions]\n"
            "- Find the relevant information from the clues, "
            "think step by step, and answer the question.\n"
            "- Make sure that your response ends with the following format: "
            '"### The answer is {your numeric-value answer}. ###".'
        )
        return prompt

    def parse_solve_step_by_step(self, llm_response: str) -> str:
        """Parser for solve_step_by_step

        Returns the last integer (as a string) found between the first and
        the last "###" of the response, or "null" when the markers or the
        integer are missing.
        """
        idx1, idx2 = llm_response.find("###"), llm_response.rfind("###")
        if idx1 == idx2:  # no or only one "###" found
            print("No feasible answer found in LLM's response.")
            return "null"

        answer = self._last_int_str(llm_response[idx1 + 3 : idx2])
        if answer is None:
            print("No feasible numeric answer found in LLM's response.")
            return "null"
        return answer

    # --- solve recursively ---
    def solve_recursively(self, clue_db: ClueDB, target_variable: str) -> dict:
        """Solve the task with recursive decomposition

        For each variable, the LLM sees the variable's clue and either
        answers directly or lists the variables it needs; the latter are
        solved recursively and the LLM is then asked for the final value.
        Whenever a response cannot be parsed, the variable is assigned 0.

        Args:
            clue_db: database queried for the clue of each variable.
            target_variable: name of the variable to be solved.

        Returns:
            dict with "solution" (int), "variables_found" (variable name
            -> value) and "clues_queried" (distinct clues returned by the
            database, in query order; may contain None).

        Raises:
            ValueError: if ``config["prompt_option"]`` is not supported.
        """
        variables_found = {}
        clues_queried = []

        self.reset_cost_metrics()
        request_agent = self.spawn_request_agent()
        dialog_agent = self.spawn_dialog_agent()

        def ask_llm(prompt: str) -> str:
            """Send one prompt through the request/dialog agents."""
            x_request = request_agent(
                content=prompt,
                option=REQUEST_AGENT_OPTION,
            )
            x = self.invoke_llm_call(x_request, dialog_agent)
            return x.content

        def process_node(variable: str) -> int:
            """Key function, defined recursively"""

            # Base case: variable has been calculated before
            if variable in variables_found:
                return variables_found[variable]

            # Otherwise, solve it with LLM
            clue = clue_db.query_for_clue(variable)  # str or None
            if clue not in clues_queried:
                clues_queried.append(clue)

            # 1. Prompt LLM to find the answer or decompose the task
            prompt = (
                f"Your task is to find the numeric value of {variable} "
                "based on the following clues:\n\n"
                f"{clue}\n\n"
                f"[Instructions]\n"
                "- If you are confident that you have found the "
                f"numeric value of {variable}, please output your answer "
                "directly without anything else.\n"
                "- Otherwise, please output a list of variables, "
                "whose numeric values are necessary and sufficient "
                f"for calculating {variable}. "
                "For example, if your target variable is A while the only "
                'clue is "A = B + C", then your response should be '
                'the list ["B", "C"].'
            )
            llm_response = ask_llm(prompt)

            # 2. Either answer is found, or need to
            # generate and solve children tasks
            idx1, idx2 = llm_response.rfind("["), llm_response.rfind("]")

            # No well-formed "[...]" (missing bracket, or "]" before "["):
            # treat the response as a direct numeric answer.
            if idx1 == -1 or idx2 < idx1:
                val = self._last_int_or_zero(llm_response)
                variables_found[variable] = val
                return val  # int

            segment = llm_response[idx1 : idx2 + 1]
            try:
                children_variables = self._parse_children(segment)
            except (
                ValueError,
                SyntaxError,
                TypeError,
                MemoryError,
                RecursionError,
            ):
                print(
                    f"Unable to parse {segment} into children variables; "
                    "returning 0.",
                )
                val = 0
                variables_found[variable] = val
                return val

            values_of_children = {}
            for var in children_variables:
                values_of_children[var] = process_node(var)

            list_of_clues = [
                f"{cvar} = {cval}"
                for cvar, cval in values_of_children.items()
            ]
            clues_formatted = "\n".join([str(clue)] + list_of_clues)

            prompt_option = self.config["prompt_option"]
            if prompt_option == "answer_directly":
                prompt_for_answering = (
                    "Please output your answer directly without anything else."
                )
            elif prompt_option == "reason_step_by_step":
                prompt_for_answering = (
                    "Please do the calculation step by step, "
                    "and conclude with your final answer."
                )
            else:
                raise ValueError(
                    f'prompt_option "{prompt_option}" is not supported.',
                )
            prompt = (
                f"Your task is to find the numeric value of {variable} "
                "based on the following clues:\n\n"
                f"{clues_formatted}\n\n"
                f"{prompt_for_answering}"
            )
            llm_response = ask_llm(prompt)

            val = self._last_int_or_zero(llm_response)
            variables_found[variable] = val
            return val

        ans = process_node(target_variable)
        return {
            "solution": ans,  # int
            "variables_found": variables_found,
            "clues_queried": clues_queried,
        }


def _parse_integer(text: str) -> int:
    """Parse an integer from `text` without using eval.

    Accepts integer strings ("-12", "+3", "007") and integral float
    strings ("-12.0").

    Raises:
        ValueError: if `text` is not an integral number.
    """
    try:
        return int(text)
    except ValueError:
        as_float = float(text)
        if not as_float.is_integer():
            raise ValueError(f"{text!r} is not an integer") from None
        return int(as_float)


def trial_reasoning_DAG(config: dict, seed: Optional[int] = None) -> dict:
    """One trial for one config

    Generates a task instance (reproducible for a given `seed`), solves it
    with the method named by ``config["solve_method"]`` and gathers error,
    latency and cost metrics.

    Args:
        config: configuration passed to the task generator and to the
            solver; "solve_method" must be "step_by_step" or "recursively".
        seed: seed for ``random`` and ``np.random`` during generation; both
            generators are re-seeded with None afterwards.

    Returns:
        dict with the solution, the true solution, the errors ("error_EM",
        "error_abs", "error_missed_coverage"), latencies, the solver's
        ``cost_metrics`` and statistics about the clues.

    Raises:
        ValueError: if ``config["solve_method"]`` is not supported.
    """

    # Generate random request (with controlled seed)
    random.seed(seed)
    np.random.seed(seed)
    (
        target_variable,
        target_clues,
        num_leaf_clues,
        question,
        true_solution,
        graph,
    ) = generate_request_reasoning_DAG(config)
    random.seed(None)
    np.random.seed(None)

    print("\n\n", "=" * 20, " question: ", "=" * 20)
    print(question)
    print("\n\n", "=" * 20, " true solution: ", "=" * 20)
    print(true_solution)

    # Set up ClueDB
    clue_db = ClueDB(graph)

    # Solve the problem and measure metrics
    time_start = time.time()
    solver = ReasoningSolver(config=config)
    solver.reset_cost_metrics()
    solve_method = config["solve_method"]
    if solve_method == "step_by_step":
        # "oracle" setting: only the clues relevant to the target are given
        # (giving the clues of the whole graph would be a weaker baseline).
        result = solver.solve_step_by_step(
            target_clues,
            question,
        )
    elif solve_method == "recursively":
        result = solver.solve_recursively(clue_db, target_variable)
    else:
        raise ValueError(f'solve_method "{solve_method}" is not supported.')
    latency = time.time() - time_start

    print("\n/// result of overall algorithm: ///")
    print(result)

    solution = str(result["solution"])

    # Compare as integers, so that formatting differences such as
    # "5" vs "+5" vs "5.0" are not counted as errors.
    true_value = _parse_integer(true_solution)
    if solution == "null":
        error_EM = 1.0
        error_abs = abs(true_value)  # TODO: is this appropriate?
    else:  # must be str of int
        solved_value = _parse_integer(solution)
        error_EM = 0.0 if solved_value == true_value else 1.0
        error_abs = abs(solved_value - true_value)

    if "clues_queried" in result:
        error_missed_coverage = calculate_clues_missed_coverage_error(
            result["clues_queried"],
            target_clues,
        )
    else:
        error_missed_coverage = 0.0

    # No parallelism is used here, so all three latencies coincide.
    latency_ideal_parallel = latency
    latency_finite_parallel = latency

    trial_result = {
        "solution": solution,  # str
        "true_solution": true_solution,  # str
        "error_EM": float(error_EM),
        "error_abs": float(error_abs),
        "error_missed_coverage": float(error_missed_coverage),
        "latency": float(latency),
        "latency_ideal_parallel": float(latency_ideal_parallel),
        "latency_finite_parallel": float(latency_finite_parallel),
    }
    trial_result.update(solver.cost_metrics)
    trial_result.update(
        {
            "num_target_clues": len(target_clues),
            "num_leaf_clues": num_leaf_clues,
            "ideal_llm_calls": 2 * len(target_clues) - num_leaf_clues,
        },
    )

    return trial_result


def calculate_clues_missed_coverage_error(
    clues: List,
    target_clues: List,
) -> float:
    """Estimate missed coverage error of retrieved references

    Args:
        clues: the clues that were actually retrieved.
        target_clues: the clues that should have been retrieved.

    Returns:
        The fraction of `target_clues` that do not appear in `clues`,
        between 0.0 (everything covered) and 1.0 (nothing covered). An
        empty `target_clues` gives 0.0, since nothing can be missed.
    """
    if not target_clues:
        return 0.0
    num_covered = sum(1.0 for clue in target_clues if clue in clues)
    return 1.0 - num_covered / len(target_clues)
