import re
import typing as tp

import ts_utils
from tree_sitter import Node, Tree
from ts_utils.parsing import load_grammar

import dataclasses
import textwrap


# ---------------------------------------------------------------------------- #
#                    tree_sitter based llm output extraction                   #
# ---------------------------------------------------------------------------- #
def find_first_syntax_error_node_index(raw_code: str, language="python"):
  """Returns the length of the error-free prefix preceding the first syntax error.

  Scans the top-level nodes of the parse tree in order and stops at the first
  one that contains a syntax error. The result is the character (not byte)
  index just past the last error-free top-level node before it, so
  ``raw_code[:index]`` is the code preceding the first broken construct.

  Example:
  >>> source = '''
  def hello_world():
      print('hello world!')

  def incomplete_fn('''
  >>> end_index = find_first_syntax_error_node_index(source)
  >>> source[:end_index]
  '''
  def hello_world():
      print('hello world!')'''

  Args:
      raw_code: input code
      language: language name passed to ``ts_utils.make_parser``.

  Returns:
      Character index into ``raw_code``:
        * 0 if the very first top-level node already contains an error;
        * otherwise the end of the last error-free top-level node before the
          first erroneous one;
        * if no syntax error is found, the end of the last top-level node
          (this may exclude trailing whitespace).
  """

  parser = ts_utils.make_parser(language)
  code_bytes = raw_code.encode()
  tree = parser.parse(code_bytes)
  root = tree.root_node
  # An empty parse has no children; fall back to scanning the root itself.
  top_level_nodes = root.children or [root]

  # Stays 0 when the first node is already broken: the valid prefix is empty.
  # (Previously this case fell through to the root's end, i.e. "no error".)
  end_byte = 0
  for node in top_level_nodes:
    if node.has_error:
      break
    end_byte = node.end_byte

  # Convert the byte offset into a character index for `raw_code`.
  return len(code_bytes[:end_byte].decode())


def find_entrypoint_index(raw_code: str, entrypoint_name: str):
  """Finds the end index of the entrypoint together with the functions it needs.

  Takes into account the entrypoint's dependencies on other functions,
  transitively: if the entrypoint calls ``a`` and ``a`` calls ``b``, both
  definitions are kept even when ``b`` is defined after ``a``. Only calls by
  bare name (``f(...)``) are followed; references to other syntactic
  structures (e.g. classes, variables, methods) are not tracked.

  Example:
  >>> source = '''
  def hello_world():
      print_hello_world()
  def print_random_fn():
      print('random')
  def print_hello_world():
      print('hello world!')
  def print_random_fn2():
      print('random')'''

  >>> index = find_entrypoint_index(source, 'hello_world')
  >>> source[:index]
  '''
  def hello_world():
      print_hello_world()
  def print_random_fn():
      print('random')
  def print_hello_world():
      print('hello world!')'''

  Args:
      raw_code: input code
      entrypoint_name: name of the entrypoint function

  Returns:
      Character index just past the last of: the entrypoint's definition and
      every function definition it (transitively) calls. If no function named
      ``entrypoint_name`` is found, returns the length of the code.
  """

  PYLANG = ts_utils.load_grammar("python")
  parser = ts_utils.make_parser("python")
  code_bytes = raw_code.encode()
  tree = parser.parse(code_bytes)
  fn_def_captures = PYLANG.query(
    """
    (function_definition) @function.def
    """
  ).captures(tree.root_node)
  # Compiled once and reused for every definition scanned below.
  call_query = PYLANG.query(
    """
        (call
            function: (identifier) @function.call)"""
  )

  # Map each function name to all of its definitions, in capture order, so
  # that redefinitions are all taken into account.
  defs_by_name: dict[str, list[Node]] = {}
  for fn_def_node, _ in fn_def_captures:
    name_node = fn_def_node.child_by_field_name("name")
    if name_node is not None:
      defs_by_name.setdefault(name_node.text.decode(), []).append(fn_def_node)

  entrypoint_defs = defs_by_name.get(entrypoint_name)
  if not entrypoint_defs:
    return len(raw_code)
  entrypoint_node = entrypoint_defs[0]

  # Worklist over definitions whose bodies still need scanning for calls.
  # `resolved_names` guarantees termination on (mutual) recursion.
  max_end_byte = entrypoint_node.end_byte
  resolved_names: set[str] = set()
  pending = [entrypoint_node]
  while pending:
    scope = pending.pop()
    for call_node, _ in call_query.captures(scope):
      callee = call_node.text.decode()
      if callee in resolved_names or callee not in defs_by_name:
        continue
      resolved_names.add(callee)
      for callee_def in defs_by_name[callee]:
        max_end_byte = max(max_end_byte, callee_def.end_byte)
        pending.append(callee_def)

  return len(code_bytes[:max_end_byte].decode())


def extract_entrypoint_name_from_assertions(assertions: str | tp.Sequence[str]):
  """Extracts the expected entrypoint name from one or more test assertions.

  Looks for statements of the form ``assert <name>(...) <cmp> ...`` and
  collects ``<name>`` from every test case.

  Args:
      assertions: a single assertion source string or a sequence of them.

  Returns:
      The single function name called on the left of every matched comparison.

  Raises:
      AssertionError: if no candidate name is found, or if the candidates
          disagree. Raised explicitly rather than via ``assert`` so the checks
          still run under ``python -O``.
  """

  if isinstance(assertions, str):
    assertions = [assertions]
  PYLANG = ts_utils.load_grammar("python")
  # The query does not depend on the test case, so compile it once.
  call_in_assert_query = PYLANG.query(
    """
            (assert_statement
                (comparison_operator 
                    (call
                        function: (identifier) @function.call)))
        """
  )
  entrypoint_name_candidates: list[str] = []
  for test_case in assertions:
    tree = ts_utils.parse(test_case, "python")
    for capture, _ in call_in_assert_query.captures(tree.root_node):
      entrypoint_name_candidates.append(capture.text.decode())

  if not entrypoint_name_candidates:
    raise AssertionError("No entrypoint name candidates found")
  first = entrypoint_name_candidates[0]
  if any(c != first for c in entrypoint_name_candidates):
    raise AssertionError(
      "Inconsistent entrypoint name candidates: "
      f"{sorted(set(entrypoint_name_candidates))}"
    )
  return first


def try_extract_code_from_chat_outputs(generation: str):
  """Returns the text between the outermost ``` fences, else the input as-is.

  The pattern is greedy and matched with DOTALL, so the capture runs from the
  first ``` to the last ``` in the string (any language tag right after the
  opening fence is kept). When there is no fenced span, ``generation`` is
  returned unchanged.
  """
  fenced = re.findall("```(.*)```", generation, re.DOTALL | re.MULTILINE)
  return fenced[0] if fenced else generation


class OutputParserException(Exception):
  """Signals that an output parser could not extract a result from its input."""

  def __init__(self, message: str) -> None:
    # `message` becomes the exception's sole argument (and its str()).
    super().__init__(message)


# Patterns tried in insertion order to locate code inside a chat generation.
# Each has exactly one capturing group holding the code body; the callers in
# this file match them with re.DOTALL so `.` spans newlines.
CODE_IN_MARKDOWN_REGEX = dict(
  # CodeLlama-Instruct style: [PYTHON] ... [/PYTHON]
  codellama_instruct_tokens=r"\[PYTHON\](.*)\[/PYTHON]",
  # Markdown fence with an optional info string alone on the opening line,
  # e.g. ```python, ```Python, ```py3, ```c++ (optionally followed by spaces).
  # The tag line is stripped only when it ends right after the tag; otherwise
  # the text after the fence is kept as code.
  code_in_markdown=r"```(?:[\w+#.-]*[ \t]*\n)?(.*)```",
)


def extract_code_from_markdown(generation: str):
  """Pulls the code body out of a chat-style generation.

  Every pattern in ``CODE_IN_MARKDOWN_REGEX`` is tried in insertion order,
  e.g. instruct tokens such as ``[PYTHON]...[/PYTHON]`` or a markdown fence
  with an optional language tag line::

      ```(lang)?
      <code>
      ```

  Returns:
      The capture of the first match of the first pattern that matches.

  Raises:
      OutputParserException: if none of the patterns match.
  """
  for pattern in CODE_IN_MARKDOWN_REGEX.values():
    found = re.findall(pattern, generation, re.DOTALL | re.MULTILINE)
    if found:
      return found[0]
  raise OutputParserException("No code found")


def extract_until_first_top_level_error_node(code):
  """Cuts ``code`` just before the first top-level node that breaks parsing.

  Code that already parses without errors is returned unchanged. Otherwise
  the top-level nodes containing errors are tried in order. The first one
  whose preceding text re-parses cleanly determines the cut, and that prefix
  is returned.

  Raises:
      OutputParserException: if no such cut point exists.
  """
  encoded = code.encode()
  tree = ts_utils.parse(encoded, "python")
  root = tree.root_node
  if not root.has_error:
    return code

  broken_nodes = (node for node in root.children if node.has_error)
  for broken in broken_nodes:
    prefix = encoded[: broken.start_byte]
    reparsed_tree = ts_utils.parse(prefix, "python")
    if not reparsed_tree.root_node.has_error:
      return prefix.decode()
  raise OutputParserException("Could not find a top level error node")


def extract_first_function(code):
  """Trims the code to the end of the first function definition.

  Only function definitions that are direct children of the module or of a
  top-level ERROR node are considered. Everything from the start of ``code``
  through the end of the first captured one is kept, so preceding code such
  as imports survives. "First" is whichever capture the query returns first
  (presumably document order).

  Raises:
      OutputParserException: if no function definition is captured.
  """
  grammar = load_grammar("python")
  encoded = code.encode()
  tree = ts_utils.parse(encoded, "python")
  function_query = grammar.query(
    """
    (ERROR
        (function_definition) @f)
    (module
        (function_definition) @f)"""
  )
  matches = function_query.captures(tree.root_node)
  if not matches:
    raise OutputParserException("No function found")
  first_fn, _ = matches[0]
  return encoded[: first_fn.end_byte].decode()


def extract_full_code_in_generation(code: str, *, first_function_only=False):
  """Trims a generation to its valid code, stopping before the first ``assert``.

  Args:
      code: raw generated source text.
      first_function_only: if True, delegate to ``extract_first_function``,
          which keeps code up to the end of the first function.

  Returns:
      The processed code:
        1. If the code has a syntax error, it is cut at the start of the first
           error-containing top-level node whose preceding text parses without
           errors. If no such cut exists, ``code`` is returned unchanged
           (no assert stripping is applied).
        2. The (possibly trimmed) code is then truncated at the first
           module-level ``assert`` statement, if any.
  """
  # Decide delegation before parsing: the parse would be thrown away anyway.
  if first_function_only:
    return extract_first_function(code)

  code_bytes = code.encode()
  tree = ts_utils.parse(code_bytes, "python")

  if tree.root_node.has_error:
    # Trim code to the first top-level error node that yields a clean prefix.
    for toplevel_node in tree.root_node.children:
      if not toplevel_node.has_error:
        continue
      trimmed_code_bytes = code_bytes[: toplevel_node.start_byte]
      reparsed_tree = ts_utils.parse(trimmed_code_bytes, "python")
      if not reparsed_tree.root_node.has_error:
        # Reuse the verified parse instead of parsing the same bytes again.
        code_bytes, tree = trimmed_code_bytes, reparsed_tree
        break
    else:
      # Unfixable error
      return code

  PY = ts_utils.load_grammar("python")
  # Remove any assertions: stop at the first module-level assert statement.
  assert_captures = PY.query("(module (assert_statement) @assert)").captures(
    tree.root_node
  )
  stop_idx = len(code_bytes)
  if assert_captures:
    assert_node, _ = assert_captures[0]
    stop_idx = min(stop_idx, assert_node.start_byte)

  return code_bytes[:stop_idx].decode()


def passthrough_output_parser(generation: str):
  """Identity output parser: hands ``generation`` back without modification."""
  unchanged = generation
  return unchanged


# ---------------------------------------------------------------------------- #
#                           Normalize variables names                          #
# ---------------------------------------------------------------------------- #


def capture_variables_in_assignments(tree: Tree):
  """Collects the identifier nodes of variables bound anywhere in ``tree``.

  Binding sites considered:
    * the ``left`` side of assignments, augmented assignments, ``for``
      statements and comprehension ``for`` clauses;
    * the children of ``as_pattern_target`` nodes;
    * function parameters.

  The whole tree is searched. Bindings nested inside loop bodies, right-hand
  sides (chained assignments, comprehensions) and parameter defaults are
  therefore found too.

  Targets are reduced to base identifiers. ``obj.attr`` contributes ``obj``.
  Default and typed parameters contribute only the parameter name, not the
  default value or the type annotation.

  Returns:
      Identifier nodes in pre-order traversal order. A name bound several
      times appears several times.
  """
  variables: list[Node] = []

  def extract_identifiers(node: Node | None) -> list[Node]:
    if node is None:
      return []
    if node.type == "identifier":
      return [node]
    if node.type == "attribute":
      return extract_identifiers(node.child_by_field_name("object"))
    if node.type in ("default_parameter", "typed_default_parameter"):
      return extract_identifiers(node.child_by_field_name("name"))

    children = node.children
    if node.type == "typed_parameter":
      # `x: int` -> only `x`; the annotation lives in a `type` child and
      # must not be mistaken for a variable.
      children = [child for child in children if child.type != "type"]

    found: list[Node] = []
    for child in children:
      found.extend(extract_identifiers(child))
    return found

  def visit(node: Node) -> None:
    if node.type in (
      "assignment",
      "augmented_assignment",
      "for_statement",
      "for_in_clause",
    ):
      variables.extend(extract_identifiers(node.child_by_field_name("left")))
    elif node.type in ("as_pattern_target", "parameters"):
      for child in node.children:
        variables.extend(extract_identifiers(child))

    # Always descend, even below a binding site: a `for` body, the right-hand
    # side of an assignment or a parameter default can bind more variables.
    for child in node.children:
      visit(child)

  visit(tree.root_node)

  return variables


def get_all_top_level_identifiers(node: Node | None):
  """Collects identifier nodes under ``node``, skipping member and callee names.

  * ``identifier``: the node itself.
  * ``attribute``: only the object side is searched (``a.b`` yields ``a``).
  * ``call``: only the argument list is searched; the callee is ignored.
  * anything else: all children are searched, left to right.
  """
  if not node:
    return []
  kind = node.type
  if kind == "identifier":
    return [node]
  if kind == "attribute":
    return get_all_top_level_identifiers(node.child_by_field_name("object"))
  if kind == "call":
    return get_all_top_level_identifiers(node.child_by_field_name("arguments"))
  return [
    identifier
    for child in node.children
    for identifier in get_all_top_level_identifiers(child)
  ]


def get_terminal_nodes_with_tokens(tree: Tree, source_bytes: bytes):
  """Pairs every leaf node of ``tree`` with the source text ending at that leaf.

  Each token spans from the end of the previously yielded leaf (or byte 0)
  through the end of the current leaf. It therefore carries any whitespace
  or other uncovered text that precedes the leaf. This assumes
  ``ts_utils.iter.iternodes`` yields nodes in document order; confirm
  against its implementation.

  Returns:
      A list of ``(leaf_node, token_text)`` tuples.
  """
  pairs = []
  cursor_byte = 0
  leaves = (
    n for n in ts_utils.iter.iternodes(tree.walk()) if n.child_count == 0
  )
  for leaf in leaves:
    pairs.append((leaf, source_bytes[cursor_byte : leaf.end_byte].decode()))
    cursor_byte = leaf.end_byte
  return pairs


def normalize_varnames(source: str, parser_or_language) -> str:
  """Renames every bound variable in ``source`` to a canonical ``V<i>`` name.

  Variables are the identifiers reported by
  ``capture_variables_in_assignments``. Each distinct name gets ``V0``,
  ``V1``, ... in order of its first binding. Every identifier occurrence of
  that name is then rewritten. The output is rebuilt from
  ``get_terminal_nodes_with_tokens``, so text after the last leaf (e.g. a
  trailing newline) is presumably not reproduced.

  Args:
      source: Python source text.
      parser_or_language: forwarded to ``ts_utils.parse``.

  Returns:
      The normalized source text.
  """
  source_bytes = source.encode("utf-8")
  tree = ts_utils.parse(source_bytes, parser_or_language=parser_or_language)

  remapped_names: dict[str, str] = {}
  for node in capture_variables_in_assignments(tree):
    name = node.text.decode()
    # Number each distinct name once, at its first binding. Reassigning on
    # every binding reused `len(remapped_names)` without growing the dict, so
    # e.g. bindings [x, x, y] mapped both x and y to "V1".
    if name and name not in remapped_names:
      remapped_names[name] = f"V{len(remapped_names)}"

  tokens: list[str] = []
  for node, token in get_terminal_nodes_with_tokens(tree, source_bytes):
    if node.type == "identifier":
      name = node.text.decode()
      # The token ends with the identifier itself (any preceding text is
      # gap text), so rewrite only that suffix rather than every substring.
      if name in remapped_names and token.endswith(name):
        token = token[: len(token) - len(name)] + remapped_names[name]
    tokens.append(token)
  return "".join(tokens)


def cached_normalizer(language: str):
  """Builds a ``normalize_varnames`` callable bound to one reusable parser.

  The parser for ``language`` is created once, here. Every call of the
  returned function reuses it instead of building a new one.
  """
  shared_parser = ts_utils.parsing.make_parser(language)

  def normalizer(source: str):
    return normalize_varnames(source, parser_or_language=shared_parser)

  return normalizer


def extract_first_number(input_string):
  """Returns the first unsigned integer or decimal literal in ``input_string``.

  The match is returned as a string, e.g. ``"12.50"`` or ``"42"``.

  Raises:
      OutputParserException: if the string contains no digits.
  """
  found = re.search(r"\d+\.\d+|\d+", input_string)
  if found is None:
    raise OutputParserException("No number found")
  return found.group(0)


def gsm8k_first_number_from_last_line(input_string):
  """Returns the last unsigned integer or decimal literal in ``input_string``.

  NOTE(review): despite the name, lines are not considered. The whole string
  is scanned and the final number found anywhere wins, e.g.
  ``"... is 3 and 4"`` gives ``"4"``.

  Raises:
      OutputParserException: if the string contains no digits.
  """
  numbers = re.findall(r"\d+\.\d+|\d+", input_string)
  if not numbers:
    raise OutputParserException("No number found")
  return numbers[-1]


# Registry of output parsers, keyed by the name used to select them
# (presumably from configuration; verify against callers).
OUTPUT_PARSERS = dict(
  extract_code_from_markdown=extract_code_from_markdown,
  extract_until_first_top_level_error_node=extract_until_first_top_level_error_node,
  extract_first_function=extract_first_function,
  extract_full_code_in_generation=extract_full_code_in_generation,
  passthrough_output_parser=passthrough_output_parser,
  gsm8k_first_number_from_last_line=gsm8k_first_number_from_last_line,
)


# Generic type variables shared by the parse-result types below.
T = tp.TypeVar("T")
S = tp.TypeVar("S")
F = tp.TypeVar("F")


@dataclasses.dataclass
class ParseSuccess(tp.Generic[T]):
  """Outcome for an input that one of the output parsers handled."""

  # The value the successful parser returned.
  value: T
  # Index, within the sequence of parsers tried, of the parser that succeeded.
  parser: int


@dataclasses.dataclass
class ParseFailure(tp.Generic[F]):
  """Outcome for an input that no output parser could handle."""

  # Placeholder stored instead of a parsed value.
  value: F
  # The exception that caused the failure.
  error: Exception


# Either a success carrying an S or a failure carrying an F.
ParseResult = tp.Union[ParseSuccess[T], ParseFailure[F]]
# A callable turning a raw T into a parsed S.
OutputParser = tp.Callable[[T], S]


def try_parse_output(
  inputs: tp.Sequence[T],
  *,
  output_parsers: tp.Sequence[OutputParser[T, S]],
  fallback: F,
  exception: tp.Type[Exception]
  | tuple[tp.Type[Exception], ...] = OutputParserException,
) -> list[ParseResult[S, F]]:
  """Extracts structured data from a sequence of inputs using a sequence of parsers.

  For each input the parsers are tried in order, and the first one that
  returns without raising wins. The result always has exactly one entry per
  input, in input order.

  Args:
      inputs: sequence of inputs to parse.
      output_parsers: a sequence of output parsers that will be tried in order.
      fallback: value stored in the ``ParseFailure`` when no parser succeeds.
      exception: the exception class or classes that count as a parse
          failure. Other exceptions are re-raised immediately. Defaults to
          ``OutputParserException``.

  Returns:
      One entry per input:
        * ``ParseSuccess``: the parsed value and the index of the parser
          that produced it;
        * ``ParseFailure``: ``fallback`` and the error raised by the last
          parser tried. If ``output_parsers`` is empty, the error is an
          ``OutputParserException``.
  """
  results: list[ParseResult[S, F]] = []

  for item in inputs:
    last_error: Exception | None = None
    for parser_idx, parser in enumerate(output_parsers):
      try:
        value = parser(item)
      except Exception as e:
        if not isinstance(e, exception):
          raise  # Re-raise exceptions that don't match the specified types.
        last_error = e
        continue
      results.append(ParseSuccess(value=value, parser=parser_idx))
      break
    else:
      # Every parser failed (or none were given). Detect this by loop
      # exhaustion rather than comparing against output_parsers[-1], which
      # misfires when the same callable appears more than once in the list.
      if last_error is None:
        last_error = OutputParserException("No output parsers were provided")
      results.append(ParseFailure(value=fallback, error=last_error))

  return results
