import argparse
import concurrent.futures
import json
import logging
import os
import shutil
import sys

import tqdm
import yaml
import glob

#
# Make imports robust regardless of current working directory.
# This pipeline often runs `python main.py` from inside `multi-turn/`, but
# some modules live at the repo root (e.g. `defense/`, `unified_judge.py`).
#
# Parent of this script's directory, made absolute and normalized.
_REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if not any(entry == _REPO_ROOT for entry in sys.path):
    # Prepend so root-level packages win over same-named modules elsewhere.
    sys.path.insert(0, _REPO_ROOT)

from multi_turn_generation.base import BasePlanGenerator
from data.harmbench import load_datasets

def build_plan_generator(name: str, args: dict) -> BasePlanGenerator:
    """
    Factory to instantiate a plan generator by name.

    Supported names (case-insensitive): ``crescendo``, ``actorattack``,
    ``xteaming``, ``fitd``, ``coa``. The generator class is imported inside
    its branch, so only the selected generator's module is loaded.

    Args:
        name: Plan generator name.
        args: Config mapping. ``args[name.lower()]`` must be a mapping of
            keyword arguments for the selected generator's constructor.

    Returns:
        The constructed plan generator.

    Raises:
        ValueError: If ``name`` is not a supported plan generator.
        KeyError: If ``args`` has no section for the selected generator.
    """
    name = name.lower()
    if name == "crescendo":
        from multi_turn_generation.crescendo import CrescendoPlanGenerator as generator_cls
    elif name == "actorattack":
        from multi_turn_generation.actorattack import ActorAttackPlanGenerator as generator_cls
    elif name == "xteaming":
        from multi_turn_generation.xteaming import XTeamingPlanGenerator as generator_cls
    elif name == "fitd":
        from multi_turn_generation.fitd import FITDPlanGenerator as generator_cls
    elif name == "coa":
        from multi_turn_generation.coa import CoAPlanGenerator as generator_cls
    else:
        raise ValueError(
            f"Unknown plan generator '{name}'. "
            "Supported: actorattack, coa, crescendo, fitd, xteaming."
        )

    # A bare KeyError('crescendo') from args[name] is hard to diagnose from a
    # worker-thread log, so name the missing section explicitly.
    if name not in args:
        raise KeyError(
            f"Config has no '{name}' section with plan generator arguments."
        )
    return generator_cls(**args[name])

def build_updater(name: str, args: dict):
    """
    Factory to instantiate an updater by name.

    Supported names (case-insensitive): ``crescendo``, ``actorattack``,
    ``xteaming``, ``fitd``, ``coa``. The updater class is imported inside its
    branch, so only the selected updater's module is loaded.

    Args:
        name: Updater name.
        args: Updater config; passed unchanged as the updater constructor's
            single positional argument.

    Returns:
        The constructed updater.

    Raises:
        ValueError: If ``name`` is not a supported updater. (Falling through
            and returning None would only fail later, inside each worker,
            with an unhelpful AttributeError.)
    """
    name = name.lower()
    if name == 'crescendo':
        from updater.crescendo import CrescendoUpdater
        return CrescendoUpdater(args)
    elif name == 'actorattack':
        from updater.actorattack import ActorAttackUpdater
        return ActorAttackUpdater(args)
    elif name == 'xteaming':
        from updater.xteaming import XTeamingUpdater
        return XTeamingUpdater(args)
    elif name == 'fitd':
        from updater.fitd import FITDUpdater
        return FITDUpdater(args)
    elif name == 'coa':
        from updater.coa import CoAUpdater
        return CoAUpdater(args)

    raise ValueError(
        f"Unknown updater '{name}'. "
        "Supported: actorattack, coa, crescendo, fitd, xteaming."
    )

def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the updater loop.

    The paired ``--enable-X`` / ``--disable-X`` flags share one destination and
    default to ``None`` ("not given on the command line"), so callers can tell
    an omitted flag apart from an explicit choice.
    """
    parser = argparse.ArgumentParser(description="Run updater loop with a plan generator.")
    parser.add_argument(
        "--plan_generator",
        type=str,
        default="crescendo",
        help="Plan generator to use (e.g., 'crescendo').",
    )
    parser.add_argument(
        "--updater",
        type=str,
        default="xteaming",
        help="Updater to use (e.g., 'xteaming').",
    )
    parser.add_argument(
        "--number_of_behaviors",
        type=int,
        default=None,
        help="Number of harmful behaviors to process from the dataset.",
    )
    parser.add_argument(
        "--max_workers",
        type=int,
        default=10,
        help="Max threads for running behaviors in parallel.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="outputs",
        help="Directory to save aggregated results.",
    )
    parser.add_argument(
        "--output_name",
        type=str,
        default='Test',
        help="name used in output_dir",
    )
    parser.add_argument(
        "--updater_config",
        type=str,
        default=None,
        help="Path to updater config; defaults to ./config/updater/<updater>-config.yaml when not provided.",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Force re-run: clear existing output directory for this run before starting.",
    )
    parser.add_argument(
        "--disable-determinism",
        action="store_true",
        help="Disable determinism defaults in UnifiedLLMClient (sets ENABLE_DETERMINISM=false).",
    )
    parser.add_argument(
        "--enable-defense",
        action="store_true",
        dest="enable_defense",
        help="Enable mislead defense mechanism for response evaluation (default: config value, else enabled)",
    )
    parser.add_argument(
        "--disable-defense",
        action="store_false",
        dest="enable_defense",
        help="Disable mislead defense mechanism",
    )
    parser.add_argument(
        "--enable-proact",
        action="store_true",
        dest="enable_proact",
        help="Enable ProAct proactive defense (spurious responses) in addition to / instead of mislead defense (default: config value, else disabled)",
    )
    parser.add_argument(
        "--disable-proact",
        action="store_false",
        dest="enable_proact",
        help="Disable ProAct proactive defense",
    )
    parser.add_argument(
        "--enable-guard",
        action="store_true",
        dest="enable_guard",
        help="Enable Guard defense (LlamaGuard server): if unsafe, replace target response with a fixed refusal (default: config value, else disabled)",
    )
    parser.add_argument(
        "--disable-guard",
        action="store_false",
        dest="enable_guard",
        help="Disable Guard defense",
    )
    # Without an explicit default, argparse uses the default of the FIRST action
    # registered for a shared dest (store_true -> False). That made "flag
    # omitted" identical to "--disable-X", so the config-file fallback could
    # never apply and mislead defense was silently off by default.
    parser.set_defaults(enable_defense=None, enable_proact=None, enable_guard=None)
    return parser


def _resolve_toggle(updater_cfg: dict, key: str, cli_value, default: bool) -> None:
    """Set ``updater_cfg[key]`` from a tri-state CLI flag (True/False/None).

    Precedence: an explicit CLI flag wins; otherwise an existing config value is
    kept; otherwise ``default`` is used. Mutates ``updater_cfg`` in place.
    """
    if cli_value is not None:
        updater_cfg[key] = bool(cli_value)
    else:
        updater_cfg.setdefault(key, default)


def _load_updater_config(path: str) -> dict:
    """Load the updater YAML config, requiring a top-level mapping.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the document is empty or not a mapping.
    """
    with open(path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    if not isinstance(cfg, dict):
        raise ValueError(
            f"Updater config {path!r} must contain a YAML mapping, got {type(cfg).__name__}."
        )
    return cfg


def _collect_processed_indices(output_dir: str) -> set:
    """Return behavior indices that already have a readable ``behavior_<idx>.json``.

    Files whose suffix is not an int are ignored. Files that cannot be read or
    parsed as JSON (e.g. truncated by an interrupted write) are NOT counted, so
    those behaviors are re-run instead of being skipped forever.
    """
    processed = set()
    for path in glob.glob(os.path.join(output_dir, "behavior_*.json")):
        stem = os.path.splitext(os.path.basename(path))[0]  # behavior_<idx>
        _, _, suffix = stem.partition("_")
        try:
            idx = int(suffix)
        except ValueError:
            continue
        try:
            with open(path, "r", encoding="utf-8") as f:
                json.load(f)
        except (OSError, ValueError) as exc:  # JSONDecodeError/UnicodeDecodeError are ValueErrors
            logging.warning(
                "Ignoring unreadable result file %s (%s); behavior %d will be re-run.",
                path, exc, idx,
            )
            continue
        processed.add(idx)
    return processed


def _write_json_atomic(path: str, payload) -> None:
    """Write ``payload`` as UTF-8 JSON to ``path`` without leaving partial files.

    The data goes to a sibling temp file that is then renamed over ``path``.
    The temp name does not match ``behavior_*.json``, so a crash mid-write can
    never be mistaken for a finished result on resume.
    """
    tmp_path = f"{path}.{os.getpid()}.tmp"
    try:
        # Explicit UTF-8: ensure_ascii=False emits raw non-ASCII characters,
        # which the platform default encoding may be unable to represent.
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, path)
    except BaseException:
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def main():
    """Parse CLI arguments and run the updater loop over unprocessed behaviors.

    Each behavior runs in a worker thread that builds its own updater and plan
    generator, then writes ``<output_dir>/<output_name>/behavior_<idx>.json``.
    Behaviors with an existing readable result file are skipped, so an
    interrupted run can be resumed; ``--force`` deletes the run directory first.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Determinism control (used by UnifiedLLMClient); keep judge determinism controlled by its own seed.
    if args.disable_determinism:
        os.environ["ENABLE_DETERMINISM"] = "false"
    else:
        # Default to enabled unless user explicitly sets env var already.
        os.environ.setdefault("ENABLE_DETERMINISM", "true")

    # Use output_name directly: if user specified, use it as-is; if auto-generated, it's already in full format
    args.output_dir = os.path.join(args.output_dir, args.output_name)

    if args.updater_config is None or args.updater_config.lower() == "none":
        args.updater_config = f"./config/updater/{args.updater}-config.yaml"

    logging.basicConfig(level=logging.INFO)

    try:
        updater_cfg = _load_updater_config(args.updater_config)
    except (OSError, ValueError) as exc:
        parser.error(f"cannot load updater config: {exc}")
    args.max_turns = updater_cfg.get("max_turns")

    # Explicit CLI flags win; otherwise keep the config value; otherwise use the
    # documented defaults (mislead defense enabled, ProAct and Guard disabled).
    _resolve_toggle(updater_cfg, "enable_defense", args.enable_defense, default=True)
    _resolve_toggle(updater_cfg, "enable_proact", args.enable_proact, default=False)
    _resolve_toggle(updater_cfg, "enable_guard", args.enable_guard, default=False)
    logging.info(
        "Defense settings: enable_defense=%s enable_proact=%s enable_guard=%s",
        updater_cfg["enable_defense"],
        updater_cfg["enable_proact"],
        updater_cfg["enable_guard"],
    )

    # Preload similarity model once per process (before ThreadPoolExecutor starts).
    # This avoids N threads racing to load weights on first use and reduces tail latency.
    if updater_cfg.get("enable_defense", False):
        from defense.utils import get_similarity_model

        logging.info("Preloading similarity model for mislead defense...")
        get_similarity_model()
        logging.info("Similarity model ready.")

    data_df = load_datasets(number_of_behaviors=args.number_of_behaviors)

    # Implement --force here (main.py owns the output dir naming).
    if args.force and os.path.isdir(args.output_dir):
        logging.info("Force re-run enabled; deleting output dir: %s", args.output_dir)
        shutil.rmtree(args.output_dir, ignore_errors=True)
    os.makedirs(args.output_dir, exist_ok=True)

    # Collect already processed behavior indices from existing files (set: O(1) membership).
    processed_idxs = _collect_processed_indices(args.output_dir)
    logging.info("Found %d existing behavior files: %s", len(processed_idxs), sorted(processed_idxs))

    # Prepare parameters for each behavior; the dataset row position is the behavior number.
    all_param_dicts = [
        {
            "behavior_number": idx,
            "behavior_id": row.get("BehaviorID"),
            "harmful_behavior": row["Behavior"],
        }
        for idx, (_, row) in enumerate(data_df.iterrows())
        if idx not in processed_idxs
    ]

    def run_single_behavior(behavior_number, behavior_id, harmful_behavior):
        # Fresh updater / plan generator per behavior so threads share no instances.
        updater = build_updater(args.updater, updater_cfg)
        plan_gen = build_plan_generator(args.plan_generator, updater_cfg)

        result = updater.run_with_plan(
            plan_gen,
            harmful_behavior=harmful_behavior,
            max_turns=args.max_turns,
        )
        behavior_result = {
            "plan_generation": plan_gen.save_info,
            'behavior_number': behavior_number,
            "behavior_id": behavior_id,
            "behavior": harmful_behavior,
            "feedback_result": result,
        }
        # Save per-behavior result immediately so each thread persists its output
        behavior_path = os.path.join(args.output_dir, f"behavior_{behavior_number}.json")
        _write_json_atomic(behavior_path, behavior_result)
        logging.info("Saved behavior %s results to %s", behavior_number, behavior_path)

    failed = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor:
        futures = {
            executor.submit(run_single_behavior, **params): params for params in all_param_dicts
        }
        for future in tqdm.tqdm(
            concurrent.futures.as_completed(futures), total=len(futures)
        ):
            behavior_number = futures[future]["behavior_number"]
            try:
                future.result()
            except Exception as e:
                failed.append(behavior_number)
                logging.error(
                    "Behavior %s generated an exception", behavior_number, exc_info=e
                )

    if failed:
        logging.error("%d behavior(s) failed: %s", len(failed), sorted(failed))
    logging.info("Finished")
    logging.info("All behavior results saved to %s", args.output_dir)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
