# eval_attack.py

import argparse
import logging
import pathlib

import dotenv
import numpy as np
import sklearn.metrics
import torch

import eval.attack_util as attack_util
import eval.data as data
import eval.settings as settings
import eval.util as util


def _describe_model_from_config(cfg: settings.Settings) -> str:
    """
    Build a human-friendly, single-line model description for logging.

    The description is read from ``cfg.canaries.optimizer``. The explicit
    ``model_name`` / ``model_kwargs`` attributes win when ``model_name`` is
    non-empty; otherwise the legacy ``architecture`` attribute (a dict or an
    object carrying an ``architecture_name`` field) is consulted.

    Returns something like: "unrolled | arch=resnet9 kwargs={'num_classes': 10}".
    Returns "(no optimizer info)" when the config has no optimizer block, and
    "(unable to resolve model info: ...)" if anything raises while inspecting
    the config. This function never raises: it only feeds a log line.
    """
    try:
        opt = getattr(getattr(cfg, "canaries", None), "optimizer", None)
        if opt is None:
            return "(no optimizer info)"

        opt_type = getattr(opt, "optimizer_type", None)

        # Prefer explicit model_name/model_kwargs if present
        name = getattr(opt, "model_name", None)
        kw = getattr(opt, "model_kwargs", None)

        # Fall back to the legacy "architecture" shape if needed
        arch = getattr(opt, "architecture", None)
        if not name and arch is not None:
            if isinstance(arch, dict):
                fields = dict(arch)
            elif callable(getattr(arch, "model_dump", None)):
                # Model object exposing model_dump(): use its declared fields only.
                fields = arch.model_dump()
            else:
                # Generic object: instance data attributes only. Enumerating
                # dir() would also pick up every public method and class
                # attribute, which made the logged kwargs unreadable.
                fields = {
                    key: value
                    for key, value in vars(arch).items()
                    if not key.startswith("_") and not callable(value)
                }
            name = fields.get("architecture_name")
            kw = {key: value for key, value in fields.items() if key != "architecture_name"}

        shown_name = name if name else "unknown"
        shown_kwargs = kw if kw is not None else {}
        return f"{opt_type} | arch={shown_name} kwargs={shown_kwargs}"
    except Exception as e:
        # Deliberately broad: a malformed config must not abort the evaluation
        # just because the descriptive log line could not be built.
        return f"(unable to resolve model info: {e})"


def _require_shape(tensor: torch.Tensor, expected_shape: tuple, what: str) -> None:
    """Raise ``ValueError`` unless ``tensor.shape`` equals ``expected_shape``.

    Used for data loaded from disk, where a stale or mismatched file is a real
    possibility; an explicit raise (unlike ``assert``) also fires under ``-O``.
    """
    actual_shape = tuple(tensor.shape)
    if actual_shape != tuple(expected_shape):
        raise ValueError(f"{what} have shape {actual_shape}, expected {tuple(expected_shape)}")


def _load_stacked_predictions(
    kind: str,
    cache_file,
    per_model_file_fn,
    num_models: int,
    expected_shape: tuple,
    use_cache: bool,
    device: torch.device,
) -> torch.Tensor:
    """Load canary predictions of all ``kind`` ("shadow"/"target") models as one tensor.

    If ``use_cache`` is set and ``cache_file`` exists, the combined tensor is
    read from it. Otherwise the ``"pred_canaries"`` entry of each per-model
    file (path obtained from ``per_model_file_fn(model_idx)``) is stacked along
    dim 1 and the result is written to ``cache_file``.

    Everything is loaded with ``map_location=device`` so files saved from a GPU
    run can be read on a CPU-only host.

    Raises:
        ValueError: if the resulting tensor does not have ``expected_shape``.
    """
    if use_cache and cache_file.exists():
        logging.info("Loading cached %s model predictions", kind)
        predictions = torch.load(cache_file, map_location=device)
        _require_shape(predictions, expected_shape, f"Cached {kind} model predictions")
        return predictions

    logging.info("Loading %s model predictions (per-model)", kind)
    per_model_predictions = [
        torch.load(per_model_file_fn(model_idx), map_location=device)["pred_canaries"]
        for model_idx in range(num_models)
    ]
    predictions = torch.stack(per_model_predictions, dim=1).to(device)
    _require_shape(predictions, expected_shape, f"Stacked {kind} model predictions")

    # Cache combined
    cache_file.parent.mkdir(parents=True, exist_ok=True)
    torch.save(predictions, cache_file)
    logging.info("Cached %s model predictions to %s", kind, cache_file)
    return predictions


def _tpr_at_fpr(fprs: np.ndarray, tprs: np.ndarray, target_fpr: float) -> float:
    """Linearly interpolate the TPR of a ROC curve at ``target_fpr``.

    Returns NaN for an empty curve. Outside the covered FPR range the first /
    last TPR is returned (``np.interp`` default behavior).
    """
    if fprs.size == 0 or tprs.size == 0:
        return float("nan")
    # Stable sort: guarantees non-decreasing FPRs for np.interp while keeping
    # the relative order of points that share an FPR (the default argsort kind
    # is not stable and could shuffle their TPRs).
    order = np.argsort(fprs, kind="stable")
    return float(np.interp(target_fpr, fprs[order], tprs[order]))


def _roc_with_tpr(y_true: np.ndarray, y_score: np.ndarray, target_fpr: float):
    """Return ``(fprs, tprs, tpr_at_target_fpr)`` for the given labels and scores."""
    fprs, tprs, _ = sklearn.metrics.roc_curve(y_true=y_true, y_score=y_score)
    return fprs, tprs, _tpr_at_fpr(fprs, tprs, target_fpr)


def main() -> None:
    """Run the LiRA and global-threshold attacks for one experiment directory.

    Steps:
      1. Load environment variables, CLI arguments and the experiment config.
      2. Load canary targets plus shadow/target model predictions (optionally
         from the combined cache files).
      3. For the hinge and logit score functions, run the LiRA attack and a
         global-threshold baseline, computing ROC curves and TPR at 0.1% FPR.
      4. Save all per-score results into a single ``.npz``-style archive via
         ``np.savez``.

    Raises:
        FileNotFoundError: if the experiment config file does not exist.
        ValueError: if loaded predictions do not have the shape implied by the
            config.
    """
    dotenv.load_dotenv()
    args = parse_args()
    util.setup_logging()
    config_path = util.DirectoryManager.get_config_path(args.dir)
    logging.info("Using config from %s", config_path)
    if not config_path.exists():
        raise FileNotFoundError(f"Config file not found at {config_path}")
    config = settings.Settings.model_validate_json(config_path.read_text())
    directory_manager = util.DirectoryManager(args.dir)

    logging.info(
        "Attacking %(num_target)d target and %(num_shadow)d shadow models",
        dict(num_target=config.num_models_target, num_shadow=config.num_models_shadow),
    )
    logging.info("Base dataset: %s", config.base_dataset.name)

    # Model/trainer info (compatible with new unified config where trainer may be absent)
    if getattr(config, "model_trainer", None) is not None:
        logging.info("Model trainer: %s", config.model_trainer.trainer_type)
    else:
        logging.info("Model trainer: (not specified; using optimizer block)")
    logging.info("Model (from optimizer): %s", _describe_model_from_config(config))

    device = torch.device("cpu")  # attacks are lightweight; CPU is fine and avoids GPU contention

    # -------------------------
    # Load canary targets
    # -------------------------
    canary_targets = torch.load(directory_manager.get_canaries_targets_path(), map_location=device).to(device)

    # -------------------------
    # Load shadow and target model predictions
    # -------------------------
    num_classes = config.base_dataset.get_num_classes()
    shadow_predictions = _load_stacked_predictions(
        kind="shadow",
        cache_file=directory_manager.get_full_shadow_predictions_file(),
        per_model_file_fn=directory_manager.get_shadow_predictions_file,
        num_models=config.num_models_shadow,
        expected_shape=(config.num_canaries, config.num_models_shadow, num_classes),
        use_cache=args.use_cache,
        device=device,
    )
    target_predictions = _load_stacked_predictions(
        kind="target",
        cache_file=directory_manager.get_full_target_predictions_file(),
        per_model_file_fn=directory_manager.get_target_predictions_file,
        num_models=config.num_models_target,
        expected_shape=(config.num_canaries, config.num_models_target, num_classes),
        use_cache=args.use_cache,
        device=device,
    )

    # -------------------------
    # LiRA attack
    # -------------------------
    logging.info("Performing LiRA attacks")
    scores = {}

    # Prepare membership masks (IN/OUT); transposed below to [num_canaries, num_models_*]
    membership_masks_targets, membership_masks_shadow = data.generate_full_canary_membership_masks(
        num_canaries=config.num_canaries,
        num_non_canaries=config.base_dataset.get_num_train_samples() - config.num_canaries,
        num_models_target=config.num_models_target,
        num_models_shadow=config.num_models_shadow,
        sample_non_canaries=config.sample_non_canaries,
        global_seed=config.global_seed,
    )
    membership_masks_targets = membership_masks_targets.T
    membership_masks_shadow = membership_masks_shadow.T
    assert membership_masks_targets.shape == (config.num_canaries, config.num_models_target)
    assert membership_masks_shadow.shape == (config.num_canaries, config.num_models_shadow)

    # Ground-truth labels are the same for every score function and for both
    # the LiRA and the global-threshold ROC, so flatten them once.
    y_true = membership_masks_targets.flatten().cpu().numpy()
    target_fpr = 0.001  # TPR is reported at 0.1% FPR

    # Score functions to evaluate
    for score_name, score_fn in (("hinge", attack_util.hinge_score), ("logit", attack_util.logit_score)):
        logging.info("Using %s score", score_name)

        # LiRA expects an extra trailing dim (e.g., MC samples dim); add singleton
        shadow_scores = score_fn(shadow_predictions, canary_targets).unsqueeze(-1)
        target_scores = score_fn(target_predictions, canary_targets).unsqueeze(-1)

        # Core LiRA attack; output must match the target membership mask shape
        attack_scores = attack_util.lira_attack(
            target_scores=target_scores,
            shadow_scores=shadow_scores,
            shadow_membership_mask=membership_masks_shadow,
        )

        # -------------------------
        # ROC statistics
        # -------------------------
        assert membership_masks_targets.shape == attack_scores.shape

        fprs, tprs, tpr_at = _roc_with_tpr(y_true, attack_scores.flatten().cpu().numpy(), target_fpr)

        logging.info("TPR at 0.1%% FPR: %.1f%%", tpr_at * 100)
        logging.info("FPR resolution: %f%%", (2 / attack_scores.numel()) * 100)

        # -------------------------
        # Global threshold baseline (no LiRA): raw target scores as attack scores
        # -------------------------
        assert target_scores.shape[-1] == 1 and target_scores.shape[:-1] == membership_masks_targets.shape
        fprs_g, tprs_g, tpr_at_g = _roc_with_tpr(y_true, target_scores.flatten().cpu().numpy(), target_fpr)

        logging.info("Global threshold TPR at 0.1%% FPR: %.1f%%", tpr_at_g * 100)

        scores[score_name] = {
            "fprs": fprs,
            "tprs": tprs,
            "attack_scores": attack_scores.cpu().numpy(),
            "tpr_at_0p1pct_fpr": tpr_at,
            "fprs_global": fprs_g,
            "tprs_global": tprs_g,
            "tpr_at_0p1pct_fpr_global": tpr_at_g,
        }

    # -------------------------
    # Save everything
    # -------------------------
    results_file = directory_manager.get_attack_results_file()
    results_file.parent.mkdir(parents=True, exist_ok=True)
    # Flatten to "<score>_<metric>" keys so everything fits in one archive.
    results_dict = {
        f"{score_name}_{metric_name}": metric_value
        for score_name, metrics in scores.items()
        for metric_name, metric_value in metrics.items()
    }
    np.savez(results_file, **results_dict)
    logging.info("Saved attack results to %s", results_file)


def parse_args() -> argparse.Namespace:
    """Parse the command line of the attack-evaluation script.

    Options:
        --dir        (required) experiment base directory, parsed as ``pathlib.Path``.
        --use-cache  flag; when given, ``use_cache`` is ``True``, otherwise ``False``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--dir",
        required=True,
        type=pathlib.Path,
        help="Path to experiment base directory",
    )
    cli.add_argument(
        "--use-cache",
        help="Use cached predictions if available",
        action="store_true",
    )
    namespace = cli.parse_args()
    return namespace


# Script entry point: run the attack evaluation only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()