import torch
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import random_split
from util.objective_functions import fast_huber_TV, fast_huber_grad, huber
from util.utils import bt_line_search, psnr, num_iters_before_under_tol, get_best_and_worst_case_convs
from util.optimisation_functions import optimisation_p_pointwise, optimisation_alpha, optimisation_convolution, find_appropriate_lambda_alpha, find_appropriate_lambda_pointwise, find_appropriate_lambda_kernel, convolution, lbfgs_all_functions
from util.visualising_learned_matrices import visualise_learned_scalars, visualise_learned_matrices
import torchvision

# Run on the GPU when PyTorch can see one; otherwise stay on the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Experiment parameters
training_iterations = 1000  # number of iterations over which the optimisation algorithm is learned
noise_level = 0.0025  # relative strength of the noise added to the observations
huber_const = 0.01  # epsilon used in the definition of the Huber norm
reg_const = 1e-5  # scalar multiplying the regulariser
N_TRAIN = 100  # number of images (functions) to train over
N_VAL = 100  # number of images (functions) to test over
n = 96  # side length of an image (images are assumed to be square)


# defining the kernel for the forward operator
def gaussian_kernel(size: int, sigma: float) -> torch.Tensor:
    """Create a normalised 2D Gaussian kernel.

    Args:
        size: Side length of the square kernel, in pixels. Must be positive.
        sigma: Standard deviation of the Gaussian, in pixels. Must be positive.

    Returns:
        A float32 tensor of shape ``(1, 1, size, size)`` whose entries sum to one.

    Raises:
        ValueError: If ``size`` or ``sigma`` is not strictly positive.
    """
    if size <= 0:
        raise ValueError(f'size must be positive, got {size}')
    if sigma <= 0:
        raise ValueError(f'sigma must be positive, got {sigma}')

    # Centre the grid on (size - 1) / 2 so that even-sized kernels are symmetric
    # as well; for odd sizes this is exactly size // 2.
    coords = torch.arange(size, dtype=torch.float32) - (size - 1) / 2
    squared_distance = coords.view(-1, 1) ** 2 + coords.view(1, -1) ** 2
    kernel_2d = torch.exp(-0.5 * squared_distance / sigma**2)
    # Normalise so that the kernel sums to one
    kernel_2d = kernel_2d / kernel_2d.sum()
    # Add leading singleton dimensions: (size, size) -> (1, 1, size, size)
    return kernel_2d.unsqueeze(0).unsqueeze(0)

# Blur kernel of the forward operator: side length 5, standard deviation 1.5.
# NOTE: the assignment below rebinds the name `gaussian_kernel` from the factory
# function to the kernel tensor it returns, so the factory cannot be called again.
size, sigma = 5, 1.5
gaussian_kernel = gaussian_kernel(size, sigma).to(device)

def forward_operator(x):
    """Apply the forward (blur) operator: convolve ``x`` with the module-level kernel."""
    blurred = convolution(x, gaussian_kernel)
    return blurred

# The forward operator is reused as its own adjoint.
# NOTE(review): presumably justified by the symmetry of the Gaussian kernel; this also
# depends on how `convolution` handles the image boundary - confirm.
adjoint_operator = forward_operator

def init_recon(x):
    """Return the initial reconstruction, which is the observation ``x`` itself."""
    initial_guess = x
    return initial_guess

# Norm of the forward operator used in the smoothness bound.
# NOTE(review): taken to be 1.0 - presumably because the blur kernel is non-negative and sums to one; confirm.
operator_norm = 1.0

## Calculate the Lipschitz constant of the gradient of the objective
# Contribution of the Huber-TV regulariser.
# NOTE(review): the factor 8 is presumably the squared norm of the 2D finite-difference operator - confirm.
L_smooth_reg = 8 * reg_const / huber_const
# Data-fidelity contribution (operator_norm ** 2) plus regulariser contribution
L = operator_norm ** 2 + L_smooth_reg
print(f'f is L={L} smooth.')



# Load datasets
# Prefer the cached train/validation splits. If they cannot be read, fall back to
# building fresh splits from the STL10 test set.
try:
    train_dataset = torch.load(f'train_dataset_blur_{N_TRAIN}.pt')
    val_dataset = torch.load(f'val_dataset_blur_{N_VAL}.pt')
except Exception as load_error:
    # Catch Exception rather than using a bare `except:` so that Ctrl-C and
    # SystemExit still propagate, and report why the fallback is being used.
    print(f'Could not load cached datasets ({load_error!r}); building them from STL10 instead.')
    dataset = torchvision.datasets.STL10('STL10_test', split='test', transform=torchvision.transforms.ToTensor(), folds=1, download=True)
    # Keep only as many images as are needed for the two splits
    dataset = torch.utils.data.Subset(dataset, list(range(N_TRAIN + N_VAL)))
    # NOTE: the split is random and is not saved here, so it differs between runs
    train_dataset, val_dataset = random_split(dataset, [N_TRAIN, N_VAL])


# Create data loaders for training and validation.
# The batch sizes are N_TRAIN and N_VAL, i.e. the split sizes requested when the
# datasets are built from STL10, so each loader is expected to yield one batch.
data_loader = DataLoader(train_dataset, batch_size=N_TRAIN, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=N_VAL, shuffle=True)

# Overall objective: data fidelity plus Huber-TV regulariser, averaged over the batch
def objective_function(x, y):
    """Return the batch-averaged objective value at ``x`` for observations ``y``."""
    residual = forward_operator(x) - y
    data_term = 0.5 * torch.norm(residual) ** 2
    reg_term = reg_const * fast_huber_TV(x, delta=huber_const)
    batch_size = x.shape[0]
    return (data_term + reg_term) / batch_size

# Gradient of the objective, without dividing by x.shape[0] - this factor is absorbed into L
def grad_objective(x, y):
    """Return the gradient of the (un-averaged) objective at ``x`` for observations ``y``."""
    residual = forward_operator(x) - y
    data_grad = adjoint_operator(residual)
    reg_grad = reg_const * fast_huber_grad(x, delta=huber_const)
    return data_grad + reg_grad


## The "individual" functions return one objective value per batch element
def fast_huber_TV_individual(x, delta=0.01):
    """Return ``reg_const`` times the Huber-TV of every batch element of ``x``.

    Differences are taken along dims 2 and 3 and summed over dims 1-3, so the
    result has one entry per element along dim 0.
    """
    diff_dim2 = x[:, :, 1:, :] - x[:, :, :-1, :]
    diff_dim3 = x[:, :, :, 1:] - x[:, :, :, :-1]
    per_element_dims = (1, 2, 3)
    tv_dim2 = torch.sum(huber(diff_dim2, delta), dim=per_element_dims)
    tv_dim3 = torch.sum(huber(diff_dim3, delta), dim=per_element_dims)
    # The regularisation weight is applied here, not by the caller
    return reg_const * (tv_dim2 + tv_dim3)

def objective_function_individual(x, y):
    """Return the objective value of every batch element separately (no averaging)."""
    batch_size = x.shape[0]
    flat_residual = (forward_operator(x) - y).view(batch_size, -1)
    data_term = 0.5 * torch.norm(flat_residual, dim=1) ** 2
    return data_term + fast_huber_TV_individual(x, delta=huber_const)


# Step size 1/L, computed while L is still a scalar (it is turned into a list below)
one_over_L = 1 / L

# Switches selecting which parts of the experiment are run
calculate_final_vals = False  # whether to calculate the final values of the dictionaries with provable convergence guarantees
test_train = True  # whether to test on the training set
calculate_final_kernel = False  # whether to calculate the final kernel value with provable convergence guarantees

# One copy of the smoothness constant per training image
L = [L] * N_TRAIN


# Load learned dictionaries
## When calculate_final_vals is True, one more iteration is used for learning the parameters
## so that the learned algorithm is convergent; otherwise the saved "_lambda" versions are read
dict_kernel = torch.load('learned_operators/kernel_dictionary_blur.pt')

if not calculate_final_vals:
    dict_pointwise = torch.load('learned_operators/pointwise_dictionary_blur_lambda.pt')
    dict_alpha = torch.load('learned_operators/alpha_dictionary_blur_lambda.pt')
else:
    dict_pointwise = torch.load('learned_operators/pointwise_dictionary_blur.pt')
    dict_alpha = torch.load('learned_operators/alpha_dictionary_blur.pt')

# Size of the learned kernels, read from the last two dims of the first entry
kernel_shape = dict_kernel[0].shape
kernel_height = kernel_shape[-2]
kernel_width = kernel_shape[-1]




## Visualise the learned parameters
visualise_learned_scalars(dict_alpha, 2/max(L))

visualise_learned_matrices(dict_pointwise, indices=[5,6,7,8])

visualise_learned_matrices(dict_kernel, indices=[0,2,5,15,100])

## The same matrices again, with every entry clipped to the range [-5, 5]
kernel_cap = 5*torch.ones_like(dict_kernel[0])
clipped_kernels = {}
for key in dict_kernel:
    clipped_kernels[key] = torch.max(-kernel_cap, torch.min(kernel_cap, dict_kernel[key]))
visualise_learned_matrices(clipped_kernels, indices=[0,2,5,15,100])

pointwise_cap = 5*torch.ones_like(dict_pointwise[0])
clipped_pointwise = {}
for key in dict_pointwise:
    clipped_pointwise[key] = torch.max(-pointwise_cap, torch.min(pointwise_cap, dict_pointwise[key]))
visualise_learned_matrices(clipped_pointwise, indices=[5,6,7,8])




'''
Code to show that, in general, consecutive learned kernels get closer to each other over the iterations

iterations = 1 + np.array(range(len(dict_kernel) - 1))
relative_changes = [(torch.norm(dict_kernel[i + 1] - dict_kernel[i]) / torch.norm(dict_kernel[i])).item() for i in range(len(dict_kernel) - 1)]
plt.figure(figsize=(10, 6))
plt.loglog(iterations, relative_changes, linestyle='-', color='blue', linewidth=2, markersize=6)
plt.grid(True, which="both", ls="--", linewidth=0.5)
#plt.title('Relative Change in Kernel Over Iterations', fontsize=16, fontweight='bold')
plt.xlabel(r'Iteration $t$', fontsize=20)
plt.ylabel(r'$\frac{\| \kappa_t - \kappa_{t-1} \|_2}{\| \kappa_{t-1} \|_2}$', fontsize=20)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.tight_layout()
plt.show()
'''



if __name__ == '__main__':

    if test_train:
        ## Run the learned methods and NAG on the training data
        for i, data in enumerate(data_loader):

            xs, labels = data
            # Average over dim 1 (presumably the colour channels) and keep that dim
            xs = xs.mean(dim=1, keepdim=True).to(device)

            plt.rcdefaults()

            # Display the true image
            plt.imshow(xs[0, :, :, :].squeeze().cpu().numpy(), cmap='gray')
            plt.title('True Image')
            plt.show()

            # Create clean and noisy observations
            # The Gaussian noise is scaled, per batch element, by noise_level times the
            # mean absolute value of that element's clean observation
            clean_observations = forward_operator(xs)
            noisy_observations1 = clean_observations + (noise_level * torch.randn_like(clean_observations) * torch.mean(torch.abs(clean_observations).view(xs.shape[0], -1), dim=1).unsqueeze(1).unsqueeze(1).unsqueeze(1))

            # Display the noisy observation
            plt.imshow(noisy_observations1[0, :, :, :].squeeze().cpu().numpy(), cmap='gray')
            plt.title('Noisy Observations')
            plt.show()

            # Initial reconstruction from noisy observations
            inpt = init_recon(noisy_observations1)



            ## calculate the lambda value that would make the next learned preconditioner lead to guaranteed convergence for the other methods
            if calculate_final_vals:

                inpt_get_lmbda_pointwise = inpt.clone()
                inpt_get_lmbda_alpha = inpt.clone()

                ## iterate up to the end of the learned dictionaries
                for k in tqdm(range(len(dict_kernel))):
                    inpt_get_lmbda_pointwise = inpt_get_lmbda_pointwise - dict_pointwise[k] * grad_objective(inpt_get_lmbda_pointwise, noisy_observations1)
                    inpt_get_lmbda_alpha = inpt_get_lmbda_alpha - dict_alpha[k] * grad_objective(inpt_get_lmbda_alpha, noisy_observations1)

                ## find lambda value
                lmbda_pointwise = find_appropriate_lambda_pointwise(lambda x: objective_function(x, noisy_observations1), lambda x: grad_objective(x, noisy_observations1), inpt_get_lmbda_pointwise, dict_pointwise[len(dict_kernel)-1], 0, n, L, tol=0, max_iter=100)
                ## solve the optimisation problem to find the pointwise operator with this lambda
                # NOTE: `k` is the last value of the loop above, so k+1 == len(dict_kernel)
                dict_pointwise[k+1] = optimisation_p_pointwise(lambda x: objective_function(x, noisy_observations1), lambda x: grad_objective(x, noisy_observations1), inpt_get_lmbda_pointwise, dict_pointwise[len(dict_kernel)-1], L, 0, max_iter=1000, tol=0.001, lmbda=lmbda_pointwise)

                # Same two steps for the scalar step size
                lmbda_alpha = find_appropriate_lambda_alpha(lambda x: objective_function(x, noisy_observations1), lambda x: grad_objective(x, noisy_observations1), inpt_get_lmbda_alpha, torch.tensor(dict_alpha[len(dict_alpha)-1]).to(device), 0, n, L, tol=0, max_iter=100)
                dict_alpha[k+1] = optimisation_alpha(lambda x: objective_function(x, noisy_observations1), lambda x: grad_objective(x, noisy_observations1), inpt_get_lmbda_alpha, torch.tensor(dict_alpha[len(dict_alpha)-1]).to(device), L, 0, max_iter=1000, tol=0, lmbda=lmbda_alpha).item()

                # NOTE(review): these files are written to the working directory, whereas the
                # "_lambda" dictionaries are loaded from 'learned_operators/' at the top of the
                # file - confirm whether the missing directory prefix is intentional.
                torch.save(dict_pointwise, 'pointwise_dictionary_blur_lambda.pt')
                torch.save(dict_alpha, 'alpha_dictionary_blur_lambda.pt')



            final_training_iter = len(dict_kernel) - 1

            # Initialize objective values for different methods
            obj_pointwise = [objective_function(inpt, noisy_observations1).item()]
            obj_convolution = [objective_function(inpt, noisy_observations1).item()]
            obj_scalar = [objective_function(inpt, noisy_observations1).item()]


            # Clone initial reconstruction for different methods
            inpt_pointwise = inpt.clone()
            inpt_convolution = inpt.clone()
            inpt_scalar = inpt.clone()

            # Variables for NAG method
            tnew = 0
            told = 0
            obj_nag = [objective_function(inpt, noisy_observations1).item()]
            obj_nag_indiv = [[i.item() for i in objective_function_individual(inpt, noisy_observations1)]]
            # NOTE(review): psnr_nag is not read again in this file
            psnr_nag = [psnr(inpt, xs)]
            inpt_nag = inpt.clone()
            inpt_nag_m1 = inpt.clone()


            for k in tqdm(range(1001)):
                # NAG: update the momentum sequence t, extrapolate to yk, then take a 1/L gradient step from yk
                tnew, told = (1 + np.sqrt(1 + 4 * told ** 2)) / 2, tnew
                alphat = (told - 1) / tnew
                yk = inpt_nag + alphat * (inpt_nag - inpt_nag_m1)
                grad_nag = grad_objective(yk, noisy_observations1)
                inpt_nag, inpt_nag_m1 = yk - one_over_L * grad_nag, inpt_nag
                obj_nag.append(objective_function(inpt_nag, noisy_observations1).item())
                obj_nag_indiv.append([i.item() for i in objective_function_individual(inpt_nag, noisy_observations1)])


                grad_pointwise = grad_objective(inpt_pointwise, noisy_observations1)
                grad_convolution = grad_objective(inpt_convolution, noisy_observations1)
                grad_scalar = grad_objective(inpt_scalar, noisy_observations1)

                # Update inputs using different methods
                # Past the end of the dictionaries, the last learned entry is reused
                idx = min(k, len(dict_kernel)-1)
                inpt_pointwise = inpt_pointwise - dict_pointwise[idx] * grad_pointwise
                inpt_scalar = inpt_scalar - dict_alpha[idx] * grad_scalar
                inpt_convolution = inpt_convolution - convolution(grad_convolution, dict_kernel[idx])

                # Append objective values
                obj_pointwise.append(objective_function(inpt_pointwise, noisy_observations1).item())
                obj_convolution.append(objective_function(inpt_convolution, noisy_observations1).item())
                obj_scalar.append(objective_function(inpt_scalar, noisy_observations1).item())

            # Smallest objective value reached by any method, used as a stand-in for the minimum
            approx_min = min([min(obj_nag), min(obj_pointwise), min(obj_scalar), min(obj_convolution)])
            # NOTE(review): approx_min_list is not read again in this file
            approx_min_list = objective_function_individual(inpt_convolution, noisy_observations1)
            # Initial optimality gap, used to normalise the training curves
            objective_difference_iter0 = objective_function(inpt, noisy_observations1).item() - approx_min
    
    ## now for the test data
    for i, data in enumerate(val_dataloader):
        # Run every method (learned pointwise / scalar / convolution, gradient descent
        # with step 1/L, and NAG) on the validation batch, recording the objective
        # value after each iteration.

        # Preprocess the images: average over dim 1 (presumably the colour channels)
        xs_val, labels_val = data
        xs_val = xs_val.mean(dim=1, keepdim=True).to(device)

        plt.rcdefaults()

        # Create clean and noisy observations
        # The Gaussian noise is scaled, per batch element, by noise_level times the
        # mean absolute value of that element's clean observation
        clean_observations_val = forward_operator(xs_val)
        noisy_observations1_val = clean_observations_val + (noise_level * torch.randn_like(clean_observations_val) * torch.mean(torch.abs(clean_observations_val).view(xs_val.shape[0], -1), dim=1).unsqueeze(1).unsqueeze(1).unsqueeze(1))

        # Initial reconstruction from noisy observations
        inpt_val = init_recon(noisy_observations1_val)

        # Every method starts from the same point, so the initial objective is
        # evaluated once and shared by all the histories
        initial_obj_val = objective_function(inpt_val, noisy_observations1_val).item()
        obj_val_pointwise = [initial_obj_val]
        obj_val_full = [initial_obj_val]
        obj_val_convolution = [initial_obj_val]
        obj_val_scalar = [initial_obj_val]
        obj_val_one_over_l = [initial_obj_val]
        obj_val_backtracking = [initial_obj_val]
        obj_val_pointwise_risk = [initial_obj_val]
        obj_val_convolution_risk = [initial_obj_val]
        obj_val_scalar_risk = [initial_obj_val]
        obj_val_convolution_convergence = [initial_obj_val]
        obj_val_convolution_indiv = [[i.item() for i in objective_function_individual(inpt_val, noisy_observations1_val)]]

        # Clone initial reconstruction for different methods
        inpt_val_pointwise = inpt_val.clone()
        inpt_val_full = inpt_val.clone()
        inpt_val_convolution = inpt_val.clone()
        inpt_val_scalar = inpt_val.clone()
        inpt_val_one_over_l = inpt_val.clone()
        inpt_val_backtracking = inpt_val.clone()
        inpt_val_pointwise_risk = inpt_val.clone()
        inpt_val_convolution_risk = inpt_val.clone()
        inpt_val_scalar_risk = inpt_val.clone()
        inpt_val_convolution_convergence = inpt_val.clone()

        # Variables for NAG method
        tnew = 0
        told = 0
        obj_val_nag = [initial_obj_val]
        obj_val_nag_indiv = [[i.item() for i in objective_function_individual(inpt_val, noisy_observations1_val)]]
        psnr_nag = [psnr(inpt_val, xs_val)]
        inpt_val_nag = inpt_val.clone()
        inpt_val_nag_m1 = inpt_val.clone()

        for k in tqdm(range(1001)):
            # NAG: update the momentum sequence t, extrapolate to yk, then take a 1/L gradient step from yk
            tnew, told = (1 + np.sqrt(1 + 4 * told ** 2)) / 2, tnew
            alphat = (told - 1) / tnew
            yk = inpt_val_nag + alphat * (inpt_val_nag - inpt_val_nag_m1)
            grad_nag = grad_objective(yk, noisy_observations1_val)
            inpt_val_nag, inpt_val_nag_m1 = yk - one_over_L * grad_nag, inpt_val_nag
            obj_val_nag.append(objective_function(inpt_val_nag, noisy_observations1_val).item())
            obj_val_nag_indiv.append([i.item() for i in objective_function_individual(inpt_val_nag, noisy_observations1_val)])

            # Gradients are only evaluated for the iterates that are actually updated below.
            # The "full", "backtracking", "convolution_risk" and "convolution_convergence"
            # iterates are never updated in this script, so their gradients are not computed.
            grad_pointwise = grad_objective(inpt_val_pointwise, noisy_observations1_val)
            grad_convolution = grad_objective(inpt_val_convolution, noisy_observations1_val)
            grad_scalar = grad_objective(inpt_val_scalar, noisy_observations1_val)
            grad_one_over_l = grad_objective(inpt_val_one_over_l, noisy_observations1_val)
            grad_pointwise_risk = grad_objective(inpt_val_pointwise_risk, noisy_observations1_val)
            grad_scalar_risk = grad_objective(inpt_val_scalar_risk, noisy_observations1_val)

            # Update inputs using different methods
            # Past the end of a dictionary its last entry is reused; the "risk" variants
            # stop one entry earlier (index len(dict_pointwise)-2)
            idx_kernel = min(k, len(dict_kernel)-1)
            idx = min(k, len(dict_pointwise)-1)
            idx_risk = min(k, len(dict_pointwise)-2)

            inpt_val_pointwise = inpt_val_pointwise - dict_pointwise[idx] * grad_pointwise
            inpt_val_scalar = inpt_val_scalar - dict_alpha[idx] * grad_scalar
            inpt_val_convolution = inpt_val_convolution - convolution(grad_convolution, dict_kernel[idx_kernel])
            inpt_val_one_over_l = inpt_val_one_over_l - one_over_L * grad_one_over_l
            inpt_val_pointwise_risk = inpt_val_pointwise_risk - dict_pointwise[idx_risk] * grad_pointwise_risk
            inpt_val_scalar_risk = inpt_val_scalar_risk - dict_alpha[idx_risk] * grad_scalar_risk
            # Per-image convolution history is only recorded while learned kernels are available
            if k < len(dict_kernel):
                obj_val_convolution_indiv.append([i.item() for i in objective_function_individual(inpt_val_convolution, noisy_observations1_val)])

            # Backtracking baseline (disabled). To re-enable it, also compute
            # grad_backtracking = grad_objective(inpt_val_backtracking, noisy_observations1_val)
            #alpha_bt = torch.stack([torch.tensor(bt_line_search(lambda x: objective_function(x, noisy_observations1_val[i].unsqueeze(0)), lambda x: grad_objective(x, noisy_observations1_val[i].unsqueeze(0)), inpt_val_backtracking[i].unsqueeze(0), -grad_backtracking[i].unsqueeze(0), t=5, beta=0.5, alpha=1e-1)).to(device) for i in range(inpt_val.shape[0])]).unsqueeze(1).unsqueeze(1).unsqueeze(1)
            #inpt_val_backtracking = inpt_val_backtracking - alpha_bt * grad_backtracking

            # Append objective values
            obj_val_pointwise.append(objective_function(inpt_val_pointwise, noisy_observations1_val).item())
            obj_val_convolution.append(objective_function(inpt_val_convolution, noisy_observations1_val).item())
            obj_val_scalar.append(objective_function(inpt_val_scalar, noisy_observations1_val).item())
            obj_val_one_over_l.append(objective_function(inpt_val_one_over_l, noisy_observations1_val).item())
            obj_val_pointwise_risk.append(objective_function(inpt_val_pointwise_risk, noisy_observations1_val).item())
            obj_val_scalar_risk.append(objective_function(inpt_val_scalar_risk, noisy_observations1_val).item())
            # These three iterates never move away from the initial reconstruction, so
            # their objective stays at the initial value; the histories are still filled
            # so that code reading them keeps working.
            obj_val_convolution_convergence.append(initial_obj_val)
            obj_val_backtracking.append(initial_obj_val)
            obj_val_convolution_risk.append(initial_obj_val)

            # if k in [20,  25, 30, 40, 60]:
            #     print('\n\n\n\n\n\n\n\n\n')
            #     print(k)
            #     psnrs_conv = [psnr(inpt_val_convolution[zz,:,:,:], xs_val[zz,:,:,:]) for zz in range(xs_val.shape[0])]
            #     psnrs_nag = [psnr(inpt_val_nag[zz,:,:,:], xs_val[zz,:,:,:]) for zz in range(xs_val.shape[0])]
            #     diffs = [psnrs_conv[zz] - psnrs_nag[zz] for zz in range(xs_val.shape[0])]
            #     max_diff_index = diffs.index(max(diffs))
            #     second_max_diff_index = diffs.index(sorted(diffs)[-2])
            #     print(max_diff_index)
            #     print(f'PSNRS: {psnrs_conv[max_diff_index]}, {psnrs_nag[max_diff_index]}')
            #     ## now plot the images and the reconstructions
            #     plt.imshow(xs_val[max_diff_index, :, :, :].squeeze().cpu().numpy(), cmap='gray')
            #     plt.title('True Image')
            #     plt.show()

            #     plt.imshow(noisy_observations1_val[max_diff_index, :, :, :].squeeze().cpu().numpy(), cmap='gray')
            #     plt.title('Noisy Observations')
            #     plt.show()

            #     plt.imshow(inpt_val_convolution[max_diff_index, :, :, :].squeeze().cpu().numpy(), cmap='gray')
            #     plt.title('Convolution Reconstruction')
            #     plt.show()

            #     plt.imshow(inpt_val_nag[max_diff_index, :, :, :].squeeze().cpu().numpy(), cmap='gray')
            #     plt.title('NAG Reconstruction')
            #     plt.show()

            #     print(f'PSNR Convolution: {psnrs_conv[max_diff_index]}')
            #     psnr_compare = psnrs_conv[max_diff_index]
            #     if k == 30:
            #         diff_idx = max_diff_index
            # if k > 30 and k % 5 == 0:
            #     plt.imshow(inpt_val_nag[max_diff_index, :, :, :].squeeze().cpu().numpy(), cmap='gray')
            #     plt.title(f'NAG Reconstruction - {k} iterations')
            #     plt.show()
            #     print(f'PSNR NAG: {psnr(inpt_val_nag[max_diff_index,:,:,:], xs_val[max_diff_index,:,:,:])}, PSNR Convolution: {psnr_compare}')

        # L-BFGS baseline (disabled)
        #obj_val_lbfgs = lbfgs_all_functions(objective_function, grad_objective, inpt_val, noisy_observations1_val, L[0], memory=10, max_iter=len(dict_kernel))

        # The best NAG value is used as a stand-in for the minimum of the objective
        approx_min_val = min(obj_val_nag)
        approx_min_val_list = objective_function_individual(inpt_val_nag, noisy_observations1_val)

        # Initial optimality gap, used to normalise the validation curves
        objective_difference_iter0_val = initial_obj_val - approx_min_val

        
        
        #### GENERALISATION PLOT:
        # Normalised optimality gap of the three learned methods on the training
        # batch (solid lines) and on the validation batch (dashed lines).
        if test_train:
            iterations = np.arange(0, len(dict_kernel))
            plt.figure(figsize=(10, 8), dpi=300)
            learned_methods = [
                ('Learned Pointwise', 'red', obj_pointwise, obj_val_pointwise),
                ('Learned Scalar', 'blue', obj_scalar, obj_val_scalar),
                ('Learned Convolution', 'black', obj_convolution, obj_val_convolution),
            ]
            # All training curves are drawn first so the legend keeps the train-then-test order
            for method_name, colour, train_objs, _ in learned_methods:
                plt.loglog(1+iterations, (-approx_min + np.array(train_objs)[iterations])/objective_difference_iter0, label=f'{method_name} - Train', linestyle='-', color=colour, linewidth=2)
            for method_name, colour, _, val_objs in learned_methods:
                plt.loglog(1+iterations, (-approx_min_val + np.array(val_objs)[iterations])/objective_difference_iter0_val, label=f'{method_name} - Test', linestyle='--', color=colour, linewidth=4)
            # Only one x-label call is kept: a second call would simply overwrite the first
            plt.xlabel('Iteration', fontsize=24)
            plt.ylabel('Optimality in Function Value', fontsize=30)
            plt.ylim(1e-8, None)  # Set lower limit to 10^-8
            plt.tick_params(axis='both', which='major', labelsize=24)
            plt.legend(fontsize=20, loc='lower left')
            plt.grid(True, linestyle='-', linewidth=1)
            plt.tight_layout()
            plt.show()
                

        # Compare all methods on the validation batch: normalised optimality gap per iteration
        iterations = np.arange(0, len(dict_kernel))
        plt.figure(figsize=(10, 8), dpi=300)
        convergence_curves = (
            (obj_val_pointwise, 'Learned Pointwise', 'red'),
            (obj_val_scalar, 'Learned Scalar', 'blue'),
            (obj_val_convolution, 'Learned Convolution', 'black'),
            (obj_val_nag, 'NAG', 'magenta'),
            (obj_val_one_over_l, 'GD', 'orange'),
        )
        for objective_values, curve_label, curve_colour in convergence_curves:
            normalised_gap = (-approx_min_val + np.array(objective_values)[iterations])/objective_difference_iter0_val
            plt.loglog(1+iterations, normalised_gap, label=curve_label, linestyle='-', color=curve_colour, linewidth=3)
        plt.xlabel('Iteration', fontsize=24)
        plt.ylabel('Optimality in Function Value', fontsize=30)
        plt.tick_params(axis='both', which='major', labelsize=24)
        plt.ylim(1e-8, 1)
        plt.legend(fontsize=20, loc='lower left')
        plt.grid(True, linestyle='-', linewidth=1)
        plt.tight_layout()
        plt.show()


        # Compare each learned method with its "risk" variant over the whole run.
        # The dashed vertical line marks the last index of the corresponding dictionary.
        iterations = np.arange(0, len(obj_val_convolution))
        iteration_regularised_pointwise = len(dict_pointwise)-1
        iteration_regularised_scalar = len(dict_alpha)-1

        # Normalised optimality gaps of the four curves
        scalar_gap = (-approx_min_val + np.array(obj_val_scalar)[iterations])/objective_difference_iter0_val
        scalar_risk_gap = (-approx_min_val + np.array(obj_val_scalar_risk)[iterations])/objective_difference_iter0_val
        pointwise_gap = (-approx_min_val + np.array(obj_val_pointwise)[iterations])/objective_difference_iter0_val
        pointwise_risk_gap = (-approx_min_val + np.array(obj_val_pointwise_risk)[iterations])/objective_difference_iter0_val

        # Scalar step sizes
        plt.figure(figsize=(10, 8), dpi=300)
        plt.semilogy(iterations, scalar_gap, label='Learned Scalar - \nGuaranteed Convergence', linestyle='-', color='green', linewidth=3)
        plt.semilogy(iterations, scalar_risk_gap, label='Learned Scalar - \nNo Guaranteed Convergence', linestyle='-', color='blue', linewidth=3)
        plt.axvline(x=iteration_regularised_scalar, color='black', linestyle='--')
        plt.xlabel(r'Iteration $t$', fontsize=24)
        plt.ylabel(r'$\frac{F(x_t) - F(x^*)}{F(x_0) - F(x^*)}$', fontsize=30)
        plt.ylim(5e-5, 1)
        plt.tick_params(axis='both', which='major', labelsize=24)
        plt.legend(fontsize=20, loc='upper left')
        plt.grid(True, linestyle='-', linewidth=1)
        plt.tight_layout()
        plt.show()

        # Pointwise preconditioners
        plt.figure(figsize=(10, 8), dpi=300)
        plt.semilogy(iterations, pointwise_gap, label='Learned Pointwise - \nGuaranteed Convergence', linestyle='-', color='green', linewidth=3)
        plt.semilogy(iterations, pointwise_risk_gap, label='Learned Pointwise - \nNo Guaranteed Convergence', linestyle='-', color='red', linewidth=3)
        plt.axvline(x=iteration_regularised_pointwise, color='black', linestyle='--')
        plt.xlabel(r'Iteration $t$', fontsize=24)
        plt.ylabel(r'$\frac{F(x_t) - F(x^*)}{F(x_0) - F(x^*)}$', fontsize=30)
        plt.tick_params(axis='both', which='major', labelsize=24)
        plt.ylim(1e-4, 1)
        plt.legend(fontsize=20, loc='upper left')
        plt.grid(True, linestyle='-', linewidth=1)
        plt.tight_layout()
        plt.show()
        
    
            
        # Per-image comparison of the learned convolution against NAG, for the images
        # selected as best and worst case by get_best_and_worst_case_convs.
        best_conv_list, best_conv_comparison_list, best_index, worst_conv_list, worst_conv_comparison_list, worst_index = get_best_and_worst_case_convs(obj_val_convolution_indiv, obj_val_nag_indiv, approx_min_val_list)
        iterations_compare = np.arange(0, len(best_conv_comparison_list))

        # Best case.
        # The convolution history can be shorter than the NAG history; pad it with
        # zeros so that both can be indexed by iterations_compare.
        best_conv_list += [0 for _ in range(len(best_conv_comparison_list) - len(best_conv_list))]
        best_min = approx_min_val_list[best_index].item()
        # Divide by the initial gap F(x_0) - F(x^*), as the y-label states and as the
        # worst-case plot below already does.
        best_initial_gap = -best_min + best_conv_list[0]
        plt.semilogy(1+iterations_compare, (-best_min + np.array(best_conv_list)[iterations_compare])/best_initial_gap, label='Learned Convolution', linestyle='-', color='black', linewidth=3)
        plt.semilogy(1+iterations_compare, (-best_min + np.array(best_conv_comparison_list)[iterations_compare])/best_initial_gap, label='NAG', linestyle='-', color='magenta', linewidth=3)
        plt.xlabel(r'Iteration $t$', fontsize=24)
        plt.ylabel(r'$\frac{F(x_t) - F(x^*)}{F(x_0) - F(x^*)}$', fontsize=30)
        plt.title('Best Case Convolution', fontsize=20)
        plt.tick_params(axis='both', which='major', labelsize=24)
        plt.ylim(1e-7, None)
        plt.xlim(0, 900)
        plt.legend(fontsize=20, loc='upper right')
        plt.show()

        # Worst case (same padding and normalisation)
        worst_conv_list += [0 for _ in range(len(best_conv_comparison_list) - len(worst_conv_list))]
        worst_min = approx_min_val_list[worst_index].item()
        worst_initial_gap = -worst_min + worst_conv_list[0]
        plt.semilogy(1+iterations_compare, (-worst_min + np.array(worst_conv_list)[iterations_compare])/worst_initial_gap, label='Learned Convolution', linestyle='-', color='black', linewidth=3)
        plt.semilogy(1+iterations_compare, (-worst_min + np.array(worst_conv_comparison_list)[iterations_compare])/worst_initial_gap, label='NAG', linestyle='-', color='magenta', linewidth=3)
        plt.xlabel(r'Iteration $t$', fontsize=24)
        plt.ylabel(r'$\frac{F(x_t) - F(x^*)}{F(x_0) - F(x^*)}$', fontsize=30)
        plt.title('Worst Case Convolution', fontsize=20)
        plt.tick_params(axis='both', which='major', labelsize=24)
        plt.ylim(1e-7, None)
        plt.xlim(0, 750)
        plt.legend(fontsize=20, loc='upper right')
        plt.show()

        # Report, for every method, the value returned by num_iters_before_under_tol
        # for its normalised optimality gap at each tolerance.
        tolerances = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7]
        reported_methods = [
            ('CONVOLUTION', obj_val_convolution),
            ('NAG', obj_val_nag),
        ]
        # obj_val_lbfgs only exists when the lbfgs_all_functions call is enabled;
        # referencing it unconditionally raised a NameError and aborted the report.
        try:
            reported_methods.append(('BFGS', obj_val_lbfgs))
        except NameError:
            print('SKIPPING BFGS: obj_val_lbfgs has not been computed.')
        reported_methods += [
            ('BACKTRACKING', obj_val_backtracking),
            ('SCALAR', obj_val_scalar),
            ('Pointwise', obj_val_pointwise),
            ('GD', obj_val_one_over_l),
        ]
        for method_name, objective_values in reported_methods:
            # Normalise once per method instead of once per tolerance
            normalised_gap = (-approx_min_val + np.array(objective_values))/objective_difference_iter0_val
            for tol in tolerances:
                print(f'NUM ITERS TO REACH {tol} FOR {method_name}: {num_iters_before_under_tol(normalised_gap, tol)}')
        