import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch.nn.functional as F
from util.utils import apply_p_full, norm_difference_learned_tilde, power_method_square_func

# Module-wide default torch device: 'cuda' when a CUDA device is available, otherwise 'cpu'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def optimisation_p_pointwise(objective, grad_func, x, p_init, L, approx_min, max_iter=1000, tol=1e-3, verbose=True, lmbda=0):
    """
    Learn a pointwise (elementwise) preconditioner p with accelerated gradient descent (AGD).

    With ``g = grad_func(x)`` fixed, the tracked quantity is ``objective(x - p * g)``.
    The update direction for p is
    ``lmbda * (p - p_tilde) - mean_over_batch(grad_func(x - p * g) * g)``,
    where ``p_tilde`` is the constant tensor ``1 / max(L)``.

    Args:
        objective (callable): Objective function to minimize; must return a one-element tensor.
        grad_func (callable): Function to compute gradient with respect to x.
        x (torch.Tensor): Input tensor; the first dimension is treated as the batch.
        p_init (torch.Tensor): Initial pointwise preconditioning parameter.
        L (float, one-element tensor, or sequence): Smoothness constant(s) of the objective.
            A scalar is replicated for every sample of x; a sequence/1-D tensor is used per sample.
        approx_min (float): Approximate minimum value of the objective function (plot offset only).
        max_iter (int, optional): Maximum number of iterations. Default is 1000.
        tol (float, optional): Stop once the gradient norm relative to the first iteration
            falls below this value. Default is 1e-3.
        verbose (bool, optional): If True, prints progress and plots objective values and gradient norms. Default is True.
        lmbda (float, optional): Regularization parameter. Default is 0.

    Returns:
        torch.Tensor: Optimized pointwise preconditioning parameter p.
    """
    n_samples = x.shape[0]

    # Normalise L to one entry per sample. Plain Python numbers and per-sample
    # tensors are accepted in addition to one-element tensors and lists.
    if isinstance(L, (list, tuple)):
        L = list(L)
    elif torch.as_tensor(L).numel() == 1:
        L = [L] * n_samples
    else:
        L = list(L)
    L_max = max(L)

    # Initialize p and previous p values
    p = p_init.clone()
    p_prev = p_init.clone()
    grad_f_current = grad_func(x)  # Gradient at the fixed point x, reused every iteration

    # Initial objective value (evaluated once, reused for the verbose print)
    initial_obj = objective(x - p * grad_f_current).item()
    obj_lst = [initial_obj]

    # Step size: 1 / (lmbda + mean_i L_i * max|g_i|^2), built on the gradient's own device/dtype
    L_vec = torch.as_tensor([float(l_i) for l_i in L], dtype=grad_f_current.dtype, device=grad_f_current.device)
    grad_sup = torch.max(torch.abs(grad_f_current).reshape(n_samples, -1), dim=1)[0]
    step_size_p = 1 / (lmbda + torch.mean(L_vec * grad_sup ** 2))

    norm_lst = []  # Norm of the p-gradient for each iteration
    tnew = 0
    told = 0
    p_tilde = torch.ones_like(p) / float(L_max)  # Regularisation target for p

    if verbose:
        print('L max:', L_max)
        print('Initial objective:', initial_obj)

    init_norm = None
    for it in tqdm(range(max_iter)):
        # Momentum schedule.
        # NOTE(review): the new t is computed from `told`, i.e. the value from two
        # iterations back, which differs from the textbook FISTA recurrence. Kept
        # unchanged to preserve the numerical behaviour of existing experiments.
        tnew, told = (1 + np.sqrt(1 + 4 * told**2)) / 2, tnew
        alphat = (told - 1) / tnew
        ykp = p + alphat * (p - p_prev)

        # Gradient of f at the preconditioned step taken from x
        x_new = x - ykp * grad_f_current
        grad_new = grad_func(x_new)

        # Data term of the gradient w.r.t. p, averaged over the batch dimension
        first_grad = -torch.mean(grad_new * grad_f_current, dim=0)
        p_grad = lmbda * (ykp - p_tilde) + first_grad

        # Update p and p_prev
        p, p_prev = ykp - step_size_p * p_grad, p

        # Store objective and norm of gradient
        obj_lst.append(objective(x - p * grad_f_current).item())
        grad_norm = torch.norm(p_grad).item()
        norm_lst.append(grad_norm)

        # Relative-gradient stopping rule; a zero initial gradient means p is
        # already stationary, so stop instead of dividing by zero.
        if init_norm is None:
            init_norm = grad_norm
        if init_norm == 0 or grad_norm / init_norm < tol:
            break

    # Plot results if verbose
    if verbose:
        print('FINAL OBJ:', obj_lst)
        plt.semilogy(-approx_min + np.array(obj_lst))
        plt.title('Objective Function Value p')
        plt.show()
        plt.semilogy(np.array(norm_lst))
        plt.title('Norm of Gradient p')
        plt.show()

    return p


def optimisation_p_full(objective, grad_func, x, p, L, approx_min, max_iter=10, tol=1e-3, verbose=True, lmbda=0):
    """
    Learn a full (matrix) preconditioner with accelerated gradient descent.

    With ``g = grad_func(x)`` fixed, the tracked quantity is
    ``objective(x - apply_p_full(p, g))``. The update direction for p is
    ``lmbda * (p - p_tilde) - (1/N) * sum_i outer(grad_new_i, g_i)`` with
    ``p_tilde = I / max(L)`` and gradients flattened per sample.

    Args:
        objective (callable): Objective function to be minimized; must return a one-element tensor.
        grad_func (callable): Function to compute the gradient of the objective function with respect to x.
        x (torch.Tensor): Input tensor; the first dimension is treated as the batch.
        p (torch.Tensor): Initial preconditioner. The code builds ``torch.eye(p.shape[0])`` and
            reshapes a (D, D) gradient to ``p.shape``, so p is presumably a (D, D) matrix with
            D the flattened per-sample size -- confirm against apply_p_full.
        L (float, one-element tensor, or sequence): Smoothness constant(s); only ``max(L)`` is used here.
        approx_min (float): Approximate minimum value of the objective function (plot offset only).
        max_iter (int, optional): Maximum number of iterations. Default is 10.
        tol (float, optional): Relative gradient-norm tolerance; checked every 100 iterations. Default is 1e-3.
        verbose (bool, optional): If True, prints progress and plots graphs. Default is True.
        lmbda (float, optional): Regularization parameter for the preconditioner. Default is 0.

    Returns:
        torch.Tensor: Optimized preconditioning parameter (detached).
    """
    n_samples = x.shape[0]

    # Normalise L to one entry per sample; accepts Python numbers, one-element
    # tensors, per-sample tensors and lists.
    if isinstance(L, (list, tuple)):
        L = list(L)
    elif torch.as_tensor(L).numel() == 1:
        L = [L] * n_samples
    else:
        L = list(L)
    L_max = max(L)

    # Initialize variables
    p = p.to(dtype=x.dtype).detach()
    p_prev = p.clone()
    grad_f_current = grad_func(x)
    grad_current_flat = grad_f_current.reshape(grad_f_current.shape[0], -1)

    # Initial objective value
    obj_lst = [objective(x - apply_p_full(p, grad_f_current)).item()]

    # Step size for preconditioning update
    step_size_p = 1 / (lmbda + torch.norm(grad_f_current)**2 / n_samples)

    norm_lst = []
    tnew = 0
    told = 0
    # Regularisation target: scaled identity on p's own device/dtype
    p_tilde = torch.eye(p.shape[0], dtype=p.dtype, device=p.device) / float(L_max)

    init_norm = None
    for it in tqdm(range(max_iter)):
        # Momentum schedule (kept exactly as in the original implementation)
        tnew, told = (1 + np.sqrt(1 + 4*told**2))/2, tnew
        alphat = (told - 1)/tnew

        # Extrapolated point and gradient of f after the preconditioned step
        ykp = p + alphat * (p - p_prev)
        x_new = x - apply_p_full(ykp, grad_f_current)
        grad_new = grad_func(x_new)
        grad_new_flat = grad_new.reshape(grad_new.shape[0], -1)

        # -(1/N) * sum_i outer(grad_new_i, grad_current_i), computed as a single matmul
        first_grad = -(grad_new_flat.t() @ grad_current_flat) / n_samples

        # Preconditioner gradient and update
        p_grad = lmbda * (ykp - p_tilde) + first_grad.view(p.shape)
        p, p_prev = ykp - step_size_p * p_grad, p

        grad_norm = torch.norm(p_grad).item()
        if init_norm is None:
            init_norm = grad_norm

        # Log progress (and test the stopping rule) every 100 iterations
        if it % 100 == 0:
            obj_lst.append(objective(x - apply_p_full(p, grad_f_current)).item())
            norm_lst.append(grad_norm)

            # Stop if relative gradient norm is below tolerance; a zero initial
            # gradient means p is already stationary.
            if init_norm == 0 or grad_norm / init_norm < tol:
                break

        # Intermediate plot every 1000 iterations; now honours `verbose` so that
        # silent callers are not blocked by plt.show().
        if verbose and it % 1000 == 0:
            print('FINAL OBJ:', obj_lst)
            plt.semilogy(-approx_min + np.array(obj_lst))
            plt.title('Objective Function Value p')
            plt.show()

    # Final output with optional plotting
    if verbose:
        print('FINAL OBJ:', obj_lst)
        plt.semilogy(-approx_min + np.array(obj_lst))
        plt.title('Objective Function Value p')
        plt.show()
        plt.semilogy(np.array(norm_lst))
        plt.title('Norm of Gradient p')
        plt.show()

    return p.detach()


def convolution(kernel, image, same=True, groups=1):
    """
    Convolve ``kernel`` with ``image`` via ``F.conv2d``, flipping before and after the call.

    Note the argument roles in the underlying call: the (flipped) ``kernel`` is passed
    to ``F.conv2d`` as the *input* and ``image`` as the *weight*. With ``same=True``
    the spatial size of the result therefore equals that of ``kernel``.

    Args:
        kernel (torch.Tensor): 2-D, 3-D or 4-D tensor; lower ranks are promoted to 4-D
            by adding leading singleton dimensions.
        image (torch.Tensor): 2-D, 3-D or 4-D tensor, promoted to 4-D the same way.
        same (bool, optional): If True, uses ``padding='same'``. If False, pads each
            spatial axis with ``image`` size minus one along that axis (full convolution).
            Default is True.
        groups (int, optional): Accepted for interface compatibility but ignored: the
            group count is always recomputed from the tensor shapes. Default is 1.

    Returns:
        torch.Tensor: 4-D result of the convolution.

    Raises:
        ValueError: If either tensor is not 4-D after promotion.
    """
    # Promote 2-D / 3-D inputs to 4-D
    if image.dim() == 2:
        image = image.unsqueeze(0).unsqueeze(0)
    elif image.dim() == 3:
        image = image.unsqueeze(0)
    if kernel.dim() == 2:
        kernel = kernel.unsqueeze(0).unsqueeze(0)
    elif kernel.dim() == 3:
        kernel = kernel.unsqueeze(0)

    # Validate before indexing shapes so that bad ranks raise the documented error
    if image.dim() != 4 or kernel.dim() != 4:
        raise ValueError('Image shape is not correct')

    # Group count is derived from the shapes: the larger of the two leading
    # dimensions of each tensor, then the smaller of those two numbers.
    num_ker = max(kernel.shape[0], kernel.shape[1])
    num_img = max(image.shape[0], image.shape[1])
    groups = min(num_ker, num_img)

    # Full convolution needs (size - 1) padding per axis; using the width for both
    # axes is only correct for square `image`.
    padding = 'same' if same else (image.shape[-2] - 1, image.shape[-1] - 1)

    flipped = torch.flip(kernel, dims=(2, 3))
    if groups != 1:
        # One channel per group: kernel slices become input channels and image
        # slices become single-channel filters.
        result = F.conv2d(
            flipped.reshape(1, num_ker, kernel.shape[-2], kernel.shape[-1]),
            image.reshape(num_img, 1, image.shape[-2], image.shape[-1]),
            padding=padding,
            groups=groups,
        )
    else:
        result = F.conv2d(flipped, image, padding=padding)

    return torch.flip(result, dims=(2, 3))


def convolution_adjoint(kernel, image, groups=1):
    """
    Apply the adjoint-style convolution built on top of ``convolution``.

    Calls ``convolution(image, flip(kernel), same=False)`` and crops the last two
    axes of the result to the index range ``[w // 2, 3 * w // 2)``, where ``w`` is
    the size of the last axis of ``kernel``. The same range is used on both axes.

    Args:
        kernel (torch.Tensor): 4-D tensor (it is flipped along dims 2 and 3).
        image (torch.Tensor): Tensor with at least two dimensions.
        groups (int, optional): Ignored; the value passed on is recomputed from the
            leading two dimensions of ``kernel`` and ``image``. Default is 1.

    Returns:
        torch.Tensor: Cropped result of the non-'same' convolution.
    """
    width = kernel.shape[-1]
    start, stop = width // 2, 3 * width // 2

    # Group count derived from the shapes, overriding the argument
    kernel_count = max(kernel.shape[0], kernel.shape[1])
    image_count = max(image.shape[0], image.shape[1])
    groups = min(kernel_count, image_count)

    mirrored = torch.flip(kernel, dims=(2, 3))
    uncropped = convolution(image, mirrored, same=False, groups=groups)
    return uncropped[:, :, start:stop, start:stop]
    
def optimisation_convolution(objective, grad_func, x, p_init, L, approx_min, kernel_width, kernel_height, max_iter=10, tol=1e-3, verbose=True, lmbda=0, L_smooth_p=0, output_L=False, manual=True):
    """
    Learn a convolutional preconditioner (kernel) with accelerated gradient descent.

    With ``g = grad_func(x)`` fixed, the tracked quantity is
    ``objective(x - convolution(g, p))``. The kernel gradient is
    ``lmbda * (p - p_tilde)`` plus a data term obtained either manually through
    ``convolution_adjoint`` or through autograd. ``p_tilde`` is zero everywhere
    except for ``1 / max(L)`` at the kernel centre.

    Args:
        objective (callable): Objective function to minimize; must return a one-element tensor.
        grad_func (callable): Function to compute the gradient of the objective with respect to x.
        x (torch.Tensor): Input tensor; first dimension is the batch, last two are spatial.
        p_init (torch.Tensor): Initial kernel weights.
        L (float, one-element tensor, or sequence): Smoothness constant(s) of f, scalar or per sample.
        approx_min (float): Approximate minimum value of the objective function (plot offset only).
        kernel_width (int): Size of the kernel along dimension 2 of ``p_tilde``.
        kernel_height (int): Size of the kernel along dimension 3 of ``p_tilde``.
        max_iter (int, optional): Maximum number of iterations. Default is 10.
        tol (float, optional): Relative gradient-norm tolerance for stopping. Default is 1e-3.
        verbose (bool, optional): If True, print progress and plot objective/gradient histories. Default is True.
        lmbda (float, optional): Regularization parameter. Default is 0.
        L_smooth_p (float, optional): Precomputed smoothness constant for the kernel problem;
            0 means "estimate it here". Default is 0.
        output_L (bool, optional): If True, also return L_smooth_p. Default is False.
        manual (bool, optional): If True, compute the data-term gradient via
            ``convolution_adjoint``; otherwise use autograd. Default is True.

    Returns:
        torch.Tensor: Optimized kernel weights.
        float: Additionally L_smooth_p when ``output_L`` is True.
    """
    N = x.shape[0]

    # Normalise L to one entry per sample; accepts Python numbers, one-element
    # tensors, per-sample tensors and lists.
    if isinstance(L, (list, tuple)):
        L = list(L)
    elif torch.as_tensor(L).numel() == 1:
        L = [L] * N
    else:
        L = list(L)
    L_max = max(L)

    # Initialize variables
    p = p_init.clone().detach().to(dtype=x.dtype)
    p_prev = p.clone()
    grad_f_current = grad_func(x)

    # Estimate the smoothness constant of the kernel problem if not provided
    if L_smooth_p == 0:
        for i in tqdm(range(grad_f_current.shape[0])):
            sample_grad = grad_f_current[i].unsqueeze(0)
            pm_value = power_method_square_func(
                lambda v, g=sample_grad: convolution(v, g),
                x.shape[-1],
                dtype=grad_f_current.dtype,
            )
            L_smooth_p += L[i] * pm_value**2 / N

    print('L smooth p', L_smooth_p)
    step_size_p = 1 / (lmbda + L_smooth_p)  # Step size for the kernel update

    # Track progress of objective value and gradient norms
    obj_lst = [objective(x - convolution(grad_f_current, p)).item()]
    norm_lst = []
    tnew = 0
    told = 0
    shape_w = x.shape[-2]
    shape_h = x.shape[-1]

    # Regularisation target: scaled delta at the kernel centre. The centre index is
    # computed per axis ((size - 1) // 2 equals size // 2 - 1 for even sizes and
    # size // 2 for odd sizes), so non-square kernels index in range.
    p_tilde = torch.zeros((1, 1, kernel_width, kernel_height), dtype=p.dtype, device=p.device)
    p_tilde[0, 0, (kernel_width - 1) // 2, (kernel_height - 1) // 2] = 1 / float(L_max)

    # Crop window applied to the manual gradient, one range per spatial axis
    row_lo, row_hi = (shape_w - kernel_width) // 2, (shape_w + kernel_width) // 2
    col_lo, col_hi = (shape_h - kernel_height) // 2, (shape_h + kernel_height) // 2

    init_norm = None
    for it in tqdm(range(max_iter)):
        # Momentum schedule (kept exactly as in the original implementation)
        tnew, told = (1 + np.sqrt(1 + 4 * told**2)) / 2, tnew
        alphat = (told - 1) / tnew

        # Extrapolated kernel. It is a fresh leaf each iteration so that the
        # autograd branch never chains graphs across iterations.
        ykp = (p + alphat * (p - p_prev)).detach()
        if not manual:
            ykp.requires_grad_(True)

        # Preconditioned step with the extrapolated kernel
        x_new = x - convolution(grad_f_current, ykp)

        # Data term of the kernel gradient
        if manual:
            new_grad = grad_func(x_new)
            adjoint = convolution_adjoint(grad_f_current, new_grad.view(1, N, shape_w, shape_h), groups=N)
            aut = -torch.mean(adjoint, dim=1).unsqueeze(0)[:, :, row_lo:row_hi, col_lo:col_hi]
        else:
            aut = torch.autograd.grad(objective(x_new), ykp)[0].detach()
            ykp = ykp.detach()  # drop the graph before forming the next iterate

        p_grad = lmbda * (ykp - p_tilde) + aut
        p, p_prev = ykp - step_size_p * p_grad, p

        grad_norm = torch.norm(p_grad).item()

        if it % 10 == 0:
            # Record the objective value and gradient norm for plotting
            obj_lst.append(objective(x - convolution(grad_f_current, p)).item())
            norm_lst.append(grad_norm)

        # Relative-gradient stopping rule; a zero initial gradient means the
        # kernel is already stationary.
        if init_norm is None:
            init_norm = grad_norm
        if init_norm == 0 or grad_norm / init_norm < tol:
            break

    # Display progress if verbose is True
    if verbose:
        print('FINAL OBJ:', obj_lst)
        plt.semilogy(-approx_min + np.array(obj_lst))
        plt.title('Objective Function Value p')
        plt.show()
        plt.semilogy(-min(obj_lst) + np.array(obj_lst))
        plt.title('Objective Function Value p difference')
        plt.show()
        plt.semilogy(np.array(norm_lst))
        plt.title('Norm of Gradient p')
        plt.show()

    if output_L:
        return p, L_smooth_p
    return p


def optimisation_alpha(objective, grad_func, x, alpha_init, L, approx_min, max_iter=1000, tol=1e-3, verbose=True, lmbda=0):
    """
    Learn a scalar step size alpha with accelerated gradient descent.

    With ``g = grad_func(x)`` fixed, the tracked quantity is ``objective(x - alpha * g)``.
    The update direction for alpha is
    ``lmbda * (alpha - 1 / max(L)) - sum(grad_func(x - alpha * g) * g) / N``.

    Args:
        objective (callable): Objective function to minimize; must return a one-element tensor.
        grad_func (callable): Function to compute the gradient of the objective with respect to x.
        x (torch.Tensor): Input tensor; the first dimension is treated as the batch.
        alpha_init (float or torch.Tensor): Initial step size. Plain numbers are converted
            to a tensor on the gradient's device and dtype.
        L (float, one-element tensor, or sequence): Smoothness constant(s) of f, scalar or per sample.
        approx_min (float): Approximate minimum value of the objective function (plot offset only).
        max_iter (int, optional): Maximum number of iterations. Default is 1000.
        tol (float, optional): Relative gradient-norm tolerance for stopping. Default is 1e-3.
        verbose (bool, optional): If True, print progress and plot objective/gradient histories. Default is True.
        lmbda (float, optional): Regularization parameter. Default is 0.

    Returns:
        torch.Tensor: Optimized step size (detached tensor).
    """
    n_samples = x.shape[0]

    # Normalise L to one entry per sample; accepts Python numbers, one-element
    # tensors, per-sample tensors and lists.
    if isinstance(L, (list, tuple)):
        L = list(L)
    elif torch.as_tensor(L).numel() == 1:
        L = [L] * n_samples
    else:
        L = list(L)
    L_max = max(L)

    grad_f_current = grad_func(x)

    # Accept a plain number for alpha_init, as the interface documents
    if torch.is_tensor(alpha_init):
        alpha = alpha_init
    else:
        alpha = torch.as_tensor(alpha_init, dtype=grad_f_current.dtype, device=grad_f_current.device)
    alpha_prev = alpha

    obj_lst = [objective(x - alpha * grad_f_current).item()]

    # Step size: 1 / (lmbda + mean_i L_i * ||g_i||^2)
    L_vec = torch.as_tensor([float(l_i) for l_i in L], dtype=grad_f_current.dtype, device=grad_f_current.device)
    sample_sq_norms = torch.norm(grad_f_current.reshape(n_samples, -1), dim=1) ** 2
    step_size_alpha = 1 / (lmbda + torch.mean(L_vec * sample_sq_norms))

    norm_lst = []
    tnew = 0
    told = 0
    p_tilde = 1 / float(L_max)  # Regularisation target for alpha

    init_norm = None
    for it in tqdm(range(max_iter)):
        # Momentum schedule (kept exactly as in the original implementation)
        tnew, told = (1 + np.sqrt(1 + 4 * told**2)) / 2, tnew
        alphat = (told - 1) / tnew

        # Extrapolated step size and the corresponding point
        ykp = alpha + alphat * (alpha - alpha_prev)
        x_new = x - ykp * grad_f_current

        # Gradient with respect to alpha
        grad_new = grad_func(x_new)
        first_grad = -torch.sum(grad_new * grad_f_current) / n_samples
        alpha_grad = lmbda * (ykp - p_tilde) + first_grad

        # Update alpha and previous alpha
        alpha, alpha_prev = ykp - step_size_alpha * alpha_grad, alpha

        # Record the objective value and gradient norm for plotting
        obj_lst.append(objective(x - alpha * grad_f_current).item())
        grad_norm = torch.norm(alpha_grad).item()
        norm_lst.append(grad_norm)

        # Relative-gradient stopping rule; a zero initial gradient means alpha
        # is already stationary.
        if init_norm is None:
            init_norm = grad_norm
        if init_norm == 0 or grad_norm / init_norm < tol:
            break

    if verbose:
        print('FINAL OBJ:', obj_lst)
        plt.semilogy(-approx_min + np.array(obj_lst))
        plt.title('Objective Function Value alpha')
        plt.show()
        plt.semilogy(np.array(norm_lst))
        plt.title('Norm of Gradient alpha')
        plt.show()

    return alpha.detach()


def find_appropriate_lambda_pointwise(obj_func, grad_func, x, curr_p, approx_min, n, L, max_iter=1000, tol=0.001, verbose=False):
    """
    Search for a regularisation weight lambda for the pointwise preconditioner.

    Starting from lambda = 1e-6, a preconditioner is trained with
    ``optimisation_p_pointwise`` and scored with ``norm_difference_learned_tilde``.
    If the first score is negative, lambda is halved until the score is no longer
    negative and the last lambda that still gave a negative score is returned.
    Otherwise lambda is doubled while the score stays positive and the first
    lambda with a non-positive score is returned.

    Args:
        obj_func (function): Objective function to be optimized.
        grad_func (function): Gradient of the objective function.
        x (torch.Tensor): Current point in the optimization space.
        curr_p (torch.Tensor): Initial preconditioner passed to every training run.
        approx_min (float): Forwarded to ``optimisation_p_pointwise`` as its ``approx_min``.
        n (int): Forwarded to ``norm_difference_learned_tilde``.
        L (sequence): Smoothness constants; ``max(L)`` is also passed to the scoring function.
        max_iter (int, optional): Iteration budget of each training run. Defaults to 1000.
        tol (float, optional): Tolerance of each training run. Defaults to 0.001.
        verbose (bool, optional): Forwarded to the training routine. Defaults to False.

    Returns:
        float: Selected lambda value.
    """
    def _score(candidate):
        """Train with the given lambda and return the norm-difference score."""
        learned_p = optimisation_p_pointwise(obj_func, grad_func, x, curr_p, L, approx_min, max_iter, tol, lmbda=candidate, verbose=verbose)
        return norm_difference_learned_tilde(lambda v: v * learned_p, max(L), n)

    lmbda = 1e-6
    score = _score(lmbda)

    if not (score < 0):
        # Grow lambda while the score remains strictly positive
        while score > 0:
            lmbda *= 2
            score = _score(lmbda)
            print('Lmbda + norm diff pointwise', lmbda, score)
        return lmbda

    # Shrink lambda until the score stops being negative
    while True:
        lmbda /= 2
        score = _score(lmbda)
        print('Lmbda + norm diff pointwise', lmbda, score)
        if not (score < 0):
            return lmbda * 2


def find_appropriate_lambda_alpha(obj_func, grad_func, x, curr_p, approx_min, n, L, max_iter=1000, tol=0.001, verbose=False):
    """
    Search for a regularisation weight lambda for the scalar step-size preconditioner.

    Starting from lambda = 1e-6, a step size is trained with ``optimisation_alpha``
    and scored with ``norm_difference_learned_tilde``. If the first score is
    negative, lambda is divided by 5 up to ten times; the tenth division returns
    0.0 unconditionally, otherwise the last lambda that still gave a negative
    score is returned. If the first score is not negative, lambda is multiplied
    by 5 while the score stays positive.

    Args:
        obj_func (function): Objective function to be optimized.
        grad_func (function): Gradient of the objective function.
        x (torch.Tensor): Current point in the optimization space.
        curr_p (torch.Tensor): Initial step size passed to every training run.
        approx_min (float): Forwarded to ``optimisation_alpha`` as its ``approx_min``.
        n (int): Forwarded to ``norm_difference_learned_tilde``.
        L (sequence): Smoothness constants; ``max(L)`` is also passed to the scoring function.
        max_iter (int, optional): Iteration budget of each training run. Defaults to 1000.
        tol (float, optional): Tolerance of each training run. Defaults to 0.001.
        verbose (bool, optional): Forwarded to the training routine. Defaults to False.

    Returns:
        float: Selected lambda value (0.0 when the shrinking search hits its cap).
    """
    def _score(candidate):
        """Train with the given lambda and return the norm-difference score."""
        learned_alpha = optimisation_alpha(obj_func, grad_func, x, curr_p, L, approx_min, max_iter, tol, lmbda=candidate, verbose=verbose)
        return norm_difference_learned_tilde(lambda v: learned_alpha * v, max(L), n)

    lmbda = 1e-6
    score = _score(lmbda)

    if not (score < 0):
        # Grow lambda by a factor of 5 while the score remains strictly positive
        while score > 0:
            lmbda *= 5
            score = _score(lmbda)
            print('Lmbda + norm diff alpha', lmbda, score)
        return lmbda

    # Shrink lambda by a factor of 5, giving up after the tenth attempt
    for attempt in range(1, 11):
        lmbda /= 5
        score = _score(lmbda)
        print('Lmbda + norm diff alpha', lmbda, score)
        if attempt == 10:
            return 0.
        if not (score < 0):
            break
    return lmbda * 5


def find_appropriate_lambda_kernel(obj_func, grad_func, x, curr_p, approx_min, n, L, kernel_width, kernel_height, max_iter=100, tol=0.001, verbose=False):
    """
    Search for a regularisation weight lambda for the convolutional preconditioner.

    Starting from lambda = 1e-8, a kernel is trained with ``optimisation_convolution``
    and scored with ``norm_difference_learned_tilde``. The smoothness constant
    returned by the first training run is passed back into later runs. If the
    first score is negative, lambda is halved until the score is no longer
    negative and the last lambda that still gave a negative score is returned.
    Otherwise lambda is doubled while the score stays positive.

    Args:
        obj_func (function): Objective function to be optimized.
        grad_func (function): Gradient of the objective function.
        x (torch.Tensor): Current point in the optimization space.
        curr_p (torch.Tensor): Initial kernel passed to every training run.
        approx_min (float): Forwarded to ``optimisation_convolution`` as its ``approx_min``.
        n (int): Forwarded to ``norm_difference_learned_tilde``.
        L (sequence): Smoothness constants; ``max(L)`` is also passed to the scoring function.
        kernel_width (int): Width of the convolution kernel.
        kernel_height (int): Height of the convolution kernel.
        max_iter (int, optional): Iteration budget of each training run. Defaults to 100.
        tol (float, optional): Tolerance of each training run. Defaults to 0.001.
        verbose (bool, optional): Forwarded to the training routine. Defaults to False.

    Returns:
        float: Selected lambda value.
    """
    def _score(candidate, smoothness=0):
        """Train with the given lambda; return (score, smoothness constant)."""
        learned_kernel, smoothness = optimisation_convolution(
            obj_func, grad_func, x, curr_p, L, approx_min, kernel_width, kernel_height,
            max_iter, tol, lmbda=candidate, L_smooth_p=smoothness, output_L=True, verbose=verbose,
        )
        score = norm_difference_learned_tilde(lambda v: convolution(v, learned_kernel), max(L), n)
        return score, smoothness

    lmbda = 1e-8
    score, smoothness = _score(lmbda)
    print('norm calc 1', score)

    if not (score < 0):
        print('norm calc 1 > 0', score)
        # Grow lambda while the score remains strictly positive
        while score > 0:
            lmbda *= 2
            score, smoothness = _score(lmbda, smoothness)
            print('Lmbda + norm diff conv > 0 loop', lmbda, score)
        return lmbda

    print('norm calc 1 < 0', score)
    # Shrink lambda until the score stops being negative
    while True:
        lmbda /= 2
        score, smoothness = _score(lmbda, smoothness)
        print('Lmbda + norm diff conv < 0 loop', lmbda, score)
        if not (score < 0):
            return lmbda * 2


def find_appropriate_lambda_full(obj_func, grad_func, x, curr_p, approx_min, n, L, max_iter=1000, tol=0.001, verbose=False):
    """
    Search for a regularisation weight lambda for the full (matrix) preconditioner.

    Starting from lambda = 1e-6, a preconditioner is trained with
    ``optimisation_p_full`` and scored with ``norm_difference_learned_tilde``.
    If the first score is negative, lambda is divided by 5 until the score is no
    longer negative and the last lambda that still gave a negative score is
    returned. Otherwise lambda is multiplied by 5 while the score stays positive.

    Args:
        obj_func (function): Objective function to be optimized.
        grad_func (function): Gradient of the objective function.
        x (torch.Tensor): Current point in the optimization space.
        curr_p (torch.Tensor): Initial preconditioner passed to every training run.
        approx_min (float): Forwarded to ``optimisation_p_full`` as its ``approx_min``.
        n (int): Forwarded to ``norm_difference_learned_tilde``.
        L (sequence): Smoothness constants; ``max(L)`` is also passed to the scoring function.
        max_iter (int, optional): Iteration budget of each training run. Defaults to 1000.
        tol (float, optional): Tolerance of each training run. Defaults to 0.001.
        verbose (bool, optional): Forwarded to the training routine. Defaults to False.

    Returns:
        float: Selected lambda value.
    """
    def _score(candidate):
        """Train with the given lambda and return the norm-difference score."""
        learned_matrix = optimisation_p_full(obj_func, grad_func, x, curr_p, L, approx_min, max_iter, tol, lmbda=candidate, verbose=verbose)
        return norm_difference_learned_tilde(lambda v: apply_p_full(learned_matrix, v), max(L), n)

    lmbda = 1e-6
    score = _score(lmbda)

    if not (score < 0):
        # Grow lambda by a factor of 5 while the score remains strictly positive
        while score > 0:
            lmbda *= 5
            score = _score(lmbda)
            print('Lmbda + norm full', lmbda, score)
        return lmbda

    # Shrink lambda by a factor of 5 until the score stops being negative
    while True:
        lmbda /= 5
        score = _score(lmbda)
        print('Lmbda + norm full', lmbda, score)
        if not (score < 0):
            return lmbda * 5


def line_search(f, grad_f, xk, pk, alpha0=1, c1=1e-3, c2=0.5, max_iter=10):
    """
    Bracketing line search for a step size satisfying the weak Wolfe conditions.

    At each trial step ``alpha``:
      * if the sufficient-decrease (Armijo) test fails, ``alpha`` becomes the upper
        bound and the next trial is the midpoint of the bracket;
      * else if the curvature test fails (directional derivative still below
        ``c2 * phi'(0)``), ``alpha`` becomes the lower bound and the next trial is
        either twice the lower bound (no upper bound yet) or the bracket midpoint;
      * otherwise the step is accepted.

    Works for tensors of any rank (``xk``, ``pk`` and the gradient must share a shape).

    Args:
        f (function): Objective function to minimize.
        grad_f (function): Gradient of the objective function.
        xk (torch.Tensor): Current point in the optimization space.
        pk (torch.Tensor): Search direction.
        alpha0 (float, optional): Initial step size. Defaults to 1.
        c1 (float, optional): Parameter for Armijo condition. Defaults to 1e-3.
        c2 (float, optional): Parameter for curvature condition. Defaults to 0.5.
        max_iter (int, optional): Maximum number of trial steps. Defaults to 10.

    Returns:
        float: The accepted step size, or the last candidate produced when
        ``max_iter`` trials are exhausted (that candidate has not been tested).
    """
    def _inner(a, b):
        # Full contraction over every axis; identical to dims=4 for 4-D tensors
        return torch.tensordot(a, b, dims=a.dim())

    alpha = alpha0
    alpha_upper = 0  # 0 doubles as "no upper bound found yet"
    alpha_low = 0
    phi0 = f(xk)
    phi_prime0 = _inner(grad_f(xk), pk)

    for _ in range(max_iter):
        trial = xk + alpha * pk
        if f(trial) > phi0 + c1 * alpha * phi_prime0:
            # Not enough decrease: step too long
            alpha_upper = alpha
            alpha = 0.5 * (alpha_upper + alpha_low)
        elif _inner(grad_f(trial), pk) < c2 * phi_prime0:
            # Slope still steeply negative: step too short
            alpha_low = alpha
            if alpha_upper == 0:
                alpha = 2 * alpha_low
            else:
                alpha = 0.5 * (alpha_upper + alpha_low)
        else:
            break

    return alpha


def replace_inf_with_zero(tensor_list):
    """
    Replace zero-dimensional +infinity tensors in a list with zero tensors.

    Only 0-d tensors whose value is positive infinity are replaced; every other
    entry (including multi-element tensors and -inf) is returned unchanged. The
    replacement zero has the same dtype and device as the entry it replaces, so
    the check does not depend on the module-level device or on float32.

    Args:
        tensor_list (list): List of tensors to process.

    Returns:
        list: New list with +inf scalar tensors replaced by zeros.
    """
    def _is_pos_inf_scalar(t):
        return t.dim() == 0 and bool(torch.isinf(t)) and bool(t > 0)

    return [torch.zeros_like(t) if _is_pos_inf_scalar(t) else t for t in tensor_list]


def lbfgs_two_loop_recursion(m, grad_f, sk, yk, Hk0):
    """
    Perform the two-loop recursion for L-BFGS.

    Applies the implicit inverse-Hessian approximation built from the stored
    pairs ``(sk[i], yk[i])`` to ``grad_f``. Index 0 is treated as the oldest pair
    and index ``m - 1`` as the newest. Pairs whose curvature ``yk[i] . sk[i]`` is
    zero get a weight of 0 (via ``replace_inf_with_zero``) and therefore do not
    contribute. Tensors may have any rank as long as all shapes match.

    Args:
        m (int): Number of stored pairs to use; must not exceed ``len(sk)``.
        grad_f (torch.Tensor): Current gradient.
        sk (list): Previous position differences, oldest first.
        yk (list): Previous gradient differences, oldest first.
        Hk0 (function): Applies the initial inverse-Hessian approximation to a tensor.

    Returns:
        torch.Tensor: Result of applying the approximation to ``grad_f`` (the caller
        negates it to obtain a descent direction).
    """
    def _inner(a, b):
        # Full contraction over every axis; identical to dims=4 for 4-D tensors
        return torch.tensordot(a, b, dims=a.dim())

    q = grad_f
    alpha = [0] * m
    rhok = replace_inf_with_zero([1 / _inner(yk[i], sk[i]) for i in range(m)])

    # First loop: newest pair to oldest
    for i in range(m - 1, -1, -1):
        alpha[i] = rhok[i] * _inner(sk[i], q)
        q = q - alpha[i] * yk[i]

    z = Hk0(q)

    # Second loop: oldest pair to newest
    for i in range(m):
        beta = rhok[i] * _inner(yk[i], z)
        z = z + (alpha[i] - beta) * sk[i]

    return z


def lbfgs(f, grad_f, x0, L, memory=10, max_iter=1000):
    """
    L-BFGS optimization algorithm.

    Runs ``max_iter`` iterations of L-BFGS from ``x0`` using ``lbfgs_two_loop_recursion``
    for the direction and ``line_search`` for the step size. The history buffers
    start filled with zero tensors matching ``x0`` in shape, dtype and device.

    Args:
        f (function): Objective function to minimize; must return a one-element tensor.
        grad_f (function): Gradient of the objective function (assumed deterministic:
            the gradient at each new iterate is computed once and reused).
        x0 (torch.Tensor): Initial point for the optimization.
        L (float): Lipschitz constant for the gradient; the initial inverse-Hessian
            approximation is ``x / L``.
        memory (int, optional): Number of previous iterations to use. Defaults to 10.
        max_iter (int, optional): Maximum number of iterations. Defaults to 1000.

    Returns:
        list: History of function values, ``1 + max_iter`` entries starting with ``f(x0)``.
    """
    xk = x0.clone()
    fs = [f(xk).item()]

    # zeros_like keeps the buffers on x0's device and dtype, avoiding mismatches
    # with the differences appended later.
    sk = [torch.zeros_like(x0) for _ in range(memory)]
    yk = [torch.zeros_like(x0) for _ in range(memory)]
    Hk0 = lambda v: v / L

    gfk = grad_f(xk)
    for _ in range(max_iter):
        pk = -lbfgs_two_loop_recursion(memory, gfk, sk, yk, Hk0)
        alpha_k = line_search(f, grad_f, xk, pk)
        xkp1 = xk + alpha_k * pk
        gfkp1 = grad_f(xkp1)

        # Slide the history window: drop the oldest pair, append the newest
        if memory > 0:
            sk.pop(0)
            yk.pop(0)
            sk.append(xkp1 - xk)
            yk.append(gfkp1 - gfk)

        # Carry the new point and its gradient into the next iteration
        xk, gfk = xkp1, gfkp1
        fs.append(f(xk).item())

    return fs


def lbfgs_all_functions(f, grad_f, x0, y, L, memory=10, max_iter=100):
    """
    Run L-BFGS independently on every instance of a batch and average the histories.

    For each index ``k`` along the first dimension, ``lbfgs`` is run on
    ``x0[k]`` (with a leading singleton dimension restored) using
    ``f(., y[k])`` and ``grad_f(., y[k])``; the returned objective histories are
    summed and divided by the number of instances.

    Args:
        f (function): Objective taking ``(x, y)``.
        grad_f (function): Gradient of the objective taking ``(x, y)``.
        x0 (torch.Tensor): Initial points, one per instance along dimension 0.
        y (torch.Tensor): Per-instance extra input for ``f`` and ``grad_f``.
        L (float): Lipschitz constant forwarded to ``lbfgs``.
        memory (int, optional): History size forwarded to ``lbfgs``. Defaults to 10.
        max_iter (int, optional): Iterations per instance. Defaults to 100.

    Returns:
        numpy.ndarray: Mean objective history of length ``1 + max_iter``.
    """
    instance_count = x0.shape[0]
    running_total = np.zeros(1 + max_iter)

    for k in tqdm(range(instance_count)):
        target = y[k].unsqueeze(0)
        start = x0[k].unsqueeze(0)
        history = lbfgs(
            lambda v: f(v, target),
            lambda v: grad_f(v, target),
            start,
            memory=memory,
            max_iter=max_iter,
            L=L,
        )
        running_total += np.array(history)

    return running_total / instance_count




