import torch
import torch.nn as nn
from torch.autograd.functional import jacobian
from ML_Models.base_model import BaseMLModel
import math

def _get_sq_eucl_dist(X, Y):
    """
        Return the [len(X), len(Y)] matrix of squared Euclidean distances, entry [i, j] = ||x_i - y_j||^2,
        computed by explicit broadcasting of X [N, D] against Y [M, D].
    """
    pairwise_diff = X.reshape(X.size(0), 1, X.size(1)) - Y.reshape(1, Y.size(0), Y.size(1))  # [N, M, D]
    return pairwise_diff.pow(2).sum(dim=2)

def squared_exp_kernel(X, Y, gamma=0.5, deterministic=True):
    """
        A squared exponential kernel function. See KernelRidgeRegression.__init__ for an interface description.
        Return K[i, j] = exp(-gamma*||x_i - y_j||^2), a [len(X), len(Y)] matrix.
        deterministic=True computes the squared distances with _get_sq_eucl_dist (explicit broadcasting);
        otherwise torch.cdist is used and its result squared.
    """
    if not deterministic:
        sq_dists = torch.cdist(X.unsqueeze(0), Y.unsqueeze(0)).squeeze(0).pow(2)
    else:
        sq_dists = _get_sq_eucl_dist(X, Y)
    return torch.exp(-gamma * sq_dists)

def squared_exp_kernel_deriv(X, Y, gamma=0.5):
    """
        Compute the gradient of squared_exp_kernel(X, Y, gamma)[i, j] = exp(-gamma*||x_i - y_j||^2)
        with respect to x_i, for X [N, D] and Y [M, D]:
            d/dx_i = -2*gamma*(x_i - y_j)*exp(-gamma*||x_i - y_j||^2)
        The shape of the result will be
        [N, M, D]
    """
    diff = X.reshape(X.size(0), 1, X.size(1)) - Y.reshape(1, Y.size(0), Y.size(1))  # [N, M, D]
    kernel_vals = torch.exp(-gamma * diff.pow(2).sum(dim=2))  # [N, M]
    # Chain rule: the inner derivative of ||x - y||^2 w.r.t. x is 2*(x - y).
    return -2 * gamma * diff * kernel_vals.unsqueeze(-1)


def linear_kernel(X, Y, deterministic=True):
    """
        A linear (dot-product) kernel function. See KernelRidgeRegression.__init__ for an interface description.
        Return the Gram matrix K[i, j] = < x_i, y_j >.
        `deterministic` is accepted only for interface compatibility with the other kernels and is ignored.
    """
    gram = torch.matmul(X, Y.t())
    return gram

def polynomial_kernel(X, Y, gamma=1.0, degree=2.0, coef0=1.0, deterministic=True):
    """
        A polynomial kernel function. See KernelRidgeRegression.__init__ for an interface description.
        Return K[i, j] = (gamma*< x_i, y_j > + coef0)^degree, a [len(X), len(Y)] matrix.
        `deterministic` is accepted only for interface compatibility with the other kernels and is ignored.
    """
    return torch.pow(gamma * X.matmul(Y.t()) + coef0, degree)

def polynomial_kernel_derivative(X, Y, gamma=1.0, coef0=1.0, degree=2.0):
    """
        Compute the gradient of K[i, j] = (gamma*<x_i, y_j> + coef0)^degree with respect to x_i,
        for X [N, D] and Y [M, D]:
            d/dx_i = degree * gamma * (gamma*<x_i, y_j> + coef0)^(degree-1) * y_j
        The shape of the result will be [N, M, D].
        Note: the positional order here is (gamma, coef0, degree), unlike polynomial_kernel's
        (gamma, degree, coef0); prefer passing these as keywords.
    """
    base = gamma * X.matmul(Y.t()) + coef0  # [N, M]
    # Chain rule: the inner derivative of gamma*<x, y> w.r.t. x is gamma*y.
    scale = degree * gamma * torch.pow(base, degree - 1.0)  # [N, M]
    return scale.unsqueeze(-1) * Y.unsqueeze(0)

def neural_tangent_kernel(X, Y, deterministic=True):
    """ Neural tangent kernel between the rows of X [N, D] and Y [M, D], returning an [N, M] matrix.
        Follows Zhang et al. "Rethinking Influence Functions of Neural Networks in the Over-parameterized Regime" eqn 4,
        generalized to feature vectors that are not normalized:
            K[i, j] = <x_i, y_j> * (pi - acos(u_ij)) / (2*pi),  u_ij = <x_i, y_j> / (|x_i| |y_j|)
        The cosine u_ij is clamped to [-1, 1] to keep acos defined under rounding error.
        `deterministic` is accepted only for interface compatibility and is ignored.
    """
    gram = X.matmul(Y.t())
    cosines = gram / X.norm(dim=1, keepdim=True)
    cosines = cosines / Y.norm(dim=1).reshape(1, -1)
    cosines = cosines.clamp(-1, 1)
    angular_factor = (math.pi - torch.acos(cosines)) / (2 * math.pi)
    return gram * angular_factor

def neural_tangent_kernel_derivative(X, Y):
    """
        Compute the gradient of neural_tangent_kernel(X, Y)[i, j] with respect to x_i,
        for X [N, D] and Y [M, D]. The shape of the result will be [N, M, D]:
            d k(x_i, y_j)/dx_i = ( y_j*(pi - acos(u_ij)) + <x_i, y_j> * (du_ij/dx_i) / sqrt(1 - u_ij^2) ) / (2*pi)
        with u_ij = <x_i, y_j> / (|x_i| |y_j|) and du_ij/dx_i = (I/|x_i| - x_i x_i^T/|x_i|^3) y_j/|y_j|.
        Where |u_ij| == 1 (x_i parallel or anti-parallel to y_j) the second term is 0/0; since
        du_ij/dx_i vanishes there, that term is set to zero instead of propagating NaN.
    """
    norms_x = torch.norm(X, dim=1)  # [N]
    norms_y = torch.norm(Y, dim=1)  # [M]
    # Build the identity on X's dtype/device so the function also works for float64 and GPU inputs.
    eye = torch.eye(X.size(1), dtype=X.dtype, device=X.device)
    inv_nx = norms_x.pow(-1).reshape(-1, 1, 1)
    inv_nx3 = norms_x.pow(-3).reshape(-1, 1, 1)
    # Jacobian of x/|x| w.r.t. x: I/|x| - x x^T/|x|^3, shape [N, D, D].
    dunit_dx = eye.unsqueeze(0) * inv_nx - inv_nx3 * torch.bmm(X.unsqueeze(2), X.unsqueeze(1))
    # du_ij/dx_i: [N, 1, D, D] @ [M, D, 1] -> [N, M, D, 1] -> [N, M, D].
    du_dx = torch.matmul(dunit_dx.unsqueeze(1), (Y / norms_y.reshape(-1, 1)).unsqueeze(-1)).squeeze(-1)

    dot_prods = X.matmul(Y.t())  # [N, M]
    cos_sim = torch.clamp((dot_prods / norms_x.reshape(-1, 1)) / norms_y.reshape(1, -1), -1, 1)

    # d(pi - acos(u))/du = 1/sqrt(1 - u^2); mask the singular |u| == 1 entries (see docstring).
    sin_sim = torch.sqrt(1 - cos_sim.pow(2))
    nonsingular = sin_sim > 0
    safe_sin = torch.where(nonsingular, sin_sim, torch.ones_like(sin_sim))
    angle_coeff = torch.where(nonsingular, dot_prods / safe_sin, torch.zeros_like(sin_sim))  # [N, M]

    grad = Y.unsqueeze(0) * (math.pi - torch.acos(cos_sim)).unsqueeze(-1) + angle_coeff.unsqueeze(-1) * du_dx
    return grad / (2 * math.pi)

def compute_kernel_derivative(kernel_func, X_data, X_in):
    """ Compute the derivative of K(X_in, X_data) with respect to X_in using autograd.
        kernel_func: kernel as described in KernelRidgeRegression.__init__; called as kernel_func(X_data, X_in)
            (for a symmetric kernel this is the transpose of K(X_in, X_data)).
        X_data: [M, D] fixed points.
        X_in: [N, D] points at which the derivative is taken.
        return: tensor of shape [N, M, D] whose entry [i, j, :] is d K(x_i, x_j^data) / d x_i, i.e. the same
            [N, M, D] layout as the analytical derivative functions in this file.
        Assumes each kernel entry depends only on its own pair of rows (true for the kernels defined here).
    """
    def summed_over_inputs(z):
        # Row-independence lets us sum over the X_in axis and still recover every per-point gradient
        # from a single jacobian call: d/dz_i sum_k K(x_j^data, z_k) = d K(x_j^data, z_i) / d z_i.
        return kernel_func(X_data, z).sum(dim=1)  # [M]

    jac = jacobian(summed_over_inputs, X_in.detach())  # [M, N, D]
    return jac.transpose(0, 1)

"""
    Kernel Ridge Regression
"""
class KernelRidgeRegression(BaseMLModel):
    """
        Kernel ridge regression with optional per-sample data weights.
        Predictions are y^hat(x) = k(x, X) (k(X, X) + lambd*I)^-1 y. In the weighted variant, kernel entry
        (i, j) is scaled by sqrt(w_i*w_j), the targets by sqrt(w_i), and the solution again by sqrt(w_i).
    """
    def __init__(self, kernel_func, lambd: float = 10.0, num_of_classes: int = 1, weighted_model: bool = False, train_set_size=None, kernel_deriv=None):
        """
            kernel_func: a vectorized version of the kernel function for points of dimension [input_dim]
            For inputs X (Shape[M, input_dim]) and Y (Shape [N, input_dim]), kernel func should return a
            [M, N]-matrix with the kernel evaluated between the pairs of points. This matrix is positive definite for M=N.
            It must also accept a `deterministic` keyword argument (predictions call it with deterministic=True).
            lambd: ridge regularization added to the diagonal of the kernel matrix.
            num_of_classes: stored as self.num_of_classes (not read within this class).
            weighted_model, train_set_size: forwarded to BaseMLModel.__init__; fit reads self.weighted_model
                (presumably set by the base class from this flag — confirm in BaseMLModel).
            kernel_deriv: optional analytical derivative of kernel_func; stored, not read within this class.
        """
        super().__init__(weighted_model, train_set_size)
        self.num_of_classes = num_of_classes
        self.kernel = kernel_func
        self.internal_weights = 0
        self.points = 0
        self.lambd = lambd
        self.kernel_deriv = kernel_deriv # Analytical derivative implementation for kernel function.

    def _solve_internal_weights(self, kernel_mat: torch.tensor, dweights: torch.tensor, y: torch.tensor) -> torch.tensor:
        """
            Solve the weighted ridge system shared by fit and compute_parameters_from_data_weights.
            With s = sqrt(dweights), returns s * (K * s s^T + lambd*I)^-1 (s * y), flattened to [len(y)].
            Differentiable w.r.t. dweights, kernel_mat and y.
        """
        sqrt_w = torch.sqrt(dweights)
        outer_weight = sqrt_w.reshape(-1, 1).matmul(sqrt_w.reshape(1, -1))
        eye = torch.eye(len(y), dtype=kernel_mat.dtype, device=kernel_mat.device)
        rhs = (sqrt_w * y).reshape(-1, 1)
        # Solving the linear system is cheaper and more stable than forming the inverse.
        return sqrt_w * torch.linalg.solve(kernel_mat * outer_weight + self.lambd * eye, rhs).flatten()

    def fit(self, X: torch.tensor, y: torch.tensor):
        """
            Fit the kernel ridge model to data X with labels y. We have
            y^hat = k(x^pred, X)*(k(X,X)+lambda*I)^-1 y
            In this function we calculate the term (k(X,X)+lambda*I)^-1 y, which we store as internal_weights, i.e.,
            internal_weights = (k(X,X)+lambda*I)^-1 y
            If self.weighted_model is set, self.data_weights_vector is applied as described in the class docstring.
            The training points X are stored in self.points for later predictions.
        """
        self.points = X
        kernel_mat = self.kernel(self.points, self.points)
        if self.weighted_model:
            dweights = self.data_weights_vector
        else:
            # Unit weights reduce the weighted system to plain kernel ridge regression.
            dweights = torch.ones(len(y), dtype=kernel_mat.dtype, device=kernel_mat.device)
        self.internal_weights = self._solve_internal_weights(kernel_mat, dweights, y)

    def forward(self, x: torch.tensor) -> torch.tensor:
        """ Return sigmoid of the predicted logits for inputs x [B, input_dim]. """
        return torch.sigmoid(self.predict_with_logits(x))

    def predict_with_logits(self, x: torch.tensor) -> torch.tensor:
        """ Return the raw predictions k(x, self.points) @ internal_weights, shape [B]. """
        k_premult = self.kernel(self.points, x, deterministic=True)
        return torch.matmul(k_premult.t(), self.internal_weights.reshape(-1, 1)).flatten()

    def get_all_params(self):
        """ Return the internal weights as a [1, len(self.points)] tensor. """
        return self.internal_weights.reshape(1,-1)

    def predict_from_parameters(self, x: torch.tensor, parameters: torch.tensor):
        """
            Predict logits with different weights, but using the data points stored with this model.
            x: (B, input_dim) inputs.
            parameters: weights of shape [C, len(self.points)], or just [len(self.points)] for a single weight vector
                (treated as C = 1).
            return [B, C]-matrix.
        """
        if parameters.dim() == 1:
            # Promote a single weight vector to a [1, len(self.points)] batch so the output is [B, 1].
            parameters = parameters.reshape(1, -1)
        k_premult = self.kernel(self.points, x, deterministic=True) # [len(self.points), B]
        return torch.matmul(k_premult.t(), parameters.t())

    def jac_kx(self, x_cf):
        """
            Compute Jacobian of the kernel function k_X(x).
            x_cf: Batched Input points [B, input_dim]
            return: Batched jacobian [B, N, input_dim] of self.kernel(x_cf, self.points) w.r.t x_cf
        """
        x_cfgrad = x_cf.clone().requires_grad_()
        # Each kernel row depends only on its own x, so summing over the batch axis and differentiating
        # w.r.t. the whole batch yields all per-point Jacobians at once.
        func = lambda x: torch.sum(self.kernel(x, self.points), dim=0)  # output shape N (self.npoints)
        ret = jacobian(func, x_cfgrad,)  # [N, B, input_dim]
        return ret.transpose(0, 1)

    def compute_parameters_from_data_weights(self, data_weights: torch.tensor, X: torch.tensor, y: torch.tensor):
        """ Compute the internal weights of a model refit on (X, y) with per-sample data weights.
            The weighted ridge system is solved exactly (not approximated), using the same formulation as fit.
            The result is differentiable w.r.t. data_weights, which makes it usable in
            end-to-end gradient descent optimization over soft removal of points.
            Requires a weighted model (checked with an assert).
            return: tensor of shape [len(y)].
        """
        assert self.weighted_model
        kernel_mat = self.kernel(X, X)
        return self._solve_internal_weights(kernel_mat, data_weights, y)

    def parameter_change_under_removal(self, X: torch.tensor, y: torch.tensor, ind: torch.tensor = Ellipsis):
        """
            Compute model weight change.
            To accomplish that we use the leave-one-out weights from equation 9 in "Rethinking Influence Functions of Neural Networks
            in the Over-parameterized Regime" by Zhang and Zhang (2021) (https://arxiv.org/pdf/2112.08297.pdf).
            X: data (unused)
            y: Labels assigned to all data points.
            ind: pass data point indices here (as 1D torch index tensor), if you are only interested in the change when removing some specific points.
                Otherwise changes for all points will be computed.
            return: A matrix of shape [len(ind), len(self.points)], where row n represents the change of the
                internal_weights under removal of point ind[n]. Data weights are not taken into account.
        """
        # The change in the weights (K^-1 * Y) is given by -K^-1[:, i] * (K^-1[:, i]^T * Y) / K^-1[i, i].
        with torch.no_grad():
            gram = self.kernel(self.points, self.points)
            kernel_mat = gram + self.lambd * torch.eye(len(self.points), dtype=gram.dtype, device=gram.device)
            kernel_mat_inv = torch.linalg.inv(kernel_mat)
        diag_select = torch.diag(kernel_mat_inv)[ind]
        kernel_mat_select = kernel_mat_inv[:, ind] # Select the required columns [len(self.points), len(ind)]
        kity = kernel_mat_select.t().matmul(y.reshape(-1, 1))/diag_select.reshape(-1, 1) # shape [len(ind), 1]
        diff = -kernel_mat_select * kity.reshape(1, -1)
        return diff.t()
