from __future__ import annotations

import abc
import typing
from typing import Any, Literal, Optional

import eval.optimizers.influence_stochastic
import eval.trainers.torch_trainer
import influence.torch_architectures
import pydantic
from eval import (
    baseline_canaries,
    data,
    optimizers,  # only for type hints; actual import in build
    trainers,
    util,
)
from pydantic import model_validator
from unrolled_canaries import DPParams as _DPParams  # type reference


class CanaryType(pydantic.BaseModel, metaclass=abc.ABCMeta):
    """Abstract config for one kind of canary.

    Concrete subclasses pin ``canary_type`` to a literal, which ``Settings.canaries``
    uses as its union discriminator.
    """

    canary_type: str

    @abc.abstractmethod
    def build_generator(
        self,
        directory_manager: util.DirectoryManager,
        dataset_loader: data.DatasetLoader,
    ) -> baseline_canaries.BaselineCanaryGenerator:
        """Return the generator object that produces canaries of this kind."""

class IdentityCanary(CanaryType):
    """Canary config of type ``"identity"``.

    ``build_generator`` hands back the same ``InDistributionCanaryGenerator`` used by
    ``InDistributionCanary``; no optimization is performed.
    """

    canary_type: typing.Literal["identity"]
    # None is meant to allow a random label.
    # NOTE(review): neither `canary_label` nor `optimizer` is read by
    # `build_generator` below; confirm they are consumed elsewhere.
    canary_label: int | None = None
    optimizer: typing.Optional[
        UnrolledOptimizerConfig | InfluenceStochasticOptimizerConfig
    ] = None

    def build_generator(
        self,
        directory_manager: util.DirectoryManager,
        dataset_loader: data.DatasetLoader,
    ) -> baseline_canaries.InDistributionCanaryGenerator:
        # Identity canaries reuse the in-distribution plumbing as-is; the
        # generator is constructed without either argument.
        generator = baseline_canaries.InDistributionCanaryGenerator()
        return generator

class InDistributionCanary(CanaryType):
    """Canary config of type ``"id"``, built by ``InDistributionCanaryGenerator``."""

    canary_type: typing.Literal["id"]

    def build_generator(
        self,
        directory_manager: util.DirectoryManager,
        dataset_loader: data.DatasetLoader,
    ) -> baseline_canaries.InDistributionCanaryGenerator:
        # The generator takes no constructor arguments; both parameters are unused.
        generator = baseline_canaries.InDistributionCanaryGenerator()
        return generator


class MislabeledCanary(CanaryType):
    """Canary config of type ``"mislabeled"``, built by ``MislabeledCanaryGenerator``."""

    canary_type: typing.Literal["mislabeled"]

    def build_generator(
        self,
        directory_manager: util.DirectoryManager,
        dataset_loader: data.DatasetLoader,
    ) -> baseline_canaries.MislabeledCanaryGenerator:
        # The generator takes no constructor arguments; both parameters are unused.
        generator = baseline_canaries.MislabeledCanaryGenerator()
        return generator


class CanaryOptimizerConfig(pydantic.BaseModel, metaclass=abc.ABCMeta):
    """Abstract config for a canary optimizer.

    Concrete subclasses pin ``optimizer_type`` to a literal, which
    ``OptimizedCanary.optimizer`` uses as its union discriminator.
    """

    optimizer_type: str

    @abc.abstractmethod
    def build_optimizer(self, sample_non_canaries: bool) -> optimizers.CanaryOptimizer:
        """Construct the optimizer described by this config."""


class DPParams(pydantic.BaseModel):
    """DP-SGD parameters as read from the config.

    ``JaxTrainerConfig.build_trainer`` copies these three values into the training
    code's own ``DPParams`` type.
    """

    noise_multiplier: float
    l2_norm_clip: float
    delta: float



class UnrolledOptimizerConfig(CanaryOptimizerConfig):
    """Config for the unrolled canary optimizer (``optimizer_type="unrolled"``).

    The model architecture can be given directly via ``model_name`` / ``model_kwargs``.
    Incoming JSON may instead provide an ``architecture`` mapping of the form
    ``{"architecture_name": "<Name>", **kwargs}``. ``_translate_architecture`` rewrites
    that mapping into ``model_name`` / ``model_kwargs`` before field validation.
    """

    optimizer_type: Literal["unrolled"]
    # JSON-facing architecture spec; translated away by `_translate_architecture`
    # unless `model_name` is also given explicitly.
    architecture: Optional[dict] = None

    # -------------------------
    # Model training parameters
    # -------------------------
    num_models: int = 4
    """Number of models in the ensemble."""

    learning_rate: float = 0.1
    """Learning rate for the model training."""

    momentum: float = 0.9
    """Momentum for the training."""

    num_epochs: int = 10
    """Number of epochs to train each model for."""

    batch_size: int = 128
    """Batch size."""

    standardize: bool = True
    """Standardize the training data to have zero mean and unit variance."""

    # NOTE(review): `label_smoothing` is not forwarded by `build_optimizer`;
    # confirm whether UnrolledOptimizer is meant to receive it.
    label_smoothing: float = 0.1

    # -------------------------------
    # Canary optimization parameters
    # -------------------------------
    canary_learning_rate: float = 1.0
    """Learning rate for the optimizer."""

    canary_momentum: float = 0.9
    """Momentum for the optimizer."""

    canary_search_steps: int = 300
    """Number of steps to optimize the canary."""

    clip_canary: bool = True
    """Clip the canary to [0, 1] (or standardized bounds if `standardize=True`)."""

    loss_type: Literal["l2", "hinge", "lira"] = "hinge"
    """Loss type for the canary. Can be 'l2', 'hinge' or 'lira'."""

    loss_agg: Literal["max", "mean"] = "mean"
    """Canary loss aggregation over the models. 'lira' loss type has its own aggregation."""

    # ---------------------------------------
    # Model architecture (modular / pluggable)
    # ---------------------------------------
    model_name: Literal["mlp", "resnet9", "resnet18", "resnet50", "wrn16_4"] = "mlp"
    """
    Which architecture to use. Built-ins: 'mlp', 'resnet9', 'resnet18', 'resnet50',
    'wrn16_4'.
    """

    model_kwargs: dict[str, Any] = {}
    """
    Keyword args forwarded to the selected architecture.
    Examples:
      - {"width": 512, "depth": 2, "num_classes": 10} for 'mlp'
      - {"num_classes": 10, "stem_channels": 64} for 'resnet9'
      - {"num_classes": 10, "base_channels": 64, "small_input": True} for 'resnet50'
    For 'wrn16_4', depth=16, widen_factor=4, num_classes=10 and dropout_rate=0.0
    are filled in when omitted.
    """

    # ---------------------------------------
    fixed_variance: Optional[float] = None
    """Fixed variance for 'lira' loss. If None, it is estimated from the data."""

    dp_params: Optional[_DPParams] = None
    """If set, use DP-SGD optimizer with the given parameters."""

    @model_validator(mode="before")
    @classmethod
    def _translate_architecture(cls, data: Any) -> Any:
        """Translate an ``architecture`` mapping into ``model_name`` + ``model_kwargs``.

        Applies only when the raw input is a dict containing ``architecture`` and no
        explicit ``model_name``. ``architecture_name`` is lower-cased into ``model_name``.
        The remaining keys become ``model_kwargs``.

        The caller's dict and the nested ``architecture`` dict are copied, not mutated.
        Validating the same raw config twice therefore behaves identically, and the
        caller's data is left intact.

        Raises:
            ValueError: if ``architecture`` is not a mapping, or if its
                ``architecture_name`` is missing, empty, or not a string.
        """
        if not isinstance(data, dict) or "architecture" not in data or "model_name" in data:
            return data

        data = dict(data)  # shallow copy so the caller's input is left untouched
        arch = data.pop("architecture") or {}
        if not isinstance(arch, dict):
            # Raise ValueError (not AttributeError) so pydantic reports a ValidationError.
            raise ValueError("optimizer.architecture must be a mapping")
        arch = dict(arch)  # copy before popping so the nested input is not mutated
        name = arch.pop("architecture_name", None)
        if not name:
            raise ValueError("architecture_name is required in optimizer.architecture")
        if not isinstance(name, str):
            raise ValueError("optimizer.architecture.architecture_name must be a string")
        data["model_name"] = name.lower()
        data["model_kwargs"] = arch
        return data

    def _resolve_architecture(self):
        """Return ``(model_ctor, model_kwargs)`` for ``model_name`` + ``model_kwargs``.

        ``model_kwargs`` is copied, so filling in defaults never mutates the config.

        Raises:
            ValueError: if ``model_name`` is not a known architecture.
        """
        # Late import to avoid slow JAX/Flax imports during config construction
        from architectures import MLP, ResNet9, ResNet18, ResNet50, WideResNet

        constructors = {
            "mlp": MLP,
            "resnet9": ResNet9,
            "resnet18": ResNet18,
            "resnet50": ResNet50,
            "wrn16_4": WideResNet,
        }
        model_ctor = constructors.get(self.model_name)
        if model_ctor is None:
            raise ValueError(f"Unknown model_name={self.model_name!r}")

        kwargs = dict(self.model_kwargs)
        if self.model_name == "wrn16_4":
            # "wrn16_4" selects WideResNet with depth 16 / widen factor 4; these
            # defaults apply only where the user did not supply a value.
            kwargs.setdefault("depth", 16)
            kwargs.setdefault("widen_factor", 4)
            kwargs.setdefault("num_classes", 10)
            kwargs.setdefault("dropout_rate", 0.0)

        return model_ctor, kwargs

    def build_optimizer(self, sample_non_canaries: bool) -> optimizers.CanaryOptimizer:
        """Build an ``UnrolledOptimizer`` from this config.

        Args:
            sample_non_canaries: forwarded unchanged to ``UnrolledOptimizer``.

        Raises:
            ValueError: propagated from ``_resolve_architecture`` for an unknown
                ``model_name``.
        """
        # Put imports here to avoid slow JAX imports at module load
        import eval.optimizers.unrolled
        import unrolled_canaries

        # Re-create dp_params as the DPParams class used by the JAX training code.
        if self.dp_params is None:
            dp_params = None
        else:
            dp_params = unrolled_canaries.DPParams(
                noise_multiplier=self.dp_params.noise_multiplier,
                l2_norm_clip=self.dp_params.l2_norm_clip,
                delta=self.dp_params.delta,
            )

        # Resolve the requested architecture into a constructor + kwargs
        model_ctor, model_kwargs = self._resolve_architecture()

        return eval.optimizers.unrolled.UnrolledOptimizer(
            num_models=self.num_models,
            learning_rate=self.learning_rate,
            momentum=self.momentum,
            num_epochs=self.num_epochs,
            batch_size=self.batch_size,
            canary_learning_rate=self.canary_learning_rate,
            canary_momentum=self.canary_momentum,
            canary_search_steps=self.canary_search_steps,
            clip_canary=self.clip_canary,
            loss_type=self.loss_type,
            loss_agg=self.loss_agg,
            # pass the modular model
            model_ctor=model_ctor,
            model_kwargs=model_kwargs,
            fixed_variance=self.fixed_variance,
            sample_non_canaries=sample_non_canaries,
            standardize=self.standardize,
            dp_params=dp_params,
        )

class InfluenceArchitectureConfig(pydantic.BaseModel, metaclass=abc.ABCMeta):
    """Abstract config for a torch architecture.

    Concrete subclasses pin ``architecture_name`` to a literal, which is used as the
    union discriminator wherever an architecture field appears.
    """

    architecture_name: str

    @abc.abstractmethod
    def build_architecture(self) -> influence.torch_architectures.TorchArchitecture:
        """Instantiate the torch architecture described by this config."""


class HLBAchitectureConfig(InfluenceArchitectureConfig):
    """Config for ``HLBTorchArchitecture`` (``architecture_name="hlb"``)."""

    architecture_name: typing.Literal["hlb"]

    base_depth: int = 64
    """Base depth of the HLB network."""

    num_epochs: float = 12.1
    """Number of (possibly fractional) training epochs per model."""

    def build_architecture(self) -> influence.torch_architectures.HLBTorchArchitecture:
        architecture_cls = influence.torch_architectures.HLBTorchArchitecture
        return architecture_cls(num_epochs=self.num_epochs, base_depth=self.base_depth)


class WideResNetAchitectureConfig(InfluenceArchitectureConfig):
    """Placeholder config for a WideResNet torch architecture (``"wrn"``).

    Passes validation, but ``build_architecture`` is not implemented yet.
    """

    architecture_name: typing.Literal["wrn"]

    def build_architecture(self) -> influence.torch_architectures.TorchArchitecture:
        # Unimplemented: fails at build time, not during config validation.
        raise NotImplementedError("TODO")


class MLPArchitectureConfig(InfluenceArchitectureConfig):
    """Config for ``MLPTorchArchitecture`` (``architecture_name="mlp"``)."""

    architecture_name: typing.Literal["mlp"]

    mlp_width: int = 20
    """Hidden width of the MLP."""

    learning_rate: float = 0.1
    """Learning rate used when training each model."""

    momentum: float = 0.9
    """Momentum used when training each model."""

    num_epochs: int = 10
    """Training epochs per model."""

    batch_size: int = 128
    """Training batch size."""

    def build_architecture(self) -> influence.torch_architectures.MLPTorchArchitecture:
        # Every hyperparameter above is forwarded under its own name.
        hyperparams = {
            "mlp_width": self.mlp_width,
            "learning_rate": self.learning_rate,
            "momentum": self.momentum,
            "num_epochs": self.num_epochs,
            "batch_size": self.batch_size,
        }
        return influence.torch_architectures.MLPTorchArchitecture(**hyperparams)


class InfluenceStochasticOptimizerConfig(CanaryOptimizerConfig):
    """Config for ``InfluenceStochasticOptimizer`` (``optimizer_type="influence_stochastic"``)."""

    optimizer_type: typing.Literal["influence_stochastic"]
    architecture: HLBAchitectureConfig | WideResNetAchitectureConfig | MLPArchitectureConfig = pydantic.Field(
        discriminator="architecture_name",
    )

    num_optimization_steps: int
    canary_lr: float
    canary_momentum: float  # 0 disables momentum
    clamp: bool = True
    num_jacobian_samples: int | None  # required, but may be explicitly None
    early_stopping_metric: typing.Literal["loss", "hinge_score", "logit_score"] | None = None
    """If set, return the canary with the maximum ABSOLUTE gap between in and out models for this metric."""
    num_eval_models: int | None = None
    """Number of models used to evaluate the canary (and for early stopping).
    When unset: one model if early stopping is enabled, otherwise zero."""

    def build_optimizer(self, sample_non_canaries: bool) -> optimizers.CanaryOptimizer:
        # `sample_non_canaries` belongs to the shared interface but is unused here.
        num_eval_models = self.num_eval_models
        if num_eval_models is None:
            # Early stopping needs a model to score against; otherwise use none.
            num_eval_models = 0 if self.early_stopping_metric is None else 1

        architecture = self.architecture.build_architecture()
        return eval.optimizers.influence_stochastic.InfluenceStochasticOptimizer(
            num_optimization_steps=self.num_optimization_steps,
            canary_lr=self.canary_lr,
            canary_momentum=self.canary_momentum,
            clamp=self.clamp,
            num_jacobian_samples=self.num_jacobian_samples,
            architecture=architecture,
            early_stopping_metric=self.early_stopping_metric,
            num_eval_models=num_eval_models,
        )


class OptimizedCanary(CanaryType):
    """Configuration for optimized canary generation (``canary_type="optimized"``).

    ``optimizer`` is a discriminated union keyed on ``optimizer_type``.
    """

    canary_type: typing.Literal["optimized"]
    optimizer: UnrolledOptimizerConfig | InfluenceStochasticOptimizerConfig = pydantic.Field(
        discriminator="optimizer_type",
    )
    canary_label: int | None = None
    """Label assigned to the canaries; None means labels are sampled uniformly."""

    def build_generator(
        self,
        directory_manager: util.DirectoryManager,
        dataset_loader: data.DatasetLoader,
    ) -> util.ConcatOptimizedCanaryGenerator:
        # Only the directory manager is needed; `dataset_loader` is unused.
        generator = util.ConcatOptimizedCanaryGenerator(directory_manager)
        return generator


class RandomCanary(CanaryType):
    """Canary config of type ``"random"``, built by ``RandomCanaryGenerator``."""

    canary_type: typing.Literal["random"]

    def build_generator(
        self,
        directory_manager: util.DirectoryManager,
        dataset_loader: data.DatasetLoader | None = None,
    ) -> baseline_canaries.RandomCanaryGenerator:
        """Build the random-canary generator.

        Takes the same ``(directory_manager, dataset_loader)`` arguments as the abstract
        ``CanaryType.build_generator`` and every sibling, so it can be called
        polymorphically. ``dataset_loader`` defaults to None so existing callers that
        pass only ``directory_manager`` keep working. Neither argument is used.
        """
        return baseline_canaries.RandomCanaryGenerator()


class AdversarialCanary(CanaryType):
    """Canary config of type ``"adversarial"``, built by ``AdversarialCanaryGenerator``."""

    canary_type: typing.Literal["adversarial"]

    learning_rate: float = 0.1
    """Learning rate used when training the models."""

    momentum: float = 0.9
    """Momentum used when training the models."""

    num_epochs: int = 10
    """Training epochs per model."""

    batch_size: int = 128
    """Training batch size."""

    standardize: bool = True
    """Standardize the training data to zero mean and unit variance."""

    sample_non_canaries: bool = True
    """If True, the victim ensemble is trained on 50-50 dataset splits."""

    mlp_width: int = 20
    """Hidden width of the MLP."""

    attack_eps: float = 0.3
    """Adversarial-attack epsilon, relative to the [0, 1] pixel range."""

    attack_norm: typing.Literal["linf", "l2"] = "linf"
    """Norm used by the adversarial attack."""

    attack_num_models: int = 8
    """Ensemble size used by the adversarial attack."""

    attack_aggregation: typing.Literal["mean", "min"] = "min"
    """How the attack aggregates across the ensemble."""

    def build_generator(
        self,
        directory_manager: util.DirectoryManager,
        dataset_loader: data.DatasetLoader,
    ) -> baseline_canaries.AdversarialCanaryGenerator:
        # Every field except the `canary_type` discriminator is forwarded by name;
        # `directory_manager` is unused.
        forwarded_fields = (
            "learning_rate",
            "momentum",
            "num_epochs",
            "batch_size",
            "standardize",
            "sample_non_canaries",
            "mlp_width",
            "attack_eps",
            "attack_norm",
            "attack_num_models",
            "attack_aggregation",
        )
        options = {field: getattr(self, field) for field in forwarded_fields}
        return baseline_canaries.AdversarialCanaryGenerator(dataset_loader=dataset_loader, **options)


class BaseDataset(pydantic.BaseModel, metaclass=abc.ABCMeta):
    """Abstract description of a base dataset.

    Concrete subclasses pin ``name`` to a literal, which ``Settings.base_dataset`` uses
    as its union discriminator.
    """

    name: str

    @abc.abstractmethod
    def get_image_shape(self) -> tuple[int, int, int]:
        """Return the shape of one image, e.g. ``(3, 32, 32)`` for CIFAR-10."""

    @abc.abstractmethod
    def get_num_classes(self) -> int:
        """Return the number of label classes."""

    @abc.abstractmethod
    def get_num_train_samples(self) -> int:
        """Return the size of the training split."""

    @abc.abstractmethod
    def build_loader(self) -> data.DatasetLoader:
        """Return a loader for this dataset."""


class CIFAR10(BaseDataset):
    """CIFAR-10: 50,000 training images of shape (3, 32, 32) across 10 classes."""

    name: typing.Literal["cifar10"]

    def get_image_shape(self) -> tuple[int, int, int]:
        channels, height, width = 3, 32, 32
        return (channels, height, width)

    def get_num_classes(self) -> int:
        num_classes = 10
        return num_classes

    def get_num_train_samples(self) -> int:
        num_train = 50_000
        return num_train

    def build_loader(self) -> data.CIFAR10Loader:
        loader = data.CIFAR10Loader()
        return loader


class MNIST(BaseDataset):
    """MNIST: 60,000 training images of shape (1, 28, 28) across 10 classes."""

    name: typing.Literal["mnist"]

    def get_image_shape(self) -> tuple[int, int, int]:
        channels, height, width = 1, 28, 28
        return (channels, height, width)

    def get_num_classes(self) -> int:
        num_classes = 10
        return num_classes

    def get_num_train_samples(self) -> int:
        num_train = 60_000
        return num_train

    def build_loader(self) -> data.MNISTLoader:
        loader = data.MNISTLoader()
        return loader


class ModelTrainerConfig(pydantic.BaseModel, metaclass=abc.ABCMeta):
    """Abstract config for a model trainer; subclasses pin ``trainer_type`` to a literal."""

    trainer_type: str

    @abc.abstractmethod
    def build_trainer(self) -> trainers.ModelTrainer:
        """Construct the trainer described by this config."""


class JaxTrainerConfig(ModelTrainerConfig):
    """Config for ``JaxTrainer`` (``trainer_type="jax"``)."""

    trainer_type: typing.Literal["jax"]

    # ---- model training parameters ----
    learning_rate: float = 0.1
    """Learning rate used when training each model."""

    momentum: float = 0.9
    """Momentum used when training each model."""

    num_epochs: int = 10
    """Training epochs per model."""

    batch_size: int = 128
    """Training batch size."""

    standardize: bool = True
    """Standardize the training data to zero mean and unit variance."""

    # ---- model architecture parameters ----
    mlp_width: int = 20
    """Hidden width of the MLP."""

    dp_params: DPParams | None = None
    """If set, train with a DP-SGD optimizer using these parameters."""

    def build_trainer(self) -> trainers.ModelTrainer:
        # JAX-backed modules are imported lazily because importing them is slow.
        import eval.trainers.jax_trainer
        # NOTE(review): other code in this file uses `unrolled_canaries.DPParams`;
        # confirm that `unrolled` is the intended module here.
        import unrolled

        dp_params = None
        if self.dp_params is not None:
            # Copy the config values into the training code's own DPParams type.
            source = self.dp_params
            dp_params = unrolled.DPParams(
                noise_multiplier=source.noise_multiplier,
                l2_norm_clip=source.l2_norm_clip,
                delta=source.delta,
            )

        trainer_cls = eval.trainers.jax_trainer.JaxTrainer
        return trainer_cls(
            learning_rate=self.learning_rate,
            momentum=self.momentum,
            num_epochs=self.num_epochs,
            batch_size=self.batch_size,
            mlp_width=self.mlp_width,
            standardize=self.standardize,
            dp_params=dp_params,
        )


class TorchTrainerConfig(ModelTrainerConfig):
    """Config for ``TorchTrainer`` (``trainer_type="torch"``).

    ``architecture`` is a discriminated union keyed on ``architecture_name``.
    """

    trainer_type: typing.Literal["torch"]
    architecture: HLBAchitectureConfig | WideResNetAchitectureConfig | MLPArchitectureConfig = pydantic.Field(
        discriminator="architecture_name",
    )

    def build_trainer(self) -> eval.trainers.torch_trainer.TorchTrainer:
        built_architecture = self.architecture.build_architecture()
        return eval.trainers.torch_trainer.TorchTrainer(architecture=built_architecture)


from typing import Optional, Literal

class Settings(pydantic.BaseModel):
    """Top-level experiment settings.

    ``canaries`` is a discriminated union keyed on ``canary_type``, and
    ``base_dataset`` one keyed on ``name``. ``model_trainer`` is optional in general,
    but ``_require_trainer_conditionally`` requires it for influence-stochastic
    optimization.
    """

    global_seed: int
    num_canaries: int
    num_models_target: int
    num_models_shadow: int
    canaries: OptimizedCanary | InDistributionCanary | RandomCanary | MislabeledCanary | AdversarialCanary | IdentityCanary = (
        pydantic.Field(discriminator="canary_type")
    )
    sample_non_canaries: bool = False
    base_dataset: CIFAR10 | MNIST = pydantic.Field(discriminator="name")

    # Only "classification" is accepted at present; it is also the default.
    canary_task_selection: Literal["classification"] = "classification"

    model_trainer: Optional[JaxTrainerConfig | TorchTrainerConfig] = None

    @model_validator(mode="after")
    def _require_trainer_conditionally(self):
        """Reject optimized canaries using ``influence_stochastic`` without a trainer."""
        if not isinstance(self.canaries, OptimizedCanary):
            return self
        optimizer_type = getattr(self.canaries.optimizer, "optimizer_type", None)
        if optimizer_type == "influence_stochastic" and self.model_trainer is None:
            raise ValueError(
                "model_trainer is required when optimizer_type='influence_stochastic'.",
            )
        return self