# eval/optimizers/unrolled.py

import pathlib
from typing import Any, Callable, Dict, Optional, Tuple

import jax
import jax.numpy as jnp
import numpy as np
import optax
import torch
from canary_opt import fit_canary
from eval.optimizers.base import CanaryOptimizer

# BN-aware helpers (create_train_state now returns (state, batch_stats))
from unrolled_canaries import (  # <- note: from unrolled
    DPParams,
    create_train_state,
    generate_model_perms,
)
from utils import make_summary_writer

# Module-level switch read by `UnrolledOptimizer.optimize`: when True, an optax
# warmup + cosine-decay schedule is built and passed to `create_train_state`;
# when False, the optimizer's constant `learning_rate` is passed through as-is.
ENABLE_LR_SCHEDULE = False


class UnrolledOptimizer(CanaryOptimizer):
    """Canary optimizer that delegates the image search to ``fit_canary``.

    Usage is two-phase: call :meth:`prepare_data` once to cache the dataset as
    JAX arrays (NHWC), then call :meth:`optimize` per canary. All constructor
    arguments are keyword-only and are stored on the instance unchanged (apart
    from ``int`` coercion of the two TBPTT controls and a shallow copy of
    ``model_kwargs``).
    """

    # Warmup length, in epochs, of the optional learning-rate schedule
    # (only used when the module-level ENABLE_LR_SCHEDULE flag is True).
    _WARMUP_EPOCHS = 5

    def __init__(
        self,
        *,
        num_models: int,
        learning_rate: float,
        momentum: float,
        num_epochs: int,
        batch_size: int,
        canary_learning_rate: float,
        canary_momentum: float,
        canary_search_steps: int,
        clip_canary: bool,
        loss_type: str,
        loss_agg: str,
        # ---- modular architecture ----
        model_ctor: Callable[..., Any],
        model_kwargs: Optional[Dict[str, Any]] = None,
        # ------------------------------
        fixed_variance: Optional[float],  # kept for API parity (unused in tail-only loss)
        sample_non_canaries: bool,
        standardize: bool,
        dp_params: Optional[DPParams],
        # ---- TBPTT controls ----
        tbptt_k_steps: int = 4,
        test_eval_every_k: int = 10,
    ):
        """Store hyper-parameters; no data is loaded and no model is built here.

        Args:
            num_models: Total number of models; must be even because
                :meth:`optimize` splits them into equal IN and OUT halves.
            learning_rate: Model learning rate (or schedule peak value).
            momentum: Model momentum, forwarded to ``create_train_state``.
            num_epochs: Training epochs, forwarded to ``generate_model_perms``.
            batch_size: Batch size, forwarded to ``generate_model_perms``.
            canary_learning_rate: SGD learning rate for the canary image.
            canary_momentum: SGD momentum for the canary image.
            canary_search_steps: Forwarded to ``fit_canary``.
            clip_canary: Forwarded to ``fit_canary`` together with the clip range.
            loss_type: Forwarded to ``fit_canary``.
            loss_agg: Forwarded to ``fit_canary``.
            model_ctor: Callable that builds the architecture.
            model_kwargs: Keyword arguments for ``model_ctor`` (copied).
            fixed_variance: Stored but not read anywhere in this class.
            sample_non_canaries: Forwarded to ``generate_model_perms``.
            standardize: If True, images are standardized with the dataset
                mean/std before being handed to the models.
            dp_params: Optional DP settings; ``None`` disables DP.
            tbptt_k_steps: Forwarded to ``fit_canary``.
            test_eval_every_k: Forwarded to ``fit_canary`` as ``eval_every_k``.
        """
        self.num_models = num_models
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.canary_learning_rate = canary_learning_rate
        self.canary_momentum = canary_momentum
        self.canary_search_steps = canary_search_steps
        self.clip_canary = clip_canary
        self.loss_type = loss_type
        self.loss_agg = loss_agg

        self.model_ctor = model_ctor
        # Copy so later mutation of the caller's dict cannot change this instance.
        self.model_kwargs = dict(model_kwargs) if model_kwargs else {}

        self.fixed_variance = fixed_variance
        self.sample_non_canaries = sample_non_canaries
        self.standardize = standardize
        self.dp_params = dp_params

        self.tbptt_k_steps = int(tbptt_k_steps)
        self.test_eval_every_k = int(test_eval_every_k)

        # JAX-formatted data cache, filled by prepare_data().
        self._train_images = None            # NHWC, standardized iff self.standardize
        self._train_images_unstd = None      # NHWC, as passed in (used for init-from-dataset)
        self._train_targets = None
        self._test_images = None
        self._test_targets = None
        self._dataset_mean = None            # (C,)
        self._dataset_std = None             # (C,)

    # ------------------------------------------------------------------ helpers

    def _make_lr_schedule(self, num_samples: int):
        """Return the model learning rate: a constant, or a warmup-cosine schedule."""
        if not ENABLE_LR_SCHEDULE:
            return self.learning_rate

        steps_per_epoch = num_samples // self.batch_size
        total_steps = steps_per_epoch * self.num_epochs
        # Keep at least one step for the cosine phase so short runs
        # (num_epochs <= _WARMUP_EPOCHS) still yield a valid schedule.
        warmup_steps = min(
            self._WARMUP_EPOCHS * steps_per_epoch,
            max(total_steps - 1, 0),
        )
        # optax defines `decay_steps` as the TOTAL schedule length, warmup
        # included; it subtracts `warmup_steps` internally for the cosine part.
        return optax.warmup_cosine_decay_schedule(
            init_value=0.0,
            peak_value=self.learning_rate,
            warmup_steps=warmup_steps,
            decay_steps=total_steps,
            end_value=0.0,
        )

    def _initial_canary(self, subkey, selected_idx: int, image_shape_hwc: Tuple[int, ...]):
        """Return the unstandardized starting canary (dataset image or uniform noise)."""
        pool = self._train_images_unstd
        if pool is not None and 0 <= selected_idx < pool.shape[0]:
            start = jnp.array(pool[selected_idx])
            # Shape check / fix in case broadcasting or typing surprises.
            if start.shape != image_shape_hwc:
                start = start.reshape(image_shape_hwc)
            return start
        # Fallback: random uniform in [0, 1).
        return jax.random.uniform(subkey, shape=image_shape_hwc)

    @staticmethod
    def _channel_vector(stat):
        """Reduce a per-channel statistic given as (C,), CHW, HWC or other to 1-D."""
        if stat.ndim == 3:
            # A leading axis of size 1 or 3 is taken to be the channel axis (CHW);
            # anything else is treated as HWC.
            if stat.shape[0] in (1, 3):
                return stat[:, 0, 0]
            return stat[0, 0, :]
        if stat.ndim == 1:
            return stat
        # Fallback: flatten (reshape also copes with non-contiguous tensors).
        return stat.reshape(-1)

    # --------------------------------------------------------------- public API

    def optimize(
        self,
        target: int,
        seed: int,
        canary_idx: int,
        intermediate_log_dir: pathlib.Path,
    ) -> Tuple[torch.Tensor, int]:
        """Search for a canary image for label ``target``.

        Args:
            target: Label handed to ``fit_canary`` as ``canary_label`` and
                returned unchanged.
            seed: Seed for the JAX PRNG key that drives permutations, model
                initialisation and (if needed) the random canary init.
            canary_idx: Index of the training image used to *initialise* the
                canary. If it is out of range, or no unstandardized copy is
                cached, a uniform-random image is used instead. The canary is
                always *inserted* at the last training position.
            intermediate_log_dir: Directory for the summary writer and for
                ``intermediate_canaries.npz``.

        Returns:
            ``(image, target)`` where ``image`` is a float32 CHW torch tensor.
            When ``standardize`` is on, the image is mapped back and clipped
            to ``[0, 1]``.

        Raises:
            AssertionError: If :meth:`prepare_data` has not been called, or
                ``num_models`` is odd.
        """
        assert self._train_images is not None, "call prepare_data() before optimize()"
        assert self._train_targets is not None, "call prepare_data() before optimize()"
        assert self._dataset_mean is not None, "call prepare_data() before optimize()"
        assert self._dataset_std is not None, "call prepare_data() before optimize()"
        # Checked up front so a bad configuration fails before any log files are created.
        assert self.num_models % 2 == 0, "num_models must be even (half IN, half OUT)."

        num_samples = int(self._train_targets.shape[0])
        half_models = self.num_models // 2

        # Keep historical behavior: the canary is *inserted* at the last position,
        # but may be *initialized* from the user-selected dataset index.
        insert_idx = num_samples - 1
        selected_idx = int(canary_idx)

        key = jax.random.PRNGKey(seed)
        summary_writer = make_summary_writer(intermediate_log_dir, {})

        # Permutations are generated with the *insert* index. To keep the canary at
        # the dataset position it was initialized from, pass selected_idx instead.
        key, subkey = jax.random.split(key)
        model_perms = generate_model_perms(
            subkey,
            num_models=half_models,
            num_epochs=self.num_epochs,
            num_samples=num_samples,
            batch_size=self.batch_size,
            sample_non_canaries=self.sample_non_canaries,
            canary_idx=insert_idx,
        )
        # IN and OUT models share the same permutations for deterministic dynamics.
        model_perms_in = model_perms_out = model_perms

        # Build architecture.
        arch = self.model_ctor(**self.model_kwargs)

        key, subkey_in, subkey_out = jax.random.split(key, 3)
        image_shape_hwc = tuple(self._train_images.shape[1:])  # (H, W, C)

        lr_schedule = self._make_lr_schedule(num_samples)

        # BN-aware init returns (state, batch_stats) for each half.
        init_state_in, batch_stats_in = create_train_state(
            jax.random.split(subkey_in, half_models),
            learning_rate=lr_schedule,
            momentum=self.momentum,
            num_models=half_models,
            arch=arch,
            image_shape=image_shape_hwc,
            dp_params=self.dp_params,
        )
        init_state_out, batch_stats_out = create_train_state(
            jax.random.split(subkey_out, half_models),
            learning_rate=lr_schedule,
            momentum=self.momentum,
            num_models=half_models,
            arch=arch,
            image_shape=image_shape_hwc,
            dp_params=self.dp_params,
        )
        del subkey_in, subkey_out

        # Initialize the canary image (in unstandardized space).
        key, subkey_image = jax.random.split(key)
        init_canary = self._initial_canary(subkey_image, selected_idx, image_shape_hwc)
        del subkey_image, key  # no further randomness is drawn in this method

        # Standardization handling + clip range expressed in the model's input space.
        if self.standardize:
            mean_hwc = self._dataset_mean.reshape(1, 1, -1)  # (1, 1, C)
            std_hwc = self._dataset_std.reshape(1, 1, -1)    # (1, 1, C)
            init_canary = (init_canary - mean_hwc) / std_hwc
            # Images of [0, 1] under the standardization map.
            clip_min = -mean_hwc / std_hwc
            clip_max = (1.0 - mean_hwc) / std_hwc
        else:
            mean_hwc = std_hwc = None
            clip_min, clip_max = 0.0, 1.0

        # Fixtures: (IN state, IN BN, IN perms, OUT state, OUT BN, OUT perms).
        fixtures = (
            init_state_in, batch_stats_in, model_perms_in,
            init_state_out, batch_stats_out, model_perms_out,
        )

        # Optimizer for the canary image itself.
        optimizer = optax.sgd(
            learning_rate=self.canary_learning_rate,
            momentum=self.canary_momentum,
        )

        # Path to save intermediate canaries/losses.
        intermediate_results_path = intermediate_log_dir / "intermediate_canaries.npz"
        intermediate_results_path.parent.mkdir(parents=True, exist_ok=True)

        final_canary = fit_canary(
            canary_params=init_canary,
            optimizer=optimizer,
            canary_search_steps=self.canary_search_steps,
            canary_label=target,
            output_path=intermediate_results_path,
            fixtures=fixtures,
            summary_writer=summary_writer,
            train_images_full=self._train_images,
            train_targets_full=self._train_targets,
            is_mlp=bool(getattr(arch, "is_mlp", False)),
            canary_idx=insert_idx,  # keep insertion at the last position
            clip_canary=self.clip_canary,
            clip_min=clip_min,
            clip_max=clip_max,
            loss_type=self.loss_type,
            loss_agg=self.loss_agg,
            use_dp=self.dp_params is not None,
            tbptt_k_steps=int(self.tbptt_k_steps),
            # Evaluation on the test set during the search.
            test_images=self._test_images,
            test_targets=self._test_targets,
            eval_every_k=int(self.test_eval_every_k),
        )

        # Undo standardization and clamp back to [0, 1].
        if self.standardize:
            final_canary = final_canary * std_hwc + mean_hwc
            final_canary = jnp.clip(final_canary, 0.0, 1.0)

        # HWC (JAX) -> CHW (torch), float32, on CPU.
        canary_chw = np.array(final_canary.transpose(2, 0, 1))
        current_canary_image = torch.from_numpy(canary_chw).to(torch.float32)
        return current_canary_image, target

    def prepare_data(
        self,
        train_images: torch.Tensor,   # (N, C, H, W) in [0,1]
        train_labels: torch.Tensor,
        test_images: torch.Tensor,    # (N, C, H, W) in [0,1]
        test_labels: torch.Tensor,
        dataset_mean_std: Tuple[torch.Tensor, torch.Tensor],  # per-channel or CHW
        # ---- optional test set subsampling for speed ----
        max_test_samples: Optional[int] = None,
        test_sample_seed: Optional[int] = None,
    ) -> None:
        """Convert the torch dataset to cached JAX arrays in NHWC layout.

        Args:
            train_images: Training images, (N, C, H, W).
            train_labels: Training labels; stored as int16.
            test_images: Test images, (N, C, H, W).
            test_labels: Test labels; stored as int16.
            dataset_mean_std: ``(mean, std)`` tensors; each may be (C,), CHW,
                HWC, or any other shape (which is flattened).
            max_test_samples: If set and smaller than the test set, keep only
                this many test samples.
            test_sample_seed: If set, the kept test samples are a seeded random
                subset; otherwise the first ``max_test_samples`` are kept.
        """
        # PyTorch (N, C, H, W) -> JAX NHWC.
        train_images = train_images.permute(0, 2, 3, 1).contiguous()
        test_images = test_images.permute(0, 2, 3, 1).contiguous()

        # Optionally shrink the test set.
        if max_test_samples is not None and max_test_samples < test_images.shape[0]:
            if test_sample_seed is not None:
                gen = torch.Generator().manual_seed(test_sample_seed)
                keep = torch.randperm(test_images.shape[0], generator=gen)[:max_test_samples]
            else:
                keep = torch.arange(max_test_samples)
            test_images = test_images[keep]
            test_labels = test_labels[keep]

        # Convert to JAX. The unstandardized training images are kept for
        # init-from-dataset; JAX arrays are immutable, so the same array can
        # safely back `_train_images` until (and unless) it is standardized below.
        train_nhwc = jnp.array(train_images)
        self._train_images_unstd = train_nhwc
        self._train_images = train_nhwc
        self._train_targets = jnp.array(train_labels, dtype=jnp.int16)
        self._test_images = jnp.array(test_images)
        self._test_targets = jnp.array(test_labels, dtype=jnp.int16)

        # Canonical per-channel (C,) statistics.
        mean_t, std_t = dataset_mean_std
        mean_c = self._channel_vector(mean_t)
        std_c = self._channel_vector(std_t)
        self._dataset_mean = jnp.array(mean_c, dtype=jnp.float32)
        self._dataset_std = jnp.array(
            jnp.maximum(jnp.array(std_c), 1e-8),  # avoid div-by-zero
            dtype=jnp.float32,
        )

        if self.standardize:
            mean_nhwc = self._dataset_mean.reshape(1, 1, 1, -1)  # broadcasts over N, H, W
            std_nhwc = self._dataset_std.reshape(1, 1, 1, -1)
            self._train_images = (self._train_images - mean_nhwc) / std_nhwc
            self._test_images = (self._test_images - mean_nhwc) / std_nhwc
