import torch
from torch import nn

class StreamingLDA(nn.Module):
    """
    Deep Streaming Linear Discriminant Analysis (SLDA) for streaming learning.

    The model keeps a running mean per class (``muK``), a per-class sample
    count (``cK``) and one covariance matrix shared by all classes (``Sigma``).
    ``fit`` updates them from a single sample and ``fit_batch`` from a
    mini-batch; ``forward`` scores inputs with a linear classifier built from
    the class means and the shrinkage-regularised pseudo-inverse of ``Sigma``.

    NOTE: the statistics are plain tensor attributes rather than registered
    buffers, so ``Module.to()`` and ``state_dict()`` do not include them.
    """

    def __init__(self, args, test_batch_size=4096, shrinkage_param=1e-4,
                 streaming_update_sigma=True, device=None):
        """
        Init function for the SLDA model.
        :param args: config object; the feature dimension is
            ``sum(args.backbone_args['h_dims'])`` and the number of classes in
            the stream is ``args.n_cls``
        :param test_batch_size: batch size for inference
        :param shrinkage_param: value of the shrinkage parameter
        :param streaming_update_sigma: True if sigma is plastic else False
        :param device: device holding the SLDA statistics; when None, 'cuda'
            is used if available and 'cpu' otherwise
        """

        super().__init__()

        # SLDA parameters
        if device is None:
            # Was hard-coded to 'cuda', which made the model unusable on
            # CPU-only machines.
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        self.input_shape = sum(args.backbone_args['h_dims'])
        self.num_classes = args.n_cls
        self.test_batch_size = test_batch_size
        self.shrinkage_param = shrinkage_param
        self.streaming_update_sigma = streaming_update_sigma

        # setup weights for SLDA
        self.muK = torch.zeros((self.num_classes, self.input_shape), device=self.device)
        self.cK = torch.zeros(self.num_classes, device=self.device)
        # Initial Sigma is all ones; the first streaming update weights it by
        # num_updates == 0, so it is discarded by that update unless
        # streaming_update_sigma is False.
        self.Sigma = torch.ones((self.input_shape, self.input_shape), device=self.device)
        self.num_updates = 0
        # Cached pseudo-inverse of the shrunk covariance; recomputed in
        # forward() whenever num_updates differs from prev_num_updates.
        self.Lambda = torch.zeros_like(self.Sigma)
        self.prev_num_updates = -1

    def fit(self, x, y):
        """
        Fit the SLDA model to a new sample (x,y).
        :param x: a torch tensor holding ONE input sample, shape (d,) or (1, d);
            use fit_batch for several samples
        :param y: a torch tensor holding the sample's label
        :return: None
        """
        x = x.to(self.device)
        y = y.long().to(self.device)

        # make sure things are the right shape
        if len(x.shape) < 2:
            x = x.unsqueeze(0)
        if len(y.shape) == 0:
            y = y.unsqueeze(0)

        with torch.no_grad():

            # covariance update, using the class mean from before this sample:
            # Sigma <- (t * Sigma + t / (t + 1) * (x - mu_y)^T (x - mu_y)) / (t + 1)
            if self.streaming_update_sigma:
                x_minus_mu = (x - self.muK[y])
                mult = torch.matmul(x_minus_mu.transpose(1, 0), x_minus_mu)
                delta = mult * self.num_updates / (self.num_updates + 1)
                self.Sigma = (self.num_updates * self.Sigma + delta) / (self.num_updates + 1)

            # incremental update of the class mean
            self.muK[y, :] += (x - self.muK[y, :]) / (self.cK[y] + 1).unsqueeze(1)
            self.cK[y] += 1
            self.num_updates += 1

    def fit_batch(self, x, y):
        """
        Fit the SLDA model to a mini-batch of samples.

        Samples are grouped by label and the classes are processed in the
        order returned by ``torch.unique``. For each class the stored mean is
        merged with the batch mean, and (when ``streaming_update_sigma`` is
        set) Sigma gets the same style of update as in ``fit``, using the class
        mean from before this batch. When every class occurs at most once in
        the batch, the result matches calling ``fit`` on the samples in that
        class order.
        :param x: a torch tensor of input features, (N x d), or a single (d,) vector
        :param y: a torch tensor of N labels, or a scalar label
        :return: None
        """
        x = x.to(self.device)
        y = y.long().to(self.device)

        if len(x.shape) == 1:
            x = x.unsqueeze(0)
        if len(y.shape) == 0:
            y = y.unsqueeze(0)

        with torch.no_grad():
            for cls in torch.unique(y):
                x_cls = x[y == cls]
                count_cls = x_cls.shape[0]

                # Mean update. Clone because self.muK[cls] is overwritten below.
                old_mean = self.muK[cls].clone()
                new_mean = old_mean + (x_cls.mean(dim=0) - old_mean) * (
                    count_cls / (self.cK[cls] + count_cls))

                # Covariance update
                if self.streaming_update_sigma:
                    x_minus_mu = x_cls - old_mean
                    delta = (x_minus_mu.T @ x_minus_mu) * (
                        self.num_updates / (self.num_updates + count_cls))
                    self.Sigma = (self.num_updates * self.Sigma + delta) / (
                        self.num_updates + count_cls)

                self.muK[cls] = new_mean
                self.cK[cls] += count_cls
                # Advance the running total per class (previously once per
                # batch) so the next class's Sigma update weights the current
                # estimate by the number of samples it actually summarises.
                self.num_updates += count_cls

    def forward(self, X, return_probas=False):
        """
        Score test data X with the current SLDA model.
        :param X: a torch tensor that contains N data samples (N x d)
        :param return_probas: True to return row-wise softmax probabilities
            instead of raw scores
        :return: an (N x num_classes) tensor of linear discriminant scores
            ``x @ Lambda @ mu_k - 0.5 * mu_k^T @ Lambda @ mu_k``, or their softmax
        """
        X = X.to(self.device)

        with torch.no_grad():
            # initialize parameters for testing
            num_samples = X.shape[0]
            scores = torch.empty((num_samples, self.num_classes), device=self.device)
            # Clamp to >= 1: an empty X or a non-positive test_batch_size would
            # otherwise give range() a zero (ValueError) or negative step
            # (loop skipped, uninitialised scores returned).
            mb = max(1, min(self.test_batch_size, num_samples))

            # compute/load Lambda matrix
            if self.prev_num_updates != self.num_updates:
                # there have been updates to the model, recompute Lambda
                Lambda = torch.pinverse(
                    (1 - self.shrinkage_param) * self.Sigma
                    + self.shrinkage_param * torch.eye(self.input_shape, device=self.device))
                self.Lambda = Lambda
                self.prev_num_updates = self.num_updates
            else:
                Lambda = self.Lambda

            # parameters for predictions
            M = self.muK.transpose(1, 0)
            W = torch.matmul(Lambda, M)
            c = 0.5 * torch.sum(M * W, dim=0)

            # loop in mini-batches over test samples; the last batch is
            # shifted back so every batch has exactly mb rows (overlapping
            # rows are simply recomputed)
            for i in range(0, num_samples, mb):
                start = min(i, num_samples - mb)
                end = i + mb
                x = X[start:end]
                scores[start:end, :] = torch.matmul(x, W) - c

            # return scores or probabilities
            if not return_probas:
                return scores
            return torch.softmax(scores, dim=1)
