import argparse
import logging
import pathlib

import dotenv
import torch

import eval.data as data
import eval.settings as settings
import eval.util as util


def main() -> None:
    """Generate the baseline canaries for an experiment and save them.

    Steps, in order: load environment variables via ``dotenv``, parse the
    command line, set up logging, read and validate the experiment config
    found under ``--dir``, load the base training data, select canary
    indices, run the configured canary generator, validate its output, and
    write the canary images and targets with ``torch.save``.

    Raises:
        FileNotFoundError: If the config file resolved from ``--dir`` does
            not exist.
    """
    dotenv.load_dotenv()
    cli_args = parse_args()
    util.setup_logging()

    experiment_dir = cli_args.dir
    settings_file = util.DirectoryManager.get_config_path(experiment_dir)
    # The path is logged before the existence check so that a failing run
    # still shows which file was expected.
    logging.info("Using config from %s", settings_file)
    if not settings_file.exists():
        raise FileNotFoundError(f"Config file not found at {settings_file}")
    raw_settings = settings_file.read_text()
    config = settings.Settings.model_validate_json(raw_settings)
    dirs = util.DirectoryManager(experiment_dir)

    # Load raw data
    logging.info("Loading raw data")
    base_dataset = config.base_dataset
    loader = base_dataset.build_loader()
    loader.prepare_raw_data()
    all_train_images, all_train_targets = loader.load_train_data()

    # Only the selected indices are needed here; the second element of the
    # returned pair is ignored.
    selected_indices, _ = data.select_canary_indices(
        num_canaries=config.num_canaries,
        num_samples=all_train_images.shape[0],
        global_seed=config.global_seed,
        manual_selection=cli_args.manual_canary_selection,
    )

    generator = config.canaries.build_generator(dirs, loader)

    logging.info("Generating canaries")
    generate_kwargs = {
        "num_canaries": config.num_canaries,
        "image_shape": base_dataset.get_image_shape(),
        "num_classes": base_dataset.get_num_classes(),
        # Training samples at the selected indices are handed to the generator.
        "replaced_images": all_train_images[selected_indices],
        "replaced_targets": all_train_targets[selected_indices],
        "global_seed": config.global_seed,
    }
    canaries, targets = generator.generate(**generate_kwargs)

    # Check the generator output against the configured dataset properties
    # before anything is written to disk.
    util.validate_canaries(
        canaries,
        targets,
        image_shape=base_dataset.get_image_shape(),
        num_classes=base_dataset.get_num_classes(),
        num_canaries=config.num_canaries,
    )

    images_file = dirs.get_canaries_images_path()
    targets_file = dirs.get_canaries_targets_path()
    for payload, destination in ((canaries, images_file), (targets, targets_file)):
        torch.save(payload, destination)
    logging.info("Saved baseline canaries to %s and %s", images_file, targets_file)


def parse_args() -> argparse.Namespace:
    """Parse this script's command-line options from ``sys.argv``.

    Returns:
        A namespace with two attributes:
        ``dir`` -- required, converted to ``pathlib.Path``;
        ``manual_canary_selection`` -- a list of one or more ints when the
        option is given, otherwise ``None``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--dir",
        required=True,
        type=pathlib.Path,
        help="Path to experiment base directory",
    )

    manual_selection_options = {
        "type": int,
        "nargs": "+",  # at least one index must follow the flag
        "default": None,
        "help": "Manually select canary indices from the dataset (overrides random selection with global seed)",
    }
    cli.add_argument("--manual-canary-selection", **manual_selection_options)

    namespace = cli.parse_args()
    return namespace


# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
