import torch
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import random_split
from util.objective_functions import fast_huber_TV, fast_huber_grad
from util.optimisation_functions import optimisation_convolution, convolution
import torchvision

# Prefer the GPU when torch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


# Experiment parameters
training_iterations = 1000  # number of iterations over which the optimisation algorithm is learned
noise_level = 0.0025  # factor multiplying the Gaussian noise added to the blurred images
huber_const = 0.01  # epsilon of the Huber norm (passed as `delta` to the Huber-TV helpers)
reg_const = 1e-05  # scalar weight multiplying the regulariser
N_TRAIN = 100  # number of images (functions) in the training split
N_VAL = 100  # number of images (functions) in the validation split
n = 96  # side length of an image (images are assumed to be square)

# Kernel that defines the forward (blur) operator
def gaussian_kernel(size: int, sigma: float) -> torch.Tensor:
    """Build a normalised 2D Gaussian kernel of shape ``(1, 1, size, size)``.

    Args:
        size: Number of taps along each side of the square kernel.
        sigma: Standard deviation of the Gaussian, in units of kernel taps.

    Returns:
        A float tensor whose entries sum to one, with two leading singleton
        dimensions added in front of the ``size x size`` grid.
    """
    # Integer offsets from the centre tap, e.g. [-2, -1, 0, 1, 2] for size 5.
    offsets = torch.arange(size).float() - size // 2
    rows = offsets.view(-1, 1)
    cols = offsets.view(1, -1)
    # Broadcasting the column and row vectors yields the full 2D grid.
    weights = torch.exp(-0.5 * (rows**2 + cols**2) / sigma**2)
    weights = weights / weights.sum()
    return weights[None, None]

# Blur kernel used by the forward operator: a 5 x 5 window with sigma 1.5.
size = 5
sigma = 1.5
# NOTE: this rebinds the module-level name `gaussian_kernel` from the factory
# function to the tensor it returns, so the factory cannot be called again
# further down in this file.
gaussian_kernel = gaussian_kernel(size=size, sigma=sigma).to(device)


def forward_operator(x):
    """Blur ``x`` by passing it and the module-level Gaussian kernel to ``convolution``."""
    blurred = convolution(x, gaussian_kernel)
    return blurred


# The very same callable is reused as the adjoint operator.
# NOTE(review): presumably justified by the symmetry of the Gaussian kernel --
# confirm against the padding/boundary handling inside `convolution`.
adjoint_operator = forward_operator

def init_recon(x):
    """Return the initial reconstruction, which is the observation ``x`` itself."""
    initial_guess = x
    return initial_guess

# Value used for the norm of the forward (blur) operator.
# NOTE(review): presumably 1 because the Gaussian kernel is normalised to sum
# to one -- confirm against the implementation of `convolution`.
operator_norm = 1.0

## Calculate the Lipschitz constant of gradient
# Contribution of the Huber-TV regulariser: 8 * reg_const / huber_const.
# NOTE(review): the factor 8 is presumably a bound on the squared norm of the
# 2D finite-difference operator -- confirm against `fast_huber_TV`.
L_smooth_reg = 8 * reg_const / huber_const
# Total smoothness constant: data-fidelity part (operator_norm ** 2) plus the
# regulariser part. Note that `L` is rebound to a list further down the file.
L = operator_norm ** 2 + L_smooth_reg
print(f'f is L={L} smooth.')

# Reuse the cached train/validation split when it can be read from disk;
# otherwise build it from STL10 and cache it for later runs.
try:
    train_dataset = torch.load(f'train_dataset_blur_{N_TRAIN}.pt')
    val_dataset = torch.load(f'val_dataset_blur_{N_VAL}.pt')
except Exception as load_error:
    # `except Exception` (instead of a bare `except:`) keeps the best-effort
    # fallback for missing or unreadable cache files, but no longer swallows
    # KeyboardInterrupt / SystemExit, and reports why the cache was rejected.
    print(f'Cached datasets unavailable ({load_error!r}); rebuilding them from STL10.')
    dataset = torchvision.datasets.STL10('STL', split='train', transform=torchvision.transforms.ToTensor(), folds=1, download=True)
    # Keep only the first N_TRAIN + N_VAL images, then split them at random.
    dataset = torch.utils.data.Subset(dataset, list(range(N_TRAIN + N_VAL)))
    train_dataset, val_dataset = random_split(dataset, [N_TRAIN, N_VAL])

    torch.save(train_dataset, f'train_dataset_blur_{N_TRAIN}.pt')
    torch.save(val_dataset, f'val_dataset_blur_{N_VAL}.pt')

# Create data loaders for training and validation.
# The batch size equals the split size, so each loader yields a single batch.
data_loader = DataLoader(train_dataset, batch_size=N_TRAIN, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=N_VAL, shuffle=True)

# Overall objective function, averaged over the batch dimension
def objective_function(x, y):
    """Evaluate the regularised least-squares objective at ``x`` for data ``y``.

    The value is ``0.5 * ||forward_operator(x) - y||^2`` plus ``reg_const``
    times the Huber-TV term, with the sum divided by ``x.shape[0]``.
    """
    residual = forward_operator(x) - y
    data_fidelity = 0.5 * torch.norm(residual) ** 2
    regulariser = reg_const * fast_huber_TV(x, delta=huber_const)
    return (data_fidelity + regulariser) / x.shape[0]

# Gradient of the objective function - without dividing by x.shape[0]; that factor is absorbed into L
def grad_objective(x, y):
    """Return the gradient of the (un-averaged) objective at ``x`` for data ``y``.

    This is the adjoint operator applied to the residual, plus ``reg_const``
    times the Huber-TV gradient.
    """
    residual = forward_operator(x) - y
    data_gradient = adjoint_operator(residual)
    regulariser_gradient = reg_const * fast_huber_grad(x, delta=huber_const)
    return data_gradient + regulariser_gradient


# Initialize preconditioner and other variables
# Step size of the AGD baseline. It is computed here, while `L` is still the
# scalar smoothness constant (`L` is rebound to a list at the end of this block).
one_over_L = 1 / L

## try to load pre-learned dictionaries, else start new
# Each dictionary maps a training iteration index to the kernel learned for it.
try:
    # map_location moves the stored kernels onto the device used in this run,
    # so dictionaries saved on a GPU can still be loaded on a CPU-only machine.
    dict_kernel_3 = torch.load('learned_operators/kernel_dictionary_blur_3.pt', map_location=device)
    dict_kernel_5 = torch.load('learned_operators/kernel_dictionary_blur_5.pt', map_location=device)
    dict_kernel_7 = torch.load('learned_operators/kernel_dictionary_blur_7.pt', map_location=device)
    dict_kernel_11 = torch.load('learned_operators/kernel_dictionary_blur_11.pt', map_location=device)
    dict_kernel = torch.load('learned_operators/kernel_dictionary_blur.pt', map_location=device)
except Exception as load_error:
    # `except Exception` (instead of a bare `except:`) keeps the all-or-nothing
    # fallback to empty dictionaries without swallowing KeyboardInterrupt /
    # SystemExit, and reports why the stored dictionaries were not used.
    print(f'Pre-learned kernel dictionaries unavailable ({load_error!r}); starting with empty ones.')
    dict_kernel_3 = {}
    dict_kernel_5 = {}
    dict_kernel_7 = {}
    dict_kernel_11 = {}
    dict_kernel = {}

# From here on `L` is a list holding one copy of the smoothness constant per
# training image; `L_max` keeps the scalar value.
L = [L] * N_TRAIN
L_max = L[0]

## Setting sizes of kernels
# Every learned kernel is square, so width and height share a single value.
kernel_width_3 = kernel_height_3 = 3
kernel_width_5 = kernel_height_5 = 5
kernel_width_7 = kernel_height_7 = 7
kernel_width_11 = kernel_height_11 = 11
            
            
def _simulate_observations(images):
    """Blur ``images`` with ``forward_operator`` and add scaled Gaussian noise.

    The noise added to each image is ``noise_level`` times standard normal
    noise times the mean absolute value of that image's blurred version.
    """
    clean = forward_operator(images)
    mean_abs = torch.mean(torch.abs(clean).view(images.shape[0], -1), dim=1)
    return clean + (noise_level * torch.randn_like(clean) * mean_abs.unsqueeze(1).unsqueeze(1).unsqueeze(1))


def _run_agd(start, observations, num_iterations=1001):
    """Run accelerated gradient descent with step size ``one_over_L``.

    Returns:
        A tuple ``(final_iterate, history)`` where ``history`` holds the
        objective value at ``start`` followed by one value per iteration.
    """
    history = [objective_function(start, observations).item()]
    current = start.clone()
    previous = start.clone()
    t_current = 0
    for _ in range(num_iterations):
        # Momentum sequence t_{k+1} = (1 + sqrt(1 + 4 t_k^2)) / 2. The next
        # value must be derived from the *current* t; deriving it from the
        # previous one makes the sequence advance only every other iteration.
        t_next = (1 + np.sqrt(1 + 4 * t_current ** 2)) / 2
        momentum = (t_current - 1) / t_next
        t_current = t_next
        # On the first pass momentum is -1, but current == previous, so the
        # extrapolated point is simply the starting point.
        lookahead = current + momentum * (current - previous)
        gradient = grad_objective(lookahead, observations)
        current, previous = lookahead - one_over_L * gradient, current
        history.append(objective_function(current, observations).item())
    return current, history


def _show_image(image, title, cmap=None, colorbar=False):
    """Display a single tensor as an image with the given title."""
    plt.imshow(image.squeeze().detach().cpu().numpy(), cmap=cmap)
    if colorbar:
        plt.colorbar()
    plt.title(title)
    plt.show()


def _plot_relative_convergence(agd_history, minimum, initial_gap, curves):
    """Log-log plot of (F(x_t) - F*) / (F(x_0) - F*) for AGD and each learned kernel.

    Args:
        agd_history: Objective values of the AGD baseline.
        minimum: Approximate minimum F* subtracted from every curve.
        initial_gap: F(x_0) - F*, used to normalise every curve.
        curves: Sequence of ``(label, colour, objective_values)``; the length
            of the first entry determines how many iterations are plotted.
    """
    steps = np.arange(0, len(curves[0][2]))
    plt.figure(figsize=(10, 8), dpi=300)
    plt.loglog(1 + steps, (-minimum + np.array(agd_history)[steps]) / initial_gap, label='AGD', linestyle='-', color='magenta', linewidth=3)
    for label, colour, values in curves:
        plt.loglog(1 + steps, (-minimum + np.array(values)[steps]) / initial_gap, label=label, linestyle='-', color=colour, linewidth=3)
    plt.xlabel(r'Iteration $t$', fontsize=24)
    plt.ylabel(r'$\frac{F(x_t) - F(x^*)}{F(x_0) - F(x^*)}$', fontsize=30)
    plt.ylim(1e-8, None)  # Set lower limit to 10^-8
    plt.tick_params(axis='both', which='major', labelsize=24)
    plt.legend(fontsize=20, loc='lower left')
    plt.grid(True, linestyle='-', linewidth=1)
    plt.tight_layout()
    plt.show()


if __name__ == '__main__':

    # One entry per learned preconditioner, listed in the order in which the
    # kernels are optimised and plotted. `cache` maps an iteration index to the
    # kernel learned for that iteration.
    kernel_specs = []
    for tag, width, height, cache, colour in (
        ('3', kernel_width_3, kernel_height_3, dict_kernel_3, 'red'),
        ('5', kernel_width_5, kernel_height_5, dict_kernel_5, 'blue'),
        ('7', kernel_width_7, kernel_height_7, dict_kernel_7, 'purple'),
        ('11', kernel_width_11, kernel_height_11, dict_kernel_11, 'orange'),
    ):
        kernel_specs.append({
            'tag': tag,
            'width': width,
            'height': height,
            'centre': (width // 2, height // 2),
            'cache': cache,
            'colour': colour,
            'label': f'{height} x {width}',
            'message': f'calc kernel {tag}',
            'show_kernel': True,
        })
    # Full-size (n x n) kernel; its initial non-zero entry sits at n // 2 - 1.
    kernel_specs.append({
        'tag': 'full',
        'width': n,
        'height': n,
        'centre': (n // 2 - 1, n // 2 - 1),
        'cache': dict_kernel,
        'colour': 'black',
        'label': f'{n} x {n}',
        'message': 'calc kernel',
        'show_kernel': False,
    })

    max_iter_conv = 5000

    # ---- Validation data: observations, starting points and AGD baseline ----
    for val_data in val_dataloader:
        x_val, labels_val = val_data
        # Average the colour channels to obtain single-channel images.
        x_val = x_val.mean(dim=1, keepdim=True).to(device)
        noisy_observations_val = _simulate_observations(x_val)
        inpt_val = init_recon(noisy_observations_val)

        initial_objective_val = objective_function(inpt_val, noisy_observations_val).item()
        val_iterates = {spec['tag']: inpt_val.clone() for spec in kernel_specs}
        val_history = {spec['tag']: [initial_objective_val] for spec in kernel_specs}

        # Accelerated Gradient Descent (AGD) baseline
        _, obj_agd_val = _run_agd(inpt_val, noisy_observations_val)

        # Approximate minimum value of the objective function
        approx_min_val = min(obj_agd_val)

    # ---- Training data ----
    for data in data_loader:
        # Preprocess the images
        xs, labels = data
        xs = xs.mean(dim=1, keepdim=True).to(device)

        # Display the true image
        _show_image(xs[0, :, :, :], 'True Image', cmap='gray')

        # Create noisy observations and display the first one
        noisy_observations1 = _simulate_observations(xs)
        _show_image(noisy_observations1[0, :, :, :], 'Noisy Observations', cmap='gray')

        # Initial reconstruction from noisy observations
        inpt = init_recon(noisy_observations1)
        _show_image(inpt[0, :, :, :], 'Initial Reconstruction', cmap='gray')

        initial_objective = objective_function(inpt, noisy_observations1).item()
        train_iterates = {spec['tag']: inpt.clone() for spec in kernel_specs}
        train_history = {spec['tag']: [initial_objective] for spec in kernel_specs}

        # Accelerated Gradient Descent (AGD) baseline
        inpt_agd, obj_agd = _run_agd(inpt, noisy_observations1)
        _show_image(inpt_agd[0, :, :, :], 'Final Reconstruction - AGD', cmap='gray')

        # Plot convergence of AGD
        approx_min = min(obj_agd)
        iterations = np.arange(0, len(obj_agd))
        marker_every = 10
        plt.figure(figsize=(10, 8), dpi=300)
        plt.semilogy(iterations, (-approx_min + np.array(obj_agd)), label='Accelerated Gradient Descent', linestyle='-.', marker='x', markersize=4, markevery=marker_every, color='yellow')
        plt.xlabel('Iteration $k$', fontsize=14)
        plt.ylabel(f'$f(x^k) - f(x^*)$ Averaged Over Dataset of size {N_TRAIN}', fontsize=14)
        plt.title('Training Convergence Plots for Image Deblurring', fontsize=16)
        plt.legend(fontsize=12, loc='upper right')
        plt.grid(True, which="both", linestyle='--', linewidth=0.5)
        plt.tight_layout()
        plt.show()

        # Normalisation constants F(x_0) - F(x^*) for the relative plots.
        objective_difference_iter0 = initial_objective - approx_min
        objective_difference_iter0_val = initial_objective_val - approx_min_val

        # Objective and gradient of the current training batch, as handed to
        # the kernel optimiser.
        def batch_objective(x):
            return objective_function(x, noisy_observations1)

        def batch_gradient(x):
            return grad_objective(x, noisy_observations1)

        kernels = {}

        # Main training loop for preconditioning methods
        for k in tqdm(range(training_iterations)):
            # Gradients at the current iterates, taken before any kernel is learned.
            gradients = {tag: grad_objective(iterate, noisy_observations1) for tag, iterate in train_iterates.items()}

            if k == 0:
                # Starting kernels: all zeros except for a single entry 1 / L_max.
                for spec in kernel_specs:
                    start_kernel = torch.zeros((1, 1, spec['width'], spec['height'])).to(device)
                    row, col = spec['centre']
                    start_kernel[0, 0, row, col] = 1 / L_max
                    kernels[spec['tag']] = start_kernel

            for spec in kernel_specs:
                tag = spec['tag']
                try:
                    kernels[tag] = spec['cache'][k]
                except KeyError:
                    # No kernel stored for this iteration yet: learn one, warm-started
                    # from the kernel of the previous iteration.
                    print(spec['message'])
                    kernels[tag] = optimisation_convolution(batch_objective, batch_gradient, train_iterates[tag], kernels[tag], L, approx_min, spec['width'], spec['height'], max_iter=max_iter_conv, tol=0.0001, manual=False, verbose=True)
                    # The caches live in memory only; torch.save the cache here to persist it.
                    spec['cache'][k] = kernels[tag]

            for spec in kernel_specs:
                tag = spec['tag']

                # Preconditioned gradient step on the training batch
                train_iterates[tag] = train_iterates[tag] - convolution(gradients[tag], kernels[tag])
                train_history[tag].append(objective_function(train_iterates[tag], noisy_observations1).item())

                # Same kernel applied to the validation batch
                gradient_val = grad_objective(val_iterates[tag], noisy_observations_val)
                val_iterates[tag] = val_iterates[tag] - convolution(gradient_val, kernels[tag])
                val_history[tag].append(objective_function(val_iterates[tag], noisy_observations_val).item())

            if k % 50 == 0:
                for spec in kernel_specs:
                    if spec['show_kernel']:
                        _show_image(kernels[spec['tag']], f"Kernel {spec['tag']} Iteration {k}", colorbar=True)

                # Every history list grows by exactly one entry per iteration,
                # so all curves of a plot share the same length.
                _plot_relative_convergence(
                    obj_agd,
                    approx_min,
                    objective_difference_iter0,
                    [(spec['label'], spec['colour'], train_history[spec['tag']]) for spec in kernel_specs],
                )
                _plot_relative_convergence(
                    obj_agd_val,
                    approx_min_val,
                    objective_difference_iter0_val,
                    [(spec['label'], spec['colour'], val_history[spec['tag']]) for spec in kernel_specs],
                )