import numpy as np

from dataclasses import dataclass

@dataclass
class SamplingArgs:
    """
    Args:
        param_sampler                           Parameter sampler in the SWIM algorithm. Set to 'random'
                                                to use fixed distributions ("ELM") instead.
        seed                                    Random state to be used in the Dense layers of the
                                                swimnetworks module.
        sample_uniformly                        Whether to use uniform distribution when sampling
                                                from the input space for the SWIM algorithm.
                                                Setting this to 'False' enables our Approximate-SWIM
                                                algorithm.
        elm_bias_start                          Low value of the uniform distribution used in ELM
                                                for sampling the bias parameter.
        elm_bias_end                            High value of the uniform distribution used in ELM
                                                for sampling the bias parameter.
        elm_weight_loc                          Normal distribution mean for sampling weights in ELM
        elm_weight_std                          Normal distribution scale (std. deviation)
                                                for sampling weights in ELM
        resample_duplicates                     Whether to resample parameters that appear twice in
                                                order to try to avoid redundant parameters (same basis functions).
        dtype                                   numpy dtype, should be either np.float32 for single
                                                np.float64 for double precision.
    """
    param_sampler: str = "relu" # 'relu' or 'tanh'
    seed: int = 1947265
    sample_uniformly: bool = False # Enables Approximate-SWIM algorithm when True function values are not available (e.g., Hamiltonian values).
    swim_dy_norm_ord: int | float = np.inf
    elm_bias_start: float = -np.pi # recommended to set according to input data distribution
    elm_bias_end: float = np.pi
    elm_weight_loc: float = 0.0
    elm_weight_std: float = 1.0
    enc_sigma: float = 1.0                  # Random Fourier Feature parameter for the node and edge encoders
    msg_sigma: float = 1.0                  # Random Fourier Feature parameter for the message encoder
    resample_duplicates: bool = True
    dtype: type = np.float32

    def __post_init__(self):
        if self.param_sampler == "random":
            assert self.sample_uniformly, "param_sampler is 'random' but sample_uniformly is set to 'True'. Ensure it is set to 'False' for sanity check, otherwise if you want to use Approximate-SWIM set param_sampler to either 'tanh' or 'relu'."
        assert self.elm_bias_end >= self.elm_bias_start, "ELM bias range is violated."
        assert self.elm_weight_std >= 0.0, "ELM weight (normal distribution) std must be non-negative."
        assert self.dtype == np.float32 or self.dtype == np.float64, "Only single and double are supported."

