# SPDX-License-Identifier: MIT
from __future__ import annotations
import argparse
import os
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from .train import train_net
from .adapt import ToxicityAdaptor, AdaptConfig
from .memory import MemoryStream
from .io_utils import append_attack_result_csv


def main() -> None:
    """Run the train-then-adapt pipeline from the command line.

    Steps, in order:

    1. Parse the command-line arguments.
    2. Read the prompts from the ``goal`` column of ``--advbench_csv``.
       This happens before any model is loaded so that a missing file or
       a malformed CSV is reported immediately.
    3. Make sure the parent directories of ``--csv_out`` and ``--json_out``
       exist, so the run cannot fail on a missing directory at the very end.
    4. Load the tokenizer and causal LM and move the model to the selected
       device (``--device`` if given, otherwise CUDA when available, else CPU).
    5. Train the auxiliary net with ``train_net`` on a placeholder dataset.
    6. For every prompt, call ``ToxicityAdaptor.adapt`` and append the result
       to ``--csv_out``.
    7. Dump ``MemoryStream.memory`` to ``--json_out`` as UTF-8 JSON.

    Exits with status 2 (via ``argparse``) when required arguments are
    missing or the CSV has no ``goal`` column.
    """
    import json

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str, required=True)
    parser.add_argument("--advbench_csv", type=str, required=True)
    parser.add_argument("--csv_out", type=str, required=True)
    parser.add_argument("--json_out", type=str, required=True)
    parser.add_argument("--device", type=str, default="")
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--max_tokens", type=int, default=2)
    args = parser.parse_args()

    # Validate the cheap inputs first: loading the model is slow, so a bad
    # CSV should be rejected before that work starts.
    df = pd.read_csv(args.advbench_csv)
    if "goal" not in df.columns:
        parser.error(
            f"--advbench_csv: {args.advbench_csv!r} has no 'goal' column "
            f"(found columns: {list(df.columns)})"
        )
    # Empty cells are read as NaN floats; they are not usable prompts.
    prompts = df["goal"].dropna().tolist()

    # Create the output directories up front. The JSON file is only written
    # after every prompt has been processed, so discovering a missing
    # directory at that point would throw away the whole run.
    for out_path in (args.csv_out, args.json_out):
        os.makedirs(os.path.dirname(os.path.abspath(out_path)), exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    hf_model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
    device = torch.device(
        args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    )
    hf_model.to(device)

    # PLACEHOLDER: real runs need a training dataset mapping word -> float
    # score. No default dataset ships with this script, so this dummy dict
    # must be replaced by the user before the trained net is meaningful.
    train_dataset = {"placeholder": 0.0}

    net = train_net(
        data=train_dataset,
        max_tokens=args.max_tokens,
        model=hf_model,
        tokenizer=tokenizer,
        epochs=args.epochs,
        batch_size=args.batch_size,
        learning_rate=args.lr,
        device=str(device),
        capture_mode="token",
        alpha=0.7,
        target_dtype=torch.float32,
        log_file="",
    )

    controller = ToxicityAdaptor(hf_model, tokenizer, net, net_device=str(device))
    cfg = AdaptConfig(
        u0=4.0,
        max_iters=20,
        eps=0.2,
        threshold=4.0,
        max_new_tokens=80,
        verbose_pipeline=False,
    )
    # The same MemoryStream is passed to every adapt() call and dumped once
    # after the loop.
    memory = MemoryStream()

    for prompt in prompts:
        result = controller.adapt(prompt, cfg=cfg, memory=memory)
        # Results are appended per prompt, so the CSV reflects progress even
        # if a later prompt fails.
        append_attack_result_csv(
            csv_path=args.csv_out,
            malicious_prompt=prompt,
            ok=result.ok,
            output=result.output,
            mu=result.u,
        )

    with open(args.json_out, "w", encoding="utf-8") as f:
        json.dump(memory.memory, f, ensure_ascii=False, indent=2)


# Run the pipeline only when this file is executed as a script, not when it
# is imported. NOTE(review): the module uses relative imports, so it has to be
# launched in package context (e.g. `python -m <package>.<module>`) — confirm
# the intended invocation.
if __name__ == "__main__":
    main()