import torch
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import random_split
from util.objective_functions import fast_huber_TV, fast_huber_grad
from util.optimisation_functions import optimisation_p_pointwise, optimisation_alpha, optimisation_convolution, convolution
import torchvision

# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# ---- Experiment configuration ----
# Problem size
n = 96  # side length of each image (images are assumed to be square)
N_TRAIN = 100  # number of images (functions) to train over
N_VAL = 100  # number of images (functions) to test over

# Objective function
noise_level = 0.0025  # multiplier applied to the noise added to the observations
huber_const = 0.01  # epsilon used in the definition of the Huber norm
reg_const = 1e-05  # scalar multiplying the regulariser

# Learning schedule
training_iterations = 1000  # number of iterations over which the optimisation algorithm is learned

# defining the kernel for the forward operator
def gaussian_kernel(size: int, sigma: float) -> torch.Tensor:
    """Build a normalised 2D Gaussian kernel.

    Grid coordinates are ``arange(size) - size // 2`` along both axes, so the
    peak sits at index ``size // 2``.

    Args:
        size: Number of samples along each side of the kernel.
        sigma: Standard deviation of the Gaussian, in pixels.

    Returns:
        Tensor of shape ``(1, 1, size, size)`` whose entries sum to one.
    """
    offsets = torch.arange(size).float() - size // 2
    # Broadcasting a column against a row gives the squared distance from the
    # kernel centre at every grid position.
    rows = offsets[:, None]
    cols = offsets[None, :]
    weights = torch.exp(-0.5 * (rows**2 + cols**2) / sigma**2)
    weights = weights / weights.sum()
    # Prepend two singleton dimensions to the (size, size) grid.
    return weights[None, None]

# Blur kernel used by the forward operator: 5 x 5 samples, standard deviation 1.5.
size, sigma = 5, 1.5
# NOTE: this rebinds the module-level name `gaussian_kernel` from the factory
# function to the tensor it returns; the function cannot be called again
# after this line.
gaussian_kernel = gaussian_kernel(size=size, sigma=sigma).to(device)

def forward_operator(x):
    """Apply the blur: pass ``x`` and the Gaussian kernel tensor to ``convolution``.

    Args:
        x: Input handed to ``convolution`` unchanged.
            (presumably a batch of images shaped (batch, 1, H, W) — TODO confirm)

    Returns:
        Whatever ``convolution(x, gaussian_kernel)`` returns.
    """
    blurred = convolution(x, gaussian_kernel)
    return blurred

# The adjoint is taken to be the forward operator itself (same function object).
# NOTE(review): this presumes the blur is self-adjoint, which depends on the
# kernel being symmetric and on how `convolution` treats image borders — confirm.
adjoint_operator = forward_operator 

def init_recon(x):
    """Return the initial reconstruction for an observation.

    The observation itself is used as the starting point, so the argument is
    handed back unchanged (same object, no copy).
    """
    initial_estimate = x
    return initial_estimate

# Hard-coded value used as the norm of the forward operator.
# NOTE(review): presumably justified by the blur kernel being normalised to
# sum to one — confirm against the boundary handling of `convolution`.
operator_norm = 1.0

## Calculate the Lipschitz constant of the gradient of the objective:
## the data-fidelity part contributes operator_norm ** 2 and the smoothed
## regulariser contributes 8 * reg_const / huber_const.
## (the factor 8 is presumably a bound on the squared norm of the 2D
## finite-difference operator inside the Huber TV term — TODO confirm)
L_smooth_reg = 8 * reg_const / huber_const
L = operator_norm ** 2 + L_smooth_reg
# At this point `L` is a plain float; it is rebound to a list further down the file.
print(f'f is L={L} smooth.')



# Load the cached train/validation splits if present; otherwise build them
# from STL10 and cache them for the next run.
try:
    train_dataset = torch.load(f'train_dataset_blur_{N_TRAIN}.pt')
    val_dataset = torch.load(f'val_dataset_blur_{N_VAL}.pt')
except Exception as load_error:
    # Best-effort cache: any loading problem (missing file, unpicklable
    # content, ...) falls back to rebuilding the splits. `Exception` is
    # caught rather than everything so that Ctrl-C / SystemExit still abort
    # the script instead of silently triggering a dataset download.
    # NOTE(review): on PyTorch releases where `torch.load` defaults to
    # `weights_only=True`, unpickling a `Subset` is expected to fail, which
    # would rebuild the splits on every run — the message below makes that
    # visible; confirm against the installed version.
    print(f'Could not load cached datasets ({load_error!r}); rebuilding from STL10.')
    dataset = torchvision.datasets.STL10('STL', split='train', transform=torchvision.transforms.ToTensor(), folds=1, download=True)

    # Keep only the first N_TRAIN + N_VAL images, then split them at random.
    dataset = torch.utils.data.Subset(dataset, list(range(N_TRAIN + N_VAL)))
    train_dataset, val_dataset = random_split(dataset, [N_TRAIN, N_VAL])

    torch.save(train_dataset, f'train_dataset_blur_{N_TRAIN}.pt')
    torch.save(val_dataset, f'val_dataset_blur_{N_VAL}.pt')

# Create data loaders for training and validation. The batch size equals the
# requested split size, so a split of that size is delivered as one batch.
data_loader = DataLoader(train_dataset, batch_size=N_TRAIN, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=N_VAL, shuffle=True)

def objective_function(x, y):
    """Evaluate the overall objective, averaged over the batch.

    The value is ``0.5 * ||A(x) - y||^2 + reg_const * fast_huber_TV(x)``
    divided by ``x.shape[0]``, where ``A`` is ``forward_operator``.

    Args:
        x: Current reconstruction(s); ``x.shape[0]`` is used as the batch size.
        y: Observations compared against ``forward_operator(x)``.

    Returns:
        The batch-averaged objective value.
    """
    residual = forward_operator(x) - y
    data_fidelity = 0.5 * torch.norm(residual) ** 2
    regulariser = reg_const * fast_huber_TV(x, delta=huber_const)
    batch_size = x.shape[0]
    return (data_fidelity + regulariser) / batch_size

def grad_objective(x, y):
    """Evaluate the gradient of the objective, without batch averaging.

    Unlike ``objective_function``, the result is not divided by
    ``x.shape[0]``; that information is absorbed into ``L``.

    Args:
        x: Point at which the gradient is evaluated.
        y: Observations compared against ``forward_operator(x)``.

    Returns:
        ``adjoint_operator(forward_operator(x) - y)`` plus ``reg_const`` times
        ``fast_huber_grad(x)``.
    """
    residual = forward_operator(x) - y
    data_term = adjoint_operator(residual)
    regulariser_term = reg_const * fast_huber_grad(x, delta=huber_const)
    return data_term + regulariser_term


# Step size 1/L, computed here while `L` is still the scalar smoothness
# constant (it is rebound to a per-sample list below).
one_over_L = 1 / L

## Try to load pre-learned dictionaries (iteration index -> learned operator),
## else start new. All three are reset together if any of them fails to load.
try:
    # `map_location` places stored tensors on the device used by this run, so
    # operators learned on a GPU can be reused on a CPU-only machine and
    # vice versa.
    dict_pointwise = torch.load('learned_operators/pointwise_dictionary_blur_lambda.pt', map_location=device)
    dict_alpha = torch.load('learned_operators/alpha_dictionary_blur.pt', map_location=device)
    dict_kernel = torch.load('learned_operators/kernel_dictionary_blur.pt', map_location=device)
except Exception as load_error:
    # Best-effort: any loading problem means "learn from scratch". Only
    # `Exception` is caught so that Ctrl-C / SystemExit are not swallowed.
    print(f'No usable pre-learned operators ({load_error!r}); starting from empty dictionaries.')
    dict_pointwise = {}
    dict_alpha = {}
    dict_kernel = {}

# One smoothness constant per training sample (all identical here). `L_max`
# is taken from the scalar before the rebinding so that it stays defined
# even when N_TRAIN is 0.
L_max = L
L = [L_max] * N_TRAIN


def nesterov_momentum(t_prev):
    """Advance the Nesterov/FISTA momentum sequence by one step.

    Args:
        t_prev: Current value ``t_k`` of the sequence (``0`` before the first
            iteration).

    Returns:
        Tuple ``(t_next, beta)`` of plain floats, where
        ``t_next = (1 + sqrt(1 + 4 * t_prev ** 2)) / 2`` and
        ``beta = (t_prev - 1) / t_next`` is the extrapolation weight applied
        to the difference of the last two iterates.
    """
    t_next = float((1 + np.sqrt(1 + 4 * t_prev ** 2)) / 2)
    beta = float((t_prev - 1) / t_next)
    return t_next, beta


def _show_image(batch, title):
    """Display the first image of ``batch`` in grayscale under ``title``."""
    plt.imshow(batch[0, :, :, :].squeeze().detach().cpu().numpy(), cmap='gray')
    plt.title(title)
    plt.show()
    # Release the figure so that figures do not pile up on backends where
    # `show` returns without destroying them.
    plt.close()


def _plot_nag_convergence(obj_nag, approx_min):
    """Plot the objective gap of the NAG baseline on a semilog axis."""
    iterations = np.arange(0, len(obj_nag))
    marker_every = 10
    fig = plt.figure(figsize=(10, 8), dpi=300)
    plt.semilogy(iterations, (-approx_min + np.array(obj_nag)), label='Accelerated Gradient Descent', linestyle='-.', marker='x', markersize=4, markevery=marker_every, color='yellow')
    plt.xlabel('Iteration $k$', fontsize=14)
    plt.ylabel(f'$f(x^k) - f(x^*)$ Averaged Over Dataset of size {N_TRAIN}', fontsize=14)
    plt.title('Training Convergence Plots for Image Deblurring', fontsize=16)
    plt.legend(fontsize=12, loc='upper right')
    plt.grid(True, which="both", linestyle='--', linewidth=0.5)
    plt.tight_layout()
    plt.show()
    plt.close(fig)


def _plot_training_progress(obj_pointwise, obj_scalar, obj_convolution, obj_nag, approx_min, kernel_height, kernel_width):
    """Plot the objective gap of every method against the NAG baseline (log-log)."""
    iterations = np.arange(0, len(obj_pointwise))
    fig = plt.figure(figsize=(10, 8), dpi=300)
    plt.loglog(1+iterations, (-approx_min + np.array(obj_pointwise)), label='Pointwise', linestyle='-.', marker='x', markersize=4, color='red')
    plt.loglog(1+iterations, (-approx_min + np.array(obj_scalar)), label='Scalar', linestyle='-.', marker='x', markersize=4, color='blue')
    plt.loglog(1+iterations, (-approx_min + np.array(obj_convolution)), label=f'Convolution - {kernel_height} x {kernel_width}', linestyle='-.', marker='x', markersize=4, color='black')
    # Only the first len(iterations) NAG values are comparable with the
    # learned methods at this point of training.
    plt.loglog(1+iterations, (-approx_min + np.array(obj_nag)[:len(iterations)]), label='NAG', linestyle='-.', marker='x', markersize=4, color='yellow')
    plt.xlabel(r'Iteration $t$', fontsize=14)
    plt.ylabel('Optimality in Function Value', fontsize=14)
    plt.title('Training Convergence Plots for Image Deblurring', fontsize=16)
    plt.legend(fontsize=12, loc='upper right')
    plt.grid(True, which="both", linestyle='--', linewidth=0.5)
    plt.tight_layout()
    plt.show()
    # One figure is created per training iteration; close it so that they
    # do not accumulate in memory.
    plt.close(fig)


if __name__ == '__main__':
    # Settings that do not change between training iterations.
    # The NAG baseline must provide at least one objective value per training
    # iteration because the comparison plot reads obj_nag up to that index.
    nag_iterations = max(1001, training_iterations + 1)
    pointwise_iters = 1000
    max_iter_conv = 5000
    ## set the size of the kernel
    kernel_width = n
    kernel_height = kernel_width

    for data in data_loader:
        # Preprocess the images: average the colour channels to grayscale.
        xs, _ = data
        xs = xs.mean(dim=1, keepdim=True).to(device)

        # Display the true image
        _show_image(xs, 'True Image')

        # Create clean and noisy observations. The noise of every image is
        # scaled by the mean absolute value of its clean observation.
        clean_observations = forward_operator(xs)
        noise_scale = torch.mean(torch.abs(clean_observations).view(xs.shape[0], -1), dim=1).view(-1, 1, 1, 1)
        noisy_observations1 = clean_observations + (noise_level * torch.randn_like(clean_observations) * noise_scale)

        # Display the noisy observation
        _show_image(noisy_observations1, 'Noisy Observations')

        # Objective and gradient with the observations of this batch bound in.
        def objective(x):
            return objective_function(x, noisy_observations1)

        def gradient(x):
            return grad_objective(x, noisy_observations1)

        # Initial reconstruction from noisy observations
        inpt = init_recon(noisy_observations1)

        # Every method starts from the same point, so the initial objective
        # value is evaluated once and shared.
        initial_objective = objective(inpt).item()
        obj_pointwise = [initial_objective]
        obj_convolution = [initial_objective]
        obj_scalar = [initial_objective]
        obj_nag = [initial_objective]

        # Clone initial reconstruction for different methods
        inpt_pointwise = inpt.clone()
        inpt_convolution = inpt.clone()
        inpt_scalar = inpt.clone()
        inpt_nag = inpt.clone()
        inpt_nag_m1 = inpt.clone()

        # Accelerated Gradient Descent (NAG) method.
        # The momentum sequence is advanced sequentially, t_k -> t_{k+1}, on
        # every iteration. (A simultaneous tuple assignment would compute the
        # new value from a stale t and advance the sequence only every
        # second iteration.)
        t_k = 0
        for _ in range(nag_iterations):
            t_next, alphat = nesterov_momentum(t_k)
            yk = inpt_nag + alphat * (inpt_nag - inpt_nag_m1)
            grad_nag = gradient(yk)
            inpt_nag, inpt_nag_m1 = yk - one_over_L * grad_nag, inpt_nag
            obj_nag.append(objective(inpt_nag).item())
            t_k = t_next

        # Final approximation
        _show_image(inpt_nag, 'Final Reconstruction - NAG')

        # The best NAG value serves as the estimate of the minimal objective.
        approx_min = min(obj_nag)

        # Plot convergence of NAG
        _plot_nag_convergence(obj_nag, approx_min)

        # Starting points for the learned operators: each one reproduces plain
        # gradient descent with step size 1 / L_max.
        pointwise_preconditioner = (1 / L_max) * torch.ones((1, 1, n, n), device=device)
        alpha_step = torch.tensor((1 / L_max), device=device)
        kernel = torch.zeros((1, 1, kernel_width, kernel_height), device=device)
        kernel[0, 0, kernel_width // 2, kernel_height // 2] = 1 / L_max

        # Main training loop for preconditioning methods
        for k in tqdm(range(training_iterations)):
            grad_pointwise = gradient(inpt_pointwise)
            grad_convolution = gradient(inpt_convolution)
            grad_scalar = gradient(inpt_scalar)

            # Pointwise preconditioner optimization
            if k in dict_pointwise:
                ## if already learned at this iteration, then use it
                pointwise_preconditioner = dict_pointwise[k]
            else:
                ## else learn it, warm-started from the previous iteration
                pointwise_preconditioner = optimisation_p_pointwise(objective, gradient, inpt_pointwise, pointwise_preconditioner, L, approx_min, max_iter=pointwise_iters, tol=0.001, verbose=True)
                # Remember the result, as is done for the step size and the
                # kernel below.
                dict_pointwise[k] = pointwise_preconditioner
                #torch.save(dict_pointwise, 'pointwise_dictionary_blur.pt')

            # Alpha step optimization
            if k in dict_alpha:
                alpha_step = torch.tensor(dict_alpha[k]).to(device)
            else:
                alpha_step = optimisation_alpha(objective, gradient, inpt_scalar, alpha_step, L, approx_min, max_iter=100, verbose=False)
                dict_alpha[k] = alpha_step.item()
                #torch.save(dict_alpha, 'alpha_dictionary_blur.pt')

            # Convolution kernel optimization
            if k in dict_kernel:
                kernel = dict_kernel[k]
            else:
                print('calc kernel')
                kernel = optimisation_convolution(objective, gradient, inpt_convolution, kernel, L, approx_min, kernel_width, kernel_height, max_iter=max_iter_conv, tol=0.0001, manual=True, verbose=True)
                dict_kernel[k] = kernel
                #torch.save(dict_kernel, 'kernel_dictionary_blur.pt')

            # Update inputs using different methods
            inpt_pointwise = inpt_pointwise - pointwise_preconditioner * grad_pointwise
            inpt_scalar = inpt_scalar - alpha_step * grad_scalar
            inpt_convolution = inpt_convolution - convolution(grad_convolution, kernel)

            # Append objective values
            obj_pointwise.append(objective(inpt_pointwise).item())
            obj_convolution.append(objective(inpt_convolution).item())
            obj_scalar.append(objective(inpt_scalar).item())

            # Live comparison of all methods after this iteration.
            _plot_training_progress(obj_pointwise, obj_scalar, obj_convolution, obj_nag, approx_min, kernel_height, kernel_width)
