"""
Script running the simple FH-attack

Takes a model, a payload (set of functions and prompt), and a gold function.
Saves the run configuration and the attack logs under the prefix given by --saving-path.
"""

import argparse
import json
import pandas as pd

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from code_nanoGCG.gcg_simple import ProbeSamplingConfig, GCGConfig, GCG
from code_nanoGCG.utils import generate_payload, generate_ground_truth

# Hugging Face access token, passed as ``token=`` to ``from_pretrained`` when
# loading the model and tokenizer in ``main``.
# NOTE(review): the value below looks like a placeholder — confirm it is
# replaced with a real token before running against gated models.
API_TOKEN = "your_hf_token"


def _str2bool(value):
    """
    Convert a command-line string to a boolean.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    for every non-empty string; this converter parses the text instead.

    Parameters
    ----------
    value : str or bool
        Raw value from the command line (or an already-parsed bool).

    Returns
    -------
    bool
        The parsed boolean.

    Raises
    ------
    argparse.ArgumentTypeError
        If the text is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    # The empty string is kept as False: with ``type=bool`` it was the only
    # spelling that produced False, so existing invocations keep working.
    if normalized in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value (true/false), got {value!r}")


def parse_args(argv=None) -> argparse.Namespace:
    """
    Parse command-line arguments for FH-attack execution

    Parameters
    ----------
    argv : sequence of str, optional
        Argument list to parse. When None (the default), argparse reads
        ``sys.argv[1:]``, which is the behaviour of a plain ``parse_args()`` call.

    Returns
    -------
    argparse.Namespace
        A namespace containing all parsed command-line arguments.

    Command-line Arguments
    ----------------------
    --model-tokenizer : str, default="meta-llama/Llama-3.2-3B-Instruct"
        Path to Hugging Face tokenizer.

    --model-path : str, default="meta-llama/Llama-3.2-3B-Instruct"
        Path to Hugging Face model.

    --dataset-path : str, default="../data/original/BFCL_v3_multiple.json"
        Path to the dataset used for executing the FH-attack (BFCL_v3_multiple as default).

    --dataset-path-gt : str, default="../data/ground_truth/BFCL_v3_multiple_gt.json"
        Path to the ground-truth dataset to access gold functions given the default prompt.

    --index-to-test : int, default=0
        Index of the dataset sample to run on.

    --use-bfcl : bool, default=True
        Whether to use the prompt provided by the BFCL's payload (if not, use custom prompt).
        Accepts true/false, yes/no, 1/0 (case-insensitive).

    --prompt : str, default="description"
        Custom prompt to run the FH-attack, if not using the original prompt provided by the BFCL's payload.

    --attack-location : str, default="description"
        Specifies where to inject or test the attack in the input (only description supported).

    --device : str, default="cuda"
        The device to run the model on (e.g., "cuda", "cpu").

    --dtype : str, default="float16"
        Data type for model weights and inference (e.g., "float16", "bfloat16", "float32").

    --probe-sampling : flag
        If set, enables probe-based sampling mode.

    --saving-path : str or None, default=None
        Path to save experiment outputs and logs.

    --run-number : int, default=0
        Experiment run id when executing multiple runs.

    --optim-str-size : int, default=60
        Size of the optimization string (number of adversarial tokens to be included).

    --expected-target : str, default="NaN"
        Expected target output (for monitoring the attack).

    --is-verbose : bool, default=False
        Whether to print verbose debug information during execution.
        Accepts true/false, yes/no, 1/0 (case-insensitive).

    --num-steps : int, default=500
        Number of epochs of the FH-attack.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument("--model-tokenizer", type=str, default="meta-llama/Llama-3.2-3B-Instruct")
    parser.add_argument("--model-path", type=str, default="meta-llama/Llama-3.2-3B-Instruct")
    parser.add_argument("--dataset-path", type=str, default="../data/original/BFCL_v3_multiple.json")
    parser.add_argument(
        "--dataset-path-gt",
        type=str,
        default="../data/ground_truth/BFCL_v3_multiple_gt.json",
    )
    parser.add_argument("--index-to-test", type=int, default=0)
    # _str2bool instead of bool: bool("False") is True, so type=bool could never disable the flag.
    parser.add_argument("--use-bfcl", type=_str2bool, default=True)
    parser.add_argument("--prompt", type=str, default="description")
    parser.add_argument("--attack-location", type=str, default="description")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--dtype", type=str, default="float16")
    parser.add_argument("--probe-sampling", action="store_true")
    parser.add_argument("--saving-path", type=str, default=None)
    parser.add_argument("--run-number", type=int, default=0)
    parser.add_argument("--optim-str-size", type=int, default=60)
    parser.add_argument("--expected-target", type=str, default="NaN")
    parser.add_argument("--is-verbose", type=_str2bool, default=False)
    parser.add_argument("--num-steps", type=int, default=500)
    return parser.parse_args(argv)


def _build_optim_str_init(size):
    """
    Build the initial adversarial string: ``size`` copies of "x" separated by single spaces.

    Raises
    ------
    ValueError
        If ``size`` is smaller than 1.
    """
    if size < 1:
        raise ValueError(f"--optim-str-size must be a positive integer, got {size}")
    return " ".join(["x"] * size)


def _known_tool_call_prefix(model_path):
    """
    Return the hard-coded tool-call prefix for a known model family, or None.

    The family is detected by a case-sensitive substring match on ``model_path``,
    checked in the order llama, granite, mistral.
    """
    prefixes = (
        ("llama", '<|python_tag|>{"name": "'),
        ("granite", '<tool_call>[{"name": "'),
        ("mistral", '[TOOL_CALLS] [{"name": "'),
    )
    for family, prefix in prefixes:
        if family in model_path:
            return prefix
    return None


def main():
    """
    Executing the FH-attack given the arguments parsed by the parse_args function.

    Returns
    -------
    None
        Results are written to disk: the run configuration is saved as
        ``<saving_path>config_<run_number>.json`` and ``saving_path`` is handed to
        ``GCG.run`` for its logs.

    Raises
    ------
    ValueError
        If ``--saving-path`` is missing, ``--optim-str-size`` is not positive, or the
        payload offers no function other than the ground truth to use as a target.
    """

    args = parse_args()

    # Validate cheap inputs before the (expensive) model load so misconfiguration fails fast.
    if args.saving_path is None:
        raise ValueError("--saving-path is required: it is used as the prefix for the config and log files")
    optim_str_init = _build_optim_str_init(args.optim_str_size)

    model_dtype = getattr(torch, args.dtype)

    # Placement of the main model is delegated to device_map="auto";
    # --device is only applied to the draft model below.
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        token=API_TOKEN,
        torch_dtype=model_dtype,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_tokenizer, token=API_TOKEN)

    probe_sampling_config = None
    if args.probe_sampling:
        # Same repositories as the main model/tokenizer, so pass the same token.
        draft_model = AutoModelForCausalLM.from_pretrained(
            args.model_path,
            token=API_TOKEN,
            torch_dtype=model_dtype,
        ).to(args.device)
        draft_tokenizer = AutoTokenizer.from_pretrained(args.model_tokenizer, token=API_TOKEN)
        probe_sampling_config = ProbeSamplingConfig(
            draft_model=draft_model,
            draft_tokenizer=draft_tokenizer,
        )

    # Load the payload from the dataset
    list_payload = generate_payload(args.dataset_path)
    list_ground_truth = generate_ground_truth(args.dataset_path_gt)

    # Test on one payload
    payload = list_payload[args.index_to_test]
    ground_truth = list_ground_truth[args.index_to_test]
    print("ground_truth:", ground_truth)

    # Loading all tools available
    tools = [json.loads(str(function_str)) for function_str in payload["functions_str"]]
    print("Available tools:", tools)

    list_tools = list(tools)
    list_function_name = [tool["name"] for tool in list_tools]

    questions = [payload["query"]] if args.use_bfcl else [args.prompt]
    print("prompt", questions)

    # Target function: the first available function that is not the ground truth.
    # Its index in the function list is passed to the attack.
    candidate_names = [name for name in list_function_name if name != ground_truth]
    if not candidate_names:
        raise ValueError(
            f"no function other than the ground truth {ground_truth!r} is available "
            f"at index {args.index_to_test}; cannot pick an attack target"
        )
    target_function_name = candidate_names[0]
    print("target_function_name:", target_function_name)
    target_index = list_function_name.index(target_function_name)

    # All available functions, and the prompt
    messages = [{"functions_str": list_tools, "query": question} for question in questions]

    prefix = _known_tool_call_prefix(args.model_path)
    if prefix is not None:
        modified_target = f'{prefix}{target_function_name}",'
        expected_target = f'{prefix}{ground_truth}",'
    else:
        # Unknown model family: generate the model's own tool call for the prompt
        # and swap the ground-truth name for the target name.
        conversation = [
            {"role": "user", "content": messages[-1]["query"]},
        ]

        template = tokenizer.apply_chat_template(
            conversation=conversation,
            tools=[{"type": "function", "function": x} for x in messages[-1]["functions_str"]],
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",  # Return as tensor directly
        ).to(model.device)
        input_length = template.shape[-1]

        with torch.no_grad():
            outputs = model.generate(
                input_ids=template,
                do_sample=False,
                max_new_tokens=100,
                return_dict_in_generate=True,
                output_scores=True,
            )

        # Keep only the newly generated tokens (everything after the prompt).
        generated_tokens = outputs.sequences[:, input_length:]
        expected_target = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

        modified_target = expected_target.replace(ground_truth, target_function_name)

        # Only touch the CUDA allocator when CUDA is actually present.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()

    print("target:", modified_target)

    run_config = {
        "model": args.model_path,
        "type": "mini",
        "index_to_test": args.index_to_test,
        "attack_location": args.attack_location,
        "target_function_name": target_function_name,
        "target": modified_target,
        "prompt": questions[0],
        "initial_str": args.optim_str_size,
        "add_space_before_target": False,
        "num_steps": args.num_steps,
    }

    # saving_path is used as a plain string prefix (no separator is inserted).
    df = pd.DataFrame([run_config])
    df.to_json(args.saving_path + f"config_{args.run_number}.json", orient="records", indent=4)

    config = GCGConfig(
        verbosity="DEBUG",
        probe_sampling_config=probe_sampling_config,
        num_steps=args.num_steps,
        optim_str_init=optim_str_init,
    )

    gcg_instance = GCG(model, tokenizer, config)

    _ = gcg_instance.run(
        messages,  # Initial Payload
        modified_target,  # Target string to optimise towards
        target_index,  # Index location of the target in the payload
        args.attack_location,  # Attack location - description or name
        args.saving_path,  # path to save the logs
        args.run_number,  # run number
        expected_target,  # expected_target
        args.is_verbose,  # verbosity
    )


# Run the attack only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
