from enum import Enum

import torch as t


class CompensatingFeatureBiasPosition(Enum):
    """
    Choices for where a compensating feature bias is placed.

    How each option is applied is decided by the code that consumes this
    enum. Every member's value is its lowercase name.
    """

    NONE = "none"
    MAGNITUDE = "magnitude"
    BOTH = "both"


class ResamplingType(Enum):
    """
    Available resampling modes.

    What each mode does is defined by the code that consumes this enum.
    Every member's value is its lowercase name.
    """

    NONE = "none"
    NAIVE = "naive"
    FANCY = "fancy"


class AutoEncoderType(Enum):
    """
    Autoencoder variants, with boolean properties that group them.

    Each property is a membership test against a fixed set of members.
    """

    VANILLA = "vanilla"
    TOPM = "topm"
    ZIPF_TOPM = "zipf_topm"
    TOPK = "topk"
    MUTUAL_CHOICE = "mutual_choice"

    @property
    def vanilla(self) -> bool:
        """
        Whether this variant is in the "vanilla" group.

        Every member currently defined is listed, so this is True for all of
        them. A newly added member is excluded until it is listed here.
        """
        vanilla_group = (
            AutoEncoderType.VANILLA,
            AutoEncoderType.TOPM,
            AutoEncoderType.ZIPF_TOPM,
            AutoEncoderType.TOPK,
            AutoEncoderType.MUTUAL_CHOICE,
        )
        return self in vanilla_group

    @property
    def topm(self) -> bool:
        """Whether this variant is TOPM or ZIPF_TOPM."""
        return self is AutoEncoderType.TOPM or self is AutoEncoderType.ZIPF_TOPM

    @property
    def secondary_loss(self) -> bool:
        """Whether this variant is TOPM, ZIPF_TOPM or MUTUAL_CHOICE."""
        with_secondary = (
            AutoEncoderType.TOPM,
            AutoEncoderType.ZIPF_TOPM,
            AutoEncoderType.MUTUAL_CHOICE,
        )
        return self in with_secondary

    @property
    def generalised_topk(self) -> bool:
        """Whether this variant is TOPK, TOPM, ZIPF_TOPM or MUTUAL_CHOICE."""
        topk_like = (
            AutoEncoderType.TOPK,
            AutoEncoderType.TOPM,
            AutoEncoderType.ZIPF_TOPM,
            AutoEncoderType.MUTUAL_CHOICE,
        )
        return self in topk_like


class SchedulerType(Enum):
    """
    Available scheduler choices.

    Note that DECREASING_K_M's value is ``"variable_k_m"``, not its
    lowercase name. Lookup by value must use that string.
    """

    COSINE = "cosine"
    CUSTOM = "custom"
    DECREASING_K_M = "variable_k_m"
    NONE = "none"


class PreprocessStrategy(Enum):
    """
    Enumeration of the preprocessing strategies that can be applied to data.

    Attributes:
        NONE: No preprocessing.
        SHIFT_ZERO: Use a neuron bias, initialised to zero.
        SHIFT_MEDIAN: Use a neuron bias, initialised to the geometric median.
        SHIFT_AND_SCALE_ZERO: Use a neuron bias and a scale factor, initialised to 0 and 1 respectively.
        SHIFT_AND_SCALE_MEDIAN_ONLY: Initialise the neuron bias at the median and the scale factor to 1.
        SHIFT_AND_SCALE_STANDARD: Initialise the neuron bias at the median and the scale factor at the MAD (Median Absolute Deviation).
        SHIFT_AND_NORMALISE: Initialise the neuron bias at the median, then scale with a personalised normalisation factor.

    NOTE(review): SHIFT_AND_NORMALISE is described above as initialising the
    bias at the median, yet it is in neither ``shift`` nor ``median_shift``.
    Its only flag is ``normalise``. Confirm that callers handle its shift
    through that flag.
    """

    NONE = "none"
    SHIFT_ZERO = "shift_zero"
    SHIFT_MEDIAN = "shift_median"
    SHIFT_AND_SCALE_ZERO = "shift_and_scale_zero"
    SHIFT_AND_SCALE_MEDIAN_ONLY = "shift_and_scale_median_only"
    SHIFT_AND_SCALE_STANDARD = "shift_and_scale_standard"
    SHIFT_AND_NORMALISE = "shift_and_normalise"

    @property
    def median_shift(self) -> bool:
        """
        Whether the neuron bias is initialised at the median.

        Returns:
            bool: True for SHIFT_MEDIAN, SHIFT_AND_SCALE_MEDIAN_ONLY and
            SHIFT_AND_SCALE_STANDARD; False for every other member.
        """
        median_initialised = (
            PreprocessStrategy.SHIFT_MEDIAN,
            PreprocessStrategy.SHIFT_AND_SCALE_MEDIAN_ONLY,
            PreprocessStrategy.SHIFT_AND_SCALE_STANDARD,
        )
        return self in median_initialised

    @property
    def shift(self) -> bool:
        """
        Whether the strategy uses a shifting neuron bias.

        Returns:
            bool: True for every member except NONE and SHIFT_AND_NORMALISE.
        """
        shifting = (
            PreprocessStrategy.SHIFT_ZERO,
            PreprocessStrategy.SHIFT_MEDIAN,
            PreprocessStrategy.SHIFT_AND_SCALE_ZERO,
            PreprocessStrategy.SHIFT_AND_SCALE_MEDIAN_ONLY,
            PreprocessStrategy.SHIFT_AND_SCALE_STANDARD,
        )
        return self in shifting

    @property
    def scale_MAD(self) -> bool:
        """
        Whether the scale factor is initialised at the Median Absolute Deviation.

        Returns:
            bool: True only for SHIFT_AND_SCALE_STANDARD.
        """
        return self is PreprocessStrategy.SHIFT_AND_SCALE_STANDARD

    @property
    def scale(self) -> bool:
        """
        Whether the strategy uses a scale factor.

        Returns:
            bool: True for the three SHIFT_AND_SCALE_* members, False otherwise.
        """
        scaling = (
            PreprocessStrategy.SHIFT_AND_SCALE_ZERO,
            PreprocessStrategy.SHIFT_AND_SCALE_MEDIAN_ONLY,
            PreprocessStrategy.SHIFT_AND_SCALE_STANDARD,
        )
        return self in scaling

    @property
    def normalise(self) -> bool:
        """
        Whether the strategy uses a personalised normalisation factor.

        Returns:
            bool: True only for SHIFT_AND_NORMALISE.
        """
        return self is PreprocessStrategy.SHIFT_AND_NORMALISE


class AutocastDtype(Enum):
    """
    Enumeration class representing the data types available for autocasting.

    Attributes:
        NONE: Autocasting disabled (``is_enabled`` is False); ``dtype``
            reports 32-bit floating point.
        BF16: bfloat16 (16-bit "brain" floating point).
        FP16: 16-bit IEEE floating point.
    """

    NONE = "none"
    BF16 = "bf16"
    FP16 = "fp16"

    @property
    def is_enabled(self) -> bool:
        """
        Check whether autocasting is enabled.

        Returns:
            bool: True for every member except NONE.
        """
        return self is not AutocastDtype.NONE

    @property
    def dtype(self) -> t.dtype:
        """
        Return the torch dtype corresponding to this member.

        Returns:
            t.dtype: ``t.bfloat16`` for BF16, ``t.float16`` for FP16, and
            ``t.float32`` for NONE. NONE falls back to full precision.
        """
        if self is AutocastDtype.BF16:
            return t.bfloat16
        if self is AutocastDtype.FP16:
            return t.float16
        return t.float32


class TopMFlavour(Enum):
    """
    Flavours available for the top-m selection.

    How each flavour behaves is defined by the code that consumes this enum.
    Every member's value is its lowercase name.
    """

    UNIFORM = "uniform"
    ZIPF = "zipf"
