# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

from typing_extensions import override

from .data_utils import SLOTS
from .tool_utils import get_tool_utils


if TYPE_CHECKING:
    from .tool_utils import FunctionCall


# Matches a ``{{identifier}}`` placeholder inside a string slot.
_PLACEHOLDER_PATTERN = re.compile(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}")


def _has_placeholder(slots: SLOTS) -> bool:
    r"""
    Returns True if any string slot contains a ``{{identifier}}`` placeholder.

    Non-string slots (e.g. dicts or sets) are skipped.
    """
    return any(isinstance(slot, str) and _PLACEHOLDER_PATTERN.search(slot) for slot in slots)


@dataclass
class Formatter(ABC):
    r"""
    Base class for turning inputs into a list of template slots.
    """

    # Template pieces; subclasses accept strings (optionally containing
    # ``{{name}}`` placeholders), dicts and sets.
    slots: SLOTS = field(default_factory=list)
    # Passed to ``get_tool_utils`` by the tool-aware subclasses.
    tool_format: Optional[str] = None

    @abstractmethod
    def apply(self, **kwargs) -> SLOTS:
        r"""
        Forms a list of slots according to the inputs to encode.
        """
        ...

    def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
        r"""
        Extracts function calls from the response message if using tools.

        Returns either a string or a list of ``FunctionCall`` items. The base
        implementation is not supported and raises NotImplementedError;
        tool-aware subclasses override it.
        """
        raise NotImplementedError


@dataclass
class EmptyFormatter(Formatter):
    r"""
    Formatter that returns its slots unchanged; placeholders are forbidden.
    """

    def __post_init__(self):
        if _has_placeholder(self.slots):
            raise ValueError("Empty formatter should not contain any placeholder.")

    @override
    def apply(self, **kwargs) -> SLOTS:
        return self.slots


@dataclass
class StringFormatter(Formatter):
    r"""
    Formatter that fills ``{{name}}`` placeholders in string slots from keyword arguments.
    """

    def __post_init__(self):
        if not _has_placeholder(self.slots):
            raise ValueError("A placeholder is required in the string formatter.")

    @override
    def apply(self, **kwargs) -> SLOTS:
        r"""
        Substitutes each keyword argument into the string slots.

        Raises:
            RuntimeError: if a keyword value is not a string, or a slot is not
                a string, dict or set.
        """
        elements = []
        for slot in self.slots:
            if isinstance(slot, str):
                for name, value in kwargs.items():
                    if not isinstance(value, str):
                        raise RuntimeError(f"Expected a string, got {value}")

                    # Only the first occurrence of each placeholder is replaced.
                    slot = slot.replace("{{" + name + "}}", value, 1)

                elements.append(slot)
            elif isinstance(slot, (dict, set)):
                elements.append(slot)
            else:
                raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")

        return elements


@dataclass
class FunctionFormatter(Formatter):
    r"""
    Formatter that renders one or more JSON-encoded function calls into slots.
    """

    def __post_init__(self):
        # Prepend the tool-format-specific function slots to the configured ones.
        self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots

    @override
    def apply(self, **kwargs) -> SLOTS:
        r"""
        Renders the function call(s) in ``content`` into slots.

        ``content`` must be a JSON object with ``name`` and ``arguments`` keys,
        or a JSON list of such objects. The slots are repeated once per call,
        with ``{{name}}`` and ``{{arguments}}`` filled in.

        Raises:
            RuntimeError: if ``content`` is not valid JSON, or a slot is not a
                string, dict or set.
        """
        content = kwargs.pop("content")
        try:
            tool_calls = json.loads(content)
        except json.JSONDecodeError as err:
            raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") from err  # flat string

        if not isinstance(tool_calls, list):  # a single call; a JSON list encodes parallel function calls
            tool_calls = [tool_calls]

        functions: List[Tuple[str, str]] = [
            (tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)) for tool_call in tool_calls
        ]

        elements = []
        for name, arguments in functions:
            for slot in self.slots:
                if isinstance(slot, str):
                    slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
                    elements.append(slot)
                elif isinstance(slot, (dict, set)):
                    elements.append(slot)
                else:
                    raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")

        return elements


@dataclass
class ToolFormatter(Formatter):
    r"""
    Formatter that renders a JSON tool description via the configured tool utils.
    """

    def __post_init__(self):
        self.tool_utils = get_tool_utils(self.tool_format)

    @override
    def apply(self, **kwargs) -> SLOTS:
        r"""
        Parses ``content`` as JSON and formats it with ``tool_utils.tool_formatter``.

        An empty (or null) tool description yields a single empty-string slot.

        Raises:
            RuntimeError: if ``content`` is not valid JSON.
        """
        content = kwargs.pop("content")
        try:
            tools = json.loads(content)
        except json.JSONDecodeError as err:
            raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}") from err  # flat string

        return [self.tool_utils.tool_formatter(tools) if tools else ""]

    @override
    def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
        # Delegates entirely to the tool-format-specific extractor.
        return self.tool_utils.tool_extractor(content)