import logging
import pathlib

import torch

import eval.baseline_canaries as baseline_canaries


def model_indices_type(indices_str: str) -> range:
    """Parse a model-indices string into a non-empty range.

    Args:
        indices_str: String in format "N" for a single index or "N-M" for a
            range of indices (M exclusive).

    Returns:
        A non-empty range object with a non-negative start.

    Raises:
        ValueError: If the string cannot be parsed, has a negative bound, or
            describes an empty range (e.g. "5-5" or "5-3").
    """
    if "-" in indices_str:
        # "1-2-3" fails to unpack and "-3" fails int(""): both raise ValueError.
        start, end = map(int, indices_str.split("-"))
        result = range(start, end)
    else:
        start = int(indices_str)
        result = range(start, start + 1)

    # An empty range would silently select no models, so treat it as invalid.
    # A non-empty range with start >= 0 also has stop > 0.
    if result.start < 0 or len(result) == 0:
        raise ValueError(f"Invalid model indices: {indices_str}")

    return result


class DirectoryManager:
    """Compute paths of experiment artifacts under a single experiment directory.

    All methods only build ``pathlib.Path`` objects; nothing is created or read
    on disk. Layout relative to ``experiment_dir``:

        config.json
        canaries_images.pt, canaries_targets.pt
        canaries/<canary_idx>/{intermediate_canaries.npz, result.pt, logs}
        shadow_models/<model_idx>/{model.pt, predictions.pt, metrics.json}
        target_models/<model_idx>/{model.pt, predictions.pt, metrics.json}
        shadow_predictions_full.pt, target_predictions_full.pt
        attack_results.npz
        influence_functions/checkpoints/model_<model_idx>.pt
        influence_functions/scores/scores_<model_idx>.pt
        influence_functions/influence_results_<model_idx>
    """

    def __init__(self, experiment_dir: pathlib.Path):
        self._experiment_dir = experiment_dir

    @classmethod
    def get_config_path(cls, experiment_dir: pathlib.Path) -> pathlib.Path:
        # Classmethod so the config can be located before a manager exists.
        return experiment_dir / "config.json"

    # --- Canaries -----------------------------------------------------------

    def get_canary_intermediate_file(self, canary_idx: int) -> pathlib.Path:
        return self._get_canary_dir(canary_idx) / "intermediate_canaries.npz"

    def get_canary_intermdiate_file(self, canary_idx: int) -> pathlib.Path:
        # Misspelled legacy name, kept so existing callers keep working.
        return self.get_canary_intermediate_file(canary_idx)

    def get_canary_log_dir(self, canary_idx: int) -> pathlib.Path:
        return self._get_canary_dir(canary_idx) / "logs"

    def get_optimized_canary_file(self, canary_idx: int) -> pathlib.Path:
        return self._get_canary_dir(canary_idx) / "result.pt"

    def get_canaries_images_path(self) -> pathlib.Path:
        return self._experiment_dir / "canaries_images.pt"

    def get_canaries_targets_path(self) -> pathlib.Path:
        return self._experiment_dir / "canaries_targets.pt"

    # --- Shadow / target models ---------------------------------------------

    def get_shadow_model_file(self, model_idx: int) -> pathlib.Path:
        return self._get_shadow_dir(model_idx) / "model.pt"

    def get_target_model_file(self, model_idx: int) -> pathlib.Path:
        return self._get_target_dir(model_idx) / "model.pt"

    def get_shadow_predictions_file(self, model_idx: int) -> pathlib.Path:
        return self._get_shadow_dir(model_idx) / "predictions.pt"

    def get_full_shadow_predictions_file(self) -> pathlib.Path:
        return self._experiment_dir / "shadow_predictions_full.pt"

    def get_target_predictions_file(self, model_idx: int) -> pathlib.Path:
        return self._get_target_dir(model_idx) / "predictions.pt"

    def get_full_target_predictions_file(self) -> pathlib.Path:
        return self._experiment_dir / "target_predictions_full.pt"

    def get_shadow_metrics_file(self, model_idx: int) -> pathlib.Path:
        return self._get_shadow_dir(model_idx) / "metrics.json"

    def get_target_metrics_file(self, model_idx: int) -> pathlib.Path:
        return self._get_target_dir(model_idx) / "metrics.json"

    def get_attack_results_file(self) -> pathlib.Path:
        return self._experiment_dir / "attack_results.npz"

    # --- Influence functions ------------------------------------------------

    def get_influence_model_file(self, model_idx: int) -> pathlib.Path:
        return self._get_influence_dir() / "checkpoints" / f"model_{model_idx}.pt"

    def get_influence_scores_file(self, model_idx: int) -> pathlib.Path:
        return self._get_influence_dir() / "scores" / f"scores_{model_idx}.pt"

    def get_influence_results_dir(self, model_idx: int) -> pathlib.Path:
        return self._get_influence_dir() / f"influence_results_{model_idx}"

    # --- Private directory helpers ------------------------------------------

    def _get_canary_dir(self, canary_idx: int) -> pathlib.Path:
        return self._experiment_dir / "canaries" / str(canary_idx)

    def _get_shadow_dir(self, model_idx: int) -> pathlib.Path:
        return self._experiment_dir / "shadow_models" / str(model_idx)

    def _get_target_dir(self, model_idx: int) -> pathlib.Path:
        return self._experiment_dir / "target_models" / str(model_idx)

    def _get_influence_dir(self) -> pathlib.Path:
        return self._experiment_dir / "influence_functions"

def validate_canaries(
    canaries: torch.Tensor,
    targets: torch.Tensor,
    image_shape: tuple[int, int, int],
    num_classes: int,
    num_canaries: int,
) -> None:
    """Validate canary images and targets.

    Checks raise explicitly instead of using ``assert`` statements, so they
    still run when Python is started with ``-O``.

    Args:
        canaries: Tensor of canary images. Must have shape
            ``(num_canaries, *image_shape)``, dtype ``torch.float32`` and all
            values in ``[0, 1]``.
        targets: Tensor of canary targets. Must have shape ``(num_canaries,)``,
            dtype ``torch.int64`` and values in ``[0, num_classes)``.
        image_shape: Expected shape of a single canary image.
        num_classes: Number of classes; every target must be below it.
        num_canaries: Expected number of canaries.

    Raises:
        AssertionError: If any validation check fails.
    """
    expected_images_shape = (num_canaries, *image_shape)
    if canaries.shape != expected_images_shape:
        raise AssertionError(
            f"Expected canaries of shape {expected_images_shape}, got {tuple(canaries.shape)}"
        )
    if canaries.dtype != torch.float32:
        raise AssertionError(f"Expected canaries of dtype torch.float32, got {canaries.dtype}")
    # min()/max() raise on empty tensors, so value-range checks need data.
    # The positive form also rejects NaN, because NaN comparisons are False.
    if canaries.numel() > 0 and not (canaries.min() >= 0.0 and canaries.max() <= 1.0):
        raise AssertionError("Canary pixel values must lie in [0, 1]")

    if targets.shape != (num_canaries,):
        raise AssertionError(
            f"Expected targets of shape {(num_canaries,)}, got {tuple(targets.shape)}"
        )
    if targets.dtype != torch.int64:
        raise AssertionError(f"Expected targets of dtype torch.int64, got {targets.dtype}")
    if targets.numel() > 0 and not (targets.min() >= 0 and targets.max() < num_classes):
        raise AssertionError(f"Canary targets must lie in [0, {num_classes})")


def validate_membership_masks(
    membership_masks: torch.Tensor,
    num_canaries: int,
    num_non_canaries: int,
    num_models: int,
    sample_non_canaries: bool,
) -> None:
    """Validate a per-model membership mask over canaries and non-canaries.

    Columns ``[0, num_canaries)`` are canaries; the remaining columns are
    non-canaries. Checks raise explicitly instead of using ``assert``
    statements, so they still run under ``python -O``.

    Args:
        membership_masks: Boolean tensor of shape
            ``(num_models, num_canaries + num_non_canaries)``.
        num_canaries: Number of canary columns (placed first).
        num_non_canaries: Number of non-canary columns (placed after canaries).
        num_models: Number of models (rows).
        sample_non_canaries: If True, every non-canary must be included in
            exactly ``num_models // 2`` models. Otherwise it must be included
            in all models.

    Raises:
        AssertionError: If any validation check fails.
    """
    expected_shape = (num_models, num_canaries + num_non_canaries)
    if membership_masks.shape != expected_shape:
        raise AssertionError(
            f"Expected membership masks of shape {expected_shape}, "
            f"got {tuple(membership_masks.shape)}"
        )
    if membership_masks.dtype != torch.bool:
        raise AssertionError(
            f"Expected membership masks of dtype torch.bool, got {membership_masks.dtype}"
        )

    half = num_models // 2

    # Each canary must be a member of exactly half of the models.
    canary_counts = membership_masks[:, :num_canaries].sum(dim=0)
    if not torch.all(canary_counts == half):
        raise AssertionError(f"Each canary must be included in exactly {half} models")

    # Non-canaries are either sampled like canaries or always included.
    expected_non_canary_count = half if sample_non_canaries else num_models
    non_canary_counts = membership_masks[:, num_canaries:].sum(dim=0)
    if not torch.all(non_canary_counts == expected_non_canary_count):
        raise AssertionError(
            f"Each non-canary must be included in exactly {expected_non_canary_count} models"
        )


def setup_logging() -> None:
    """Configure root logging at INFO level with timestamped messages.

    Delegates to ``logging.basicConfig``, which does nothing if the root
    logger already has handlers configured.
    """
    log_format = "%(asctime)s - %(levelname)s - %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)


class ConcatOptimizedCanaryGenerator(baseline_canaries.BaselineCanaryGenerator):
    """Canary generator that loads individually optimized canaries from disk.

    Canary ``i`` is read from ``directory_manager.get_optimized_canary_file(i)``.
    That file must hold a mapping with an ``"image"`` tensor and an integer
    ``"target"``. The loaded canaries are stacked into one batch.
    """

    def __init__(self, directory_manager: DirectoryManager):
        # NOTE(review): the base-class __init__ is not called; its signature is
        # not visible here, so this is left unchanged.
        self._directory_manager = directory_manager

    def generate(
        self,
        num_canaries: int,
        image_shape: tuple[int, int, int],
        num_classes: int,
        replaced_images: torch.Tensor,
        replaced_targets: torch.Tensor,
        global_seed: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Load and stack the optimized canaries ``0 .. num_canaries - 1``.

        Args:
            num_canaries: Number of canaries to load.
            image_shape: Expected shape of each canary image.
            num_classes: Number of classes; each target must be in
                ``[0, num_classes)``.
            replaced_images: Unused here; presumably part of the base
                generator interface — verify against the base class.
            replaced_targets: Unused here (see ``replaced_images``).
            global_seed: Unused here (see ``replaced_images``).

        Returns:
            A float32 tensor of shape ``(num_canaries, *image_shape)`` and an
            int64 tensor of shape ``(num_canaries,)`` holding the targets.

        Raises:
            AssertionError: If a loaded canary has the wrong image shape or an
                invalid target.
        """
        if num_canaries == 0:
            # torch.stack rejects an empty list; return correctly shaped empties.
            return (
                torch.empty((0, *image_shape), dtype=torch.float32),
                torch.empty((0,), dtype=torch.int64),
            )

        images = []
        targets = []
        for canary_idx in range(num_canaries):
            canary_file = self._directory_manager.get_optimized_canary_file(canary_idx)
            # NOTE: torch.load unpickles data; only use files produced by this
            # pipeline, never untrusted input.
            raw_canary = torch.load(canary_file)

            image = raw_canary["image"]
            target = raw_canary["target"]

            # Explicit raises (not assert) so validation survives `python -O`.
            if image.shape != image_shape:
                raise AssertionError(
                    f"Canary {canary_idx}: expected image shape {image_shape}, "
                    f"got {tuple(image.shape)}"
                )
            if not isinstance(target, int):
                raise AssertionError(
                    f"Canary {canary_idx}: expected int target, got {type(target).__name__}"
                )
            if not 0 <= target < num_classes:
                raise AssertionError(
                    f"Canary {canary_idx}: target {target} not in [0, {num_classes})"
                )

            images.append(image.to(torch.float32))
            targets.append(target)

        return torch.stack(images), torch.tensor(targets, dtype=torch.int64)
