from typing import Callable, Tuple

import numpy as np
import scipy.linalg as spla
import scipy.stats as spst


def posterior_factory(
    kappa: float,
    stepsize: float,
    n: int,
    *,
    verbose: bool = True,
) -> Tuple[Callable, Callable, Callable]:
    """Build the log-density, its gradient and a metric for a rippled quadratic target.

    The (unnormalised) negative log-density is

        U(x) = x[0]**2 / 2
               + sum_{i=1}^{n-1} (kappa / 3) * (x[i]**2 - stepsize * cos(x[i] / sqrt(stepsize)))

    So the first coordinate is standard normal. Each remaining coordinate is a
    quadratic well overlaid with a cosine ripple of period
    2*pi*sqrt(stepsize). Its curvature (kappa / 3) * (2 + cos(x / sqrt(stepsize)))
    oscillates between kappa / 3 and kappa. For n > 1 the target is therefore
    NOT a multivariate normal.

    Args:
        kappa: Scale of the quadratic and cosine terms in coordinates 1..n-1
            (referred to as "the condition number").
        stepsize: The leapfrog step size. It sets both the amplitude
            (kappa * stepsize / 3) and the period of the cosine ripple. It is
            expected to be positive, because its square root is taken.
        n: The dimension of the target.
        verbose: If True (the default, matching previous behaviour), print
            ``kappa`` and ``stepsize`` when the factory is called.

    Returns:
        A tuple ``(log_posterior, grad_log_posterior, metric)``:
            log_posterior: Callable ``x -> -U(x)``.
            grad_log_posterior: Callable ``x -> -grad U(x)``.
            metric: Callable ``() -> np.identity(n)``.
    """
    if verbose:
        print("kappa", kappa)
        print("stepsize", stepsize)

    # Loop-invariant constants shared by the log-density and its gradient.
    third_kappa = kappa / 3
    root_step = np.sqrt(stepsize)

    def log_posterior(x: np.ndarray) -> float:
        """Evaluate the unnormalised log-density -U(x).

        Args:
            x: The location at which to evaluate the log-density. Only entries
                0..n-1 are used, and x is expected to have length n.

        Returns:
            The value of the log-density.
        """
        x = np.asarray(x)
        tail = x[1:n]
        # np.sum over an empty slice is 0.0, so n == 1 reduces to the 1-D normal.
        potential = 0.5 * x[0] ** 2 + np.sum(
            third_kappa * (tail ** 2 - stepsize * np.cos(tail / root_step))
        )
        return -potential

    def grad_log_posterior(x: np.ndarray) -> np.ndarray:
        """Compute the gradient of the log-density, -grad U(x).

        Args:
            x: The location at which to evaluate the gradient. Every entry of x
                gets a component, so x is expected to have length n.

        Returns:
            A new array (x is not modified) holding the gradient of the
            log-density.
        """
        x = np.asarray(x)
        # d/dx (kappa/3) * (x**2 - h*cos(x/sqrt(h))) = (kappa/3) * (2x + sqrt(h)*sin(x/sqrt(h)))
        grad = third_kappa * (2 * x + root_step * np.sin(x / root_step))
        # Coordinate 0 is standard normal: d/dx (x**2 / 2) = x.
        grad[0] = x[0]
        return -grad

    def metric() -> np.ndarray:
        """Return the constant metric used for this target.

        Returns:
            The n x n identity matrix. A Euclidean metric is used rather than
            an inverse covariance.
        """
        return np.identity(n)

    return log_posterior, grad_log_posterior, metric
