# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT

import functools
import importlib
import inspect
from dataclasses import dataclass, field
from importlib.abc import SourceLoader
from textwrap import dedent, indent
from typing import Any, Callable, Generic, TypeVar, Union

from typing_extensions import ParamSpec

# Generic placeholders used in this module's annotations (see ``Callable[P, T]``
# below): ``T`` stands for a callable's return type and ``P`` for its parameters.
T = TypeVar("T")
P = ParamSpec("P")


def _to_code(func: Union["FunctionWithRequirements[T, P]", Callable[P, T], "FunctionWithRequirementsStr"]) -> str:
    """Return the Python source code for ``func``.

    Args:
        func: A plain callable, a ``FunctionWithRequirements`` wrapper, or a
            ``FunctionWithRequirementsStr``.

    Returns:
        str: For ``FunctionWithRequirementsStr`` the stored source string, unchanged.
            Otherwise the text reported by ``inspect.getsource``; if that text starts
            with ``@``, its first line is removed.

    Raises:
        OSError, TypeError: Propagated from ``inspect.getsource`` when the source
            cannot be retrieved.
    """
    if isinstance(func, FunctionWithRequirementsStr):
        return func.func

    # Wrappers created with ``FunctionWithRequirements.from_callable`` do not get
    # ``functools.update_wrapper`` applied, so ``inspect.getsource`` cannot resolve
    # them. Look at the wrapped callable directly; for decorator-built wrappers
    # this yields the same source text as before.
    if isinstance(func, FunctionWithRequirements):
        func = func.func

    code = inspect.getsource(func)
    # Strip the decorator (only the first line is removed)
    if code.startswith("@"):
        code = code[code.index("\n") + 1 :]
    return code


@dataclass
class Alias:
    """A name to import together with the alias it is bound to (``<name> as <alias>``)."""

    # The module or symbol being imported.
    name: str
    # The identifier written after ``as`` when the import is rendered.
    alias: str


@dataclass
class ImportFromModule:
    """A ``from <module> import ...`` requirement."""

    # Module the names are imported from.
    module: str
    # Names to import; each entry is either a bare name or an ``Alias``
    # (rendered as ``name as alias``).
    imports: list[Union[str, Alias]]


# A single import requirement: a bare module name (``import x``), an ``Alias``
# (``import x as y``), or an ``ImportFromModule`` (``from x import ...``).
Import = Union[str, ImportFromModule, Alias]


def _import_to_str(im: Import) -> str:
    """Render one import requirement as a line of Python source.

    Args:
        im: A module name, an ``Alias``, or an ``ImportFromModule``.

    Returns:
        str: ``import x``, ``import x as y`` or ``from m import a, b as c``.
    """
    if isinstance(im, str):
        return f"import {im}"
    if isinstance(im, Alias):
        return f"import {im.name} as {im.alias}"

    # Anything else is treated as an ImportFromModule: render each entry, then
    # join them into a single ``from ... import ...`` statement.
    rendered = [
        entry if isinstance(entry, str) else f"{entry.name} as {entry.alias}"
        for entry in im.imports
    ]
    names = ", ".join(rendered)
    return f"from {im.module} import {names}"


class _StringLoader(SourceLoader):
    def __init__(self, data: str):
        self.data = data

    def get_source(self, fullname: str) -> str:
        return self.data

    def get_data(self, path: str) -> bytes:
        return self.data.encode("utf-8")

    def get_filename(self, fullname: str) -> str:
        return "<not a real path>/" + fullname + ".py"


@dataclass
class FunctionWithRequirementsStr:
    """A function given as source text, plus its package and import requirements.

    The source is executed once at construction time to locate the single
    function it defines; the instance itself is not callable.
    """

    # Source text exactly as supplied by the caller.
    func: str
    # Function object obtained by executing ``func``.
    _compiled_func: Callable[..., Any]
    # Name of that function inside the executed module.
    _func_name: str
    python_packages: list[str] = field(default_factory=list)
    global_imports: list[Import] = field(default_factory=list)

    def __init__(self, func: str, python_packages: list[str] = [], global_imports: list[Import] = []):
        """Execute ``func`` and record the one function it defines.

        Args:
            func: Python source that must expose exactly one function at module level.
            python_packages: Packages the function requires.
            global_imports: Imports the function requires.

        Raises:
            ValueError: If a module spec or loader cannot be created, if executing
                the source raises, or if the executed module does not expose
                exactly one function.
        """
        # ``import importlib`` alone does not guarantee that the ``util``
        # submodule has been loaded, so import it explicitly before use.
        import importlib.util

        self.func = func
        # Store copies: the signature defaults are single shared list objects, and
        # keeping a reference to them would let one instance's mutations leak into
        # every other instance created with the defaults.
        self.python_packages = list(python_packages)
        self.global_imports = list(global_imports)

        module_name = "func_module"
        loader = _StringLoader(func)
        spec = importlib.util.spec_from_loader(module_name, loader)
        if spec is None:
            raise ValueError("Could not create spec")
        module = importlib.util.module_from_spec(spec)
        if spec.loader is None:
            raise ValueError("Could not create loader")

        # NOTE: this runs the supplied source in the current process; only pass
        # source text that is trusted.
        try:
            spec.loader.exec_module(module)
        except Exception as e:
            raise ValueError(f"Could not compile function: {e}") from e

        functions = inspect.getmembers(module, inspect.isfunction)
        if len(functions) != 1:
            raise ValueError("The string must contain exactly one function")

        self._func_name, self._compiled_func = functions[0]

    def __call__(self, *args: Any, **kwargs: Any) -> None:
        """Always raise: string-based functions cannot be called directly.

        Raises:
            NotImplementedError: Unconditionally.
        """
        raise NotImplementedError("String based function with requirement objects are not directly callable")


@dataclass
class FunctionWithRequirements(Generic[T, P]):
    """A callable bundled with the packages and imports it requires.

    Calling an instance forwards the call to the wrapped ``func``.
    """

    # The wrapped callable.
    func: Callable[P, T]
    # Packages the callable requires.
    python_packages: list[str] = field(default_factory=list)
    # Imports the callable requires.
    global_imports: list[Import] = field(default_factory=list)

    @classmethod
    def from_callable(
        cls, func: Callable[P, T], python_packages: list[str] = [], global_imports: list[Import] = []
    ) -> "FunctionWithRequirements[T, P]":
        """Wrap ``func`` together with its requirements.

        Args:
            func: The callable to wrap.
            python_packages: Packages the callable requires.
            global_imports: Imports the callable requires.

        Returns:
            FunctionWithRequirements[T, P]: The new wrapper.
        """
        # Copy the lists: the defaults are single shared objects and must not end
        # up stored on (and mutated through) individual instances.
        return cls(func=func, python_packages=list(python_packages), global_imports=list(global_imports))

    @staticmethod
    def from_str(
        func: str, python_packages: list[str] = [], global_imports: list[Import] = []
    ) -> FunctionWithRequirementsStr:
        """Build a ``FunctionWithRequirementsStr`` from function source text.

        Args:
            func: Source text of the function.
            python_packages: Packages the function requires.
            global_imports: Imports the function requires.

        Returns:
            FunctionWithRequirementsStr: The string-based wrapper.
        """
        # Pass copies so the shared default lists are never handed to an instance.
        return FunctionWithRequirementsStr(
            func=func, python_packages=list(python_packages), global_imports=list(global_imports)
        )

    # Typed with the wrapped callable's parameters (P) and return type (T).
    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
        """Invoke the wrapped callable with the given arguments and return its result."""
        return self.func(*args, **kwargs)


def with_requirements(
    python_packages: list[str] = [], global_imports: list[Import] = []
) -> Callable[[Callable[P, T]], FunctionWithRequirements[T, P]]:
    """Decorate a function with package and import requirements

    Args:
        python_packages (List[str], optional): Packages required to function. Can include version info. Defaults to [].
        global_imports (List[Import], optional): Required imports. Defaults to [].

    Returns:
        Callable[[Callable[P, T]], FunctionWithRequirements[T, P]]: The decorated function
    """

    def wrapper(func: Callable[P, T]) -> FunctionWithRequirements[T, P]:
        # Give every decorated function its own list objects. Without the copies,
        # all functions decorated through the same call (or through the shared
        # signature defaults) would hold the very same lists.
        func_with_reqs = FunctionWithRequirements(
            python_packages=list(python_packages), global_imports=list(global_imports), func=func
        )

        # Copy the function's metadata (name, docstring, ``__wrapped__``, ...) onto the wrapper.
        functools.update_wrapper(func_with_reqs, func)
        return func_with_reqs

    return wrapper


def _build_python_functions_file(
    funcs: list[Union[FunctionWithRequirements[Any, P], Callable[..., Any], FunctionWithRequirementsStr]],
) -> str:
    """Assemble the source of a Python file containing all of ``funcs``.

    Args:
        funcs: Plain callables and/or requirement wrappers to include.

    Returns:
        str: The de-duplicated import lines of all wrappers, a blank line, then
            the source of each function in order, each followed by a blank line.
    """
    # First collect all global imports. A dict is used as an ordered set so that
    # duplicates are dropped while the first-seen order is kept; iterating a
    # plain ``set`` of strings can yield a different order from one process to
    # the next, which made the generated file non-reproducible.
    import_lines: dict[str, None] = {}
    for func in funcs:
        if isinstance(func, (FunctionWithRequirements, FunctionWithRequirementsStr)):
            import_lines.update(dict.fromkeys(map(_import_to_str, func.global_imports)))

    header = "\n".join(import_lines) + "\n\n"
    body = "".join(_to_code(func) + "\n\n" for func in funcs)
    return header + body


def to_stub(func: Union[Callable[..., Any], FunctionWithRequirementsStr]) -> str:
    """Generate a stub for a function as a string

    The stub consists of the ``def`` line (name and signature), the function's
    docstring if it has one, and a ``...`` body.

    Args:
        func (Callable[..., Any]): The function to generate a stub for. May also be a
            ``FunctionWithRequirementsStr`` or ``FunctionWithRequirements`` wrapper,
            in which case the stub describes the underlying function.

    Returns:
        str: The stub for the function
    """
    if isinstance(func, FunctionWithRequirementsStr):
        return to_stub(func._compiled_func)

    # Wrappers built with ``FunctionWithRequirements.from_callable`` have no
    # ``__name__`` of their own, so describe the wrapped callable instead. For
    # decorator-built wrappers this produces the same name, signature and docstring.
    if isinstance(func, FunctionWithRequirements):
        func = func.func

    content = f"def {func.__name__}{inspect.signature(func)}:\n"
    docstring = func.__doc__

    if docstring:
        docstring = dedent(docstring)
        docstring = '"""' + docstring + '"""'
        docstring = indent(docstring, "    ")
        content += docstring + "\n"

    content += "    ..."
    return content
