# Implement a Kernel SVM for classification.
import torch
import torch.nn as nn
from ML_Models.base_model import BaseMLModel
from ML_Models.KernelRidgeRegression.model import squared_exp_kernel, polynomial_kernel, linear_kernel
import typing as tp
from torch.optim import RMSprop, Adam, SGD
from sklearn.svm import SVC
from Tools.jackknife import *

class BaseSVM(BaseMLModel):
    """ Base class for KernelSVM and KernelLeastSquaresSVM.

        Stores support vectors together with their coefficients alpha_i*y_i and a bias, and
        predicts with the kernel expansion f(x) = sum_i alpha_i*y_i * k(x_i, x) + bias.
    """
    def __init__(self, input_dim: int,  kernel_str = "rbf",
                 weighted_model: bool = False, train_set_size=None, C_param = 1.0):
        """
            :param input_dim: Data dimension
            :param kernel_str: kernel to use, either rbf, linear, poly (see https://scikit-learn.org/stable/modules/svm.html#kernel-functions)
            :param weighted_model: forwarded to BaseMLModel.
            :param train_set_size: forwarded to BaseMLModel.
            :param C_param: Penalty for soft-margin violations (C in the SVM objective).
            :raises ValueError: if kernel_str is not one of rbf, linear, poly.
        """
        super().__init__(weighted_model, train_set_size)

        # Auxiliary information
        self.input_dim = input_dim
        self.C_param = C_param
        self.kernel_str = kernel_str
        self.gamma = 1.0/input_dim # corresponds to the "auto" setting in sklearn.
        if kernel_str == "rbf":
            # A bound method (unlike a lambda defined here) can be pickled, and copy.deepcopy
            # rebinds it to the copy instead of the original object.
            self.kernel = self._rbf_kernel
        elif kernel_str == "linear":
            self.kernel = linear_kernel
        elif kernel_str == "poly":
            self.kernel = polynomial_kernel
        else:
            raise ValueError("Unsupported Kernel. Use either rbf, linear or poly.")

        # SVM parameters (set by the subclasses).
        self.support_vector_indices = torch.empty(0) # [SV] indices of the support vectors in the training set. Only needed for data weights in KernelSVM
        self.support_vector_data = torch.empty(0) # [SV, input_dim] the support vectors
        self.support_vector_labels = torch.empty(0) # [SV] labels of the support vectors
        self.alpha_target = torch.empty(0) # [SV] alpha parameters of the support vectors multiplied by the target (a_i*y_i)
        self.bias = torch.empty(0)

    def _rbf_kernel(self, X, Y, det=False):
        """ Squared-exponential kernel with bandwidth self.gamma (read at call time). """
        return squared_exp_kernel(X, Y, self.gamma, deterministic=det)

    def num_support_vectors(self) -> int:
        """ Get number of support vectors. """
        return self.support_vector_data.size(0)

    def forward(self, x: torch.tensor) -> torch.tensor:
        """ Sigmoid of the SVM decision values returned by predict_with_logits. """
        return torch.sigmoid(self.predict_with_logits(x))

    def predict_with_logits(self, x: torch.tensor) -> torch.tensor:
        """ The prediction for the SVM is given by the kernels between the support vectors and the input
            data, multiplied by the fitted weights, plus an additional bias:
            y^hat = sum_{i=1...SV} alpha_i*y_i * k(x, x_i) + bias
            where x is the query point and x_i are the support vectors.
            :param x: [B, input_dim] query points.
            :returns: [B] decision values (logits).
        """
        k_premult = self.kernel(self.support_vector_data, x, det=True) # [SV, B]
        return torch.matmul(k_premult.t(), self.alpha_target.reshape(-1, 1)).flatten() + self.bias

    def predict_from_parameters(self, x: torch.tensor, parameters: torch.tensor) -> torch.tensor:
        """
            Predict logits using different model parameters (instead of the ones stored with this object) supplied as inputs.
            :param x: [B, D] inputs.
            :param parameters: (C, num_params); the first SV columns are the alpha*label products.
            :returns: [B, C]-matrix.
        """
        k_premult = self.kernel(self.support_vector_data, x, det=True) # [SV, B]
        params_alpha = parameters[:,:self.num_support_vectors()] # [C, SV]
        params_bias = self._get_bias(parameters) # [C]
        return torch.matmul(k_premult.t(), params_alpha.t()) + params_bias.reshape(1, -1) # [B, C] + [1, C]

class KernelSVM(BaseSVM):
    """ A soft margin SVM for classification.
        Fitting relies on (and follows) the scikit-learn implementation sklearn.svm.SVC.
    """
    def __init__(self, input_dim: int,  kernel_str = "rbf", weighted_model: bool = False, train_set_size=None, C_param = 1.0):
        """
            :param input_dim: Data dimension
            :param kernel_str: kernel to use, either rbf, linear, poly (see https://scikit-learn.org/stable/modules/svm.html#kernel-functions)
            :param C_param: corresponds to C in https://scikit-learn.org/stable/modules/svm.html#svc
        """
        super().__init__(input_dim,  kernel_str, weighted_model, train_set_size, C_param)

    def fit(self, X: torch.tensor, y: torch.tensor):
        """
            Fit the SVM with sklearn.svm.SVC and store its support vectors, dual coefficients and intercept.
            :param X: training inputs, one row per data point.
            :param y: binary labels in {0, 1}.
            For the weighted model, self.data_weights_vector is passed to sklearn as sample_weight.
        """
        input_labels = 2*y.long()-1 # convert {0,1} to {-1,+1}
        sk_svm = SVC(C=self.C_param, kernel=self.kernel_str, degree=3, gamma='auto')
        sample_weight = self.data_weights_vector if self.weighted_model else None
        sk_svm.fit(X.numpy(), y.numpy(), sample_weight=sample_weight)
        self.support_vector_indices = torch.tensor(sk_svm.support_, dtype=torch.long).flatten()
        self.support_vector_data = X[self.support_vector_indices]
        self.support_vector_labels = input_labels[self.support_vector_indices]
        self.alpha_target = torch.tensor(sk_svm.dual_coef_, dtype=torch.float32).flatten()
        self.bias = torch.tensor(sk_svm.intercept_, dtype=torch.float32)
        # Sanity check: the bias recomputed from the alphas should match sklearn's intercept.
        # `not diff <= tol` also fires when the recomputed bias is NaN.
        computed_bias = self._get_bias(self.alpha_target.reshape(1,-1))
        if not torch.abs(computed_bias-self.bias).item() <= 1e-2:
            print(f"Warning: Computed bias does not match the one obtained with sklearn." +
                f"{computed_bias.item()} vs. {self.bias.item()}")

    def num_lagangians(self):
        """ Number of additional Lagrange multipliers of this model:
            one for each support vector with alpha_i = C (or C*data_weight_i) and one for the
            equality constraint sum_i alpha_i*y_i = 0.
        """
        return torch.sum(self._has_lagrangian()).item() + 1

    def get_all_params(self):
        """ Return the concatenation of the alpha*label products and the Lagrange multipliers. Shape [1, SV+R],
            where SV is the number of support vectors and R = num_lagangians().
        """
        return torch.cat((self.alpha_target.reshape(1,-1), self._compute_lambdas()), dim=1)

    def loss_objective(self, parameters: torch.tensor, X: torch.tensor, y: torch.tensor, data_weights = None):
        """ Lagrangian of the dual loss objective from https://scikit-learn.org/stable/modules/svm.html#svc
            Data weights enter by rescaling the box constraint alpha_i <= C to alpha_i <= C*data_weight_i.
            :param parameters: vector [alpha_target (SV), lambdas, mu] as returned by get_all_params (any shape, flattened).
            :param X: unused; the stored support vectors are used.
            :param y: binary {0,1} labels of the training set (indexed with support_vector_indices).
            :param data_weights: data weights, only used for the weighted model. Defaults to self.data_weights_vector.
        """
        parameters = parameters.flatten()
        n_sv = self.num_support_vectors()
        yb = 2*y-1 # convert {0,1} labels to {-1,+1}
        kernel_mat = self.kernel(self.support_vector_data, self.support_vector_data) #[SV, SV]
        alpha_target_i = parameters[:n_sv]
        # alpha_target is the product y_i*alpha_i. To get alpha_i back, we multiply by the label.
        alpha_i = alpha_target_i * yb[self.support_vector_indices]
        # \sum_i alpha_i - 0.5*\sum_i,j alpha_i*alpha_j*y_i*y_j*k(x_i, x_j)
        loss_plain = -0.5* torch.dot(alpha_target_i, kernel_mat.matmul(alpha_target_i.reshape(-1,1)).flatten()) + torch.sum(alpha_i)

        # Furthermore, the alpha_i are constrained to be <= C or <= C*dataweight for the weighted model.
        if self.weighted_model:
            if data_weights is None:
                data_weights = self.data_weights_vector
            constraint = self.C_param*data_weights[self.support_vector_indices]
        else:
            constraint = self.C_param
        # constraint is a scalar or a vector. Active constraints are taken at the fitted alphas.
        has_lagangian = (self.alpha_target*self.support_vector_labels == constraint)
        # To enforce this, we add a lagrangian term for every active bound.
        lambdas = parameters[n_sv:-1]
        if self.weighted_model:
            loss_lagrangian1 = torch.sum(lambdas * (alpha_i[has_lagangian]-constraint[has_lagangian]))
        else:
            loss_lagrangian1 = torch.sum(lambdas * (alpha_i[has_lagangian]-constraint))

        # finally, add a lagrangian for the constraint \sum_i alpha_target_i = 0
        loss_lagrangian2 = parameters[-1]*(torch.sum(alpha_target_i))
        return loss_plain + loss_lagrangian1 + loss_lagrangian2

    def parameter_change_under_removal(self, X: torch.tensor, y: torch.tensor, ind: torch.tensor = ...):
        """ Compute the parameter change under removal of data points.
            Because the exact solution is intractable, the (infinitesimal) jackknife approximation is used.
            :param ind: indices of the removed points (default: all points).
            :returns: [len(ind), SV+R] parameter changes.
        """
        opt_weights = self.get_all_params()
        jackknife_obj = self.loss_objective_for_jackknife(X, y) # function of (data weights, theta)
        j_mat = jackknife_compute_jacobians(opt_weights, jackknife_obj, self.data_weights_vector, additional_params=None).squeeze(0) # [Len params, #data weights]
        return -j_mat[:,ind].t()

    def compute_parameters_from_data_weights(self, data_weights: torch.tensor, X: torch.tensor, y: torch.tensor):
        """ Compute model parameters under soft removal of specific points.
            Note: this function is differentiable w.r.t. data_weights to be useful in
            end-to-end gradient descent optimization.
            It returns a first-order (infinitesimal jackknife) approximation of the parameter
            vector around the current parameters and data weights, i.e.,
            parameters(data_weights) = self.parameters + J*(data_weights - self.data_weights)
            :returns: [1, SV+R] parameter vector.
        """
        assert self.weighted_model
        # Apply the infinitesimal jackknife approximation.
        opt_weights = self.get_all_params()
        jackknife_obj = self.loss_objective_for_jackknife(X, y) # function of (data weights, theta)
        j_mat = jackknife_compute_jacobians(opt_weights, jackknife_obj, self.data_weights_vector, additional_params=None).squeeze(0) # [Len params, #data weights]
        return self.get_all_params().detach() + j_mat.matmul(data_weights - self.data_weights_vector.detach())

    def _compute_lambdas(self):
        """ Compute the values of the Lagrange multipliers at the fitted solution.
            :returns: [1, R] tensor: one lambda per support vector at its upper bound, followed by mu,
                the multiplier of the constraint sum_i y_i*a_i = 0 (computed to equal the bias).
        """
        kernel_mat = self.kernel(self.support_vector_data, self.support_vector_data) #[SV, SV]
        # Multiply with class labels: Q_ij = y_i*y_j*k(x_i, x_j).
        Q = kernel_mat * (self.support_vector_labels.reshape(-1,1).matmul(self.support_vector_labels.reshape(1,-1)))
        alpha_i = self.support_vector_labels*self.alpha_target
        deriv = torch.ones(len(alpha_i)) - Q.matmul(alpha_i.reshape(-1,1)).flatten()
        at_bound = self._has_lagrangian()
        # Parameter for the sum constraint - averaged over the free support vectors, equals the bias.
        # NOTE(review): with the +mu*sum(alpha_target) term in loss_objective, stationarity at a free SV
        # gives mu = -deriv_i/y_i (minus the bias). If the jackknife Jacobian is built from second
        # derivatives of loss_objective it does not depend on these values (the Lagrangian is linear
        # in the multipliers) - TODO confirm the intended sign.
        mu = torch.mean(deriv[~at_bound]/self.support_vector_labels[~at_bound])
        # now calc lambdas
        lambdas = -deriv[at_bound] - mu*self.support_vector_labels[at_bound]
        return torch.cat((lambdas.reshape(-1), mu.reshape(1))).reshape(1,-1)

    def _has_lagrangian(self):
        """ Boolean mask [SV] of support vectors whose alpha_i is exactly at the upper bound C (C*data_weight_i). """
        if self.weighted_model:
            has_lagangian = (self.alpha_target*self.support_vector_labels  == \
                self.C_param*self.data_weights_vector[self.support_vector_indices])
        else:
            has_lagangian = (self.alpha_target*self.support_vector_labels  == self.C_param)
        return has_lagangian

    def _get_bias(self, parameters):
        """ The bias is a dependent variable in the optimization problem and can be calculated from the alphas.
            See https://www.microsoft.com/en-us/research/uploads/prod/2006/01/Bishop-Pattern-Recognition-and-Machine-Learning-2006.pdf
            Eqn. (7.37) on p. 334: the mean over the free support vectors (0 < a_n < C) of
            t_n - sum_m a_m t_m k(x_n, x_m).
            :param parameters: [B, SV+R] matrix; the first SV columns are alpha*label products.
            :returns: Length B bias vector
        """
        alpha_target = parameters[:,:self.num_support_vectors()] # [B, SV]
        # alpha_target holds y_n*a_n, which is negative for the -1 class. The free-SV test must be
        # done on a_n itself, otherwise all negative-class support vectors would be ignored.
        alpha = alpha_target * self.support_vector_labels.reshape(1,-1)
        if self.weighted_model:
            upper = self.C_param*self.data_weights_vector[self.support_vector_indices].reshape(1,-1)
        else:
            upper = self.C_param
        alpha_target_use = ((alpha > 0) & (alpha < upper)).long()
        use_num = torch.sum(alpha_target_use, dim=1)
        k_premult = self.kernel(self.support_vector_data, self.support_vector_data) #[SV, SV]
        res = self.support_vector_labels.reshape(-1, 1) - torch.matmul(k_premult, alpha_target.t()) # [SV, B]
        return torch.sum(res.t()*alpha_target_use, dim=1)/use_num

class KernelLeastSquaresSVM(BaseSVM):
    """ A sparse least squares SVM for classification (or regression).
        See: G.C. Cawley, N.L.C. Talbot: "Fast exact leave-one-out cross-validation of sparse least-squares support vector machines"
        for a precise description of the model.
        This is a SVM with the support points precomputed a-priori. They do not necessarily need to be part of the data set.
    """
    def __init__(self, input_dim: int, support_points: torch.tensor, kernel_str = "rbf",
            weighted_model: bool = False, train_set_size=None, gamma = 0.1, task="classification"):
        """
            :param input_dim: Data dimension
            :param support_points: An array of N support vectors to use. Shape [N, input_dim]
            :param kernel_str: kernel to use, either rbf, linear, poly (see https://scikit-learn.org/stable/modules/svm.html#kernel-functions)
            :param gamma: The weight regularization term.
            :param task: "classification" ({0,1} labels, mapped to {-1,+1}); any other value uses the targets as given.
        """
        super().__init__(input_dim, kernel_str, weighted_model, train_set_size)
        self.support_vector_data = support_points
        # NOTE(review): BaseSVM's rbf kernel reads self.gamma when it is called, so this assignment also
        # sets the rbf bandwidth to the regularization value (instead of 1/input_dim). Confirm intended.
        self.gamma = gamma
        self.task = task
        self.K = self.kernel(self.support_vector_data, self.support_vector_data) # [N, N] kernel matrix of the support points

    def get_all_params(self):
        """ Return the concatenation of support point weights and bias. Shape [1, N+1]. """
        return torch.cat((self.alpha_target.reshape(1,-1), self.bias.reshape(1,1)), dim=1)

    def _labels_to_targets(self, y):
        """ Map {0,1} labels to {-1,+1} for classification; other tasks use y unchanged. """
        if self.task == "classification":
            return 2*y.float()-1
        return y

    def _regularizer_matrix(self, K):
        """ Block matrix [[K/(2*gamma), 0], [0, 0]] of shape [N+1, N+1]; the bias is not regularized. """
        n = len(K)
        return torch.cat((torch.cat(((0.5/self.gamma)*K, torch.zeros(n,1)),dim=1),torch.zeros(1, n+1)), dim=0)

    def _solve_weighted_system(self, X, y, weights):
        """ Solve (R + Z'^T Z') p = Z'^T (sqrt(w).y) where Z' = [sqrt(w).K(X, SV), sqrt(w)]
            (Cawley/Talbot eqn 3 with data weights). Differentiable w.r.t. weights.
            :returns: [N+1, 1] solution; first N entries are the support point weights, the last one is the bias.
        """
        sqrt_w = torch.sqrt(weights.reshape(-1, 1))
        Z = torch.cat((sqrt_w * self.kernel(X, self.support_vector_data), sqrt_w), dim=1) # [K, N+1]
        sysmat = self._regularizer_matrix(self.K) + Z.t().matmul(Z)
        rhs = Z.t().matmul(sqrt_w*y.reshape(-1,1))
        return torch.linalg.solve(sysmat, rhs)

    def fit(self, X: torch.tensor, y: torch.tensor):
        """
            Fit the SVM.
            labels are in a 0,1 binary format for classification and continuous for regression.
            The weighted model uses self.data_weights_vector, otherwise all weights are 1.
        """
        y = self._labels_to_targets(y)
        if self.weighted_model:
            use_weights = self.data_weights_vector
        else:
            use_weights = torch.ones(len(X))
        solution = self._solve_weighted_system(X, y, use_weights)
        self.alpha_target = solution[:-1]
        self.bias = solution[-1].reshape(1)

    def loss_objective(self, parameters: torch.tensor, X: torch.tensor, y: torch.tensor, data_weights = None):
        """ The loss objective stated in eqn 2 of Cawley/Talbot, but with data weights:
            0.5*beta^T K beta + gamma * sum_i w_i*(y_i - f(x_i))^2
            :param parameters: [1, N+1] support point weights followed by the bias.
            :returns: length-1 tensor.
        """
        y = self._labels_to_targets(y)
        parameters_beta = parameters[0, :-1]

        first_term = 0.5*parameters_beta.reshape(1,-1).matmul(self.K.matmul(parameters_beta.reshape(-1,1))).flatten() # 0.5 beta^T *K *beta
        sqerrs = torch.pow(y-self.predict_from_parameters(X, parameters).flatten(), 2)
        if data_weights is not None:
            sqerr_sum = self.gamma*torch.sum(data_weights*sqerrs)
        else:
            sqerr_sum = self.gamma*torch.sum(sqerrs)
        return first_term + sqerr_sum.flatten()

    def parameter_change_under_removal(self, X: torch.tensor, y: torch.tensor, ind: torch.tensor = ..., explicit_inversion=True):
        """ Compute the parameter changes under removal of single data points (exact leave-one-out).
            The analytical solution for the problem is taken from
            G.C. Cawley, N.L.C. Talbot: "Fast exact leave-one-out cross-validation of sparse least-squares support vector machines"
            but vectorized. For the weighted model, rows are scaled by sqrt(data weight) as in fit.
            :param ind: indices of the removed points (default: all points).
            :param explicit_inversion: solve each reduced system directly (more stable, more expensive)
                instead of using the Sherman-Morrison formula.
            :returns: [len(ind), N+1] parameter changes, one row per removed point.
        """
        y = self._labels_to_targets(y).reshape(-1)

        # N = Number of support Vectors, K = Number of input data instances.
        K = self.kernel(self.support_vector_data, self.support_vector_data) # Shape [N, N]
        R = self._regularizer_matrix(K) # Shape [N+1, N+1]
        Z = torch.cat((self.kernel(X, self.support_vector_data), torch.ones(len(X), 1)),dim=1) # [K, N+1]
        if self.weighted_model:
            # Match the weighted system solved in fit: Z' = sqrt(w).Z, targets sqrt(w).y.
            sqrt_w = torch.sqrt(self.data_weights_vector.reshape(-1, 1))
            Z = sqrt_w * Z
            y = sqrt_w.reshape(-1) * y
        C = R + Z.t().matmul(Z) # Shape [N+1, N+1]

        # Rhs = Z^T*t
        rhs = Z.t().matmul(y.reshape(-1,1)) # [Shape N+1, 1]
        # Omitting the i-th row of Z: Rhs new = Z_(-i)^T*t = Z^T*t - z_i*t_i (one column per i)
        rhs_new = rhs-Z.t()*y.reshape(1,-1) # [Shape N+1, K]

        if explicit_inversion:
            # C_-i = R + Z^T*Z - z_i*z_i^T, solved as a batch.
            C_batch = C.unsqueeze(0) - Z.unsqueeze(2).matmul(Z.unsqueeze(1)) # [K, N+1, N+1]
            new_params = torch.linalg.solve(C_batch, rhs_new.t()).t() # [N+1, K]
        else:
            # Sherman-Morrison: (C - z z^T)^-1 = C^-1 + (C^-1 z)(z^T C^-1) / (1 - z^T C^-1 z),
            # applied directly to each new right hand side so no [K, N+1, N+1] inverses are formed.
            Cinv = torch.inverse(C) # Shape [N+1, N+1]
            rhs_change_only = Cinv.matmul(rhs_new) # C^-1 * new right hand side. [N+1, K]
            u = Cinv.matmul(Z.t()) # C^-1 z_i in each column [N+1, K]
            v = Cinv.t().matmul(Z.t()) # C^-T z_i in each column [N+1, K]
            # The rank-1 matrix vector product (uv^T)x = u*(v^Tx), per column.
            vtx = torch.sum(v*rhs_new, dim=0) # shape [K]
            uvtx = u*vtx.reshape(1, -1)
            denominator = 1.0-torch.sum(Z.t()*u, dim=0) # Shape [K]
            new_params = rhs_change_only + uvtx/denominator.reshape(1,-1)
        param_change = new_params.t() - self.get_all_params() # [K, N+1]
        return param_change[ind]

    def compute_parameters_from_data_weights(self, data_weights: torch.tensor, X: torch.tensor, y: torch.tensor):
        """ Compute model parameters for the given (soft) data weights.
            Note: this function is differentiable w.r.t. data_weights to be useful in
            end-to-end gradient descent optimization.
            The weighted least-squares system (the same one solved in fit) is re-solved exactly,
            no approximation is used.
            :returns: [1, N+1] support point weights followed by the bias.
        """
        y = self._labels_to_targets(y)
        return self._solve_weighted_system(X, y, data_weights).reshape(1,-1)

    def _get_bias(self, parameters):
        """ The bias is the last column of the parameter matrix. """
        return parameters[:,-1]