import argparse
import logging
import pathlib

import dotenv
import torch

import eval.data as data
import eval.settings as settings
import eval.util as util


def main() -> None:
    """Regenerate the target models' membership masks and save the canary part.

    Loads the experiment ``Settings`` from the config file located via
    ``util.DirectoryManager.get_config_path(args.dir)``. It then regenerates
    the membership masks with ``data.generate_full_membership_masks``, passing
    ``config.global_seed`` (presumably so the masks match the ones used at
    training time; verify against ``eval.data``). The target masks are
    validated, and the canary columns are written to
    ``<dir>/canary_membership.pt``.

    Raises:
        FileNotFoundError: If the config file does not exist.
        ValueError: If the config requests more canaries than the base dataset
            has training samples.
    """
    dotenv.load_dotenv()
    args = parse_args()
    util.setup_logging()

    config_path = util.DirectoryManager.get_config_path(args.dir)
    logging.info("Using config from %s", config_path)
    if not config_path.exists():
        raise FileNotFoundError(f"Config file not found at {config_path}")
    config = settings.Settings.model_validate_json(config_path.read_text())

    logging.info("Adding missing membership info for %s", args.dir)

    num_samples = config.base_dataset.get_num_train_samples()
    num_non_canaries = num_samples - config.num_canaries
    # A negative count would otherwise be forwarded silently to mask generation.
    if num_non_canaries < 0:
        raise ValueError(
            f"num_canaries ({config.num_canaries}) exceeds the number of "
            f"training samples ({num_samples})"
        )

    # Generate sample splits across all models (targets and shadows); only the
    # target masks are needed here.
    membership_masks_targets, _ = data.generate_full_membership_masks(
        num_canaries=config.num_canaries,
        num_non_canaries=num_non_canaries,
        num_models_target=config.num_models_target,
        num_models_shadow=config.num_models_shadow,
        sample_non_canaries=config.sample_non_canaries,
        global_seed=config.global_seed,
    )
    util.validate_membership_masks(
        membership_masks_targets,
        num_canaries=config.num_canaries,
        num_non_canaries=num_non_canaries,
        num_models=config.num_models_target,
        sample_non_canaries=config.sample_non_canaries,
    )

    # Canaries always come first in membership masks, so the leading
    # `num_canaries` columns are exactly the canary memberships.
    canary_membership = membership_masks_targets[:, : config.num_canaries]
    # Slicing a torch tensor produces a view, and torch.save serializes the
    # view's whole underlying storage (i.e. the non-canary columns too).
    # Cloning ensures only the canary columns end up on disk.
    if isinstance(canary_membership, torch.Tensor):
        canary_membership = canary_membership.clone()

    output_path = args.dir / "canary_membership.pt"
    torch.save(canary_membership, output_path)
    logging.info("Saved canary membership to %s", output_path)


def parse_args() -> argparse.Namespace:
    """Parse this script's command-line arguments from ``sys.argv``.

    Returns:
        A namespace whose ``dir`` attribute is the experiment base directory,
        converted to ``pathlib.Path``. ``--dir`` is mandatory.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--dir",
        required=True,
        type=pathlib.Path,
        help="Path to experiment base directory",
    )
    namespace = cli.parse_args()
    return namespace


if __name__ == "__main__":
    # Run only when executed as a script, not when this module is imported.
    main()
