import os

from torchvision import datasets

from codes.components.worker import ByzantineWorker


# class LabelflippingMNIST(datasets.MNIST):
#     def __getitem__(self, index):
#         img, target = super(LabelflippingMNIST, self).__getitem__(index)
#         target = 9 - target
#         return img, target
#
#     @property
#     def raw_folder(self):
#         return os.path.join(self.root, "MNIST", "raw")
#
#     @property
#     def processed_folder(self):
#         return os.path.join(self.root, "MNIST", "processed")
#
#
# class LabelflippingCIFAR10(datasets.CIFAR10):
#     def __getitem__(self, index):
#         img, target = super(LabelflippingCIFAR10, self).__getitem__(index)
#         target = 9 - target
#         return img, target


class _LabelFlippingIterator:
    """Proxy around a batch iterator whose ``__next__`` is delegated to ``next_fn``.

    ``next(proxy)``, ``proxy.__next__()`` and ``for batch in proxy`` all go
    through ``next_fn``. ``len()`` and any other attribute lookup are
    forwarded to the wrapped iterator, so code that inspects the original
    iterator keeps working.
    """

    def __init__(self, iterator, next_fn):
        self._iterator = iterator
        self._next_fn = next_fn

    def __iter__(self):
        return self

    def __next__(self):
        return self._next_fn()

    def __len__(self):
        return len(self._iterator)

    def __getattr__(self, name):
        # Only reached when normal lookup on the proxy fails. Guard the private
        # slots so a half-initialized proxy (e.g. during copy/unpickling)
        # raises AttributeError instead of recursing forever.
        if name in ("_iterator", "_next_fn"):
            raise AttributeError(name)
        return getattr(self._iterator, name)


class LabelFlippingWorker(ByzantineWorker):
    def __init__(self, revertible_label_transformer, *args, **kwargs):
        """
        Initialize the LabelFlippingWorker instance.

        Args:
            revertible_label_transformer (callable):
                A callable function that defines how labels are transformed.
                E.g., lambda label: 9 - label for reversing labels. It is
                applied to the whole ``target`` object returned by the train
                loader (presumably a batch tensor of labels — confirm).
            *args, **kwargs:
                Forwarded unchanged to ``ByzantineWorker.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.revertible_label_transformer = revertible_label_transformer

    def train_epoch_start(self) -> None:
        """
        Start a training epoch and make the train iterator yield flipped labels.

        After the parent class has run its epoch setup, the iterator stored in
        ``self.running["train_loader_iterator"]`` is replaced by a proxy whose
        ``__next__`` returns ``(data, revertible_label_transformer(target))``.
        """
        super().train_epoch_start()
        iterator = self.running["train_loader_iterator"]
        if isinstance(iterator, _LabelFlippingIterator):
            # Already wrapped (the parent kept the same iterator object).
            # Wrapping again would apply the transform twice, which for an
            # involution such as ``9 - label`` silently restores clean labels.
            return
        # Special methods like ``__next__`` are looked up on the type, not the
        # instance, so assigning ``iterator.__next__`` in place would be
        # ignored by ``next(iterator)`` and ``for`` loops. A proxy object with
        # its own ``__next__`` works for every calling style.
        self.running["train_loader_iterator"] = _LabelFlippingIterator(
            iterator, self._wrap_iterator(iterator.__next__)
        )

    def _wrap_iterator(self, func):
        """
        Wrap an iterator's next function to apply label flipping.

        Args:
            func (function): The original zero-argument next function, which
                must return a ``(data, target)`` pair.

        Returns:
            function: A zero-argument function returning
            ``(data, revertible_label_transformer(target))``. StopIteration
            raised by ``func`` propagates unchanged.
        """
        def wrapper():
            data, target = func()
            return data, self.revertible_label_transformer(target)

        return wrapper

    def _wrap_metric(self, func):
        """
        Wrap a metric function to apply the label transformer to its target.

        Args:
            func (function): The original metric function ``(output, target)``.

        Returns:
            function: A wrapped function calling
            ``func(output, revertible_label_transformer(target))``.
        """
        # NOTE(review): presumably the target reaching a metric has already
        # been flipped by the train iterator; re-applying a revertible
        # (involutive) transformer then recovers the true labels so metrics
        # are measured against them — confirm against the parent worker.
        def wrapper(output, target):
            return func(output, self.revertible_label_transformer(target))

        return wrapper

    def add_metric(self, name, callback):
        """
        Add a metric to the worker, wrapped so its target is label-transformed.

        Args:
            name (str): Name of the metric. "loss" and "length" are reserved.
            callback (function): The metric function ``(output, target)``.

        Raises:
            KeyError: If the metric name is already used or reserved.
        """
        if name in self.metrics or name in ["loss", "length"]:
            raise KeyError(f"Metrics ({name}) already added.")

        self.metrics[name] = self._wrap_metric(callback)

    def __str__(self) -> str:
        """Return a string representation of the worker."""
        return "LabelFlippingWorker"
