from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch
import numpy as np
import numpy as np
import torch
import random
import time
from abc import ABCMeta, abstractmethod
from scipy.stats import binom
from sklearn.utils import check_array
from sklearn.utils.validation import check_is_fitted
import importlib
def _instantiate_class(module_name: str, class_name: str):
    """Import a module by name and return a fresh instance of one of its classes.

    Parameters
    ----------
    module_name: str
        Dotted path of the module to import.

    class_name: str
        Name of the attribute looked up on the imported module.

    Returns
    -------
    object
        The result of calling the resolved attribute with no arguments.
    """
    return getattr(importlib.import_module(module_name), class_name)()

class LinearBlock(torch.nn.Module):
    """Linear layer followed by an activation and optional batch-norm / dropout.

    Parameters
    ----------
    in_channels: int
        Number of input features of the linear layer

    out_channels: int
        Number of output features of the linear layer

    mid_channels: int, optional (default=None)
        Feature count given to the batch-norm layer;
        ``out_channels`` is used when it is None

    activation: str or None, optional (default='Tanh')
        Class name looked up in ``torch.nn.modules.activation``
        (e.g. Tanh, ReLU, LeakyReLU, Sigmoid).
            - If None, an identity layer is used

    bias: bool, optional (default=False)
        Passed as ``bias`` to the linear layer and as ``affine`` to batch norm

    batch_norm: bool, optional (default=False)
        Batch norm is applied only when this is exactly ``True``

    skip_connection: str or None, optional (default=None)
        If 'concat', the block input is concatenated in front of the
        block output along dimension 1

    dropout: float or None, optional (default=None)
        Dropout probability; no dropout layer is created when None
    """
    def __init__(self, in_channels, out_channels, mid_channels=None,
                 activation='Tanh', bias=False, batch_norm=False,
                 skip_connection=None, dropout=None):
        super().__init__()

        self.skip_connection = skip_connection
        self.linear = torch.nn.Linear(in_channels, out_channels, bias=bias)

        # resolve the activation by class name; no name means pass-through
        if activation is None:
            self.act_layer = torch.nn.Identity()
        else:
            self.act_layer = _instantiate_class("torch.nn.modules.activation", activation)

        self.dropout = dropout
        if dropout is not None:
            self.dropout_layer = torch.nn.Dropout(p=dropout)

        self.batch_norm = batch_norm
        if batch_norm is True:
            bn_features = mid_channels if mid_channels is not None else out_channels
            self.bn_layer = torch.nn.BatchNorm1d(bn_features, affine=bias)

    def forward(self, x):
        """Apply linear -> activation -> [batch norm] -> [dropout] -> [concat with x]."""
        out = self.act_layer(self.linear(x))

        if self.batch_norm is True:
            out = self.bn_layer(out)

        if self.dropout is not None:
            out = self.dropout_layer(out)

        if self.skip_connection == 'concat':
            out = torch.cat([x, out], dim=1)

        return out


class MLPnet(torch.nn.Module):
    """MLP-based Representation Network

    A stack of ``len(n_hidden) + 1`` LinearBlock layers: one per hidden size
    plus a final output layer of width ``n_output``.

    Parameters
    ----------
    n_features: int
        Dimensionality of the input data

    n_hidden: int or str or sequence of int, optional (default='500,100')
        Hidden layer sizes. A string is parsed as comma-separated integers,
        a single int denotes one hidden layer

    n_output: int, optional (default=20)
        Width of the final layer

    mid_channels: int, optional (default=None)
        Forwarded unchanged to every LinearBlock

    activation: str or list, optional (default='ReLU')
        - If str, the same activation is used for every hidden layer and
          the output layer gets no activation
        - If list, it must hold ``len(n_hidden) + 1`` entries, one per layer

    bias: bool, optional (default=False)
        Forwarded to every LinearBlock

    batch_norm: bool, optional (default=False)
        Forwarded to every LinearBlock (including the output layer)

    skip_connection: str or None, optional (default=None)
        None or 'concat'. Applied to the hidden layers only

    dropout: float or None, optional (default=None)
        Dropout probability of the hidden layers; the output layer has none
    """
    def __init__(self, n_features, n_hidden='500,100', n_output=20, mid_channels=None,
                 activation='ReLU', bias=False, batch_norm=False,
                 skip_connection=None, dropout=None):
        super(MLPnet, self).__init__()
        self.skip_connection = skip_connection
        self.n_output = n_output

        # normalise n_hidden into a list of plain ints
        if isinstance(n_hidden, int):
            n_hidden = [n_hidden]
        elif isinstance(n_hidden, str):
            n_hidden = [int(a) for a in n_hidden.split(',')]
        else:
            n_hidden = [int(a) for a in n_hidden]
        num_layers = len(n_hidden)

        # for only use one kind of activation layer:
        # hidden layers share it, the output layer stays linear
        if isinstance(activation, str):
            activation = [activation] * num_layers + [None]

        assert len(activation) == len(n_hidden)+1, 'activation and n_hidden are not matched'

        self.layers = []
        for i in range(num_layers + 1):
            is_output_layer = i == num_layers
            in_channels, out_channels = self.get_in_out_channels(i, num_layers, n_features,
                                                                 n_hidden, n_output, skip_connection)
            self.layers.append(
                LinearBlock(in_channels, out_channels,
                            mid_channels=mid_channels,
                            bias=bias, batch_norm=batch_norm,
                            activation=activation[i],
                            # the output layer never concatenates or drops units
                            skip_connection=None if is_output_layer else skip_connection,
                            dropout=None if is_output_layer else dropout)
            )
        self.network = torch.nn.Sequential(*self.layers)

    def forward(self, x):
        """Run ``x`` through the sequential stack of blocks."""
        return self.network(x)

    def get_in_out_channels(self, i, num_layers, n_features, n_hidden, n_output, skip_connection):
        """Return ``(in_channels, out_channels)`` of the i-th layer.

        Parameters
        ----------
        i: int
            Layer index; ``i == num_layers`` denotes the output layer

        num_layers: int
            Number of hidden layers

        n_features: int
            Input dimensionality of the network

        n_hidden: sequence of int
            Hidden layer sizes

        n_output: int
            Width of the output layer

        skip_connection: str or None
            None or 'concat'. With 'concat' every block emits its input
            concatenated with its output, so layer i receives the network
            input plus the outputs of all previous hidden layers

        Raises
        ------
        NotImplementedError
            If ``skip_connection`` is neither None nor 'concat'
        """
        if skip_connection is None:
            in_channels = n_features if i == 0 else n_hidden[i-1]
        elif skip_connection == 'concat':
            # built-in sum + int() keeps this a plain Python int
            in_channels = int(n_features + sum(n_hidden[:i]))
        else:
            raise NotImplementedError(
                f"skip_connection must be None or 'concat', got {skip_connection!r}"
            )
        out_channels = n_output if i == num_layers else n_hidden[i]
        return in_channels, out_channels


class BaseDeepAD(metaclass=ABCMeta):
    """
    Abstract class for deep outlier detection models

    Subclasses provide the data loaders, network and criterion through the
    abstract ``*_prepare`` / ``*_forward`` hooks; this class drives the
    training loop, inference loop and threshold computation.

    Parameters
    ----------

    model_name: str
        Name of the model, stored on the instance

    epochs: int, optional (default=100)
        Number of training epochs

    batch_size: int, optional (default=64)
        Number of samples in a mini-batch

    lr: float, optional (default=1e-3)
        Learning rate

    n_ensemble: int or str, optional (default=1)
        Number of ensemble size
            - If 'auto', it is derived from the training data shape in ``fit``

    epoch_steps: int, optional (default=-1)
        Maximum steps in an epoch
            - If -1, all the batches will be processed

    prt_steps: int, optional (default=10)
        Number of epoch intervals per printing

    device: str, optional (default='cuda')
        torch device,

    contamination : float in (0., 0.5), optional (default=0.1)
        The amount of contamination of the data set,
        i.e. the proportion of outliers in the data set. Used when fitting to
        define the threshold on the decision function.

    verbose: int, optional (default=1)
        Verbosity mode

    random_state: int, optional (default=42)
        the seed used by the random

    Attributes
    ----------
    decision_scores_ : numpy array of shape (n_samples,)
        The outlier scores of the training data.
        The higher, the more abnormal. Outliers tend to have higher
        scores. This value is available once the detector is fitted.

    threshold_ : float
        The threshold is based on ``contamination``. It is the
        ``n_samples * contamination`` most abnormal samples in
        ``decision_scores_``. The threshold is calculated for generating
        binary outlier labels.

    labels_ : int, either 0 or 1
        The binary labels of the training data. 0 stands for inliers
        and 1 for outliers/anomalies. It is generated by applying
        ``threshold_`` on ``decision_scores_``.

    """
    def __init__(self, model_name, epochs=100, batch_size=64, lr=1e-3,
                 n_ensemble=1,
                 epoch_steps=-1, prt_steps=10,
                 device='cuda', contamination=0.1,
                 verbose=1, random_state=42):
        self.model_name = model_name

        self.epochs = epochs
        self.batch_size = batch_size
        self.lr = lr

        self.device = device
        self.contamination = contamination

        self.epoch_steps = epoch_steps
        self.prt_steps = prt_steps
        self.verbose = verbose

        # filled in by fit()
        self.n_features = -1
        self.n_samples = -1
        self.criterion = None
        self.net = None

        self.n_ensemble = n_ensemble

        self.train_loader = None
        self.test_loader = None

        # wall-clock duration of the first training epoch
        self.epoch_time = None

        self.train_data = None
        self.train_label = None

        self.decision_scores_ = None
        self.labels_ = None
        self.threshold_ = None

        self.random_state = random_state
        self.set_seed(random_state)

    def fit(self, X, y=None):
        """
        Fit detector. y is ignored in unsupervised methods.

        Parameters
        ----------
        X : numpy array of shape (n_samples, n_features)
            The input samples.

        y : numpy array of shape (n_samples, )
            Not used in unsupervised methods, present for API consistency by convention.
            used in (semi-/weakly-) supervised methods

        Returns
        -------
        self : object
            Fitted estimator.
        """

        self.train_data = X
        self.train_label = y
        self.n_samples, self.n_features = X.shape

        if self.verbose >= 1:
            print('Start Training...')

        if self.n_ensemble == 'auto':
            self.n_ensemble = int(np.floor(100 / (np.log(self.n_samples) + self.n_features)) + 1)
        if self.verbose >= 1:
            print(f'ensemble size: {self.n_ensemble}')

        # NOTE(review): each iteration overwrites self.net, so only the last
        # trained network is kept on the instance - confirm this is intended
        # for n_ensemble > 1.
        for _ in range(self.n_ensemble):
            self.train_loader, self.net, self.criterion = self.training_prepare(X, y=y)
            self._training()

        if self.verbose >= 1:
            print('Start Inference on the training data...')

        self.decision_scores_ = self.decision_function(X)
        # _process_decision_scores() stores threshold_ and labels_ itself and
        # returns self; assigning its result to labels_ would overwrite the
        # label array with the estimator object.
        self._process_decision_scores()

        return self

    def decision_function(self, X):
        """Predict raw anomaly scores of X using the fitted detector.

        The anomaly score of an input sample is computed based on the fitted
        detector. For consistency, outliers are assigned with
        higher anomaly scores. The scores of the ``n_ensemble`` inference
        passes are summed.

        Parameters
        ----------
        X : numpy array of shape (n_samples, n_features)
            The input samples. Sparse matrices are accepted only
            if they are supported by the base estimator.

        Returns
        -------
        anomaly_scores : numpy array of shape (n_samples,)
            The anomaly score of the input samples.
        """
        testing_n_samples = X.shape[0]
        s_final = np.zeros(testing_n_samples)
        for _ in range(self.n_ensemble):
            self.test_loader = self.inference_prepare(X)
            z, scores = self._inference()
            z, scores = self.decision_function_update(z, scores)
            s_final += scores

        return s_final

    def predict(self, X, return_confidence=False):
        """Predict if a particular sample is an outlier or not.

        Parameters
        ----------
        X : numpy array of shape (n_samples, n_features)
            The input samples.

        return_confidence : boolean, optional(default=False)
            If True, also return the confidence of prediction.

        Returns
        -------
        outlier_labels : numpy array of shape (n_samples,)
            For each observation, tells whether
            it should be considered as an outlier according to the
            fitted model. 0 stands for inliers and 1 for outliers.
        confidence : numpy array of shape (n_samples,).
            Only if return_confidence is set to True.
        """

        pred_score = self.decision_function(X)
        prediction = (pred_score > self.threshold_).astype('int').ravel()

        if return_confidence:
            confidence = self._predict_confidence(pred_score)
            return prediction, confidence

        return prediction

    def _predict_confidence(self, test_scores):
        """Predict the model's confidence in making the same prediction
        under slightly different training sets.
        See :cite:`perini2020quantifying`.

        Parameters
        -------
        test_scores : numpy array of shape (n_samples,)
            The anomaly score of the input samples.

        Returns
        -------
        confidence : numpy array of shape (n_samples,)
            For each observation, tells how consistently the model would
            make the same prediction if the training set was perturbed.
            Return a probability, ranging in [0,1].

        """
        test_scores = np.asarray(test_scores)
        n = len(self.decision_scores_)

        # number of training scores <= each test score; a binary search on
        # the sorted training scores replaces a full scan per test sample
        sorted_train_scores = np.sort(np.ravel(self.decision_scores_))
        n_instances = np.searchsorted(sorted_train_scores, test_scores, side='right')

        # Derive the outlier probability using Bayesian approach
        posterior_prob = (1 + n_instances) / (2 + n)

        # Transform the outlier probability into a confidence value
        confidence = 1 - binom.cdf(n - int(n * self.contamination), n, posterior_prob)

        # for samples predicted as inliers, report the confidence of that
        # (complementary) prediction
        is_outlier = test_scores > self.threshold_
        return np.where(is_outlier, confidence, 1 - confidence)

    def _process_decision_scores(self):
        """Internal function to calculate key attributes:

        - threshold_: used to decide the binary label
        - labels_: binary labels of training data
        - _mu, _sigma: mean and standard deviation of ``decision_scores_``

        Returns
        -------
        self
        """

        self.threshold_ = np.percentile(self.decision_scores_, 100 * (1 - self.contamination))
        self.labels_ = (self.decision_scores_ > self.threshold_).astype('int').ravel()

        self._mu = np.mean(self.decision_scores_)
        self._sigma = np.std(self.decision_scores_)

        return self

    def _training(self):
        """Train ``self.net`` on ``self.train_loader`` for ``self.epochs`` epochs.

        Uses Adam with ``self.lr`` and a weight decay of 1e-5. The duration of
        the first epoch is stored in ``self.epoch_time`` and ``epoch_update``
        is called after every epoch.
        """
        optimizer = torch.optim.Adam(self.net.parameters(),
                                     lr=self.lr,
                                     weight_decay=1e-5)

        self.net.train()
        for i in range(self.epochs):
            t1 = time.time()
            total_loss = 0
            cnt = 0
            for batch_x in self.train_loader:
                loss = self.training_forward(batch_x, self.net, self.criterion)
                self.net.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                cnt += 1

                # terminate this epoch when reaching assigned maximum steps per epoch
                # (>= so that exactly epoch_steps batches are processed)
                if self.epoch_steps != -1 and cnt >= self.epoch_steps:
                    break

            t = time.time() - t1
            if self.verbose >= 1 and (i == 0 or (i+1) % self.prt_steps == 0):
                # an empty loader yields no steps; avoid dividing by zero
                avg_loss = total_loss / cnt if cnt else float('nan')
                print(f'epoch{i+1}, '
                      f'training loss: {avg_loss:.6f}, '
                      f'time: {t:.1f}s')

            if i == 0:
                self.epoch_time = t

            self.epoch_update()

    def _inference(self):
        """Run ``self.net`` over ``self.test_loader`` without gradients.

        Returns
        -------
        z : numpy array
            Concatenated first outputs of ``inference_forward``
        scores : numpy array
            Concatenated second outputs of ``inference_forward``
        """
        self.net.eval()
        with torch.no_grad():
            z_lst = []
            score_lst = []
            for batch_x in self.test_loader:
                batch_z, s = self.inference_forward(batch_x, self.net, self.criterion)

                z_lst.append(batch_z)
                score_lst.append(s)

        z = torch.cat(z_lst).data.cpu().numpy()
        scores = torch.cat(score_lst).data.cpu().numpy()

        return z, scores

    @abstractmethod
    def training_forward(self, batch_x, net, criterion):
        """define forward step in training"""
        pass

    @abstractmethod
    def inference_forward(self, batch_x, net, criterion):
        """define forward step in inference"""
        pass

    @abstractmethod
    def training_prepare(self, X, y):
        """define train_loader, net, and criterion"""
        pass

    @abstractmethod
    def inference_prepare(self, X):
        """define test_loader"""
        pass

    def epoch_update(self):
        """for any updating operation after each training epoch"""
        return

    def decision_function_update(self, z, scores):
        """for any updating operation after decision function"""
        return z, scores

    @staticmethod
    def set_seed(seed):
        """Seed torch (CPU and CUDA), numpy and ``random``, and make cuDNN deterministic."""
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        random.seed(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

class NeuTraL(BaseDeepAD):
    """
    Neural Transformation Learning-based Anomaly Detection (ICML'21)

    Parameters
    ----------
    epochs, batch_size, lr, epoch_steps, prt_steps, device, verbose, random_state
        Forwarded unchanged to ``BaseDeepAD``

    n_trans: int, optional (default=11)
        Number of transformation networks

    trans_type: str, optional (default='residual')
        Transformation type handed to the network

    temp: float, optional (default=0.1)
        Temperature of the DCL criterion

    rep_dim: int, optional (default=128)
        Representation dimensionality of the encoder

    hidden_dims: list or str or int, optional (default='100,50')
        Hidden layer sizes of the encoder

    trans_hidden_dims: list or str or int, optional (default=50)
        Hidden layer sizes of each transformation network

    act: str, optional (default='LeakyReLU')
        Activation layer name

    bias: bool, optional (default=False)
        Whether the linear layers use a bias term
    """
    def __init__(self, epochs=100, batch_size=64, lr=1e-3,
                 n_trans=11, trans_type='residual', temp=0.1,
                 rep_dim=128, hidden_dims='100,50', trans_hidden_dims=50,
                 act='LeakyReLU', bias=False,
                 epoch_steps=-1, prt_steps=10, device='cuda',
                 verbose=1, random_state=42):
        super().__init__(
            model_name='NeuTraL',
            epochs=epochs,
            batch_size=batch_size,
            lr=lr,
            epoch_steps=epoch_steps,
            prt_steps=prt_steps,
            device=device,
            verbose=verbose,
            random_state=random_state,
        )

        # transformation / loss settings
        self.n_trans = n_trans
        self.trans_type = trans_type
        self.temp = temp

        # network architecture settings
        self.enc_hidden_dims = hidden_dims
        self.trans_hidden_dims = trans_hidden_dims
        self.rep_dim = rep_dim
        self.act = act
        self.bias = bias

    def training_prepare(self, X, y):
        """Build the shuffled training loader, a fresh network and the DCL criterion."""
        loader = DataLoader(X, batch_size=self.batch_size, shuffle=True)

        net_kwargs = dict(
            n_features=self.n_features,
            n_trans=self.n_trans,
            trans_type=self.trans_type,
            enc_hidden_dims=self.enc_hidden_dims,
            trans_hidden_dims=self.trans_hidden_dims,
            activation=self.act,
            bias=self.bias,
            rep_dim=self.rep_dim,
            device=self.device,
        )
        model = TabNeutralADNet(**net_kwargs)

        loss_fn = DCL(temperature=self.temp)

        if self.verbose >= 2:
            print(model)

        return loader, model, loss_fn

    def inference_prepare(self, X):
        """Build the ordered test loader and switch the criterion to per-sample output."""
        # scoring needs one loss value per sample rather than a batch average
        self.criterion.reduction = 'none'
        return DataLoader(X, batch_size=self.batch_size, drop_last=False, shuffle=False)

    def training_forward(self, batch_x, net, criterion):
        """Return the criterion value of one training batch."""
        inputs = batch_x.float().to(self.device)
        return criterion(net(inputs))

    def inference_forward(self, batch_x, net, criterion):
        """Return the representations of one batch and their criterion values."""
        inputs = batch_x.float().to(self.device)
        batch_z = net(inputs)
        return batch_z, criterion(batch_z)


class TabNeutralADNet(torch.nn.Module):
    """
    network class of NeuTraL for tabular data

    Parameters
    ----------
    n_features: int
        dimensionality of input data

    n_trans: int
        the number of transformation times

    trans_type: str, default='residual'
        transformation type, one of 'forward', 'mul' or 'residual'

    enc_hidden_dims: list or str or int
        the number of neural units of hidden layers in encoder net

    trans_hidden_dims: list or str or int
        the number of neural units of hidden layers in transformation net

    rep_dim: int
        representation dimensionality

    activation: str
        activation layer name

    bias: bool
        whether the linear layers of both nets use a bias term

    device: str
        device

    Raises
    ------
    ValueError
        If ``trans_type`` is not one of the supported transformation types
    """
    # transformation types understood by forward()
    _TRANS_TYPES = ('forward', 'mul', 'residual')

    def __init__(self, n_features, n_trans=11, trans_type='residual',
                 enc_hidden_dims='24,24,24,24', trans_hidden_dims=24,
                 rep_dim=24,
                 activation='ReLU',
                 bias=False,
                 device='cuda'):
        super(TabNeutralADNet, self).__init__()

        # fail early: an unknown type would otherwise yield undefined views
        if trans_type not in self._TRANS_TYPES:
            raise ValueError(
                f'unknown trans_type {trans_type!r}, expected one of {self._TRANS_TYPES}'
            )

        self.enc = MLPnet(
            n_features=n_features,
            n_hidden=enc_hidden_dims,
            n_output=rep_dim,
            activation=activation,
            bias=bias,
            batch_norm=False
        )
        # each transformation net maps the input back to the input space
        self.trans = torch.nn.ModuleList(
            [MLPnet(n_features=n_features,
                    n_hidden=trans_hidden_dims,
                    n_output=n_features,
                    activation=activation,
                    bias=bias,
                    batch_norm=False) for _ in range(n_trans)]
        )

        self.trans.to(device)
        self.enc.to(device)

        self.n_trans = n_trans
        self.trans_type = trans_type
        self.z_dim = rep_dim

    def _apply_transform(self, x, mask):
        """Combine the input with one transformation-net output per ``trans_type``."""
        if self.trans_type == 'forward':
            return mask
        if self.trans_type == 'mul':
            return torch.sigmoid(mask) * x
        if self.trans_type == 'residual':
            return mask + x
        raise ValueError(
            f'unknown trans_type {self.trans_type!r}, expected one of {self._TRANS_TYPES}'
        )

    def forward(self, x):
        """Encode the input together with all of its transformed views.

        Parameters
        ----------
        x: torch.Tensor of shape (batch, n_features)
            input samples

        Returns
        -------
        zs: torch.Tensor of shape (batch, n_trans + 1, rep_dim)
            representations; index 0 along dim 1 is the untransformed input,
            indices 1..n_trans are the transformed views
        """
        views = [x]
        for i in range(self.n_trans):
            views.append(self._apply_transform(x, self.trans[i](x)))

        x_cat = torch.stack(views, dim=1)
        # the encoder works on 2-D input, so fold the view axis into the batch
        zs = self.enc(x_cat.reshape(-1, x.shape[-1]))
        zs = zs.reshape(x.shape[0], self.n_trans+1, self.z_dim)
        return zs


class DCL(torch.nn.Module):
    """Contrastive loss between the original view and its transformed views.

    Parameters
    ----------
    temperature: float, optional (default=0.1)
        Divisor applied to the cosine similarities before exponentiation

    reduction: str, optional (default='mean')
        'mean' or 'sum' reduce the per-sample losses to a scalar;
        any other value returns the per-sample losses
    """
    def __init__(self, temperature=0.1, reduction='mean'):
        super().__init__()
        self.temp = temperature
        self.reduction = reduction

    def forward(self, z):
        """Compute the loss of ``z``.

        Parameters
        ----------
        z: torch.Tensor of shape (n, k, dim)
            Representations; index 0 along dim 1 is treated as the original
            view, the remaining k-1 entries as transformed views

        Returns
        -------
        torch.Tensor
            Scalar for 'mean' / 'sum' reduction, otherwise shape (n,)
        """
        z = F.normalize(z, p=2, dim=-1)
        anchor = z[:, 0]          # n, dim
        transformed = z[:, 1:]    # n, k-1, dim
        n_samples, n_views, _ = z.shape

        # exponentiated pairwise similarities between all views: n, k, k
        pairwise = torch.exp(torch.matmul(z, z.permute(0, 2, 1) / self.temp))

        # drop every view's similarity with itself: n, k, k-1
        off_diagonal = ~torch.eye(n_views, dtype=torch.bool, device=z.device)
        off_diagonal = off_diagonal.unsqueeze(0).expand_as(pairwise)
        pairwise = pairwise.masked_select(off_diagonal).view(n_samples, n_views, -1)

        # denominator per transformed view: n, k-1
        denominator = pairwise[:, 1:].sum(-1)

        # similarity of each transformed view with the original: n, k-1
        numerator = torch.exp(torch.sum(transformed * anchor.unsqueeze(1), -1) / self.temp)

        K = n_views - 1
        scale = 1 / np.abs(K*np.log(1.0 / K))

        per_view = (torch.log(denominator) - torch.log(numerator)) * scale
        per_sample = per_view.sum(1)

        if self.reduction == 'mean':
            return torch.mean(per_sample)
        if self.reduction == 'sum':
            return torch.sum(per_sample)
        return per_sample
def main():
    """Smoke-run: fit a default NeuTraL model on a 10 x 20 all-zero matrix."""
    # NeuTraL defaults to 'cuda'; fall back to CPU so the demo also runs
    # on machines without a GPU
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = NeuTraL(device=device)
    data = np.zeros([10, 20])
    model.fit(data)


# only run the demo when executed as a script, not on import
if __name__ == '__main__':
    main()