# eval/eval_identity.py
# Identity baseline: saves the selected canaries as-is, with no optimization applied.

import argparse
import logging
import pathlib

import dotenv
import torch
import numpy as np

from eval import data, settings, util


def _atomic_torch_save(obj, path: pathlib.Path) -> None:
    """Save ``obj`` with ``torch.save`` so that ``path`` is never left half-written.

    The object is first written to a sibling temporary file, which is then
    moved over ``path`` with ``Path.replace``. An interrupted write therefore
    leaves at most a stray ``*.tmp`` file, never a truncated ``path`` that
    ``--resume`` would mistake for a finished export.
    """
    tmp_path = path.with_name(path.name + ".tmp")
    torch.save(obj, tmp_path)
    tmp_path.replace(path)


def main() -> None:
    """Export the selected training samples unchanged as "identity" canaries.

    Loads the experiment config from ``--dir`` and loads and validates the base
    dataset's training split. It then selects canary indices, either from
    ``--manual-canary-selection`` or via ``data.select_canary_indices`` seeded
    with ``config.global_seed``. Finally it saves the selected images, targets
    and indices. No optimization is applied to the samples.

    With ``--resume``, nothing is done if both the canary image and target
    files already exist.

    Raises:
        FileNotFoundError: If the experiment config file does not exist.
        ValueError: If the selected canaries do not have the shapes implied
            by the config.
    """
    dotenv.load_dotenv()
    args = parse_args()
    util.setup_logging()
    log = logging.getLogger("eval_optimize_identity")

    config_path = util.DirectoryManager.get_config_path(args.dir)
    log.info("Using config from %s", config_path)
    if not config_path.exists():
        raise FileNotFoundError(f"Config file not found at {config_path}")
    config = settings.Settings.model_validate_json(config_path.read_text())
    dm = util.DirectoryManager(args.dir)

    log.info("Identity canary export (no optimization)")
    log.info("Base dataset: %s", config.base_dataset.name)
    log.info("num_canaries: %d", config.num_canaries)

    can_img_path = dm.get_canaries_images_path()
    can_tgt_path = dm.get_canaries_targets_path()

    if args.resume and can_img_path.exists() and can_tgt_path.exists():
        log.info("Canary files already exist and --resume set. Skipping.")
        return

    # --- Load dataset ---
    loader = config.base_dataset.build_loader()
    loader.prepare_raw_data()
    train_images_full, train_targets_full = loader.load_train_data()
    data.validate_dataset(
        train_images_full,
        train_targets_full,
        image_shape=config.base_dataset.get_image_shape(),
        num_samples=config.base_dataset.get_num_train_samples(),
        num_classes=config.base_dataset.get_num_classes(),
    )

    # --- Pick indices (manual or seeded) ---
    log.info("Selecting canary indices (manual=%s)", args.manual_canary_selection is not None)
    canary_indices, _ = data.select_canary_indices(
        num_canaries=config.num_canaries,
        num_samples=train_images_full.shape[0],
        global_seed=config.global_seed,
        manual_selection=args.manual_canary_selection,
    )

    # --- Slice out canaries and save as-is (in [0,1]) ---
    canary_images = train_images_full[canary_indices].cpu()
    canary_targets = train_targets_full[canary_indices].cpu()

    # Explicit checks rather than `assert`, so they still run under `python -O`.
    C, H, W = config.base_dataset.get_image_shape()
    expected_img_shape = (config.num_canaries, C, H, W)
    if tuple(canary_images.shape) != expected_img_shape:
        raise ValueError(
            f"Canary images have shape {tuple(canary_images.shape)}, "
            f"expected {expected_img_shape}"
        )
    expected_tgt_shape = (config.num_canaries,)
    if tuple(canary_targets.shape) != expected_tgt_shape:
        raise ValueError(
            f"Canary targets have shape {tuple(canary_targets.shape)}, "
            f"expected {expected_tgt_shape}"
        )

    canary_images = canary_images.to(torch.float32)
    # Pixels are expected to already lie in [0, 1]. Warn instead of silently
    # clamping if they do not (e.g. a loader that returns 0-255 values).
    if canary_images.numel() > 0:
        lo = canary_images.min().item()
        hi = canary_images.max().item()
        if lo < 0.0 or hi > 1.0:
            log.warning(
                "Canary pixel values span [%g, %g]; clamping to [0, 1]", lo, hi
            )
    canary_images = torch.clamp(canary_images, 0.0, 1.0)
    canary_targets = canary_targets.to(torch.long)

    can_img_path.parent.mkdir(parents=True, exist_ok=True)
    can_tgt_path.parent.mkdir(parents=True, exist_ok=True)

    # Write the indices first and the image/target files last, each atomically.
    # `--resume` only checks that the image and target files exist, so this
    # order means that when both exist, the indices file has already been
    # written completely.
    idx_path = can_img_path.parent / "selected_indices.pt"
    _atomic_torch_save(torch.as_tensor(canary_indices, dtype=torch.long), idx_path)
    log.info("Saved selected canary indices -> %s", idx_path)

    _atomic_torch_save(canary_targets, can_tgt_path)
    _atomic_torch_save(canary_images, can_img_path)
    log.info("Saved canary images -> %s", can_img_path)
    log.info("Saved canary targets -> %s", can_tgt_path)

    log.info("Done.")


def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse this script's command-line arguments.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, so existing ``parse_args()`` calls behave
            exactly as before.

    Returns:
        The parsed namespace, with attributes:
        - ``dir``: a ``pathlib.Path``.
        - ``manual_canary_selection``: a list of ints, or ``None`` if the flag
          is not given.
        - ``resume``: a bool.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dir",
        type=pathlib.Path,
        required=True,
        help="Path to experiment base directory",
    )
    parser.add_argument(
        "--manual-canary-selection",
        type=int,
        nargs="+",
        default=None,
        help="Manually select train-set indices to use as canaries (length must equal num_canaries).",
    )
    parser.add_argument(
        "--resume",
        action="store_true",
        help="If canary files already exist, do nothing.",
    )
    return parser.parse_args(argv)


# Run the export only when executed as a script, not when imported.
if __name__ == "__main__":
    main()