
from torch.utils.data import DataLoader
import torch
import torch.nn.functional as F
import numpy as np


import numpy as np
import torch
import random
import time
from abc import ABCMeta, abstractmethod
from scipy.stats import binom
from sklearn.utils import check_array
from sklearn.utils.validation import check_is_fitted
import importlib

def _instantiate_class(module_name: str, class_name: str):
    module = importlib.import_module(module_name)
    class_ = getattr(module, class_name)
    return class_()

class LinearBlock(torch.nn.Module):
    """Linear layer followed by activation, optional batch-norm and dropout.

    Parameters
    ----------
    in_channels, out_channels : int
        Input / output width of the linear layer.
    mid_channels : int or None
        Number of features given to ``BatchNorm1d``; falls back to
        ``out_channels`` when None.
    activation : str or None
        Class name looked up in ``torch.nn.modules.activation``; None means
        identity.
    bias : bool
        Used both as the linear layer's ``bias`` and the batch-norm ``affine``.
    batch_norm : bool
        Batch-norm is applied only when this is exactly ``True``.
    skip_connection : str or None
        ``'concat'`` concatenates the block input to the output along dim 1.
    dropout : float or None
        Dropout probability; None disables dropout.
    """
    def __init__(self, in_channels, out_channels, mid_channels=None,
                 activation='Tanh', bias=False, batch_norm=False,
                 skip_connection=None, dropout=None):
        super().__init__()

        # plain configuration flags read again in forward()
        self.skip_connection = skip_connection
        self.dropout = dropout
        self.batch_norm = batch_norm

        # sub-modules are registered in the order: linear, act, dropout, bn
        self.linear = torch.nn.Linear(in_channels, out_channels, bias=bias)
        self.act_layer = (
            torch.nn.Identity() if activation is None
            else _instantiate_class("torch.nn.modules.activation", activation)
        )
        if dropout is not None:
            self.dropout_layer = torch.nn.Dropout(p=dropout)
        if batch_norm is True:
            norm_width = mid_channels if mid_channels is not None else out_channels
            self.bn_layer = torch.nn.BatchNorm1d(norm_width, affine=bias)

    def forward(self, x):
        """Apply linear -> activation -> (batch-norm) -> (dropout) -> (concat)."""
        hidden = self.act_layer(self.linear(x))

        if self.batch_norm is True:
            hidden = self.bn_layer(hidden)
        if self.dropout is not None:
            hidden = self.dropout_layer(hidden)

        if self.skip_connection != 'concat':
            return hidden
        return torch.cat([x, hidden], dim=1)


class MLPnet(torch.nn.Module):
    """MLP-based representation network built from ``LinearBlock`` layers.

    Parameters
    ----------
    n_features : int
        Width of the network input.
    n_hidden : int, str or list of int
        Hidden layer widths. A str is split on commas; an int means one
        hidden layer.
    n_output : int
        Width of the final layer.
    mid_channels : int or None
        Forwarded to every ``LinearBlock`` (batch-norm feature count).
    activation : str or list
        One activation name for all hidden layers (the output layer then gets
        none), or an explicit list with ``len(n_hidden) + 1`` entries.
    bias, batch_norm : bool
        Forwarded to every ``LinearBlock``.
    skip_connection : str or None
        None or ``'concat'``; never applied to the output layer.
    dropout : float or None
        Dropout probability for all layers except the output layer.
    """
    def __init__(self, n_features, n_hidden='500,100', n_output=20, mid_channels=None,
                 activation='ReLU', bias=False, batch_norm=False,
                 skip_connection=None, dropout=None):
        super(MLPnet, self).__init__()
        self.skip_connection = skip_connection
        self.n_output = n_output

        # normalise n_hidden to a list of layer widths
        if isinstance(n_hidden, str):
            n_hidden = [int(a) for a in n_hidden.split(',')]
        elif isinstance(n_hidden, (int, np.integer)):
            n_hidden = [int(n_hidden)]
        num_layers = len(n_hidden)

        # a single activation name is used for every hidden layer and the
        # output layer is left without activation
        if isinstance(activation, str):
            activation = [activation] * num_layers + [None]

        assert len(activation) == len(n_hidden)+1, 'activation and n_hidden are not matched'

        self.layers = []
        for i in range(num_layers + 1):
            is_output_layer = i == num_layers
            in_channels, out_channels = self.get_in_out_channels(
                i, num_layers, n_features, n_hidden, n_output, skip_connection
            )
            self.layers.append(
                LinearBlock(in_channels, out_channels,
                            mid_channels=mid_channels,
                            bias=bias, batch_norm=batch_norm,
                            activation=activation[i],
                            # the output layer never concatenates or drops out
                            skip_connection=None if is_output_layer else skip_connection,
                            dropout=None if is_output_layer else dropout)
            )
        self.network = torch.nn.Sequential(*self.layers)

    def forward(self, x):
        """Run ``x`` through all layers in sequence."""
        return self.network(x)

    def get_in_out_channels(self, i, num_layers, n_features, n_hidden, n_output, skip_connection):
        """Return ``(in_channels, out_channels)`` of layer ``i`` as plain ints.

        With ``'concat'`` skip connections each block outputs its input
        concatenated with its activation, so the input width of layer ``i``
        is ``n_features`` plus the widths of all earlier hidden layers.

        Raises
        ------
        NotImplementedError
            If ``skip_connection`` is neither None nor ``'concat'``.
        """
        if skip_connection is None:
            in_channels = n_features if i == 0 else n_hidden[i-1]
        elif skip_connection == 'concat':
            # int(...) keeps numpy integer types out of the layer sizes
            in_channels = n_features if i == 0 else int(sum(n_hidden[:i]) + n_features)
        else:
            raise NotImplementedError(f'unsupported skip_connection: {skip_connection!r}')
        out_channels = n_output if i == num_layers else n_hidden[i]
        return in_channels, out_channels


class BaseDeepAD(metaclass=ABCMeta):
    """
    Abstract class for deep outlier detection models

    Sub-classes implement ``training_prepare``, ``training_forward``,
    ``inference_prepare`` and ``inference_forward``; this class drives the
    training loop, the ensemble, scoring and thresholding.

    Parameters
    ----------
    model_name: str
        Name of the model, stored on the instance.

    epochs: int, optional (default=100)
        Number of training epochs

    batch_size: int, optional (default=64)
        Number of samples in a mini-batch

    lr: float, optional (default=1e-3)
        Learning rate

    n_ensemble: int or str, optional (default=1)
        Number of ensemble size
            - If 'auto', derived from the data shape in ``fit``

    epoch_steps: int, optional (default=-1)
        Maximum steps in an epoch
            - If -1, all the batches will be processed

    prt_steps: int, optional (default=10)
        Number of epoch intervals per printing

    device: str, optional (default='cuda')
        torch device,

    contamination : float in (0., 0.5), optional (default=0.1)
        The amount of contamination of the data set,
        i.e. the proportion of outliers in the data set. Used when fitting to
        define the threshold on the decision function.

    verbose: int, optional (default=1)
        Verbosity mode

    random_state: int, optional (default=42)
        the seed used by the random

    Attributes
    ----------
    decision_scores_ : numpy array of shape (n_samples,)
        The outlier scores of the training data.
        The higher, the more abnormal. Outliers tend to have higher
        scores. This value is available once the detector is fitted.

    threshold_ : float
        The threshold is based on ``contamination``. It is the
        ``n_samples * contamination`` most abnormal samples in
        ``decision_scores_``. The threshold is calculated for generating
        binary outlier labels.

    labels_ : numpy array of shape (n_samples,), values 0 or 1
        The binary labels of the training data. 0 stands for inliers
        and 1 for outliers/anomalies. It is generated by applying
        ``threshold_`` on ``decision_scores_``.

    net_lst : list
        The trained networks, one per ensemble member (filled by ``fit``).

    criterion_lst : list
        The criterion returned by ``training_prepare`` for each member.

    """
    def __init__(self, model_name, epochs=100, batch_size=64, lr=1e-3,
                 n_ensemble=1,
                 epoch_steps=-1, prt_steps=10,
                 device='cuda', contamination=0.1,
                 verbose=1, random_state=42):
        self.model_name = model_name

        self.epochs = epochs
        self.batch_size = batch_size
        self.lr = lr

        self.device = device
        self.contamination = contamination

        self.epoch_steps = epoch_steps
        self.prt_steps = prt_steps
        self.verbose = verbose

        self.n_features = -1
        self.n_samples = -1
        self.criterion = None
        self.net = None

        self.n_ensemble = n_ensemble
        # one entry per trained ensemble member
        self.net_lst = []
        self.criterion_lst = []

        self.train_loader = None
        self.test_loader = None

        self.epoch_time = None

        self.train_data = None
        self.train_label = None

        self.decision_scores_ = None
        self.labels_ = None
        self.threshold_ = None

        self.random_state = random_state
        self.set_seed(random_state)

    def fit(self, X, y=None):
        """
        Fit detector. y is ignored in unsupervised methods.

        Parameters
        ----------
        X : numpy array of shape (n_samples, n_features)
            The input samples.

        y : numpy array of shape (n_samples, )
            Not used in unsupervised methods, present for API consistency by convention.
            used in (semi-/weakly-) supervised methods

        Returns
        -------
        self : object
            Fitted estimator.
        """

        self.train_data = X
        self.train_label = y
        self.n_samples, self.n_features = X.shape

        if self.verbose >= 1:
            print('Start Training...')

        if self.n_ensemble == 'auto':
            self.n_ensemble = int(np.floor(100 / (np.log(self.n_samples) + self.n_features)) + 1)
        if self.verbose >= 1:
            print(f'ensemble size: {self.n_ensemble}')

        # Keep every trained member: previously each iteration overwrote
        # self.net, so only the last network was ever used for scoring.
        self.net_lst = []
        self.criterion_lst = []
        for _ in range(self.n_ensemble):
            self.train_loader, self.net, self.criterion = self.training_prepare(X, y=y)
            self._training()
            self.net_lst.append(self.net)
            self.criterion_lst.append(self.criterion)

        if self.verbose >= 1:
            print('Start Inference on the training data...')

        self.decision_scores_ = self.decision_function(X)
        # sets threshold_ and labels_ in place; its return value is `self`
        # and must not be assigned to labels_
        self._process_decision_scores()

        return self

    def decision_function(self, X):
        """Predict raw anomaly scores of X using the fitted detector.

        The anomaly score of an input sample is computed based on the fitted
        detector. For consistency, outliers are assigned with
        higher anomaly scores. Scores of all ensemble members are summed.

        Parameters
        ----------
        X : numpy array of shape (n_samples, n_features)
            The input samples. Sparse matrices are accepted only
            if they are supported by the base estimator.

        Returns
        -------
        anomaly_scores : numpy array of shape (n_samples,)
            The anomaly score of the input samples.
        """
        members = list(zip(getattr(self, 'net_lst', None) or [],
                           getattr(self, 'criterion_lst', None) or []))
        if not members:
            # no recorded ensemble (e.g. self.net was assigned directly):
            # fall back to scoring the current network n_ensemble times
            members = [(self.net, self.criterion)] * self.n_ensemble

        s_final = np.zeros(X.shape[0])
        for net, criterion in members:
            # the hooks below read self.net / self.criterion
            self.net, self.criterion = net, criterion
            self.test_loader = self.inference_prepare(X)
            z, scores = self._inference()
            z, scores = self.decision_function_update(z, scores)
            s_final += scores

        return s_final

    def predict(self, X, return_confidence=False):
        """Predict if a particular sample is an outlier or not.

        Parameters
        ----------
        X : numpy array of shape (n_samples, n_features)
            The input samples.

        return_confidence : boolean, optional(default=False)
            If True, also return the confidence of prediction.

        Returns
        -------
        outlier_labels : numpy array of shape (n_samples,)
            For each observation, tells whether
            it should be considered as an outlier according to the
            fitted model. 0 stands for inliers and 1 for outliers.
        confidence : numpy array of shape (n_samples,).
            Only if return_confidence is set to True.
        """

        pred_score = self.decision_function(X)
        prediction = (pred_score > self.threshold_).astype('int').ravel()

        if return_confidence:
            confidence = self._predict_confidence(pred_score)
            return prediction, confidence

        return prediction

    def _predict_confidence(self, test_scores):
        """Predict the model's confidence in making the same prediction
        under slightly different training sets.
        See :cite:`perini2020quantifying`.

        Parameters
        -------
        test_scores : numpy array of shape (n_samples,)
            The anomaly score of the input samples.

        Returns
        -------
        confidence : numpy array of shape (n_samples,)
            For each observation, tells how consistently the model would
            make the same prediction if the training set was perturbed.
            Return a probability, ranging in [0,1].

        """
        n = len(self.decision_scores_)

        # number of training scores that do not exceed each test score
        count_instances = np.vectorize(lambda x: np.count_nonzero(self.decision_scores_ <= x))
        n_instances = count_instances(test_scores)

        # Derive the outlier probability using Bayesian approach
        posterior_prob = (1 + n_instances) / (2 + n)

        # Transform the outlier probability into a confidence value
        confidence = 1 - binom.cdf(n - int(n * self.contamination), n, posterior_prob)
        prediction = (test_scores > self.threshold_).astype('int').ravel()
        # for predicted inliers the confidence is that of the opposite event
        np.place(confidence, prediction == 0, 1 - confidence[prediction == 0])
        return confidence

    def _process_decision_scores(self):
        """Internal function to calculate key attributes:

        - threshold_: used to decide the binary label
        - labels_: binary labels of training data

        Returns
        -------
        self
        """

        self.threshold_ = np.percentile(self.decision_scores_, 100 * (1 - self.contamination))
        self.labels_ = (self.decision_scores_ > self.threshold_).astype('int').ravel()

        self._mu = np.mean(self.decision_scores_)
        self._sigma = np.std(self.decision_scores_)

        return self

    def _training(self):
        """Train ``self.net`` on ``self.train_loader`` with Adam for ``self.epochs`` epochs."""
        optimizer = torch.optim.Adam(self.net.parameters(),
                                     lr=self.lr,
                                     weight_decay=1e-5)

        self.net.train()
        for i in range(self.epochs):
            t1 = time.time()
            total_loss = 0
            cnt = 0
            for batch_x in self.train_loader:
                loss = self.training_forward(batch_x, self.net, self.criterion)
                self.net.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                cnt += 1

                # terminate this epoch when reaching assigned maximum steps per epoch
                # (>= so that exactly epoch_steps batches are processed)
                if self.epoch_steps != -1 and cnt >= self.epoch_steps:
                    break

            t = time.time() - t1
            if self.verbose >= 1 and (i == 0 or (i+1) % self.prt_steps == 0):
                # max(cnt, 1) guards against an empty loader
                print(f'epoch{i+1}, '
                      f'training loss: {total_loss/max(cnt, 1):.6f}, '
                      f'time: {t:.1f}s')

            if i == 0:
                self.epoch_time = t

            self.epoch_update()

    def _inference(self):
        """Run ``self.net`` over ``self.test_loader`` without gradients.

        Returns
        -------
        z, scores : numpy arrays
            Concatenation over batches of the two outputs of
            ``inference_forward``.
        """
        self.net.eval()
        with torch.no_grad():
            z_lst = []
            score_lst = []
            for batch_x in self.test_loader:
                batch_z, s = self.inference_forward(batch_x, self.net, self.criterion)

                z_lst.append(batch_z)
                score_lst.append(s)

        z = torch.cat(z_lst).detach().cpu().numpy()
        scores = torch.cat(score_lst).detach().cpu().numpy()

        return z, scores

    @abstractmethod
    def training_forward(self, batch_x, net, criterion):
        """define forward step in training"""
        pass

    @abstractmethod
    def inference_forward(self, batch_x, net, criterion):
        """define forward step in inference"""
        pass

    @abstractmethod
    def training_prepare(self, X, y):
        """define train_loader, net, and criterion"""
        pass

    @abstractmethod
    def inference_prepare(self, X):
        """define test_loader"""
        pass

    def epoch_update(self):
        """for any updating operation after each training epoch"""
        return

    def decision_function_update(self, z, scores):
        """for any updating operation after decision function"""
        return z, scores

    @staticmethod
    def set_seed(seed):
        """Seed torch (CPU and CUDA), numpy and ``random``; make cuDNN deterministic."""
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        random.seed(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True




class ICL(BaseDeepAD):
    """
    Anomaly Detection for Tabular Data with Internal Contrastive Learning
     (ICLR'22)
    :cite:`shenkar2022internal`

    Parameters
    ----------
    epochs: int, optional (default=100)
        Number of training epochs

    batch_size: int, optional (default=64)
        Number of samples in a mini-batch

    lr: float, optional (default=1e-3)
        Learning rate

    n_ensemble: int or str, optional (default='auto')
        Number of the ensemble size (make use of the bagging effect)
            - If 'auto', the size is derived from the data shape during fit

    rep_dim: int, optional (default=128)
        Dimensionality of the representation space

    hidden_dims: List, str or int, optional (default='100,50')
        Number of neural units in hidden layers
            - If List, each item is a layer
            - If str, neural units of hidden layers are split by comma
            - If int, number of neural units of single hidden layer

    act: str, optional (default='LeakyReLU')
        activation layer name
        choice = ['ReLU', 'LeakyReLU', 'Sigmoid', 'Tanh']

    bias: bool, optional (default=False)
        Additive bias in linear layer

    kernel_size: str or int, optional (default='auto')
        the length of sub-vectors
            - If 'auto', chosen from the number of features on first training

    temperature: float, optional (default=0.01)
        tau in the cross-entropy function

    max_negatives: int, optional (default=1000)
        Maximum number of negatives (unmatched sub-vectors)

    epoch_steps: int, optional (default=-1)
        Maximum steps in an epoch
            - If -1, all the batches will be processed

    prt_steps: int, optional (default=10)
        Number of epoch intervals per printing

    device: str, optional (default='cuda')
        torch device,

    verbose: int, optional (default=2)
        Verbosity mode; at 2 or above the network is printed as well

    random_state: int, optional (default=42)
        the seed used by the random

    """
    def __init__(self, epochs=100, batch_size=64, lr=1e-3, n_ensemble='auto',
                 rep_dim=128, hidden_dims='100,50', act='LeakyReLU', bias=False,
                 kernel_size='auto', temperature=0.01, max_negatives=1000,
                 epoch_steps=-1, prt_steps=10, device='cuda',
                 verbose=2, random_state=42):
        super(ICL, self).__init__(
            model_name='ICL', epochs=epochs, batch_size=batch_size,
            lr=lr, n_ensemble=n_ensemble,
            epoch_steps=epoch_steps, prt_steps=prt_steps, device=device,
            verbose=verbose, random_state=random_state
        )

        self.hidden_dims = hidden_dims
        self.rep_dim = rep_dim
        self.act = act
        self.bias = bias

        self.kernel_size = kernel_size
        self.tau = temperature
        self.max_negatives = max_negatives

    @staticmethod
    def _auto_kernel_size(n_features):
        """Return the sub-vector length used when ``kernel_size='auto'``."""
        if n_features <= 40:
            return 2
        if n_features <= 160:
            return 10
        if n_features <= 240:
            return n_features - 150
        if n_features <= 480:
            return n_features - 200
        return n_features - 400

    def training_prepare(self, X, y):
        """Build the training loader, a fresh ``ICLNet`` and the loss.

        Raises
        ------
        ValueError
            If the data have fewer than three features.
        """
        # validate before touching any state or printing
        if self.n_features < 3:
            raise ValueError('ICL model cannot handle the data that have less than three features.')

        # pinned memory only helps host-to-GPU copies
        train_loader = DataLoader(X, batch_size=self.batch_size,
                                  shuffle=True,
                                  pin_memory=str(self.device).startswith('cuda'))

        if self.kernel_size == 'auto':
            # resolved once and stored, so later calls reuse the same value
            self.kernel_size = self._auto_kernel_size(self.n_features)

        if self.verbose >= 1:
            print(f'kernel size: {self.kernel_size}')

        net = ICLNet(
            n_features=self.n_features,
            kernel_size=self.kernel_size,
            hidden_dims=self.hidden_dims,
            rep_dim=self.rep_dim,
            activation=self.act,
            bias=self.bias
        ).to(self.device)

        criterion = torch.nn.CrossEntropyLoss()

        if self.verbose >= 2:
            print(net)

        return train_loader, net, criterion

    def training_forward(self, batch_x, net, criterion):
        """Return the contrastive cross-entropy loss of one batch.

        The shape of the result follows ``criterion.reduction``.
        """
        batch_x = batch_x.float().to(self.device)

        # positives are sub-vectors, query are their complements
        positives, query = net(batch_x)

        # cross-entropy expects the class axis at dim 1
        logit = self.cal_logit(query, positives).permute(0, 2, 1)

        # the matching (positive) pair is always stored at class index 0
        correct_class = torch.zeros((logit.shape[0], logit.shape[2]),
                                    dtype=torch.long, device=logit.device)
        return criterion(logit, correct_class)

    def inference_prepare(self, X):
        """Build the test loader and switch the criterion to per-element losses."""
        test_loader = DataLoader(X, batch_size=self.batch_size,
                                 drop_last=False, shuffle=False)
        # inference_forward averages the un-reduced loss per sample
        self.criterion.reduction = 'none'
        return test_loader

    def inference_forward(self, batch_x, net, criterion):
        """Return ``(batch_x, score)`` where score is the mean loss over sub-vectors."""
        loss = self.training_forward(batch_x, net, criterion)
        batch_z = batch_x  # for consistency
        s = loss.mean(dim=1)
        return batch_z, s

    def cal_logit(self, query, pos):
        """Compute contrastive logits between queries and sub-vector embeddings.

        Parameters
        ----------
        query, pos : torch.Tensor
            Indexed here as ``[batch, n_pos, rep_dim]`` with matching shapes.

        Returns
        -------
        logit : torch.Tensor of shape [batch, n_pos, 1 + n_neg]
            Column 0 holds the matched pair, the remaining columns the
            sampled negatives; everything is divided by ``tau``.
        """
        n_pos = query.shape[1]
        device = query.device

        # get negatives
        negative_index = np.random.choice(np.arange(n_pos), min(self.max_negatives, n_pos), replace=False)
        neg_idx = torch.as_tensor(negative_index, dtype=torch.long, device=device)
        negative = pos.permute(0, 2, 1)[:, :, neg_idx]

        pos_multiplication = (query * pos).sum(dim=2).unsqueeze(2)

        neg_multiplication = torch.matmul(query, negative)  # [batch_size, n_pos, n_neg]

        # Removal of the diagonals: a query must not use its own positive as
        # a negative. The [n_pos, n_neg] mask broadcasts over the batch.
        own_positive = torch.arange(n_pos, device=device).unsqueeze(1) == neg_idx.unsqueeze(0)
        neg_multiplication.masked_fill_(own_positive, -float('inf'))

        logit = torch.cat((pos_multiplication, neg_multiplication), dim=2)
        logit = torch.div(logit, self.tau)
        return logit


class ICLNet(torch.nn.Module):
    """Pair of encoders for sub-vectors and their complements.

    Every sample is cut into all consecutive windows of ``kernel_size``
    features (sub-vectors) and the remaining features (complements).
    ``enc_g_net`` embeds the sub-vectors, ``enc_f_net`` the complements.

    Parameters
    ----------
    n_features : int
        Number of input features.
    kernel_size : int
        Length of each sub-vector; must satisfy ``1 <= kernel_size < n_features``.
    hidden_dims : list, str or int
        Hidden widths of ``enc_f_net`` (``enc_g_net`` uses half of each).
        A str is split on commas; an int means one hidden layer.
    rep_dim : int
        Output width of both encoders.
    activation : str
        Activation name used for all layers except the first layer of
        ``enc_f_net``, which uses ``'Tanh'``.
    bias : bool
        Forwarded to both encoders.

    Raises
    ------
    ValueError
        If ``kernel_size`` is outside ``[1, n_features - 1]``.
    """
    def __init__(self, n_features, kernel_size,
                 hidden_dims='100,50', rep_dim=64,
                 activation='ReLU', bias=False):
        super(ICLNet, self).__init__()

        if not 1 <= kernel_size < n_features:
            raise ValueError(
                f'kernel_size must be in [1, n_features - 1], '
                f'got kernel_size={kernel_size} with n_features={n_features}'
            )

        self.n_features = n_features
        self.kernel_size = kernel_size

        # get consecutive subspace indices and the corresponding complement indices
        self.all_idx, self.all_idx_complement = self._build_subspace_indices(
            n_features, kernel_size
        )

        if isinstance(hidden_dims, str):
            hidden_dims = [int(a) for a in hidden_dims.split(',')]
        elif isinstance(hidden_dims, (int, np.integer)):
            hidden_dims = [int(hidden_dims)]
        n_layers = len(hidden_dims)  # hidden layers

        # the sub-vector axis is the channel axis seen by BatchNorm1d
        n_subvectors = len(self.all_idx)

        f_act = ['Tanh'] + [activation] * n_layers
        self.enc_f_net = MLPnet(
            n_features=n_features-kernel_size,
            n_hidden=hidden_dims,
            n_output=rep_dim,
            mid_channels=n_subvectors,
            batch_norm=True,
            activation=f_act,
            bias=bias,
        )

        hidden_dims2 = [int(0.5*h) for h in hidden_dims]
        g_act = [activation] * (n_layers + 1)
        self.enc_g_net = MLPnet(
            n_features=kernel_size,
            n_hidden=hidden_dims2,
            n_output=rep_dim,
            mid_channels=n_subvectors,
            batch_norm=True,
            activation=g_act,
            bias=bias,
        )

    @staticmethod
    def _build_subspace_indices(n_features, kernel_size):
        """Return window indices and their complements.

        Returns
        -------
        all_idx : numpy array of shape (n_features - kernel_size + 1, kernel_size)
            Feature indices of each consecutive window.
        all_idx_complement : numpy array of shape (n_features - kernel_size + 1, n_features - kernel_size)
            For each window, the sorted feature indices outside of it.
        """
        # arange(n - k + 1) instead of arange(n)[: -k + 1]: the slice form is
        # empty for kernel_size == 1 because the stop index becomes 0
        start_idx = np.arange(n_features - kernel_size + 1)
        all_idx = start_idx[:, None] + np.arange(kernel_size)
        every_feature = np.arange(n_features)
        all_idx_complement = np.array([np.setdiff1d(every_feature, row)
                                       for row in all_idx])
        return all_idx, all_idx_complement

    def forward(self, x):
        """Return normalised embeddings ``(sub-vectors, complements)`` of ``x``."""
        x1, x2 = self.positive_matrix_builder(data=x)
        x1 = self.enc_g_net(x1)
        x2 = self.enc_f_net(x2)
        # NOTE(review): F.normalize defaults to dim=1, which for these 3-D
        # tensors is the sub-vector axis rather than the representation axis
        # - confirm this is intended.
        x1 = F.normalize(x1)
        x2 = F.normalize(x2)
        return x1, x2

    def positive_matrix_builder(self, data):
        """
        Generate matrix of sub-vectors and matrix of complement vectors (positive pairs)

        Parameters
        ----------
        data: torch.Tensor shape (n_samples, n_features), required
            The input data.

        Returns
        -------
        matrix: torch.Tensor of shape [n_samples, number of sub-vectors, kernel_size]
            Derived sub-vectors.

        complement_matrix: torch.Tensor of shape [n_samples, number of sub-vectors, n_features-kernel_size]
            Complement vector of derived sub-vectors.

        """
        # Index the feature axis directly with the 2-D index tables; this
        # gives the same values as tiling the data to [size, dim, dim] first
        # without allocating the tiled copy.
        matrix = data[:, self.all_idx]
        complement_matrix = data[:, self.all_idx_complement]

        return matrix, complement_matrix


def main():
    """Smoke test: fit an ICL detector on a tiny all-zero data set."""
    # fall back to CPU so the demo also runs on machines without a GPU
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = ICL(device=device)
    data = np.zeros([10, 20])
    model.fit(data)


# run the demo only when executed as a script, not on import
if __name__ == '__main__':
    main()