# Implement the infinitesimal jackknife to allow
# differentiation through the stacked optimization problems using torch.autograd

from numpy import vectorize
from torch.autograd.functional import jacobian, hessian
# from torch.autograd import grad
# from torch.linalg import inv
from ML_Models.KernelRidgeRegression.model import KernelRidgeRegression
import torch

"""
We add some intuition of what we aim to accomplish with the lines of code below:
(1) Suppose we solve the classical Wachter-Style recourse objective. The optimal counterfactual (call this theta^*)
 will then be a function of the fixed model parameters (alpha) (for now, we do not use data weights for each
 training instance).
(2) To analyze the sensitivity of the optimal recourse (theta*) with respect to model parameter changes (alpha), we
want to compute (d theta^*(alpha) / d alpha). However, theta^* is an implicit function of the model parameters alpha.
Thus, we need to find a way to compute (d theta^*(alpha) / d alpha). How can we do this?
(3) This is where the jackknife comes into play! To compute (d theta^*(alpha) / d alpha), we use the fact that the
gradient of the objective with respect to theta vanishes at the optimum theta^*(alpha). Differentiating this
first-order optimality condition with respect to the model parameters alpha and applying the chain rule of
multivariate calculus, we can then easily derive (d theta^*(alpha) / d alpha). The derivation is akin to the one
seen in equation (2) in "A Higher-Order Swiss Army Infinitesimal Jackknife" by Giordano et al. (2019)
(https://arxiv.org/pdf/1907.12116.pdf).
"""


def jackknife_make_differentiable(theta_star, objective, alpha_in, additional_params):
    """Re-attach precomputed optima theta* to the autograd graph of alpha.

    The optima of
        theta* = argmin_theta objective(alpha, theta)
    are typically produced by an optimizer and therefore carry no gradient
    with respect to alpha. This function returns a tensor with the same
    values as ``theta_star`` that additionally depends on ``alpha_in``
    through a first-order (linear) expansion around the current alpha.

    Args:
        theta_star: Batched optima; documented by the original author as a
            tensor of shape (B, N).
        objective: Differentiable torch callable. It is forwarded to
            ``jackknife_compute_jacobians``, which invokes it as
            ``objective(alpha, theta)`` or, when ``additional_params`` is
            not None, as ``objective(alpha, theta, additional_params[i])``
            for the i-th optimum.
        alpha_in: Problem parameters (documented as shape (M)) that were
            used when ``theta_star`` was obtained. Gradients flow back into
            this tensor.
        additional_params: Optional per-sample extra arguments for the
            objective, indexed like ``theta_star``; may be None.

    Returns:
        ``theta_star`` plus the Jacobian applied to
        ``alpha_in - alpha_in.detach()``. That difference is numerically
        zero, so the forward value equals ``theta_star``; only the backward
        pass sees the Jacobian term.
    """
    # Linearisation of theta*(alpha) around the current alpha.
    jacobians = jackknife_compute_jacobians(theta_star, objective, alpha_in, additional_params)

    # Zero in value, but still connected to alpha_in in the autograd graph.
    alpha_offset = alpha_in - alpha_in.detach()

    first_order_term = torch.matmul(jacobians, alpha_offset)
    return theta_star + first_order_term


def jackknife_compute_jacobians(theta_star, objective, alpha_in, additional_params):
    """Compute d theta*(alpha) / d alpha for a batch of optima.

    For
        theta* = argmin_theta objective(alpha, theta)
    the gradient of the objective w.r.t. theta is zero at theta*.
    Differentiating that optimality condition w.r.t. alpha (implicit
    function theorem) gives

        d theta* / d alpha = - H_tt^{-1} H_ta

    where H_tt is the Hessian of the objective w.r.t. theta and H_ta is the
    mixed second derivative d/d alpha (d objective / d theta), both
    evaluated at (alpha_in, theta_star[i]).

    NOTE: earlier revisions of this function returned ``H_tt^{-1} H_ta``
    without the leading minus sign, i.e. the negated Jacobian.

    Args:
        theta_star: Batched optima. The first dimension is the batch; each
            row is treated as a flat parameter vector of size N.
        objective: Differentiable torch callable, invoked as
            ``objective(alpha, theta)`` or, when ``additional_params`` is
            not None, as ``objective(alpha, theta, additional_params[i])``.
            It must return a single-element tensor (shape (1) or 0-dim),
            as required by ``torch.autograd.functional.hessian``.
        alpha_in: Problem parameters (size M) that were used to obtain
            ``theta_star``. A detached copy is used, so no graph to
            ``alpha_in`` is built here.
        additional_params: Optional indexable of per-sample extra arguments
            for the objective; may be None.

    Returns:
        Tensor of shape (B, N, M) whose i-th slice is the Jacobian of the
        i-th optimum w.r.t. alpha.

    Raises:
        Whatever ``torch.linalg.solve`` raises when a Hessian H_tt is
        singular.
    """
    alpha_det = alpha_in.detach()
    batch_size = theta_star.size(0)
    n_alpha = alpha_det.numel()

    if batch_size == 0:
        # torch.stack cannot handle an empty list; return an empty batch.
        n_theta = theta_star.shape[1:].numel()
        return theta_star.new_zeros((0, n_theta, n_alpha))

    hessians_tt = []
    hessians_ta = []
    for i in range(batch_size):
        theta_i = theta_star[i]
        n_theta = theta_i.numel()
        extra = () if additional_params is None else (additional_params[i],)

        # Hessian w.r.t. theta with alpha held fixed.
        def objective_of_theta(theta, extra=extra):
            return objective(alpha_det, theta, *extra)

        h_tt = hessian(objective_of_theta, theta_i)
        hessians_tt.append(h_tt.reshape(n_theta, n_theta))

        # Gradient w.r.t. theta as a differentiable function of alpha.
        # Only the theta-gradient is taken, so the (M, M) second derivative
        # w.r.t. alpha is never computed.
        def grad_wrt_theta(alpha, theta_i=theta_i, extra=extra):
            return jacobian(
                lambda theta: objective(alpha, theta, *extra),
                theta_i,
                create_graph=True,
            )

        h_ta = jacobian(grad_wrt_theta, alpha_det)
        # The result is (1, N, M) or (N, M) depending on whether the
        # objective returns shape (1) or a 0-dim tensor; normalise both.
        hessians_ta.append(h_ta.reshape(n_theta, n_alpha))

    h_tt = torch.stack(hessians_tt)
    h_ta = torch.stack(hessians_ta)

    # Batched linear solve instead of an explicit inverse plus bmm.
    return -torch.linalg.solve(h_tt, h_ta)


def scfe_recourse_objective(kernel_model: "KernelRidgeRegression", target_value: float, lambda_=1.0):
    """Build the counterfactual (recourse) objective for a kernel model.

    The returned callable evaluates

        g(a, x, factual) = lambda_ * ||x - factual|| + (f(x, a) - target_value)^2

    where ``f(x, a)`` is
    ``kernel_model.predict_from_parameters(x.reshape(1, -1), a.reshape(1, -1))``.

    Args:
        kernel_model: Object exposing ``predict_from_parameters(x, a)``.
            The call must return a single-element tensor for one input row.
        target_value: Target prediction t in the squared-error term.
        lambda_: Weight of the distance term. Defaults to 1.0.

    Returns:
        A callable ``g(a, x, factual)`` returning a tensor of shape (1).
        ``a`` are the model parameters, ``x`` is the candidate
        counterfactual and ``factual`` is the original input.
    """
    def objective(a, x, factual):
        # Distance between the candidate counterfactual and the factual.
        distance = torch.norm(x - factual)
        # Both arguments are passed as single-row 2-D tensors; the
        # prediction is flattened to shape (1) so the objective returns a
        # single-element tensor.
        prediction = kernel_model.predict_from_parameters(
            x.reshape(1, -1), a.reshape(1, -1)
        ).reshape(1)
        return lambda_ * distance + (prediction - target_value).pow(2)

    return objective









