from typing import Optional, Tuple, Union

import torch

import numpy as np
import tqdm

from models.common import AnomalyDetector
from data.statistics import compute_feature_mean_std


class ThresholdAnomalyDetector(AnomalyDetector):
    """Flag the last time step of a window when one feature leaves a [lower, upper] band.

    The band can be given explicitly. If neither bound is given, ``fit`` derives it as
    ``mean +/- std_factor * std`` of the selected feature over the whole dataset.
    """

    def __init__(self, feature_index : int, lower_threshold : Optional[float] = None,
                 upper_threshold : Optional[float] = None, input_shape : str = 'btf',
                 device: Union[str, torch.device] = 'cpu', std_factor : float = 3):

        super().__init__()

        # A missing bound means "unbounded on that side".
        self.lower_threshold = -np.inf if lower_threshold is None else lower_threshold
        self.upper_threshold = np.inf if upper_threshold is None else upper_threshold
        # Index of the monitored feature.
        self.feature         = feature_index
        # Axis layout string such as 'btf' / 'tbf'; 'b' and 't' positions are looked up.
        self.input_shape     = input_shape
        # Device the batches are moved to while computing statistics in fit().
        self.device          = device
        # Number of standard deviations used for the derived band.
        self.std_factor      = std_factor

    def _feature_mean_std(self, dataset: torch.utils.data.DataLoader
                          ) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
        """Return per-feature (mean, unbiased std) over all batches, or None if no data.

        Per-batch moments are computed on ``self.device``. They are merged across
        batches on CPU in float64 with Chan et al.'s parallel update, so every sample
        contributes equally whatever the batch size.
        """
        reduce_dims = (self.input_shape.index('t'), self.input_shape.index('b'))

        count = 0
        mean = None  # running per-feature mean (float64, CPU)
        m2 = None    # running per-feature sum of squared deviations (float64, CPU)

        with torch.no_grad():
            for b_inputs, _ in tqdm.tqdm(dataset):
                x = b_inputs[0].to(self.device)
                if not torch.is_floating_point(x):
                    x = x.float()

                n_b = x.shape[reduce_dims[0]] * x.shape[reduce_dims[1]]
                if n_b == 0:
                    continue

                centre = x.mean(dim=reduce_dims, keepdim=True)
                m2_b = (x - centre).square().sum(dim=reduce_dims)
                mean_b = centre.reshape(m2_b.shape)

                mean_b = mean_b.to(device='cpu', dtype=torch.float64)
                m2_b = m2_b.to(device='cpu', dtype=torch.float64)

                if mean is None:
                    count, mean, m2 = n_b, mean_b, m2_b
                    continue

                # Chan's pairwise merge of (count, mean, M2) statistics.
                total = count + n_b
                delta = mean_b - mean
                mean = mean + delta * (n_b / total)
                m2 = m2 + m2_b + delta.square() * (count * n_b / total)
                count = total

        if mean is None:
            return None

        # Unbiased estimator, like torch.std's default (NaN for a single sample).
        std = torch.sqrt(m2 / (count - 1))
        return mean, std

    def fit(self, dataset: torch.utils.data.DataLoader) -> None:
        """Derive the band from ``dataset`` when no explicit bound was supplied.

        If either bound was given at construction, or a previous ``fit`` already set
        finite bounds, this is a no-op. If the dataset yields no samples, the
        thresholds stay open (+/- inf).
        """
        if not (np.isinf(self.lower_threshold) and np.isinf(self.upper_threshold)):
            return

        stats = self._feature_mean_std(dataset)
        if stats is None:
            return

        mean, std = stats
        centre = mean[self.feature].item()
        spread = self.std_factor * std[self.feature].item()

        # Stored as plain floats so later np.isinf checks work on any device.
        self.lower_threshold = centre - spread
        self.upper_threshold = centre + spread

    def compute_online_anomaly_score(self, inputs: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        """Return 1 for each window whose last step is outside the band, else 0.

        ``inputs[0]`` has shape (B, T, D) or (T, B, D) depending on ``input_shape``.
        The output has shape (B,) and dtype int32.
        """
        x = inputs[0]
        # Select the last time step of the monitored feature before comparing.
        last = x[:, -1, self.feature] if self.input_shape[0] == 'b' else x[-1, :, self.feature]

        return torch.logical_or(last < self.lower_threshold, last > self.upper_threshold).int()

    def compute_offline_anomaly_score(self, inputs: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        """Offline scoring is not supported by this detector."""
        raise NotImplementedError

    def format_online_targets(self, targets: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        """Return the label of the last time step of each window.

        The single target tensor has shape (B, T) or (T, B); the output has shape (B,).
        """
        target, = targets

        return target[:, -1] if self.input_shape[0] == 'b' else target[-1]


class MeanDistanceAnomalyDetector(AnomalyDetector):
    """Score each window by the squared deviation of its last step from a feature mean.

    The mean is computed once, in ``fit``, via ``compute_feature_mean_std``. It is
    reused for every later scoring call.
    """

    def __init__(self, feature_index: int, input_shape: str = 'btf', device: Union[str, torch.device] = 'cpu'):

        super().__init__()

        # Feature whose deviation is reported.
        self.feature = feature_index
        # Axis layout string; a leading 'b' means batch-first.
        self.input_shape = input_shape
        # Kept as configuration; not read by this class's own methods.
        self.device = device
        # Populated lazily by fit().
        self.mean = None

    def fit(self, dataset: torch.utils.data.DataLoader) -> None:
        """Compute the feature mean from ``dataset.dataset`` unless already known."""
        if self.mean is not None:
            return

        # The helper returns four values; only the first (the mean) is kept here.
        feature_mean, _std, _extra_a, _extra_b = compute_feature_mean_std(dataset.dataset)
        self.mean = feature_mean

    def compute_online_anomaly_score(self, inputs: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        """Return the squared deviation of the monitored feature at the last step.

        ``inputs[0]`` has shape (B, T, D) or (T, B, D); the result has shape (B,).
        """
        deviation = torch.square(inputs[0] - self.mean)

        if self.input_shape[0] == 'b':
            return deviation[:, -1, self.feature]
        return deviation[-1, :, self.feature]

    def compute_offline_anomaly_score(self, inputs: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        """Offline scoring is not supported by this detector."""
        raise NotImplementedError

    def format_online_targets(self, targets: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        """Return the label of the final time step of each window.

        The single target tensor has shape (B, T) or (T, B); the output has shape (B,).
        """
        (labels,) = targets

        if self.input_shape[0] == 'b':
            return labels[:, -1]
        return labels[-1]


