import torch
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from util.objective_functions import fast_huber_TV, fast_huber_grad, huber
from util.utils import apply_p_full, bt_line_search, psnr, num_iters_before_under_tol, get_best_and_worst_case_convs
from util.optimisation_functions import optimisation_p_pointwise, optimisation_alpha, optimisation_convolution, find_appropriate_lambda_alpha, find_appropriate_lambda_pointwise, find_appropriate_lambda_kernel, convolution, lbfgs_all_functions
from util.visualising_learned_matrices import visualise_learned_scalars, visualise_learned_matrices
from util.operators import get_ct_operator

# Use GPU if available, else fallback to CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

calculate_final_vals = False  # if True, the __main__ block fits one extra learned step per method and saves the dictionaries
test_train = True  # also evaluate on the training set; switched off below when the CT datasets cannot be loaded

# Define parameters
training_iterations = 1000 # number of iterations over which the optimisation algorithm is learned
noise_level = 0.01
huber_const = 0.01 # epsilon used in definition of huber norm
reg_const = 0.0001 # scalar multiplying the regularizer
N_TRAIN = 100 # number of images (functions) to train over
N_VAL = 100 # number of images (functions) to test over
n = 40 ## side length of image (assumed to be square)
n_angles = 90 ## number of angles used in the radon transform

# Get the CT forward operator and its properties from get_ct_operator.
forward_operator, adjoint_operator, operator_norm, init_recon, W, W_adj = get_ct_operator(n, n_angles)
# Smoothness constant of the regulariser term: presumably ||grad||^2 <= 8 for 2-D
# finite differences times the 1/huber_const Lipschitz constant of the Huber gradient.
L_smooth_reg = 8 * reg_const / huber_const
# Smoothness constant of the full objective: ||A||^2 for the data term plus the regulariser's.
L = operator_norm ** 2 + L_smooth_reg
print(f'f is L={L} smooth.')

# Load datasets: prefer size-tagged files, then untagged ones; if neither loads,
# fall back to a subset of STL-10 (downloaded on demand) and skip the
# training-set evaluation. map_location keeps torch's default when CUDA is
# present and remaps GPU-saved tensors to the CPU otherwise. The handlers catch
# Exception rather than everything, so Ctrl-C / SystemExit still propagate.
_dataset_map_location = None if device == 'cuda' else 'cpu'
try:
    try:
        train_dataset = torch.load(f'train_dataset_CT_{N_TRAIN}.pt', map_location=_dataset_map_location)
        val_dataset = torch.load(f'val_dataset_CT_{N_VAL}.pt', map_location=_dataset_map_location)
        print('done')
    except Exception:
        train_dataset = torch.load('train_dataset_CT.pt', map_location=_dataset_map_location)
        val_dataset = torch.load('val_dataset_CT.pt', map_location=_dataset_map_location)
except Exception:
    print('so that you dont have to donwload extra data, lets see how it performs on STL data')
    test_train = False
    import torchvision
    from torch.utils.data import random_split
    dataset = torchvision.datasets.STL10('STL', split='test', transform=torchvision.transforms.ToTensor(), folds=1, download=True)
    dataset = torch.utils.data.Subset(dataset, list(range(N_TRAIN + N_VAL)))
    train_dataset, val_dataset = random_split(dataset, [N_TRAIN, N_VAL])
# Create data loaders for training and validation; each yields the whole set as one batch.
data_loader = DataLoader(train_dataset, batch_size=N_TRAIN, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=N_VAL, shuffle=True)


def grad_objective(x, y):
    """Gradient of 0.5 * ||A x - y||^2 + reg_const * HuberTV(x) for a batch.

    A is ``forward_operator`` with adjoint ``adjoint_operator``. The regulariser
    gradient comes from ``fast_huber_grad`` with smoothing ``huber_const``
    (presumably the gradient of ``fast_huber_TV``). Unlike ``objective_function``,
    the result is NOT divided by the batch size.
    """
    residual = forward_operator(x) - y
    data_fit_grad = adjoint_operator(residual)
    regulariser_grad = fast_huber_grad(x, delta=huber_const)
    return data_fit_grad + reg_const * regulariser_grad

def objective_function(x, y):
    """Batch-averaged objective.

    Computes (0.5 * ||A x - y||^2 + reg_const * fast_huber_TV(x)) / batch size.
    The squared norm is taken over the whole batch at once. Callers convert the
    result with ``.item()``.
    """
    batch_size = x.shape[0]
    data_fit = 0.5 * torch.norm(forward_operator(x) - y) ** 2
    regulariser = reg_const * fast_huber_TV(x, delta=huber_const)
    return (data_fit + regulariser) / batch_size

def fast_huber_TV_individual(x, delta=0.01):
    """Per-image Huber total variation, scaled by ``reg_const``.

    Applies ``huber`` to vertical (dim 2) and horizontal (dim 3) forward
    differences and sums over dims (1, 2, 3) only. The result therefore has one
    entry per batch element rather than a single batch total.
    """
    per_image_dims = (1, 2, 3)
    vertical_diffs = x[:, :, 1:, :] - x[:, :, :-1, :]
    tv_vertical = torch.sum(huber(vertical_diffs, delta), dim=per_image_dims)
    horizontal_diffs = x[:, :, :, 1:] - x[:, :, :, :-1]
    tv_horizontal = torch.sum(huber(horizontal_diffs, delta), dim=per_image_dims)
    return reg_const * (tv_vertical + tv_horizontal)


def objective_function_individual(x, y):
    """Per-image (not batch-averaged) objective values.

    For each batch element b returns
    0.5 * ||A x_b - y_b||^2 + reg_const * HuberTV(x_b), as a 1-D tensor of
    length ``x.shape[0]``.
    """
    # Flatten every sample's residual so the norm is taken per image; reshape
    # (unlike view) also copes with non-contiguous operator outputs.
    residual = (forward_operator(x) - y).reshape(x.shape[0], -1)
    data_fit = 0.5 * torch.norm(residual, dim=1) ** 2
    # Pass huber_const explicitly so this matches objective_function /
    # grad_objective; the helper's own default is a hard-coded 0.01.
    return data_fit + fast_huber_TV_individual(x, delta=huber_const)


# Classical gradient-descent step size 1/L, taken while L is still a scalar.
one_over_L = 1 / L


# From here on L is a list repeating the smoothness constant once per training
# image; it is passed to the find_appropriate_lambda_* helpers and reduced with
# max(L) below.
L = [L for _ in range(N_TRAIN)]

# Learned step operators, indexed by integer training iteration below.
# map_location keeps torch's default when CUDA is present and remaps
# GPU-saved tensors to the CPU otherwise, so CPU-only machines can load them.
_dict_map_location = None if device == 'cuda' else 'cpu'
dict_kernel = torch.load('learned_operators/kernel_dictionary_CT.pt', map_location=_dict_map_location)
dict_pointwise = torch.load('learned_operators/pointwise_dictionary_CT.pt', map_location=_dict_map_location)
dict_full = {}#torch.load('learned_operators/full_dictionary_CT.pt')  # full-matrix method disabled
dict_alpha = torch.load('learned_operators/alpha_dictionary_CT.pt', map_location=_dict_map_location)


# Spatial size (last two dimensions) of the learned convolution kernels.
kernel_height = dict_kernel[0].shape[-2]
kernel_width = dict_kernel[0].shape[-1]

print('NUMBER OF TRAINING ITERATIONS:', len(dict_kernel))

print('LENGTH OF FULL DICTIONARY:', len(dict_full))

def crop_center(image, crop_size):
    """Return the central ``crop_size x crop_size`` window of a 4-D batch.

    ``image`` is unpacked as (N, C, H, W). The window's top-left corner is at
    row ``H // 2 - crop_size // 2`` and column ``W // 2 - crop_size // 2``.
    Batch and channel dimensions are kept unchanged.
    """
    _, _, height, width = image.shape
    top = height // 2 - crop_size // 2
    left = width // 2 - crop_size // 2
    rows = slice(top, top + crop_size)
    cols = slice(left, left + crop_size)
    return image[:, :, rows, cols]


# Plot the learned scalar step sizes together with the reference value
# 2 / max(L) (presumably the classical 2/L step-size bound for gradient descent).
scalar_step_reference = 2 / max(L)
visualise_learned_scalars(dict_alpha, scalar_step_reference)


# Disabled diagnostic: relative change of the learned kernel between consecutive
# training iterations. Kept as an unused string literal so it does not execute.
'''

iterations = 1 + np.array(range(len(dict_kernel) - 1))
relative_changes = [(torch.norm(dict_kernel[i + 1] - dict_kernel[i]) / torch.norm(dict_kernel[i])).item() for i in range(len(dict_kernel) - 1)]
plt.figure(figsize=(10, 6))
plt.loglog(iterations, relative_changes, linestyle='-', color='blue', linewidth=2, markersize=6)
plt.grid(True, which="both", ls="--", linewidth=0.5)
#plt.title('Relative Change in Kernel Over Iterations', fontsize=16, fontweight='bold')
plt.xlabel('Iteration', fontsize=14)
plt.ylabel('Relative Change in Kernel', fontsize=14)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.tight_layout()
plt.show()

'''


def visualise_learned_matrices(dict_matrices, indices='all', cmap='bwr'):
    """Show selected learned matrices side by side with one shared colour bar.

    This local definition overrides the function of the same name imported from
    util.visualising_learned_matrices.

    Args:
        dict_matrices: mapping from training iteration to a tensor; each entry
            is squeezed for display with ``imshow``.
        indices: keys to plot, or ``'all'`` for ``range(len(dict_matrices))``.
        cmap: matplotlib colormap name.

    Raises:
        ValueError: if no matrices are selected.
    """
    if isinstance(indices, str) and indices == 'all':
        indices = range(len(dict_matrices))
    indices = list(indices)
    if not indices:
        raise ValueError('visualise_learned_matrices: no matrices selected to plot')

    # One panel per matrix plus a final panel for the colour bar.
    num_subplots = len(indices) + 1
    plt.figure(figsize=(11 * len(indices), 10))
    # Shared colour limits so all panels are directly comparable.
    vmin = min(torch.min(dict_matrices[i]).item() for i in indices)
    vmax = max(torch.max(dict_matrices[i]).item() for i in indices)
    plt.subplots_adjust(wspace=0.05, hspace=0)

    for position, i in enumerate(indices, start=1):
        plt.subplot(1, num_subplots, position)
        image = plt.imshow(dict_matrices[i].squeeze().cpu().detach().numpy(), vmin=vmin, vmax=vmax, cmap=cmap)
        plt.title(f'Iteration {i}', fontsize=80)
        plt.axis('off')  # Remove axes text

    # Colour bar in the last panel; pass the image explicitly instead of relying
    # on matplotlib's implicit "current image" (all images share vmin/vmax/cmap).
    cax = plt.subplot(1, num_subplots, num_subplots)
    cbar = plt.colorbar(image, cax=cax)

    # Shrink the colour-bar panel to 10% of its default width.
    cax_position = cax.get_position()
    cax_width = cax_position.width * 0.1
    cax.set_position([cax_position.x0, cax_position.y0, cax_width, cax_position.height])
    cbar.ax.tick_params(labelsize=80)
    plt.show()
    
# Learned pointwise and convolution step operators for the first ten iterations.
visualise_learned_matrices(dict_pointwise, indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
visualise_learned_matrices(dict_kernel, indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
# Selected kernels with entries clipped to [-5, 5], which bounds the shared colour scale.
visualise_learned_matrices({i: torch.clamp(dict_kernel[i], min=-5, max=5) for i in dict_kernel}, indices=[0, 1, 3, 5, 10])
visualise_learned_matrices(dict_pointwise, indices=[4, 5, 6, 7, 8], cmap='Greens')


# Index of the last learned step operator.
final_training_iter = len(dict_kernel) - 1


if __name__ == '__main__':
    
    # ---- Training-set evaluation ----
    # Runs every learned method and NAG on the training batch. The results
    # (obj_* lists, approx_min, objective_difference_iter0) are used only by the
    # generalisation plot inside the validation loop below.
    if test_train:
        for i, data in enumerate(data_loader):
            
            # Crop tensor batches directly. Otherwise (e.g. [images, labels]
            # batches), take the images, average the colour channels to a single
            # channel and move them to `device`.
            # NOTE(review): the tensor branch does not call .to(device) — confirm.
            # `except Exception` rather than a bare except, so Ctrl-C still interrupts.
            try:
                xs = crop_center(data, n)#.double()
            except Exception:
                xs = crop_center(data[0].mean(dim=1, keepdim=True).to(device), n)#.double()
            
            # Reset matplotlib rc parameters to their defaults.
            plt.rcdefaults()

            # Create clean and noisy observations. Each image's Gaussian noise is
            # scaled by noise_level times that image's mean absolute clean
            # measurement (the three unsqueezes assume 4-D observations).
            clean_observations = forward_operator(xs)
            noisy_observations1 = clean_observations + (noise_level * torch.randn_like(clean_observations) * torch.mean(torch.abs(clean_observations).view(xs.shape[0], -1), dim=1).unsqueeze(1).unsqueeze(1).unsqueeze(1))
        
            # NOTE(review): the batch and noisy data generated above are replaced
            # by a fixed problem instance loaded from disk, so they are discarded.
            # Presumably deliberate (reproducible comparison) — confirm.
            # map_location remaps GPU-saved tensors to the CPU when CUDA is unavailable.
            xs = torch.load('true_image_CT0909.pt', map_location=None if device == 'cuda' else 'cpu')
            noisy_observations1 = torch.load('noisy_observations_CT0909.pt', map_location=None if device == 'cuda' else 'cpu')

            # Initial reconstruction from noisy observations
            inpt = init_recon(noisy_observations1)
            
            # Optionally fit one extra learned step per method and save the
            # extended dictionaries. First all existing learned steps are applied.
            # Then a lambda is computed with find_appropriate_lambda_* and passed,
            # together with the last learned step (presumably the initial guess),
            # to optimisation_*. The loop leaves k == len(dict_kernel) - 1, so
            # the new entry is stored under key len(dict_kernel). This assumes
            # every dictionary has at least len(dict_kernel) entries.
            if calculate_final_vals:
                
                inpt_get_lmbda_pointwise = inpt.clone()
                inpt_get_lmbda_kernel = inpt.clone()
                inpt_get_lmbda_alpha = inpt.clone()
                
                for k in tqdm(range(len(dict_kernel))):
                    inpt_get_lmbda_pointwise = inpt_get_lmbda_pointwise - dict_pointwise[k] * grad_objective(inpt_get_lmbda_pointwise, noisy_observations1)
                    inpt_get_lmbda_kernel = inpt_get_lmbda_kernel - convolution(grad_objective(inpt_get_lmbda_kernel, noisy_observations1), dict_kernel[k])
                    inpt_get_lmbda_alpha = inpt_get_lmbda_alpha - dict_alpha[k] * grad_objective(inpt_get_lmbda_alpha, noisy_observations1)
                
                lmbda_pointwise = find_appropriate_lambda_pointwise(lambda x: objective_function(x, noisy_observations1), lambda x: grad_objective(x, noisy_observations1), inpt_get_lmbda_pointwise, dict_pointwise[len(dict_pointwise)-1], 0, n, L, tol=0, max_iter=100)
                dict_pointwise[k+1] = optimisation_p_pointwise(lambda x: objective_function(x, noisy_observations1), lambda x: grad_objective(x, noisy_observations1), inpt_get_lmbda_pointwise, dict_pointwise[len(dict_pointwise)-1], L, 0, max_iter=1000, tol=0, lmbda=lmbda_pointwise)
                
                lmbda_alpha = find_appropriate_lambda_alpha(lambda x: objective_function(x, noisy_observations1), lambda x: grad_objective(x, noisy_observations1), inpt_get_lmbda_alpha, torch.tensor(dict_alpha[len(dict_alpha)-1]).to(device), 0, n, L, tol=0, max_iter=100)
                dict_alpha[k+1] = optimisation_alpha(lambda x: objective_function(x, noisy_observations1), lambda x: grad_objective(x, noisy_observations1), inpt_get_lmbda_alpha, torch.tensor(dict_alpha[len(dict_alpha)-1]).to(device), L, 0, max_iter=100, tol=0.0, lmbda=lmbda_alpha).item()
                
                lmbda_kernel = find_appropriate_lambda_kernel(lambda x: objective_function(x, noisy_observations1), lambda x: grad_objective(x, noisy_observations1), inpt_get_lmbda_kernel, dict_kernel[len(dict_kernel)-1], 0, n, L, kernel_width, kernel_height, tol=0, max_iter=100)
                dict_kernel[k+1] = optimisation_convolution(lambda x: objective_function(x, noisy_observations1), lambda x: grad_objective(x, noisy_observations1), inpt_get_lmbda_kernel, dict_kernel[len(dict_kernel)-1], L, 0, kernel_width, kernel_height, max_iter=1000, tol=0.0, lmbda=lmbda_kernel)
                
                torch.save(dict_pointwise, 'pointwise_dictionary_CT0308-2_lambda.pt')
                torch.save(dict_kernel, 'kernel_dictionary_CT0308-2_lambda.pt')
                torch.save(dict_alpha, 'alpha_dictionary_CT0308-2_lambda.pt')
            
            
            # Objective traces (batch-averaged objective_function) per method,
            # starting from the initial reconstruction.
            obj_pointwise = [objective_function(inpt, noisy_observations1).item()]
            obj_full = [objective_function(inpt, noisy_observations1).item()]
            obj_convolution = [objective_function(inpt, noisy_observations1).item()]
            obj_scalar = [objective_function(inpt, noisy_observations1).item()]

            # Clone initial reconstruction for different methods
            inpt_pointwise = inpt.clone()
            inpt_full = inpt.clone()
            inpt_convolution = inpt.clone()
            inpt_scalar = inpt.clone()

            # Variables for the Nesterov accelerated gradient (NAG) baseline, step 1/L.
            tnew = 0
            told = 0
            obj_nag = [objective_function(inpt, noisy_observations1).item()]
            obj_nag_indiv = [[i.item() for i in objective_function_individual(inpt, noisy_observations1)]]
            psnr_nag = [psnr(inpt, xs)]
            inpt_nag = inpt.clone()
            inpt_nag_m1 = inpt.clone()


            for k in tqdm(range(201)):
                # Nesterov/FISTA momentum: t_{k+1} = (1 + sqrt(1 + 4 t_k^2)) / 2 and
                # beta_k = (t_k - 1) / t_{k+1}. The recurrence must use the current
                # t (tnew); using told repeated every t value and skewed the momentum.
                tnew, told = (1 + np.sqrt(1 + 4 * tnew ** 2)) / 2, tnew
                alphat = (told - 1) / tnew
                yk = inpt_nag + alphat * (inpt_nag - inpt_nag_m1)
                grad_nag = grad_objective(yk, noisy_observations1)
                inpt_nag, inpt_nag_m1 = yk - one_over_L * grad_nag, inpt_nag
                obj_nag.append(objective_function(inpt_nag, noisy_observations1).item())
                obj_nag_indiv.append([i.item() for i in objective_function_individual(inpt_nag, noisy_observations1)])
                

                grad_pointwise = grad_objective(inpt_pointwise, noisy_observations1)
                grad_full = grad_objective(inpt_full, noisy_observations1)
                grad_convolution = grad_objective(inpt_convolution, noisy_observations1)
                grad_scalar = grad_objective(inpt_scalar, noisy_observations1)

                # Learned updates: pointwise (dict_pointwise), scalar (dict_alpha),
                # convolution (dict_kernel) and full matrix (dict_full). Past the
                # learned horizon the last learned step is reused (the min() clamps);
                # the full method simply stops updating.
                idx = min(k, len(dict_pointwise)-1)
                idx_kernel = min(k, len(dict_kernel)-1)
                inpt_pointwise = inpt_pointwise - dict_pointwise[idx] * grad_pointwise
                inpt_scalar = inpt_scalar - dict_alpha[idx] * grad_scalar
                inpt_convolution = inpt_convolution - convolution(grad_convolution, dict_kernel[idx_kernel])
                if k < len(dict_full):
                    inpt_full = inpt_full - apply_p_full(dict_full[k], grad_full)
                
                # Append objective values
                obj_pointwise.append(objective_function(inpt_pointwise, noisy_observations1).item())
                obj_convolution.append(objective_function(inpt_convolution, noisy_observations1).item())
                obj_scalar.append(objective_function(inpt_scalar, noisy_observations1).item())
                obj_full.append(objective_function(inpt_full, noisy_observations1).item())
                
            # F(x*) is approximated by the smallest objective value any method reached.
            approx_min = min([min(obj_nag), min(obj_pointwise), min(obj_full), min(obj_convolution), min(obj_scalar)])
            objective_difference_iter0 = objective_function(inpt, noisy_observations1).item() - approx_min
    
    
    for i, data in enumerate(val_dataloader):
        
        # Crop tensor batches directly. Otherwise (e.g. the [images, labels]
        # batches of the STL-10 fallback), take the images, average the colour
        # channels to one channel and move them to `device`.
        # `except Exception` rather than a bare except, so Ctrl-C still interrupts.
        try:
            xs_val = crop_center(data, n)#.double()
        except Exception:
            xs_val = crop_center(data[0].mean(dim=1, keepdim=True).to(device), n)#.double()
        
        plt.rcdefaults()

        # Create clean and noisy observations. Each image's Gaussian noise is
        # scaled by noise_level times that image's mean absolute clean measurement.
        clean_observations_val = forward_operator(xs_val)
        noisy_observations1_val = clean_observations_val + (noise_level * torch.randn_like(clean_observations_val) * torch.mean(torch.abs(clean_observations_val).view(xs_val.shape[0], -1), dim=1).unsqueeze(1).unsqueeze(1).unsqueeze(1))


        # Initial reconstruction from noisy observations
        inpt_val = init_recon(noisy_observations1_val)

        # Objective traces (batch-averaged objective_function) per method.
        obj_val_pointwise = [objective_function(inpt_val, noisy_observations1_val).item()]
        obj_val_full = [objective_function(inpt_val, noisy_observations1_val).item()]
        obj_val_convolution = [objective_function(inpt_val, noisy_observations1_val).item()]
        obj_val_scalar = [objective_function(inpt_val, noisy_observations1_val).item()]
        obj_val_one_over_l = [objective_function(inpt_val, noisy_observations1_val).item()]
        obj_val_backtracking = [objective_function(inpt_val, noisy_observations1_val).item()]
        obj_val_pointwise_risk = [objective_function(inpt_val, noisy_observations1_val).item()]
        obj_val_convolution_risk = [objective_function(inpt_val, noisy_observations1_val).item()]
        obj_val_scalar_risk = [objective_function(inpt_val, noisy_observations1_val).item()]
        # Per-image objective trace of the learned convolution method.
        obj_val_convolution_indiv = [[i.item() for i in objective_function_individual(inpt_val, noisy_observations1_val)]]
    
        # PSNR traces: only the initial value is recorded; nothing below appends to them.
        psnr_pointwise = [psnr(inpt_val, xs_val)]
        psnr_full = [psnr(inpt_val, xs_val)]
        psnr_convolution = [psnr(inpt_val, xs_val)]
        psnr_scalar = [psnr(inpt_val, xs_val)]
        psnr_one_over_l = [psnr(inpt_val, xs_val)]
        psnr_backtracking = [psnr(inpt_val, xs_val)]
        psnr_BFGS = [psnr(inpt_val, xs_val)]
        

        # Clone initial reconstruction for different methods:
        #   *_fallback: learned step while k < len(dict_kernel), then plain 1/L steps.
        #   *_risk: learned step capped at index len(dict_pointwise) - 2.
        #   backtracking: its line-search update is commented out below, so this
        #   iterate (and obj_val_backtracking) never changes.
        inpt_val_pointwise = inpt_val.clone()
        inpt_val_full = inpt_val.clone()
        inpt_val_convolution = inpt_val.clone()
        inpt_val_scalar = inpt_val.clone()
        inpt_val_one_over_l = inpt_val.clone()
        inpt_val_backtracking = inpt_val.clone()
        inpt_val_alpha_fallback = inpt_val.clone()
        inpt_val_pointwise_fallback = inpt_val.clone()
        inpt_val_convolution_fallback = inpt_val.clone()
        inpt_val_pointwise_risk = inpt_val.clone()
        inpt_val_convolution_risk = inpt_val.clone()
        inpt_val_scalar_risk = inpt_val.clone()

        # Variables for the Nesterov accelerated gradient (NAG) baseline, step 1/L.
        tnew = 0
        told = 0
        obj_val_nag = [objective_function(inpt_val, noisy_observations1_val).item()]
        obj_val_nag_indiv = [[i.item() for i in objective_function_individual(inpt_val, noisy_observations1_val)]]
        psnr_nag = [psnr(inpt_val, xs_val)]
        inpt_val_nag = inpt_val.clone()
        inpt_val_nag_m1 = inpt_val.clone()

        for k in tqdm(range(501)):
            #print(objective_function(inpt_val_full, noisy_observations1_val).item())
            # Nesterov/FISTA momentum: t_{k+1} = (1 + sqrt(1 + 4 t_k^2)) / 2 and
            # beta_k = (t_k - 1) / t_{k+1}. The recurrence must use the current
            # t (tnew); using told repeated every t value and skewed the momentum.
            tnew, told = (1 + np.sqrt(1 + 4 * tnew ** 2)) / 2, tnew
            alphat = (told - 1) / tnew
            yk = inpt_val_nag + alphat * (inpt_val_nag - inpt_val_nag_m1)
            grad_nag = grad_objective(yk, noisy_observations1_val)
            inpt_val_nag, inpt_val_nag_m1 = yk - one_over_L * grad_nag, inpt_val_nag
            obj_val_nag.append(objective_function(inpt_val_nag, noisy_observations1_val).item())
            obj_val_nag_indiv.append([i.item() for i in objective_function_individual(inpt_val_nag, noisy_observations1_val)])

            grad_pointwise = grad_objective(inpt_val_pointwise, noisy_observations1_val)
            grad_full = grad_objective(inpt_val_full, noisy_observations1_val)
            grad_convolution = grad_objective(inpt_val_convolution, noisy_observations1_val)
            grad_scalar = grad_objective(inpt_val_scalar, noisy_observations1_val)
            grad_one_over_l = grad_objective(inpt_val_one_over_l, noisy_observations1_val)
            grad_backtracking = grad_objective(inpt_val_backtracking, noisy_observations1_val)
            grad_alpha_fallback = grad_objective(inpt_val_alpha_fallback, noisy_observations1_val)
            grad_pointwise_fallback = grad_objective(inpt_val_pointwise_fallback, noisy_observations1_val)
            grad_convolution_fallback = grad_objective(inpt_val_convolution_fallback, noisy_observations1_val)
            grad_pointwise_risk = grad_objective(inpt_val_pointwise_risk, noisy_observations1_val)
            grad_convolution_risk = grad_objective(inpt_val_convolution_risk, noisy_observations1_val)
            grad_scalar_risk = grad_objective(inpt_val_scalar_risk, noisy_observations1_val)

            # Update inputs using different methods. `idx` clamps to the last learned
            # kernel index and is also used for dict_pointwise / dict_alpha, which
            # assumes those dictionaries are at least as long as dict_kernel.
            # NOTE(review): idx_full is unused (and is -1 while dict_full is empty).
            idx = min(k, len(dict_kernel)-1)
            idx_full = min(k, len(dict_full)-1)
            idx_risk = min(k, len(dict_pointwise)-2)
            inpt_val_pointwise = inpt_val_pointwise - dict_pointwise[idx] * grad_pointwise
            inpt_val_scalar = inpt_val_scalar - dict_alpha[idx] * grad_scalar
            inpt_val_convolution = inpt_val_convolution - convolution(grad_convolution, dict_kernel[idx])
            if k < len(dict_full):
                inpt_val_full = inpt_val_full - apply_p_full(dict_full[k], grad_full)
                obj_val_full.append(objective_function(inpt_val_full, noisy_observations1_val).item())
            inpt_val_one_over_l = inpt_val_one_over_l - one_over_L * grad_one_over_l 
            inpt_val_pointwise_risk = inpt_val_pointwise_risk - dict_pointwise[idx_risk] * grad_pointwise_risk
            inpt_val_convolution_risk = inpt_val_convolution_risk - convolution(grad_convolution_risk, dict_kernel[idx_risk])
            inpt_val_scalar_risk = inpt_val_scalar_risk - dict_alpha[idx_risk] * grad_scalar_risk
            if k < len(dict_kernel):
                inpt_val_alpha_fallback = inpt_val_alpha_fallback - dict_alpha[k] * grad_alpha_fallback
                inpt_val_pointwise_fallback = inpt_val_pointwise_fallback - dict_pointwise[k] * grad_pointwise_fallback
                inpt_val_convolution_fallback = inpt_val_convolution_fallback - convolution(grad_convolution_fallback, dict_kernel[k])
                # Per-image convolution trace is only recorded over the learned horizon.
                obj_val_convolution_indiv.append([i.item() for i in objective_function_individual(inpt_val_convolution, noisy_observations1_val)])
                    
            else:
                inpt_val_alpha_fallback = inpt_val_alpha_fallback - one_over_L * grad_alpha_fallback
                inpt_val_pointwise_fallback = inpt_val_pointwise_fallback - one_over_L * grad_pointwise_fallback
                inpt_val_convolution_fallback = inpt_val_convolution_fallback - one_over_L * grad_convolution_fallback
            # alpha_bt = torch.stack([torch.tensor(bt_line_search(lambda x: objective_function(x, noisy_observations1_val[i].unsqueeze(0)), lambda x: grad_objective(x, noisy_observations1_val[i].unsqueeze(0)), inpt_val_backtracking[i].unsqueeze(0), -grad_backtracking[i].unsqueeze(0), t=5, beta=0.5, alpha=1e-1)).to(device) for i in range(inpt_val.shape[0])]).unsqueeze(1).unsqueeze(1).unsqueeze(1)
            # inpt_val_backtracking = inpt_val_backtracking - alpha_bt * grad_backtracking
            
            # Append objective values
            obj_val_pointwise.append(objective_function(inpt_val_pointwise, noisy_observations1_val).item())
            obj_val_convolution.append(objective_function(inpt_val_convolution, noisy_observations1_val).item())
            obj_val_scalar.append(objective_function(inpt_val_scalar, noisy_observations1_val).item())
            obj_val_one_over_l.append(objective_function(inpt_val_one_over_l, noisy_observations1_val).item())
            obj_val_backtracking.append(objective_function(inpt_val_backtracking, noisy_observations1_val).item())
            obj_val_pointwise_risk.append(objective_function(inpt_val_pointwise_risk, noisy_observations1_val).item())
            obj_val_convolution_risk.append(objective_function(inpt_val_convolution_risk, noisy_observations1_val).item())
            obj_val_scalar_risk.append(objective_function(inpt_val_scalar_risk, noisy_observations1_val).item())

            
            
            # if k in [3, 5, 7, 10, 13, 15, 17, 20]:
            #     print('\n\n\n\n\n\n\n\n\n')
            #     print(k)
            #     psnrs_conv = [psnr(inpt_val_convolution[zz,:,:,:], xs_val[zz,:,:,:]) for zz in range(xs_val.shape[0])]
            #     psnrs_nag = [psnr(inpt_val_nag[zz,:,:,:], xs_val[zz,:,:,:]) for zz in range(xs_val.shape[0])]
            #     diffs = [psnrs_conv[zz] - psnrs_nag[zz] for zz in range(xs_val.shape[0])]
            #     max_diff_index = diffs.index(max(diffs))
            #     print(max_diff_index)
            #     print(f'PSNRS: {psnrs_conv[max_diff_index]}, {psnrs_nag[max_diff_index]}')
            #     ## now plot the images and the reconstructions
            #     plt.imshow(xs_val[max_diff_index, :, :, :].squeeze().cpu().numpy(), cmap='gray')
            #     plt.title('True Image')
            #     plt.show()
                
  
                
            #     plt.imshow(noisy_observations1_val[max_diff_index, :, :, :].squeeze().cpu().numpy(), cmap='gray')
            #     plt.title('Noisy Observations')
            #     plt.show()
                
            #     plt.imshow(inpt_val_convolution[max_diff_index, :, :, :].squeeze().cpu().numpy(), cmap='gray')
            #     plt.title('Convolution Reconstruction')
            #     plt.show()
                
            #     plt.imshow(inpt_val_nag[max_diff_index, :, :, :].squeeze().cpu().numpy(), cmap='gray')
            #     plt.title('NAG Reconstruction')
            #     plt.show()
                
            #     print(f'PSNR Convolution: {psnrs_conv[max_diff_index]}')
            #     psnr_compare = psnrs_conv[max_diff_index]
            #     if k == 30:
            #         diff_idx = max_diff_index
            # if k > 30 and k % 5 == 0:
            #     plt.imshow(inpt_val_nag[max_diff_index, :, :, :].squeeze().cpu().numpy(), cmap='gray')
            #     plt.title(f'NAG Reconstruction - {k} iterations')
            #     plt.show()
            #     print(f'PSNR NAG: {psnr(inpt_val_nag[max_diff_index,:,:,:], xs_val[max_diff_index,:,:,:])}, PSNR Convolution: {psnr_compare}')


        # Show the final NAG reconstruction of the first validation image.
        plt.imshow(inpt_val_nag[0, :, :, :].squeeze().cpu().numpy(), cmap='gray')
        plt.title('Final Reconstruction')
        plt.show()
            
        
        #obj_val_lbfgs = lbfgs_all_functions(objective_function, grad_objective, inpt_val, noisy_observations1_val, L[0], memory=10, max_iter=len(dict_kernel))

        # F(x*) is approximated by the best NAG objective (batch-averaged) and by the
        # final NAG iterate's per-image objectives.
        approx_min_val = min(obj_val_nag)
        approx_min_val_list = objective_function_individual(inpt_val_nag, noisy_observations1_val)

        objective_difference_iter0_val = objective_function(inpt_val, noisy_observations1_val).item() - approx_min_val

        #### GENERALISATION PLOT:
        # Normalised optimality gap (F(x_t) - F*) / (F(x_0) - F*) over the learned
        # horizon: training batch (solid) vs validation batch (dashed). F* is the
        # approximate minimum computed earlier (approx_min / approx_min_val).
        # NOTE(review): the training traces hold 202 values, so this indexing
        # assumes len(dict_kernel) <= 202.
        if test_train:
            iterations = np.arange(0, len(dict_kernel))
            plt.figure(figsize=(10, 8), dpi=300)
            plt.loglog(1+iterations, (-approx_min + np.array(obj_pointwise)[iterations])/objective_difference_iter0, label='Learned Pointwise - Train', linestyle='-', color='red', linewidth=2)
            plt.loglog(1+iterations, (-approx_min + np.array(obj_scalar)[iterations])/objective_difference_iter0, label='Learned Scalar - Train', linestyle='-', color='blue', linewidth=2)
            plt.loglog(1+iterations, (-approx_min + np.array(obj_convolution)[iterations])/objective_difference_iter0, label='Learned Convolution - Train', linestyle='-', color='black', linewidth=2)
            plt.loglog(1+iterations, (-approx_min_val + np.array(obj_val_pointwise)[iterations])/objective_difference_iter0_val, label='Learned Pointwise - Test', linestyle='--', color='red', linewidth=4)
            plt.loglog(1+iterations, (-approx_min_val + np.array(obj_val_scalar)[iterations])/objective_difference_iter0_val, label='Learned Scalar - Test', linestyle='--', color='blue', linewidth=4)
            plt.loglog(1+iterations, (-approx_min_val + np.array(obj_val_convolution)[iterations])/objective_difference_iter0_val, label='Learned Convolution - Test', linestyle='--', color='black', linewidth=4)
            plt.xlabel(r'Iteration $t$', fontsize=24)
            plt.ylabel(r'$\frac{F(x_t) - F(x^*)}{F(x_0) - F(x^*)}$', fontsize=30)
            plt.ylim(1e-10, None)  
            plt.legend(fontsize=24, loc='upper right')
            plt.tick_params(axis='both', which='major', labelsize=24)
            plt.grid(True, linestyle='-', linewidth=1)
            plt.tight_layout()
            plt.show()
            
        # Convergence of the different methods on the validation batch over the
        # learned horizon (iterations 0 .. len(dict_kernel) - 1).
        iterations = np.arange(0, len(dict_kernel))
        plt.figure(figsize=(10, 8), dpi=300)
        plt.loglog(1+iterations, (-approx_min_val + np.array(obj_val_pointwise)[iterations])/objective_difference_iter0_val, label='Learned Pointwise', linestyle='-', color='red')
        plt.loglog(1+iterations, (-approx_min_val + np.array(obj_val_scalar)[iterations])/objective_difference_iter0_val, label='Learned Scalar', linestyle='-', color='blue')
        plt.loglog(1+iterations, (-approx_min_val + np.array(obj_val_convolution)[iterations])/objective_difference_iter0_val, label='Learned Convolution', linestyle='-', color='black')
        #plt.loglog(1+iterations, (-approx_min_val + np.array(obj_val_full)[iterations])/objective_difference_iter0_val, label='Learned Full', linestyle='-', color='brown')
        plt.loglog(1+iterations, (-approx_min_val + np.array(obj_val_nag)[iterations])/objective_difference_iter0_val, label='NAG', linestyle='-', color='magenta')
        plt.loglog(1+iterations, (-approx_min_val + np.array(obj_val_one_over_l)[iterations])/objective_difference_iter0_val, label='GD', linestyle='-', color='orange')
        #plt.loglog(1+iterations, (-approx_min_val + np.array(obj_val_lbfgs)[iterations])/objective_difference_iter0_val, label='L-BFGS', linestyle='-', color='green')
        #plt.loglog(1+iterations, (-approx_min_val + np.array(obj_val_backtracking)[iterations])/objective_difference_iter0_val, label='Backtracking Line Search', linestyle='-', color='purple')
        plt.xlabel(r'Iteration $t$', fontsize=24)
        plt.ylabel(r'$\frac{F(x_t) - F(x^*)}{F(x_0) - F(x^*)}$', fontsize=30)
        plt.ylim(1e-10, None) 
        plt.tick_params(axis='both', which='major', labelsize=24)
        plt.legend(fontsize=20, loc='upper right')
        plt.grid(True, linestyle='-', linewidth=1)
        plt.tight_layout()
        plt.show()

        
        # Disabled plots comparing the regularised and learned variants past the
        # learned horizon; kept as an unused string literal so they do not execute.
        '''
        iterations = np.arange(len(dict_kernel), len(obj_val_convolution))
        plt.figure(figsize=(10, 8), dpi=300)
        plt.loglog(1+iterations, (-approx_min_val + np.array(obj_val_convolution)[iterations])/objective_difference_iter0_val, label=f'Convolution Regularisation - {kernel_height} x {kernel_width}', linestyle='-', color='black')
        plt.loglog(1+iterations, (-approx_min_val + np.array(obj_val_convolution_risk)[iterations])/objective_difference_iter0_val, label='Learned Convolution', linestyle=':', color='black')
        plt.xlabel(r'Iteration $t$', fontsize=24)
        plt.ylabel(r'$ F(x_t) - F(x^*)$', fontsize=14)
        plt.ylim(1e-10, None) 
        plt.legend(fontsize=20, loc='lower left')
        plt.grid(True, linestyle='-', linewidth=1)
        plt.tight_layout()
        plt.show()
        
        
        plt.figure(figsize=(10, 8), dpi=300)
        plt.loglog(1+iterations, (-approx_min_val + np.array(obj_val_scalar)[iterations])/objective_difference_iter0_val, label='Scalar Regularisation', linestyle='-', color='blue')
        plt.loglog(1+iterations, (-approx_min_val + np.array(obj_val_scalar_risk)[iterations])/objective_difference_iter0_val, label='Learned Scalar', linestyle=':', color='blue')
        plt.xlabel(r'Iteration $t$', fontsize=24)
        plt.ylabel(r'$ F(x_t) - F(x^*)$', fontsize=14)
        plt.ylim(1e-10, None)  
        plt.legend(fontsize=20, loc='lower left')
        plt.grid(True, linestyle='-', linewidth=1)
        plt.tight_layout()
        plt.show()
        
        plt.figure(figsize=(10, 8), dpi=300)
        plt.loglog(1+iterations, (-approx_min_val + np.array(obj_val_pointwise)[iterations])/objective_difference_iter0_val, label='Pointwise Regularisation', linestyle='-', color='red')
        plt.loglog(1+iterations, (-approx_min_val + np.array(obj_val_pointwise_risk)[iterations])/objective_difference_iter0_val, label='Learned Pointwise', linestyle=':', color='red')
        plt.xlabel(r'Iteration $t$', fontsize=24)
        plt.ylabel(r'$ F(x_t) - F(x^*)$', fontsize=14)
        plt.ylim(1e-10, None)  
        plt.legend(fontsize=20, loc='lower left')
        plt.grid(True, linestyle='-', linewidth=1)
        plt.tight_layout()
        plt.show()

        '''
        
            
        # Per-image comparison of learned convolution vs NAG on two validation
        # images chosen by get_best_and_worst_case_convs (presumably where the
        # convolution method does best / worst relative to NAG).
        best_conv_list, best_conv_comparison_list, best_index, worst_conv_list, worst_conv_comparison_list, worst_index = get_best_and_worst_case_convs(obj_val_convolution_indiv, obj_val_nag_indiv, approx_min_val_list)
        iterations_compare = np.arange(0, len(best_conv_comparison_list))
        # The convolution trace only covers the learned horizon; pad it with zeros to
        # the NAG length. Padded points give 0 - F(x*), which is non-positive when
        # F(x*) >= 0, so they are not drawn on the logarithmic axis.
        best_conv_list += [0 for _ in range(len(best_conv_comparison_list) - len(best_conv_list))]
        plt.semilogy(1+iterations_compare, (-approx_min_val_list[best_index].item() + np.array(best_conv_list)[iterations_compare]), label='Learned Convolution', linestyle='-', color='black')
        plt.semilogy(1+iterations_compare, (-approx_min_val_list[best_index].item() + np.array(best_conv_comparison_list)[iterations_compare]), label='NAG', linestyle='-', color='magenta')
        plt.xlabel(r'Iteration $t$', fontsize=24)
        # The curves are unnormalised per-image gaps, so label them as such.
        plt.ylabel(r'$F(x_t) - F(x^*)$', fontsize=30)
        plt.title('Best Case Convolution', fontsize=20)
        plt.tick_params(axis='both', which='major', labelsize=24)
        plt.ylim(1e-7, None)
        plt.legend(fontsize=20, loc='upper right')
        plt.xlim(0, 150)
        plt.show()
        
        worst_conv_list += [0 for _ in range(len(best_conv_comparison_list) - len(worst_conv_list))]
        plt.semilogy(1+iterations_compare, (-approx_min_val_list[worst_index].item() + np.array(worst_conv_list)[iterations_compare]), label='Learned Convolution', linestyle='-', color='black')
        plt.semilogy(1+iterations_compare, (-approx_min_val_list[worst_index].item() + np.array(worst_conv_comparison_list)[iterations_compare]), label='NAG', linestyle='-', color='magenta')
        plt.xlabel(r'Iteration $t$', fontsize=24)
        plt.ylabel(r'$F(x_t) - F(x^*)$', fontsize=30)
        plt.title('Worst Case Convolution', fontsize=20)
        plt.tick_params(axis='both', which='major', labelsize=24)
        plt.ylim(1e-7, None)
        plt.legend(fontsize=20, loc='upper right')
        plt.xlim(0, 120)
        plt.show()


        # Report how many iterations each method needs before its normalised
        # validation gap (F(x_t) - F*) / (F(x_0) - F*) drops below each tolerance.
        # The gaps are normalised by the *validation* initial gap. The training
        # value belongs to a different problem and is undefined when test_train
        # is False.
        # L-BFGS is omitted because obj_val_lbfgs is only produced by the
        # commented-out lbfgs_all_functions call; referencing it raised NameError.
        tolerances = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9, 1e-10]
        method_objectives = [
            ('CONVOLUTION', obj_val_convolution),
            ('NAG', obj_val_nag),
            ('BACKTRACKING', obj_val_backtracking),
            ('SCALAR', obj_val_scalar),
            ('Pointwise', obj_val_pointwise),
            ('GD', obj_val_one_over_l),
        ]
        for method_name, objective_values in method_objectives:
            relative_gap = (-approx_min_val + np.array(objective_values)) / objective_difference_iter0_val
            for tol in tolerances:
                print(f'NUM ITERS TO REACH {tol} FOR {method_name}: {num_iters_before_under_tol(relative_gap, tol)}')
