import argparse
import logging
import pathlib

import dotenv

### SPEEDUPS FOR JAX (optional) ###
import jax
import numpy as np
import torch
from eval import data, settings, util

import os

# 1) Persistent compilation cache: point JAX's compilation cache at a fixed
#    directory. NOTE(review): the path is hard-coded and this runs at import
#    time, i.e. before any CLI arguments are parsed.
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
# 2) Set the default matmul precision to "tensorfloat32" (TF32).
#    NOTE(review): the original note says this targets A100 GPUs for faster
#    math; presumably a speed/precision trade-off -- confirm it is acceptable
#    for the hardware this script actually runs on.
jax.config.update("jax_default_matmul_precision", "tensorfloat32")

# # 3) (Optional) small warmup to compile the hot paths right away
# _ = train_model(  # your BN-aware fn
#     state, batch_stats, epoch_perms[:1],   # 1 epoch slice with same shapes
#     train_images[:1], train_targets[:1],
#     test_images=None, test_targets=None,
#     use_dp=False, verbose=False,
# )

def main() -> None:
    """Optimize (or copy) the requested canaries and save them to disk.

    Steps:
      1. Load environment variables, CLI arguments and the experiment config
         found under ``--dir``.
      2. Load and validate the base dataset's train and validation data.
      3. Select which dataset samples act as canaries.
      4. For identity canaries, save the selected dataset images unchanged
         (optionally with a config-specified label) and return.
      5. Otherwise pick a target label for every canary, run the configured
         canary optimizer for each requested canary index, and save the
         resulting image/target pair.

    Raises:
        FileNotFoundError: If the experiment config file does not exist.
        NotImplementedError: If the configured canary type is unsupported.
        ValueError: If random mislabeling is requested for a dataset with
            fewer than two classes.
    """
    dotenv.load_dotenv()
    args = parse_args()
    util.setup_logging()
    log = logging.getLogger("eval_optimize")
    config_path = util.DirectoryManager.get_config_path(args.dir)
    log.info("Using config from %s", config_path)
    if not config_path.exists():
        raise FileNotFoundError(f"Config file not found at {config_path}")
    config = settings.Settings.model_validate_json(config_path.read_text())
    directory_manager = util.DirectoryManager(args.dir)

    if not isinstance(config.canaries, (settings.OptimizedCanary, settings.IdentityCanary)):
        raise NotImplementedError("More canary types not implemented yet")

    log.info(
        "Optimizing canaries with indices %d to %d (exclusive)",
        args.canary_indices.start,
        args.canary_indices.stop,
    )
    log.info("Base dataset: %s", config.base_dataset.name)
    log.info("Canary type: %s", config.canaries.canary_type)
    log.info("Canary optimizer: %s", config.canaries.optimizer.optimizer_type)

    # Load raw data
    dataset_loader = config.base_dataset.build_loader()
    dataset_loader.prepare_raw_data()
    train_images_full, train_targets_full = dataset_loader.load_train_data()
    data.validate_dataset(
        train_images_full,
        train_targets_full,
        image_shape=config.base_dataset.get_image_shape(),
        num_samples=config.base_dataset.get_num_train_samples(),
        num_classes=config.base_dataset.get_num_classes(),
    )
    val_images, val_targets = dataset_loader.load_val_data()
    data.validate_dataset(
        val_images,
        val_targets,
        image_shape=config.base_dataset.get_image_shape(),
        num_samples=val_images.shape[0],
        num_classes=config.base_dataset.get_num_classes(),
    )

    # Generate replacement indices for optimizers that start with real data
    canary_indices_in_dataset, _ = data.select_canary_indices(
        num_canaries=config.num_canaries,
        num_samples=train_images_full.shape[0],
        global_seed=config.global_seed,
        manual_selection=args.manual_canary_selection,
    )

    # %s (not %d): this is a whole collection of indices, not a single integer.
    log.info("Selecting canary index in dataset %s", canary_indices_in_dataset)
    # Canaries are initialized and stored as values in [0, 1],
    #  but optimization and training might use standardization

    log.info("Preparing data")
    canary_optimizer = config.canaries.optimizer.build_optimizer(config.sample_non_canaries)
    # DEBUG: ensure correct class and method resolution (flip the flag manually)
    _DEBUG = False
    if _DEBUG:
        from eval.optimizers.base import CanaryOptimizer
        from eval.optimizers.unrolled import UnrolledOptimizer

        print("Optimizer type:", type(canary_optimizer))
        print("MRO:", canary_optimizer.__class__.__mro__)

        # Must be concrete subclass
        assert isinstance(canary_optimizer, UnrolledOptimizer)

        # Ensure the bound method is NOT the abstract one
        base_prepare = CanaryOptimizer.prepare_data
        impl_prepare = type(canary_optimizer).prepare_data
        assert impl_prepare is not base_prepare, "prepare_data not overridden in UnrolledOptimizer"

    canary_optimizer.prepare_data(
        train_images_full,
        train_targets_full,
        val_images,
        val_targets,
        dataset_mean_std=dataset_loader.dataset_mean_std,
        max_test_samples=512,
        test_sample_seed=123456,
    )

    if config.canaries.canary_type == "identity":
        log.info("Identity canary: copying initial canary images without optimization")
        # Arielle TODO: add mislabeled for identity
        # Arielle TODO: clean up this code, ask michael for the "id" implementation

        for canary_idx in args.canary_indices:
            ds_idx = int(canary_indices_in_dataset[canary_idx])
            dataset_label = int(train_targets_full[ds_idx].item())

            out_path = directory_manager.get_optimized_canary_file(canary_idx)
            out_path.parent.mkdir(parents=True, exist_ok=True)

            # If config.canaries.canary_label is specified, override dataset label
            if config.canaries.canary_label is not None:
                target_label = int(config.canaries.canary_label)
                log.info(
                    "Overriding dataset label %d with config-specified label %d for canary %d",
                    dataset_label,
                    target_label,
                    canary_idx,
                )
            else:
                target_label = dataset_label

            torch.save(
                {
                    "image": train_images_full[ds_idx].to(torch.float32).cpu(),
                    "target": target_label,
                },
                out_path,
            )
        return

    rng = np.random.default_rng(config.global_seed)
    canary_seeds = rng.integers(0, np.iinfo(np.int32).max, size=config.num_canaries)

    if config.canaries.canary_label is None:
        num_classes = config.base_dataset.get_num_classes()
        if num_classes < 2:
            raise ValueError("Random mislabeling requires a dataset with at least 2 classes")

        # Legacy draw: the values are only logged, but the call is kept so the
        # RNG stream (and therefore the targets sampled below) stays unchanged.
        old_canary_targets = rng.integers(0, num_classes, size=config.num_canaries)
        log.info("Getting old canary target %s", old_canary_targets)

        # Gather true labels for ALL canaries (0-based labels assumed), not just
        # the requested subset: the targets are indexed by the global canary
        # index below and must not depend on which indices this run handles.
        true_labels = np.array(
            [
                int(train_targets_full[int(canary_indices_in_dataset[i])].item())
                for i in range(config.num_canaries)
            ],
            dtype=np.int64,
        )

        # Sample uniformly from the other K-1 classes for each element (skip-by-shift trick)
        u = rng.integers(0, num_classes - 1, size=true_labels.shape, dtype=np.int64)  # 0..K-2
        canary_targets = u + (u >= true_labels)  # maps to {0..y-1, y+1..K-1}
        log.info("Randomly choosing canary target=%s", canary_targets)
    else:
        canary_targets = np.full(config.num_canaries, config.canaries.canary_label)
    canary_targets = torch.from_numpy(canary_targets).to(torch.int64)
    assert canary_targets.shape == (config.num_canaries,)

    for canary_idx in args.canary_indices:
        log.info("Optimizing canary %d (seed %d)", canary_idx, canary_seeds[canary_idx])
        canary_in_dataset_idx = int(canary_indices_in_dataset[canary_idx])
        log.info("Canary index in dataset: %d", canary_in_dataset_idx)
        current_image, current_target = canary_optimizer.optimize(
            target=canary_targets[canary_idx].item(),
            seed=canary_seeds[canary_idx],
            canary_idx=canary_in_dataset_idx,
            intermediate_log_dir=directory_manager.get_canary_log_dir(canary_idx),
        )
        # Sanity-check the optimizer output before persisting it.
        assert current_image.shape == config.base_dataset.get_image_shape()
        assert current_image.min() >= 0
        assert current_image.max() <= 1
        assert current_image.dtype == torch.float32
        assert isinstance(current_target, int)

        # Save canary (create the parent directory, as the identity branch does)
        out_path = directory_manager.get_optimized_canary_file(canary_idx)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(
            {
                "image": current_image,
                "target": current_target,
            },
            out_path,
        )

    log.info("Finished optimizing all canaries")


def parse_args() -> argparse.Namespace:
    """Parse the command-line arguments of the canary optimization script.

    Returns:
        Namespace with ``dir`` (experiment base directory), ``canary_indices``
        (parsed by ``util.model_indices_type``) and ``manual_canary_selection``
        (list of ints, or ``None`` when the flag is not given).
    """
    # Each entry is (flag, keyword arguments forwarded to ``add_argument``).
    argument_specs = (
        (
            "--dir",
            {
                "type": pathlib.Path,
                "required": True,
                "help": "Path to experiment base directory",
            },
        ),
        (
            "--canary-indices",
            {
                "type": util.model_indices_type,
                "required": True,
                "help": "Canary indices to optimize. Can be a single index (e.g. '1') or a range (e.g. '1-5') with end exclusive",
            },
        ),
        (
            "--manual-canary-selection",
            {
                "type": int,
                "nargs": "+",
                "default": None,
                "help": "Manually select canary indices from the dataset (overrides random selection with global seed)",
            },
        ),
    )

    cli = argparse.ArgumentParser()
    for flag, options in argument_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()


# Script entry point: run the canary optimization only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
