import argparse
import json
import logging
import pathlib
import warnings

import dotenv
import numpy as np
import optax
import torch

import eval.data as data
import eval.settings as settings
import eval.util as util

# When True, InlineUnrolledJaxTrainer.train uses an optax warmup + cosine-decay
# learning-rate schedule; when False, the configured learning rate is constant.
USE_SCHEDULER = True

def _jax_model_to_serializable(model_tuple):
    """Turn a trained ``(state, batch_stats)`` pair into picklable NumPy pytrees.

    Every leaf is fetched to the host with ``jax.device_get`` and wrapped with
    ``np.asarray``. The optimizer transforms (``state.txs``) hold functions or
    closures and are deliberately NOT stored.

    Returns:
        dict with keys ``"params"``, ``"batch_stats"`` and ``"opt_states"``.
        ``"opt_states"`` is None when the state has no (or a None) ``opt_states``.
        It is kept only so training could be resumed; it contains plain
        arrays/scalars and is picklable.
    """
    import jax
    import numpy as np

    state, batch_stats = model_tuple

    def host_copy(tree):
        # Map every leaf of the pytree to a host-side NumPy array.
        return jax.tree_util.tree_map(lambda leaf: np.asarray(jax.device_get(leaf)), tree)

    params_host = host_copy(state.params)
    batch_stats_host = host_copy(batch_stats)
    raw_opt_states = getattr(state, "opt_states", None)
    opt_states_host = None if raw_opt_states is None else host_copy(raw_opt_states)

    return {
        "params": params_host,
        "batch_stats": batch_stats_host,
        "opt_states": opt_states_host,
    }

# -----------------------------------------------------------------------------
# Inline trainer: thin adapter around unrolled_canaries.* (no duplication)
# -----------------------------------------------------------------------------
class InlineUnrolledJaxTrainer:
    """
    Minimal trainer built on unrolled_canaries.* helpers.
    - BN-aware (threads batch_stats everywhere)
    - Supports MLP / ResNet9 / ResNet50 via your architectures module
    - Standardizes using dataset mean/std if standardize=True
    - Keeps same .train/.predict API as before
    """

    def __init__(
        self,
        *,
        learning_rate: float,
        momentum: float,
        num_epochs: int,
        batch_size: int,
        standardize: bool,
        dp_params,
        architecture_name: str,
        architecture_kwargs: dict,
        eval_batch_size: int = 512,   # <-- NEW: safe eval chunk size
        weight_decay: float | None = None,  # <-- NEW: optional weight decay
        label_smoothing: float = 0.0,  # <-- NEW: optional label smoothing
    ):
        self.learning_rate = float(learning_rate)
        self.momentum = float(momentum)
        self.num_epochs = int(num_epochs)
        self.batch_size = int(batch_size)
        self.standardize = bool(standardize)
        self.dp_params = dp_params
        self.architecture_name = architecture_name.lower()
        self.architecture_kwargs = dict(architecture_kwargs)
        self.images_mean_std = None  # set by caller
        self.eval_batch_size = int(eval_batch_size)
        self.weight_decay = None if weight_decay is None else float(weight_decay)
        self.label_smoothing = float(label_smoothing)

        # Resolve architecture ctor from your module
        from architectures import MLP
        ArchMap = {"mlp": MLP}
        try:
            from architectures import ResNet9
            ArchMap["resnet9"] = ResNet9
        except Exception:
            pass
        try:
            from architectures import ResNet18
            ArchMap["resnet18"] = ResNet18
        except Exception:
            pass
        try:
            from architectures import ResNet50
            ArchMap["resnet50"] = ResNet50
        except Exception:
            pass
        try:
            from architectures import WideResNet
            ArchMap["wrn16_4"] = WideResNet
        except Exception:
            pass
        if self.architecture_name not in ArchMap:
            raise ValueError(f"Unknown architecture '{self.architecture_name}'")
        self.arch = ArchMap[self.architecture_name](**self.architecture_kwargs)

        # Cache JAX helpers
        import jax
        import jax.numpy as jnp
        from unrolled_canaries import (
            create_train_state,
            generate_model_perms,
            train_model as _train_model,
            get_logits as _get_logits,
        )

        self._jax = jax
        self._jnp = jnp
        self._create_train_state = create_train_state
        self._generate_model_perms = generate_model_perms
        self._train_model = _train_model
        self._get_logits = _get_logits

    # ----- internal: preprocessing -----
    def _prep(self, images: torch.Tensor, targets: torch.Tensor | None):
        assert self.images_mean_std is not None, "images_mean_std must be set before training/prediction"

        # (N,C,H,W)[torch] -> CPU numpy -> JAX (N,H,W,C)
        np_images = images.permute(0, 2, 3, 1).contiguous().cpu().numpy()
        jx_images = self._jnp.array(np_images)
        jx_targets = None if targets is None else self._jnp.array(targets.cpu().numpy(), dtype=self._jnp.int32)

        if self.standardize:
            mean_t, std_t = self.images_mean_std
            mean = self._jnp.array(mean_t.cpu().numpy(), dtype=self._jnp.float32)
            std = self._jnp.array(std_t.cpu().numpy(), dtype=self._jnp.float32)
            mean_hwC = mean.reshape(1, 1, 1, -1)
            std_hwC = self._jnp.maximum(std, 1e-8).reshape(1, 1, 1, -1)
            jx_images = (jx_images - mean_hwC) / std_hwC

        return jx_images, jx_targets

    # ----- API: train/predict -----
    def train(
        self,
        images: torch.Tensor,
        targets: torch.Tensor,
        seed: int,
        device: torch.device,  # kept for API parity; unused (JAX decides backend)
    ):
        key = self._jax.random.PRNGKey(int(seed))
        jx_images, jx_targets = self._prep(images, targets)

        # scheduler
        # compute steps
        steps_per_epoch = (int(jx_images.shape[0]) // self.batch_size)
        total_steps = steps_per_epoch * self.num_epochs
        warmup_steps = max(5 * steps_per_epoch, 1)  # ~5 epochs warmup

        # cosine schedule from 0 -> peak -> 1% of peak
        if USE_SCHEDULER:
            lr_schedule = optax.warmup_cosine_decay_schedule(
                init_value=0.0,
                peak_value=self.learning_rate,        # 0.1 from config
                warmup_steps=warmup_steps,
                decay_steps=max(total_steps - warmup_steps, 1),
                end_value=self.learning_rate * 0.01,  # ~1e-3 final
            )
        else:
            lr_schedule = self.learning_rate  # constant

        # Init (BN-aware): returns (state, batch_stats)
        key, key_init = self._jax.random.split(key)
        image_shape_hwc = tuple(jx_images.shape[1:])
        state, batch_stats = self._create_train_state(
            (key_init,),
            learning_rate=lr_schedule,
            momentum=self.momentum,
            num_models=1,
            arch=self.arch,
            image_shape=image_shape_hwc,
            dp_params=self.dp_params,
            weight_decay=self.weight_decay,
        )

        # Epoch permutations for given subset
        key, key_perms = self._jax.random.split(key)
        epoch_perms = self._generate_model_perms(
            key=key_perms,
            num_models=1,
            num_epochs=self.num_epochs,
            num_samples=int(jx_images.shape[0]),
            batch_size=self.batch_size,
            sample_non_canaries=False,
            canary_idx=None,
        )

        # Train using unrolled_canaries.train_model (no duplication)
        state, batch_stats, _, _, train_acc = self._train_model(
            state=state,
            batch_stats=batch_stats,
            epoch_perms=epoch_perms,
            train_images=jx_images,
            train_targets=jx_targets,
            test_images=None,
            test_targets=None,
            use_dp=bool(self.dp_params is not None),
            verbose=False,
            label_smoothing=self.label_smoothing,
        )

        aux = (float(self._jax.device_get(train_acc).mean()),)
        # Return (state, batch_stats) together so we can predict with BN
        return (state, batch_stats), aux

    def predict(
        self,
        images: torch.Tensor,
        model,
        aux: tuple[float, ...] | None,
    ) -> torch.Tensor:
        (state, batch_stats) = model
        target_device = images.device

        # process in chunks to avoid huge cudnn workspaces / OOM
        N = images.shape[0]
        bs = self.eval_batch_size
        out_chunks = []
        for start in range(0, N, bs):
            end = min(start + bs, N)
            imgs_chunk = images[start:end]

            jx_images, _ = self._prep(imgs_chunk, None)  # (M,H,W,C) JAX
            # Eval logits (BN in inference mode)
            logits_per_model = self._get_logits(state, jx_images, batch_stats, train=False)
            logits_chunk = torch.from_numpy(np.array(logits_per_model["model_0"])).to(torch.float32)
            out_chunks.append(logits_chunk)

        logits = torch.cat(out_chunks, dim=0).to(target_device)
        return logits


# -----------------------------------------------------------------------------
# Trainer builders: prefer canaries.optimizer (unified config), falling back to the legacy model_trainer block
# -----------------------------------------------------------------------------
def _resolve_arch_from_optimizer(opt):
    """Return ``(model_name, model_kwargs)``, supporting both new and legacy shapes.

    Supported shapes:
      * new:    ``opt.model_name`` plus optional ``opt.model_kwargs``;
      * legacy: ``opt.architecture`` holding ``architecture_name`` plus the
        constructor kwargs, either as a dict or as an attribute object. For a
        pydantic-style object, ``model_dump()`` is used; otherwise, public
        non-callable attributes are collected.

    The name is lower-cased. The kwargs are always a fresh dict, so callers may
    mutate them without modifying the config.

    Raises:
        ValueError: if no architecture is given or the name is empty.
    """
    name = getattr(opt, "model_name", None)
    kwargs = getattr(opt, "model_kwargs", None)
    if name is None:
        arch = getattr(opt, "architecture", None)
        if arch is None:
            raise ValueError(
                "No architecture provided. Expected optimizer.model_name/model_kwargs "
                "or optimizer.architecture{architecture_name,...}."
            )
        if isinstance(arch, dict):
            fields = dict(arch)
        elif callable(getattr(arch, "model_dump", None)):
            # pydantic model: dump only declared fields. dir() would also expose
            # methods and internals (model_dump, model_config, ...), which must
            # not reach the architecture constructor as kwargs.
            fields = dict(arch.model_dump())
        else:
            fields = {}
            for attr in dir(arch):
                if attr.startswith("_") or not hasattr(arch, attr):
                    continue
                value = getattr(arch, attr)
                if not callable(value):  # skip methods / bound callables
                    fields[attr] = value
        name = fields.pop("architecture_name", None)
        kwargs = fields
    if not name:
        raise ValueError("An architecture name is required (e.g., 'mlp', 'resnet9', 'resnet50').")
    # Copy so the caller's mutations never leak back into the config object.
    return str(name).lower(), dict(kwargs or {})

def build_trainer_for_training(config: settings.Settings, dataset_loader):
    """
    Build the trainer used to train target/shadow models.

    Resolution order:
      1) ``canaries.optimizer`` with ``optimizer_type == 'unrolled'`` (for either
         canary type) -> inline JAX trainer from that block (one unified config).
      2) Otherwise ``config.model_trainer`` (legacy trainers).
      3) Neither available -> ValueError explaining what to provide.

    Whichever trainer is built gets ``images_mean_std`` from ``dataset_loader``
    if it exposes that attribute.
    """

    def _with_dataset_stats(trainer):
        # Only trainers that declare images_mean_std consume dataset statistics.
        if hasattr(trainer, "images_mean_std"):
            trainer.images_mean_std = dataset_loader.dataset_mean_std
        return trainer

    optimizer_block = getattr(getattr(config, "canaries", None), "optimizer", None)
    uses_unrolled = (
        optimizer_block is not None
        and getattr(optimizer_block, "optimizer_type", None) == "unrolled"
    )
    if uses_unrolled:
        return _with_dataset_stats(build_trainer_from_optimizer(config))

    if config.model_trainer is not None:
        return _with_dataset_stats(config.model_trainer.build_trainer())

    raise ValueError(
        "Unable to build trainer. Provide either:\n"
        "  - canaries.optimizer with optimizer_type='unrolled' (preferred, works for canary_type='optimized' or 'identity'), or\n"
        "  - a model_trainer block (legacy)."
    )

def build_trainer_from_optimizer(config):
    """Construct the inline JAX trainer from the unified optimizer block (no legacy jax_trainer).

    Args:
        config: experiment settings with ``canaries.optimizer`` whose
            ``optimizer_type`` is ``'unrolled'``.

    Returns:
        An ``InlineUnrolledJaxTrainer``. Its ``images_mean_std`` is NOT set here;
        the caller must assign it before training or prediction.

    Raises:
        ValueError: if ``canaries`` / ``canaries.optimizer`` is missing, the
            optimizer type is not ``'unrolled'``, or no architecture is named.
    """
    if not hasattr(config, "canaries") or config.canaries is None:
        raise ValueError("Missing 'canaries' in config.")
    if not hasattr(config.canaries, "optimizer") or config.canaries.optimizer is None:
        raise ValueError("Missing 'canaries.optimizer' in config.")

    opt = config.canaries.optimizer
    if getattr(opt, "optimizer_type", None) != "unrolled":
        raise ValueError(f"Unsupported optimizer_type={getattr(opt, 'optimizer_type', None)!r} — expected 'unrolled'.")

    model_name, model_kwargs = _resolve_arch_from_optimizer(opt)
    # Work on a private copy: the resolved kwargs may be the config's own dict,
    # and the back-compat fill-in below must not mutate the config.
    model_kwargs = dict(model_kwargs)
    # Back-compat: allow optimizer.mlp_width if user picked MLP but omitted width
    if model_name == "mlp" and "width" not in model_kwargs:
        w = getattr(opt, "mlp_width", None)
        if w is not None:
            model_kwargs["width"] = int(w)

    trainer = InlineUnrolledJaxTrainer(
        learning_rate=opt.learning_rate,
        momentum=opt.momentum,
        num_epochs=opt.num_epochs,
        batch_size=opt.batch_size,
        standardize=opt.standardize,
        dp_params=getattr(opt, "dp_params", None),
        architecture_name=model_name,
        architecture_kwargs=model_kwargs,
        weight_decay=getattr(opt, "weight_decay", None),
        label_smoothing=getattr(opt, "label_smoothing", 0.0),
    )
    logging.info(
        "Trainer built (unrolled_canaries-backed): arch=%s kwargs=%s | lr=%.4g mom=%.3f epochs=%d batch=%d std=%s",
        model_name,
        # default=str: this is only a log line; a non-JSON kwarg must not abort the build
        json.dumps(model_kwargs, default=str),
        opt.learning_rate,
        opt.momentum,
        opt.num_epochs,
        opt.batch_size,
        opt.standardize,
    )
    return trainer


# -----------------------------------------------------------------------------
# Helpers for early skip
# -----------------------------------------------------------------------------
def _artifact_paths_for_index(directory_manager, config, model_idx):
    """Map an absolute model index to its ``(model, predictions, metrics)`` files.

    Indices below ``config.num_models_target`` are target models. Larger
    indices are shadow models, re-based to start at 0.
    """
    dm = directory_manager
    num_targets = config.num_models_target
    if model_idx >= num_targets:
        local_idx = model_idx - num_targets
        return (
            dm.get_shadow_model_file(local_idx),
            dm.get_shadow_predictions_file(local_idx),
            dm.get_shadow_metrics_file(local_idx),
        )
    return (
        dm.get_target_model_file(model_idx),
        dm.get_target_predictions_file(model_idx),
        dm.get_target_metrics_file(model_idx),
    )


def _filter_indices_to_process(all_indices, directory_manager, config, resume: bool):
    """
    Return the subset of ``all_indices`` that still needs training.

    Iteration stops (with a warning) at the first index beyond the number of
    configured target + shadow models. Without ``resume``, every in-range index
    is kept. With ``resume``, an index is skipped only when its model,
    predictions and metrics files all exist. Partially present results trigger
    a warning and are redone.
    """
    selected = []
    for idx in all_indices:
        n_target = config.num_models_target
        n_shadow = config.num_models_shadow
        if idx >= n_target + n_shadow:
            warnings.warn(
                f"Model index {idx} is out of range "
                f"for {n_target} target and {n_shadow} shadow models"
            )
            break

        if resume:
            artifact_files = _artifact_paths_for_index(directory_manager, config, idx)
            present = [path.exists() for path in artifact_files]
            has_model, has_predictions, has_metrics = present
            if all(present):
                logging.info("Skipping model %d because it already exists", idx)
                continue
            if any(present):
                warnings.warn(
                    f"Model {idx} has partial results ({has_model=}, {has_predictions=}, {has_metrics=}); redoing"
                )

        selected.append(idx)
    return selected


# -----------------------------------------------------------------------------
# Original script main (restructured: early skip happens ASAP)
# -----------------------------------------------------------------------------
def main() -> None:
    """Train, predict and evaluate the requested target/shadow models.

    Steps:
      1. Parse CLI args and load the experiment config from ``--dir``.
      2. Decide which ``--model-indices`` still need work (``--resume`` skips
         fully completed runs) before loading any data.
      3. Load canaries and the base dataset, replace a subset of the training
         samples with canaries, and build per-model membership masks.
      4. For each index to process: train on its member subset, save the model
         file, then save predictions (base/canaries/val) and metrics.

    Raises:
        ValueError: if ``--train-batch-size`` is not positive.
        FileNotFoundError: if the experiment config file is missing.
    """
    dotenv.load_dotenv()
    args = parse_args()
    util.setup_logging()

    # Validate the override up front so a bad value fails fast, before
    # canaries/datasets are loaded and moved to the GPU.
    if args.train_batch_size is not None and args.train_batch_size <= 0:
        raise ValueError("--train-batch-size must be a positive integer")

    config_path = util.DirectoryManager.get_config_path(args.dir)
    logging.info("Using config from %s", config_path)
    if not config_path.exists():
        raise FileNotFoundError(f"Config file not found at {config_path}")
    config = settings.Settings.model_validate_json(config_path.read_text())

    directory_manager = util.DirectoryManager(args.dir)

    # -------- EARLY SKIP / RESUME PASS (no GPU, no heavy loads) --------
    requested = list(args.model_indices)
    logging.info("Requested model indices: %s", requested)
    to_process = _filter_indices_to_process(requested, directory_manager, config, resume=args.resume)

    if not to_process:
        logging.info("Nothing to do — all requested models are already completed.")
        return

    logging.info("Will process indices: %s", to_process)

    # -------- Load canaries and datasets only if there's something to do --------
    logging.info("Base dataset: %s", config.base_dataset.name)
    if args.no_store_models:
        logging.info("Storing only predictions, not models")

    # Load canaries
    canaries_images_path = directory_manager.get_canaries_images_path()
    canaries_targets_path = directory_manager.get_canaries_targets_path()
    canary_images = torch.load(canaries_images_path)
    canary_targets = torch.load(canaries_targets_path)
    util.validate_canaries(
        canary_images,
        canary_targets,
        image_shape=config.base_dataset.get_image_shape(),
        num_classes=config.base_dataset.get_num_classes(),
        num_canaries=config.num_canaries,
    )

    # Load raw data
    dataset_loader = config.base_dataset.build_loader()
    dataset_loader.prepare_raw_data()
    train_images_full, train_targets_full = dataset_loader.load_train_data()
    data.validate_dataset(
        train_images_full,
        train_targets_full,
        image_shape=config.base_dataset.get_image_shape(),
        num_samples=config.base_dataset.get_num_train_samples(),
        num_classes=config.base_dataset.get_num_classes(),
    )
    val_images, val_targets = dataset_loader.load_val_data()
    data.validate_dataset(
        val_images,
        val_targets,
        image_shape=config.base_dataset.get_image_shape(),
        num_samples=val_images.shape[0],
        num_classes=config.base_dataset.get_num_classes(),
    )

    # Move all raw data to GPU for efficiency
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_images_full = train_images_full.to(device)
    train_targets_full = train_targets_full.to(device)
    val_images = val_images.to(device)
    val_targets = val_targets.to(device)
    canary_images = canary_images.to(device)
    canary_targets = canary_targets.to(device)

    # Select subset of samples to be replaced by canaries
    _, non_canary_indices = data.select_canary_indices(
        num_canaries=config.num_canaries,
        num_samples=train_images_full.shape[0],
        global_seed=config.global_seed,
        manual_selection=args.manual_canary_selection,
    )
    train_images_non_canary = train_images_full[non_canary_indices]
    train_targets_non_canary = train_targets_full[non_canary_indices]

    # Combine canaries + non-canaries
    train_images = torch.cat([canary_images, train_images_non_canary], dim=0)
    train_targets = torch.cat([canary_targets, train_targets_non_canary], dim=0)
    assert train_images.shape == (config.num_canaries + non_canary_indices.shape[0], *train_images_full.shape[1:])
    assert train_targets.shape == (config.num_canaries + non_canary_indices.shape[0],)
    data.validate_dataset(
        train_images,
        train_targets,
        image_shape=config.base_dataset.get_image_shape(),
        num_samples=config.base_dataset.get_num_train_samples(),  # size preserved
        num_classes=config.base_dataset.get_num_classes(),
    )

    # Generate membership masks (full), we’ll index per model
    membership_masks_targets, membership_masks_shadow = data.generate_full_membership_masks(
        num_canaries=config.num_canaries,
        num_non_canaries=non_canary_indices.shape[0],
        num_models_target=config.num_models_target,
        num_models_shadow=config.num_models_shadow,
        sample_non_canaries=config.sample_non_canaries,
        global_seed=config.global_seed,
    )
    util.validate_membership_masks(
        membership_masks_targets,
        num_canaries=config.num_canaries,
        num_non_canaries=non_canary_indices.shape[0],
        num_models=config.num_models_target,
        sample_non_canaries=config.sample_non_canaries,
    )
    util.validate_membership_masks(
        membership_masks_shadow,
        num_canaries=config.num_canaries,
        num_non_canaries=non_canary_indices.shape[0],
        num_models=config.num_models_shadow,
        sample_non_canaries=config.sample_non_canaries,
    )

    # Build trainer (unified optimizer block, else legacy model_trainer)
    model_trainer = build_trainer_for_training(config, dataset_loader)
    if hasattr(model_trainer, "images_mean_std"):
        model_trainer.images_mean_std = dataset_loader.dataset_mean_std
    # Optional training batch size override (already validated at startup)
    if args.train_batch_size is not None:
        old_bs = getattr(model_trainer, "batch_size", None)
        model_trainer.batch_size = int(args.train_batch_size)
        logging.info(
            "Overriding training batch size: %s -> %d",
            old_bs if old_bs is not None else "(unset)",
            model_trainer.batch_size,
        )

    # One seed per absolute model index, derived from the global seed so that
    # a model's seed does not depend on which indices this run processes.
    rng_training = np.random.default_rng(config.global_seed)
    training_seeds = rng_training.integers(0, 2**32, size=config.num_models_target + config.num_models_shadow)

    # -------- Actual training/prediction loop only over 'to_process' --------
    for model_idx in to_process:
        seed = int(training_seeds[model_idx])

        # Target vs Shadow mapping (same file layout the resume check uses)
        if model_idx < config.num_models_target:
            membership_mask = membership_masks_targets[model_idx]
        else:
            membership_mask = membership_masks_shadow[model_idx - config.num_models_target]
        model_file, predictions_file, metrics_file = _artifact_paths_for_index(
            directory_manager, config, model_idx
        )

        current_train_images, current_train_targets = data.build_train_data(
            train_images,
            train_targets,
            membership_mask,
        )

        logging.info("Training model %d with seed %d", model_idx, seed)
        current_model, current_aux = model_trainer.train(
            current_train_images,
            current_train_targets,
            seed=seed,
            device=device,
        )

        # The model file is always written (with model=None under
        # --no-store-models) because the resume check requires it to exist.
        model_file.parent.mkdir(parents=True, exist_ok=True)
        serial_model = None
        if not args.no_store_models:
            serial_model = _jax_model_to_serializable(current_model)
        torch.save(
            {
                "model": serial_model,  # dict of NumPy params/batch_stats/opt_states, or None
                "aux": current_aux,
                "seed": seed,
            },
            model_file,
        )

        # Predict: on base, canaries, val
        pred_base = model_trainer.predict(train_images_full, model=current_model, aux=current_aux)
        pred_canaries = model_trainer.predict(canary_images, model=current_model, aux=current_aux)
        pred_val = model_trainer.predict(val_images, model=current_model, aux=current_aux)
        assert pred_base.shape == (train_images_full.shape[0], config.base_dataset.get_num_classes())
        assert pred_canaries.shape == (canary_images.shape[0], config.base_dataset.get_num_classes())
        assert pred_val.shape == (val_images.shape[0], config.base_dataset.get_num_classes())
        assert pred_base.dtype == torch.float32
        assert pred_canaries.dtype == torch.float32
        assert pred_val.dtype == torch.float32
        predictions_file.parent.mkdir(parents=True, exist_ok=True)
        torch.save(
            {
                "pred_base": pred_base.cpu(),
                "pred_canaries": pred_canaries.cpu(),
                "pred_val": pred_val.cpu(),
            },
            predictions_file,
        )

        # Evaluate
        accuracy_base, loss_base = evaluate_predictions(pred_base, train_targets_full)
        accuracy_canaries, loss_canaries = evaluate_predictions(pred_canaries, canary_targets)
        accuracy_val, loss_val = evaluate_predictions(pred_val, val_targets)
        metrics_file.parent.mkdir(parents=True, exist_ok=True)
        metrics_dict = {
            "accuracy_base": accuracy_base,
            "loss_base": loss_base,
            "accuracy_canaries": accuracy_canaries,
            "loss_canaries": loss_canaries,
            "accuracy_val": accuracy_val,
            "loss_val": loss_val,
        }
        metrics_file.write_text(json.dumps(metrics_dict))

    logging.info("Finished training all requested models")


def evaluate_predictions(predictions: torch.Tensor, targets: torch.Tensor) -> tuple[float, float]:
    """Return ``(top-1 accuracy, mean cross-entropy)`` of logits against labels.

    ``predictions`` are logits with classes on the last axis. Both metrics are
    returned as Python floats.
    """
    predicted_labels = predictions.argmax(dim=-1)
    correct = (predicted_labels == targets).float()
    mean_ce = torch.nn.functional.cross_entropy(predictions, targets)
    return correct.mean().item(), mean_ce.item()


def parse_args() -> argparse.Namespace:
    """Parse this script's command-line options and return the namespace."""
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    add("--dir", type=pathlib.Path, required=True, help="Path to experiment base directory")
    add(
        "--model-indices",
        required=True,
        type=util.model_indices_type,
        help="Model indices to process. Can be a single index (e.g. '1') or a range (e.g. '1-5') with end exclusive",
    )
    add("--no-store-models", action="store_true", help="Do not store models, only predictions")
    add("--resume", action="store_true", help="Skip training of fully completed runs")
    # Optional override of the config's training batch size (validated in main)
    add(
        "--train-batch-size",
        default=None,
        type=int,
        help="Override the training batch size from config (must be a positive integer)",
    )
    add(
        "--manual-canary-selection",
        nargs="+",
        default=None,
        type=int,
        help="Manually select canary indices from the dataset (overrides random selection with global seed)",
    )

    return parser.parse_args()


if __name__ == "__main__":
    # Script entry point: run main() only when executed directly, not on import.
    main()