import argparse
import pandas as pd
from utils import (
    load_model_and_tokenizer,
    prepare_prompt,
    parse_tool_call,
    extract_response_block,
    display_tools,
    logger,
)

import torch


def main(run_id, model_path, hijack=False, mcp=False, mcp_id="git"):
    
    # Load data
    if "llama" in model_path and mcp:
        df = pd.read_json(f"mcp/run_{mcp_id}.json").iloc[0]
    
    elif "llama" in model_path:
        df = pd.read_json(f"llama/run_{run_id}.json").iloc[0]

    elif "granite" in model_path:
        df = pd.read_json(f"granite/run_{run_id}.json").iloc[0]

    elif "mistral" in model_path:
        df = pd.read_json(f"mistral/run_{run_id}.json").iloc[0]

    else:
        print("Model not supported")

    tools = df["functions"]
    initial_prompt = df["prompt"]
    final_template = df["template"]
    tokens_optim_str = df["optim_str"]

    # Load model & tokenizer
    tokenizer, model = load_model_and_tokenizer(model_path)

    # Decode optim_str to include in the template
    optim_str = tokenizer.decode(tokens_optim_str, add_special_tokens=False)
    prompt = prepare_prompt(final_template, optim_str, hijack)

    # Display available tools
    display_tools(tools)

    logger.info("\nPrompt:\n%s", initial_prompt)

    input_ids = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=True).to(model.device)
    input_len = input_ids.shape[1]

    outputs = model.generate(
        input_ids, do_sample=False, max_new_tokens=100, return_dict_in_generate=True, output_logits=True
    )
    generated_ids = outputs.sequences[:, input_len:]
    response_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)[0]

    logger.info("\n--------- FUNCTION CALL ----------\n%s", response_text)

    extracted = extract_response_block(response_text, model_path)
    parsed_result = parse_tool_call(extracted)

    logger.info("\n--------- PARSED FUNCTION CALL ----------\n%s", parsed_result)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run inference with a selected model and data sample.")
    parser.add_argument("--model_path", required=True, help="Path to the HuggingFace model.")
    parser.add_argument(
        "--run_id",
        type=int,
        required=True,
        help="ID of the JSON run file to use (ignored for Llama models with --mcp).",
    )
    parser.add_argument("--hijack", action="store_true", help="Whether to use hijack.")
    parser.add_argument("--mcp", action="store_true", help="If True, hijack MCPs example, else BFCL (MCP hijacks only available for Llama models)")
    # Without an explicit default argparse yields None, which overrides main()'s
    # own "git" default and makes --mcp look for mcp/run_None.json.
    parser.add_argument(
        "--mcp_id",
        type=str,
        default="git",
        help="Hijack Git or Slack MCPs (default: %(default)s)",
    )

    args = parser.parse_args()

    main(
        run_id=args.run_id,
        model_path=args.model_path,
        hijack=args.hijack,
        mcp=args.mcp,
        mcp_id=args.mcp_id,
    )
