import ast_comments as ast
import difflib
import logging
from typing import List, Literal, Optional, Union

from python_utils.string_utils import get_markup_from_text

# Module-level logger, named after this module (used by remove_function_docstrings).
logger = logging.getLogger(__name__)


def get_python_code_from_text(text: str, max_cells: Optional[int] = 1) -> List[str]:
    """Return the cells that `get_markup_from_text` finds for the ``python`` markup.

    Args:
        text: Text to search for ``python`` markup cells.
        max_cells: Upper bound on the number of cells accepted; ``None``
            disables the check.

    Returns:
        The list of cells produced by `get_markup_from_text`.

    Raises:
        ValueError: If more than `max_cells` cells are found.
    """
    python_cells = get_markup_from_text(text=text, markup=["python"])
    if max_cells is None or len(python_cells) <= max_cells:
        return python_cells
    raise ValueError("More than %d code cells found in the text." % max_cells)


def validate_python_code_global_variables(code: str, variables: List[str], func_names: list[str]) -> List[str]:
    """Validate that `code` is a flat sequence of calls to predefined skills.

    Every top-level statement must be an expression statement that calls a
    function by plain name (e.g. ``pick(cup)``). Only the direct positional and
    keyword arguments of each call are inspected.

    When the code consists of exactly one call:
      * every variable passed as an argument must be listed in `variables`;
      * no string constants may be passed as arguments.

    Regardless of the number of calls, every called function must be listed in
    `func_names`. Rule violations are collected and reported together.

    Args:
        code: Python source to validate.
        variables: Names of the variables the code is allowed to pass.
        func_names: Names of the functions (skills) the code is allowed to call.

    Returns:
        The unparsed source of each call, in the order they appear.

    Raises:
        ValueError: If the code cannot be parsed, is not a sequence of calls to
            functions by name, or violates any of the rules above.
    """
    try:
        tree = ast.parse(code)
    except SyntaxError as e:
        raise ValueError(f"Syntax error in the code: {e.msg} at line {e.lineno}") from e

    assert isinstance(tree, ast.Module)

    found_variables = set()
    found_constants = set()
    calls = []
    for item in tree.body:
        if not isinstance(item, ast.Expr):
            raise ValueError(f"Expected an Expr, got {type(item)} for {item}")

        item = item.value
        # The code must be a flat sequence of function calls.
        if not isinstance(item, ast.Call):
            raise ValueError(f"Expected a Call, got {type(item)} for {item}")

        # Only calls by plain name can be matched against `func_names`;
        # attribute calls (`obj.pick(...)`) or calls of call results (`f()()`)
        # have no `.id` on their `func` node.
        if not isinstance(item.func, ast.Name):
            raise ValueError(
                f"Expected a call to a function by name, got `{ast.unparse(item.func)}` in `{ast.unparse(item)}`"
            )

        # Positional arguments: record constants and bare variable names.
        for arg in item.args:
            if isinstance(arg, ast.Constant):
                found_constants.add(arg.value)
            elif isinstance(arg, ast.Name):
                found_variables.add(arg.id)

        # Keyword arguments: same classification on the passed value.
        for kw in item.keywords:
            if isinstance(kw.value, ast.Constant):
                found_constants.add(kw.value.value)
            elif isinstance(kw.value, ast.Name):
                found_variables.add(kw.value.id)

        calls.append(item)

    errors = []

    f_calls = [ast.unparse(call) for call in calls]

    if len(f_calls) == 1:
        # sorted() keeps the order of reported errors deterministic.
        for found_variable in sorted(found_variables):
            if found_variable not in variables:
                errors.append(
                    f"Found variable '{found_variable}' in the code. Only use the provided variables: {variables}."
                )

    found_strings = [c for c in found_constants if isinstance(c, str)]
    if len(f_calls) == 1 and found_strings:
        errors.append(
            f"When only using one call, you are not allowed to use constants. Found constants in the code: {found_constants}. Either change the constants to variables or suggest multiple calls."
        )

    func_signatures = sorted({call.func.id for call in calls})
    for func_call in func_signatures:
        if func_call not in func_names:
            closest_matches = ", ".join(difflib.get_close_matches(func_call, func_names))
            errors.append(
                f"Found function call {func_call} in the code, but it is not one of the predefined skills. Did you mean one of `{closest_matches}`?"
            )

    if errors:
        raise ValueError("The code contains the following errors:\n" + "\n".join(errors))

    return f_calls


def get_top_level_definitions_from_code_str(code_str: str) -> dict[str, Union[Literal["class"], Literal["function"]]]:
    """Map each module-level class or function name in `code_str` to its kind.

    Only top-level definitions are considered; methods and nested definitions
    are ignored. Both ``def`` and ``async def`` are reported as ``"function"``.
    If a name is defined more than once, the last definition wins.

    Args:
        code_str: Python source to inspect.

    Returns:
        A dict from definition name to ``"class"`` or ``"function"``.

    Raises:
        SyntaxError: If `code_str` is not valid Python.
    """
    module = ast.parse(code_str)
    objects: dict[str, Union[Literal["class"], Literal["function"]]] = {}
    for node in module.body:
        if isinstance(node, ast.ClassDef):
            objects[node.name] = "class"
        elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            objects[node.name] = "function"
    return objects


def remove_objects_from_code_str(code_str: str, objects: List[str]) -> str:
    """Remove the named top-level classes and functions from `code_str`.

    A module-level ``class``, ``def`` or ``async def`` whose name is in
    `objects` is dropped; every other top-level node is kept in order.
    Nested definitions are never removed.

    Args:
        code_str: Python source to filter.
        objects: Names of the top-level classes/functions to remove.

    Returns:
        The remaining code, regenerated with ``ast.unparse``.

    Raises:
        SyntaxError: If `code_str` is not valid Python.
    """
    module = ast.parse(code_str)
    names_to_remove = set(objects)
    definition_types = (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)

    kept = [
        node
        for node in module.body
        if not (isinstance(node, definition_types) and node.name in names_to_remove)
    ]

    return ast.unparse(ast.Module(body=kept, type_ignores=[]))


def remove_global_instructions(code_str) -> str:
    """Strip `code_str` down to its top-level imports and ``def`` functions.

    Only module-level ``import``, ``from ... import`` and (non-async) ``def``
    nodes survive, in their original order. Any other top-level node is
    dropped, including assignments, expression statements, class definitions
    and ``async def`` functions. The bodies of kept functions are untouched.

    Args:
        code_str: Python source to filter.

    Returns:
        The filtered code, regenerated with ``ast.unparse``.
    """
    kept_types = (ast.FunctionDef, ast.Import, ast.ImportFrom)
    module = ast.parse(code_str)
    module.body = [stmt for stmt in module.body if isinstance(stmt, kept_types)]
    return ast.unparse(module)


def remove_function_docstrings(code_str: str) -> str:
    """Remove docstrings from all function definitions in the given code string.

    Both ``def`` and ``async def`` functions are handled, at any nesting depth
    (including methods). Module and class docstrings are left untouched. A
    function left with no statements gets a ``pass`` so the output stays valid.

    Args:
        code_str: Python code as a string

    Returns:
        Modified code string with function docstrings removed. If `code_str`
        cannot be parsed, the error is logged and `code_str` is returned
        unchanged.
    """
    try:
        tree = ast.parse(code_str)
    except SyntaxError as e:
        logger.error("Syntax error in code: %s", e)
        return code_str

    # `ast_comments` represents comments as `Comment` nodes inside statement
    # lists; stdlib `ast` has no such type, and `isinstance(x, ())` is False.
    comment_type = getattr(ast, "Comment", ())

    def strip_docstring(node):
        """Drop the docstring of a (async) function node in place and return it."""
        # Comments are not statements, so the docstring is the first
        # non-comment entry of the body.
        for index, stmt in enumerate(node.body):
            if isinstance(stmt, comment_type):
                continue
            if (
                isinstance(stmt, ast.Expr)
                and isinstance(stmt.value, ast.Constant)
                and isinstance(stmt.value.value, str)
            ):
                del node.body[index]
            break

        # A body that is empty or holds only comments would unparse to a
        # function with no statements, which is a syntax error.
        if all(isinstance(stmt, comment_type) for stmt in node.body):
            node.body.append(ast.Pass())
        return node

    class DocstringRemover(ast.NodeTransformer):
        def visit_FunctionDef(self, node):
            # Transform nested functions first, then this one.
            self.generic_visit(node)
            return strip_docstring(node)

        # Async functions are handled exactly like regular ones.
        visit_AsyncFunctionDef = visit_FunctionDef

    return ast.unparse(DocstringRemover().visit(tree))


def get_function_arg_variables(code_str: str) -> List[str]:
    """Return the names used as parameter defaults of top-level functions.

    For every module-level ``def`` / ``async def``, the defaults of its
    positional parameters and then of its keyword-only parameters are
    inspected. Each default that is a bare name (e.g. ``cup`` in
    ``def pick(obj=cup)``) contributes that name. Other defaults (constants,
    attributes, calls, ...) are skipped. Order is preserved and duplicates
    are kept.

    Args:
        code_str: Python source to inspect.

    Returns:
        The default-value names, in definition order.

    Raises:
        SyntaxError: If `code_str` is not valid Python.
    """
    arg_vars = []
    for node in ast.parse(code_str).body:
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        # `kw_defaults` holds None for keyword-only parameters without a
        # default; the isinstance check below skips those.
        for default in [*node.args.defaults, *node.args.kw_defaults]:
            if isinstance(default, ast.Name):
                arg_vars.append(default.id)

    return arg_vars
