import logging
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Module logging: basicConfig attaches a root handler at INFO level (it is a
# no-op if the root logger already has handlers); `logger` is this module's
# own named logger used by the helpers below.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)


def load_model_and_tokenizer(model_path, dtype="float16"):
    """Load a causal LM and its tokenizer, moving the model to the best device.

    Args:
        model_path: Hub id or local directory passed to ``from_pretrained``.
        dtype: A ``torch.dtype`` (e.g. ``torch.bfloat16``), the attribute name
            of one as a string (e.g. ``"float16"``), or ``"auto"`` to let
            transformers choose the dtype from the checkpoint.

    Returns:
        ``(tokenizer, model)``; the model is on ``"cuda"`` when available,
        otherwise on ``"cpu"``.

    Raises:
        ValueError: if *dtype* does not resolve to a ``torch.dtype``.
    """
    # Resolve and validate the dtype before any (slow) downloads/loads so a
    # typo fails fast with a clear message instead of an AttributeError, and
    # so non-dtype torch attributes (e.g. "cuda") are rejected.
    if isinstance(dtype, torch.dtype) or dtype == "auto":
        torch_dtype = dtype
    else:
        torch_dtype = getattr(torch, str(dtype), None)
        if not isinstance(torch_dtype, torch.dtype):
            raise ValueError(
                f"Unsupported dtype {dtype!r}; expected a torch dtype name "
                "such as 'float16' or 'bfloat16', a torch.dtype, or 'auto'."
            )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch_dtype,
    ).to(device)
    return tokenizer, model


def prepare_prompt(template, optim_str, jailbreak=False):
    """Build the prompt text from *template*.

    When *jailbreak* is truthy, every literal ``{optim_str}`` placeholder in
    *template* is replaced by *optim_str*; otherwise *template* is returned
    unchanged.
    """
    if not jailbreak:
        return template
    return template.replace("{optim_str}", optim_str)


def parse_tool_call(response_text):
    """Decode a tool-call payload as JSON.

    Args:
        response_text: Text expected to hold a single JSON value; surrounding
            whitespace is ignored. ``None`` is tolerated.

    Returns:
        The decoded JSON value, or ``None`` when *response_text* is ``None``
        or is not valid JSON (the failure is logged).
    """
    if response_text is None:
        logger.error("Failed to parse tool call: no response text.")
        return None
    try:
        return json.loads(response_text.strip())
    except json.JSONDecodeError as err:
        # Log the decoder message and a bounded preview of the input so the
        # failure can be diagnosed from the log alone.
        logger.error(
            "Failed to parse tool call: %s; input: %r", err, response_text[:200]
        )
        return None


def extract_response_block(raw_output, model_path):
    if "llama" in model_path:
        try:
            return raw_output.split("<|python_tag|>")[1].split("<|eom_id|>")[0].strip()
        except IndexError:
            try:
                return raw_output.split("<|eom_id|>")[0].strip()
            except IndexError:
                logger.error("Failed to extract function call from model output.")
                return ""

    elif "granite" in model_path:
        try:
            return raw_output.split("<tool_call>")[1].split("<|end_of_text|>")[0].strip()
        except IndexError:
            logger.error("Failed to extract function call from model output.")
            return ""

    elif "mistral" in model_path:
        try:
            return raw_output.split("[TOOL_CALLS]")[1].split("]")[0].strip() + "]"
        except IndexError:
            try:
                return raw_output.split("[TOOL_CALLS]")[1].strip()
            except IndexError:
                logger.error("Failed to extract function call from model output.")
                return ""

    else:
        print(f"Unknown special tokens for the following model: {model_path}")
        return model_path


def display_tools(tools):
    """Log a header, then one ``- name: description`` line per tool.

    Each item of *tools* is indexed with the ``'name'`` and
    ``'description'`` keys.
    """
    logger.info("\nAvailable Tools:")
    for entry in tools:
        name, description = entry['name'], entry['description']
        logger.info(f"- {name}: {description}")
