import torch.nn as nn
import torch 
import numpy as np

"""
Binary mask sampling based on: https://github.com/gzerveas/mvts_transformer
"""

def tst_noise_masking(X, device, r = 0.15, lm = 3):
    """Sample an independent binary noise mask for every sample in a batch.

    Each sample ``X[b, :, :]`` gets its own mask from ``noise_mask`` with its
    default settings (``mode='separate'``, ``distribution='geometric'``), i.e.
    every feature column is masked along the first per-sample axis by its own
    two-state Markov chain.

    Args:
        X: 3-D tensor; only its shape is used. Presumably
            (batch_size, sequence_length, num_features) -- TODO confirm.
        device: Device the returned mask is placed on.
        r: Fraction of positions to mask (on average).
        lm: Mean length of a masked segment.

    Returns:
        int32 tensor with the same shape as ``X`` on ``device``;
        1 means keep, 0 means masked.
    """
    mask = torch.ones(X.shape, dtype=torch.int32, device=device)
    for b in range(X.size(0)):
        # noise_mask returns a numpy bool array (True = keep).
        sample_mask = noise_mask(X[b, :, :], masking_ratio=r, lm=lm)
        mask[b, :, :] = torch.from_numpy(sample_mask).to(device=device, dtype=torch.int32)
    return mask

def noise_mask(X, masking_ratio, lm=3, mode='separate', distribution='geometric', exclude_feats=None):
    """Create a boolean noise mask for one (time, feature) sample.

    Only ``X.shape`` is used: axis 0 is the time dimension and axis 1 the
    feature dimension.

    Args:
        X: Array-like with a 2-D ``.shape``.
        masking_ratio: Fraction of positions to mask.
        lm: Mean length of a masked segment (geometric distribution only).
        mode: ``'separate'`` masks each feature independently; any other
            value masks all features at the same time positions.
        distribution: ``'geometric'`` uses a stateful Markov chain per column;
            any other value draws each position independently (Bernoulli).
        exclude_feats: Optional iterable of feature indices that are never
            masked (their columns stay all True), in every mode/distribution.

    Returns:
        bool ndarray of shape ``X.shape``; True = keep, False = masked.
    """
    if exclude_feats is not None:
        exclude_feats = set(exclude_feats)

    if distribution == 'geometric':  # stateful (Markov chain)
        if mode == 'separate':  # each variable (feature) is independent
            mask = np.ones(X.shape, dtype=bool)
            for m in range(X.shape[1]):  # feature dimension
                # Skipping excluded columns here also avoids drawing random numbers for them.
                if exclude_feats is None or m not in exclude_feats:
                    mask[:, m] = geom_noise_mask_single(X.shape[0], lm, masking_ratio)  # time dimension
        else:  # replicate across feature dimension (mask all variables at the same positions concurrently)
            mask = np.tile(np.expand_dims(geom_noise_mask_single(X.shape[0], lm, masking_ratio), 1), X.shape[1])
    else:  # each position is independent Bernoulli with p = 1 - masking_ratio
        if mode == 'separate':
            mask = np.random.choice(np.array([True, False]), size=X.shape, replace=True,
                                    p=(1 - masking_ratio, masking_ratio))
        else:
            mask = np.tile(np.random.choice(np.array([True, False]), size=(X.shape[0], 1), replace=True,
                                            p=(1 - masking_ratio, masking_ratio)), X.shape[1])

    # Excluded features must never be masked, whatever mode/distribution was
    # chosen; previously only the separate-geometric branch honoured this.
    if exclude_feats:
        mask[:, sorted(exclude_feats)] = True

    return mask

def geom_noise_mask_single(L, lm, masking_ratio):
    """Sample a 1-D keep-mask of length ``L`` from a two-state Markov chain.

    State 0 = masking, state 1 = keeping. A masked run ends with probability
    ``1 / lm`` per step (mean masked-run length ``lm``). The unmasked run's
    stopping probability is chosen so that, on average, a fraction
    ``masking_ratio`` of positions is masked.

    Args:
        L: Length of the mask.
        lm: Mean length of a masked segment.
        masking_ratio: Fraction of positions to mask. Values >= 1 mask every
            position (the chain is undefined there: p_u would divide by zero).

    Returns:
        bool ndarray of shape (L,); True = keep, False = masked.
    """
    if masking_ratio >= 1:
        # Fully masked; avoids the ZeroDivisionError in p_u below.
        return np.zeros(L, dtype=bool)

    keep_mask = np.ones(L, dtype=bool)
    p_m = 1 / lm  # probability of each masking sequence stopping. parameter of geometric distribution.
    p_u = p_m * masking_ratio / (1 - masking_ratio)  # probability of each unmasked sequence stopping. parameter of geometric distribution.
    p = [p_m, p_u]

    # Start in state 0 with masking_ratio probability
    state = int(np.random.rand() > masking_ratio)  # state 0 means masking, 1 means not masking
    for i in range(L):
        keep_mask[i] = state  # here it happens that state and masking value corresponding to state are identical
        if np.random.rand() < p[state]:
            state = 1 - state

    return keep_mask

# Masked Loss 
class MaskedMSELoss(nn.Module):
    """Mean-squared error evaluated only at masked positions.

    The mask follows the convention of the masking utilities in this file:
    True/1 marks a kept value and False/0 a masked one. The loss is computed
    only where the mask is False/0.
    """

    def __init__(self, reduction: str = 'mean'):
        """
        Args:
            reduction: Forwarded to ``nn.MSELoss`` ('mean', 'sum' or 'none').
        """
        super().__init__()

        self.reduction = reduction
        self.mse_loss = nn.MSELoss(reduction=self.reduction)

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Return the MSE between ``y_pred`` and ``y_true`` where ``mask`` is False/0.

        Args:
            y_pred: Predictions.
            y_true: Targets.
            mask: Bool or integer 0/1 tensor; False/0 selects the positions
                that contribute to the loss.

        Returns:
            The reduced loss. For ``reduction='mean'`` with no masked
            positions, a zero (instead of NaN) that stays attached to
            ``y_pred``'s autograd graph.
        """
        # Normalize integer 0/1 masks (e.g. from tst_noise_masking): ``~`` on
        # an int tensor is a bitwise NOT, and masked_select requires a bool mask.
        selected = ~mask.bool()

        masked_pred = torch.masked_select(y_pred, selected)
        masked_true = torch.masked_select(y_true, selected)

        if self.reduction == 'mean' and masked_pred.numel() == 0:
            # Mean over zero elements is NaN and would poison training.
            return masked_pred.sum()

        return self.mse_loss(masked_pred, masked_true)

"""
This script provides classes for simulating various types of sensor faults
on PyTorch tensors for pretraining. These tensors are expected to represent
time series data with dimensions (batch_size, sequence_length, num_features).

Each class simulates a specific fault by modifying elements of the
input tensor based on a provided binary mask. The `apply_mask` method in each
class takes the input tensor and a `binary_mask` tensor. Alterations are
applied where the `binary_mask` is False (or 0). See Chapter 4.1 in the paper for more details.

Classes:
    bias_mask: Adds a randomly sampled additive bias to elements indicated by
               the `binary_mask`. The bias for each batch sample is drawn
               uniformly from a range defined during class initialization.
    mean_mask: Zeros out the values of elements indicated by the `binary_mask`.
    performance_degradation_mask: Adds random Gaussian noise to elements
                                   indicated by the `binary_mask`, simulating
                                   heavy noise. The noise
                                   level is provided during class initialization.
"""

# Bias
class bias_mask():
    """Simulate a sensor bias fault by adding a constant offset to masked values.

    For each (batch sample, feature) one offset is drawn uniformly from
    ``[lower_bound, upper_bound)``. It is added at every position along the
    second axis where ``binary_mask`` is False/0.
    """

    def __init__(self, device, lower_bound = -1, upper_bound = 1, ):
        self.lower_bound = lower_bound  # lower end of the uniform bias range
        self.upper_bound = upper_bound  # upper end of the uniform bias range
        self.device = device  # device the output (and the mask) is moved to

    def apply_mask(self, X, binary_mask):
        """Return a copy of ``X`` on ``self.device`` with a bias added where the mask is False/0.

        Args:
            X: 3-D input tensor, (batch_size, sequence_length, num_features).
            binary_mask: Bool or 0/1 tensor broadcastable to ``X``; it may live
                on any device.

        Returns:
            The biased copy; ``X`` itself is not modified.
        """
        tensor = torch.clone(X).to(self.device)
        # Drawn as float32 on the CPU (like the original FloatTensor) to keep
        # the same RNG stream. The size-1 middle dim shares the bias over time.
        add = torch.empty(tensor.size(0), 1, tensor.size(2), dtype=torch.float32, device='cpu')
        add = add.uniform_(self.lower_bound, self.upper_bound).to(self.device)
        # Move the mask as well so a mask on another device cannot break the product.
        faulty = (~binary_mask.bool()).int().to(self.device)
        return tensor + add * faulty

# Mean
class mean_mask():
    """Simulate a fault by zeroing every value where ``binary_mask`` is 0/False."""

    def __init__(self, device):
        # Target device for both the copied input and the mask.
        self.device = device

    def apply_mask(self, X, binary_mask):
        """Return a copy of ``X`` on ``self.device`` multiplied elementwise by ``binary_mask``.

        Positions where the mask is 0/False become 0; positions where it is
        1/True keep their value. ``X`` itself is not modified.
        """
        target = self.device
        copied = torch.clone(X).to(target)
        return copied * binary_mask.to(target)
    
# Noise
class performance_degradation_mask():
    """Simulate heavy sensor noise by adding Gaussian noise where the mask is False/0."""

    def __init__(self, device, noise_level = 0.4):
        self.noise_level = noise_level  # standard deviation of the added zero-mean noise
        self.device = device  # device the output (and the mask) is moved to

    def apply_mask(self, X, binary_mask):
        """Return a noisy copy of ``X`` on ``self.device``.

        Args:
            X: Input tensor, expected (batch_size, sequence_length, num_features).
            binary_mask: Bool or 0/1 tensor broadcastable to ``X``; noise is
                added only where it is False/0. It may live on any device.

        Returns:
            The perturbed copy; ``X`` itself is not modified.
        """
        tensor = torch.clone(X).to(self.device)
        # One independent N(0, noise_level^2) draw per element.
        noise = torch.normal(mean=0, std=self.noise_level, size=tuple(tensor.shape)).to(self.device)
        # Move the mask as well so a mask on another device cannot break the product.
        faulty = (~binary_mask.bool()).int().to(self.device)
        return tensor + noise * faulty

