import torch
import torch.nn.functional as F
import numpy as np

# Module-wide default device: start from the CPU and switch to CUDA only
# when PyTorch reports that a CUDA device is available.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

def apply_p_full(p: torch.Tensor, imgs: torch.Tensor) -> torch.Tensor:
    """
    Apply a full matrix `p` to every image of a single-channel batch.

    Each image is flattened (row-major) into a vector of length
    ``height * width``, multiplied by `p`, and reshaped back into an image.
    For the whole batch this is computed as ``X @ p.T``, where each row of
    ``X`` is one flattened image.

    Args:
        p (torch.Tensor): Matrix of shape (height*width, height*width) applied
            to the flattened images.
        imgs (torch.Tensor): Batch of images of shape
            (batch_size, 1, height, width). The tensor does not need to be
            contiguous, and the images do not need to be square.

    Returns:
        torch.Tensor: The transformed batch, of shape
                      (batch_size, 1, height, width).

    Raises:
        AssertionError: If the images are not single-channel (i.e., `channels != 1`).
    """
    batch_size, channels, height, width = imgs.shape
    assert channels == 1, "The function currently supports single-channel images only."

    # `reshape` (unlike `view`) also accepts non-contiguous inputs such as
    # transposed or sliced batches; it only copies when a view is impossible.
    flat_imgs = imgs.reshape(batch_size, -1)

    # Row-vector form of p @ x for every image x in the batch.
    transformed = torch.matmul(flat_imgs, p.t())

    # Restore the original spatial layout (height and width may differ).
    return transformed.reshape(batch_size, 1, height, width)


def psnr(imgs1: list, imgs2: list) -> float:
    """
    Compute the average Peak Signal-to-Noise Ratio (PSNR) over paired images.

    For each pair the PSNR is ``20 * log10(max_pixel) - 10 * log10(mse)``,
    where ``mse`` is the mean squared error between the two images and
    ``max_pixel`` is the largest value found in the reference image. Identical
    images therefore give an infinite PSNR.

    Args:
        imgs1 (list of torch.Tensor): Reconstructed images.
        imgs2 (list of torch.Tensor): True/original (reference) images, in the
            same order and of the same length as `imgs1`.

    Returns:
        float: PSNR averaged over all image pairs.

    Raises:
        ValueError: If the two inputs have different lengths or are empty.
    """
    num_images = len(imgs1)
    # Reject mismatched inputs explicitly: otherwise surplus reference images
    # would be silently left out of the average.
    if num_images != len(imgs2):
        raise ValueError(
            f"imgs1 and imgs2 must have the same length, got {num_images} and {len(imgs2)}."
        )
    if num_images == 0:
        raise ValueError("Cannot compute PSNR of an empty batch of images.")

    total_psnr = 0.0
    for reconstructed, reference in zip(imgs1, imgs2):
        mse = F.mse_loss(reconstructed, reference)

        # The peak signal value is taken from the reference image itself.
        max_pixel = reference.max()

        psnr_val = 20 * torch.log10(max_pixel) - 10 * torch.log10(mse)
        total_psnr += psnr_val.item()

    return total_psnr / num_images


def bt_line_search(f, grad_f, x: torch.Tensor, direction: torch.Tensor, alpha: float = 1e-3, beta: float = 0.5, t: float = 1.0) -> float:
    """
    Backtracking line search to find an appropriate step size `t` for optimization.

    Starting from the initial `t`, the step size is multiplied by `beta` until
    the sufficient decrease condition

        f(x + t * direction) <= f(x) + alpha * t * <grad_f(x), direction>

    holds. The objective value and the directional derivative at `x` do not
    depend on `t`, so they are evaluated once before the loop; `f` is then
    called only once per candidate step size.

    Args:
        f (callable): Objective function.
        grad_f (callable): Function to compute the gradient of the objective function.
        x (torch.Tensor): Current point/vector in the search space.
        direction (torch.Tensor): Search direction, with as many elements as
            the gradient returned by `grad_f(x)`.
        alpha (float, optional): Parameter for sufficient decrease condition. Defaults to 1e-3.
        beta (float, optional): Rate at which `t` is reduced (0 < beta < 1). Defaults to 0.5.
        t (float, optional): Initial step size. Defaults to 1.0.

    Returns:
        float: Step size `t` that satisfies the sufficient decrease condition.
    """
    # Quantities at the current point are loop-invariant: compute them once.
    f_x = f(x)
    grad = grad_f(x)
    directional_derivative = grad.flatten().dot(direction.flatten())

    # Shrink t until the sufficient decrease condition is met.
    while f(x + t * direction) > f_x + alpha * t * directional_derivative:
        t *= beta

    return t


def power_method_square_func(lin_op, n: int = 28, num_iters: int = 100, tol: float = 1e-6, device: str = 'cuda', dtype=torch.float32) -> torch.Tensor:
    """
    Power method estimating the norm of a linear operator acting on n x n inputs.

    A random tensor of shape (1, 1, n, n) is normalized and then repeatedly
    passed through `lin_op` and re-normalized. The returned estimate is the
    norm of `lin_op` applied to the final iterate.

    Note that `lin_op` is applied directly (not `lin_op^T lin_op`), so the
    iteration targets the dominant eigen-direction of the operator; for
    symmetric operators the result coincides with the largest singular value.

    Args:
        lin_op (callable): Linear operator function that takes a tensor of
            shape (1, 1, n, n) and returns a tensor of the same shape.
        n (int, optional): Dimension of the square matrix (height and width). Defaults to 28.
        num_iters (int, optional): Maximum number of iterations. Defaults to 100.
        tol (float, optional): Tolerance on the change between consecutive
            normalized iterates. Defaults to 1e-6.
        device (str, optional): Device to perform computations on (e.g., 'cuda' or 'cpu'). Defaults to 'cuda'.
        dtype (torch.dtype, optional): Data type for the computations. Defaults to torch.float32.

    Returns:
        torch.Tensor: Scalar tensor holding the estimate. It is zero when
        `lin_op` maps the current iterate to the zero tensor.
    """
    # Random start vector, normalized to unit norm.
    v = torch.randn((1, 1, n, n), device=device).to(dtype=dtype)
    v = v / torch.norm(v)

    for _ in range(num_iters):
        v_new = lin_op(v)
        v_new_norm = torch.norm(v_new)

        # The operator annihilated the iterate: normalizing would divide by
        # zero and poison the result with NaNs. The estimate along this
        # direction is exactly this (zero) norm.
        if v_new_norm == 0:
            return v_new_norm

        v_new = v_new / v_new_norm

        # Stop once consecutive normalized iterates agree within `tol`.
        if torch.norm(v - v_new) < tol:
            break

        v = v_new

    # Norm of the operator applied to the final unit-norm iterate.
    s = torch.norm(lin_op(v))

    return s


def norm_difference_learned_tilde(learned_matrix_function, L_max: float, n: int) -> torch.Tensor:
    """
    Compare a learned update rule with a vanilla gradient descent step of size 1/L_max.

    The power method (`power_method_square_func`, with its default iteration
    count and tolerance) is run on the operator

        x -> learned_matrix_function(x) - x / L_max

    acting on inputs of shape (1, 1, n, n), and `1 / L_max` is subtracted from
    the resulting norm estimate.

    The computation runs on the module-level `device` (CUDA when available,
    otherwise CPU), so it also works on machines without a GPU.

    Args:
        learned_matrix_function (callable): Function that applies the learned update rule to an input tensor.
        L_max (float): max Lipschitz constant of training function gradients.
        n (int): Dimension of the square matrix (height and width).

    Returns:
        torch.Tensor: Scalar tensor equal to the power-method estimate for the
        difference operator minus `1 / L_max`.
    """
    def difference_operator(x):
        # Learned update minus the scaled identity x / L_max.
        return learned_matrix_function(x) - x / L_max

    # Pass the module-level device explicitly instead of relying on the
    # power method's hard-coded 'cuda' default.
    return power_method_square_func(difference_operator, n=n, device=device) - 1 / L_max



def num_iters_before_under_tol(conv_list, tol=1e-6):
    """
    Find the index of the first iteration whose value f(x_t) - f^* is strictly
    below the specified tolerance.

    Values that are exactly equal to `tol` (or that are NaN) do not count as
    being below the tolerance. If no value is below the tolerance, the length
    of the list is returned.

    Args:
        conv_list (list of float): Values f(x_t) - f^* from an optimization process.
        tol (float, optional): The tolerance threshold to check against. Default is 1e-6.

    Returns:
        int: The index of the first value below the tolerance, or the length
        of the list if no value is below `tol`.
    """
    # A single scan with one comparison keeps the "found" and "not found"
    # cases consistent and stops at the first hit.
    for index, value in enumerate(conv_list):
        if value < tol:
            return index

    return len(conv_list)


def num_iters_list_of_lists_before_under_tol(conv_list, tol=1e-6):
    """
    Apply `num_iters_before_under_tol` to every column of a list of lists.

    Column `j` is the sequence formed by the `j`-th entry of each sublist.
    The number of columns is taken from the first sublist.

    Args:
        conv_list (list of list): Rows of values (presumably one row per
            iteration, with one entry per tracked case).
        tol (float, optional): Tolerance forwarded to
            `num_iters_before_under_tol`. Default is 1e-6.

    Returns:
        list: One result of `num_iters_before_under_tol` per column.
    """
    num_columns = len(conv_list[0])

    results = []
    for column_index in range(num_columns):
        # Gather the column_index-th entry of every row into one sequence.
        column = [row[column_index] for row in conv_list]
        results.append(num_iters_before_under_tol(column, tol))

    return results



def get_best_and_worst_case_convs(conv_list, conv_list_comparison, approx_min_list):
    """
    Identify the best and worst convergence cases based on the ratio of convergence iterations.

    The approximated minimum is subtracted from every row of both inputs. For
    each case (column), the number of iterations before the default tolerance
    of `num_iters_list_of_lists_before_under_tol` is reached is then computed
    for both inputs. The best case has the lowest ratio
    ``iterations(conv_list) / iterations(conv_list_comparison)`` and the worst
    case has the highest.

    Note: a comparison count of zero makes the corresponding ratio inf or NaN
    under NumPy division, which affects the selected indices.

    Args:
        conv_list (list): A list of lists containing the convergence data for different cases.
        conv_list_comparison (list): A list of lists containing the comparison convergence data.
        approx_min_list (torch.Tensor): A tensor containing the approximated minimum values to be subtracted from each case.

    Returns:
        tuple: A tuple containing:
            - best_conv_list (list): The best convergence data from `conv_list`.
            - best_conv_comparison_list (list): The best convergence data from `conv_list_comparison`.
            - best_index (int): The index of the best convergence case.
            - worst_conv_list (list): The worst convergence data from `conv_list`.
            - worst_conv_comparison_list (list): The worst convergence data from `conv_list_comparison`.
            - worst_index (int): The index of the worst convergence case.
    """
    # Convert the offsets to a NumPy array once; doing it per row would repeat
    # the tensor-to-host conversion for every entry of both lists.
    approx_min = approx_min_list.cpu().numpy()

    # Subtract the approximated minimum from each row.
    conv_arr = np.array([np.array(row) - approx_min for row in conv_list])
    conv_arr_comparison = np.array([np.array(row) - approx_min for row in conv_list_comparison])

    # Per-case iteration counts and their ratio (lower is better).
    iters = np.array(num_iters_list_of_lists_before_under_tol(conv_arr))
    iters_comparison = np.array(num_iters_list_of_lists_before_under_tol(conv_arr_comparison))
    ratios = iters / iters_comparison

    # Plain ints, as documented, rather than NumPy integer scalars.
    best_index = int(np.argmin(ratios))
    worst_index = int(np.argmax(ratios))

    # Extract the trajectories of the selected cases from the original data.
    best_conv_list = [row[best_index] for row in conv_list]
    best_conv_comparison_list = [row[best_index] for row in conv_list_comparison]
    worst_conv_list = [row[worst_index] for row in conv_list]
    worst_conv_comparison_list = [row[worst_index] for row in conv_list_comparison]

    return best_conv_list, best_conv_comparison_list, best_index, worst_conv_list, worst_conv_comparison_list, worst_index

