import torch
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch.nn.functional as F
from torch.utils.data import random_split
from util.operators import get_ct_operator
from util.objective_functions import fast_huber_TV, fast_huber_grad
from util.optimisation_functions import optimisation_p_pointwise, optimisation_alpha, optimisation_convolution, convolution
# Use GPU if available, else fallback to CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Path of the image tensor that is loaded when no cached train/validation split is found
dataset_name = 'data/SARS-COV-2_CT_non-COVID-128.pt'

# Define parameters
training_iterations = 1000 # number of iterations over which the optimisation algorithm is learned
noise_level = 0.01 # noise scale, relative to the mean absolute value of each clean observation
huber_const = 0.01 # epsilon used in definition of huber norm
reg_const = 0.0001 # scalar multiplying the regularizer
N_TRAIN = 100 # number of images (functions) to train over
N_VAL = 100 # number of images (functions) to validate over
n = 40 ## side length of image (assumed to be square)
n_angles = 90 ## number of angles used in the radon transform

# Get the CT forward operator and its properties
# (W and W_adj are not used in the part of the file visible here)
forward_operator, adjoint_operator, operator_norm, init_recon, W, W_adj = get_ct_operator(n, n_angles)
# Smoothness constant attributed to the regulariser term
# NOTE(review): presumably the Lipschitz constant of the Huber-TV gradient - confirm the factor 8
L_smooth_reg = 8 * reg_const / huber_const
# Smoothness constant of the whole objective: squared operator norm (data term) plus the regulariser part
L = operator_norm ** 2 + L_smooth_reg
print(f'f is L={L} smooth.')


# Load datasets.
# Sources are tried in this order:
#   1. train/validation splits cached in the working directory by an earlier run;
#   2. the tensor stored at `dataset_name`, which is split and then cached;
#   3. the STL10 test split from torchvision (downloaded on demand) as a last resort.
# `except Exception` is used instead of a bare `except` so that Ctrl-C and
# SystemExit are not silently turned into a fallback to the next source.
try:
    try:
        train_dataset = torch.load('train_dataset_CT.pt')
        val_dataset = torch.load('val_dataset_CT.pt')
    except Exception:
        # No usable cache: build the split from the raw tensor and cache it.
        dataset = torch.load(dataset_name).to(device)
        dataset = torch.utils.data.Subset(dataset, list(range(N_TRAIN + N_VAL)))
        train_dataset, val_dataset = random_split(dataset, [N_TRAIN, N_VAL])
        torch.save(train_dataset, 'train_dataset_CT.pt')
        torch.save(val_dataset, 'val_dataset_CT.pt')
except Exception:
    print('so that you dont have to donwload extra data, lets see how it performs on STL data')
    test_train = False
    # Imported lazily so torchvision is only required when the fallback is used.
    import torchvision
    dataset = torchvision.datasets.STL10('STL', split='test', transform=torchvision.transforms.ToTensor(), folds=1, download=True)
    dataset = torch.utils.data.Subset(dataset, list(range(N_TRAIN + N_VAL)))
    train_dataset, val_dataset = random_split(dataset, [N_TRAIN, N_VAL])
# Create data loaders for training and validation.
# The batch size equals the split size, so each loader yields one batch holding the whole split.
data_loader = DataLoader(train_dataset, batch_size=N_TRAIN, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=N_VAL, shuffle=True)

def crop_center(image, crop_size):
    """Cut a centred crop_size x crop_size window out of the last two axes.

    `image` must have exactly four dimensions (it is unpacked into four
    sizes); the first two axes are returned untouched.
    """
    _, _, height, width = image.shape
    top = height // 2 - crop_size // 2
    left = width // 2 - crop_size // 2
    rows = slice(top, top + crop_size)
    cols = slice(left, left + crop_size)
    return image[:, :, rows, cols]

# Overall objective function, divided by the size of the leading axis of x
def objective_function(x, y):
    """Return the data-fit term plus the weighted Huber-TV term, divided by x.shape[0].

    The data-fit term is half the squared norm of forward_operator(x) - y,
    taken over the whole tensor. The regulariser is fast_huber_TV(x) with
    delta=huber_const, scaled by reg_const.
    Assumes x.shape[0] is the batch size, so the result is a per-image mean.
    """
    residual = forward_operator(x) - y
    data_fit = 0.5 * torch.norm(residual) ** 2
    regulariser = reg_const * fast_huber_TV(x, delta=huber_const)
    return (data_fit + regulariser) / x.shape[0]

# Gradient of the objective, not divided by x.shape[0]
def grad_objective(x, y):
    """Return the gradient of the un-averaged objective at x.

    Unlike objective_function, the result is not divided by x.shape[0];
    the original author notes that this factor is absorbed into L.
    """
    residual = forward_operator(x) - y
    data_grad = adjoint_operator(residual)
    reg_grad = reg_const * fast_huber_grad(x, delta=huber_const)
    return data_grad + reg_grad

# Fixed step size 1/L, used by the accelerated gradient baseline in the main script.
one_over_L = 1 / L

# NOTE: from here on `L` is no longer a scalar. It is rebound to a list that
# repeats the same constant once per training image, which is the form passed
# to the optimisation_* helpers. Compute anything that needs the scalar
# (such as one_over_L) before this line.
L = [L for _ in range(N_TRAIN)]
L_max = max(L)

# Re-use previously learned preconditioners when they are available.
# The three dictionaries are treated as a unit: if any of them fails to load,
# all of them start empty. `except Exception` keeps that best-effort behaviour
# without also swallowing KeyboardInterrupt / SystemExit as a bare `except` would.
try:
    dict_pointwise = torch.load('learned_operators/pointwise_dictionary_CT.pt')
    dict_kernel = torch.load('learned_operators/kernel_dictionary_CT.pt')
    dict_alpha = torch.load('learned_operators/alpha_dictionary_CT.pt')
except Exception:
    dict_pointwise = {}
    dict_kernel = {}
    dict_alpha = {}

# Number of iterations for which a learned kernel was already cached.
FINAL_TRAIN_ITER = len(dict_kernel)

print('final training iteration:', FINAL_TRAIN_ITER)

if __name__ == '__main__':
    # Settings of the convolution preconditioner are identical for every
    # iteration, so they are set once instead of inside the training loop.
    kernel_width = n
    kernel_height = kernel_width
    # Index of the kernel's centre tap; for an even size this is the lower of
    # the two middle indices.
    center = kernel_height // 2 - 1 if kernel_height % 2 == 0 else kernel_height // 2
    max_iter_conv = 5000

    # Line style shared by every convergence curve.
    marker_every = 10
    curve_style = dict(linestyle='-.', marker='x', markersize=4, markevery=marker_every)

    for data in data_loader:
        # Preprocess the images.
        # A plain tensor batch is cropped directly. A batch without `.shape`
        # (e.g. the [images, labels] list produced for the STL fallback) is
        # first reduced to a single channel by averaging over dim 1.
        try:
            xs = crop_center(data, n)
        except AttributeError:
            xs = crop_center(data[0].mean(dim=1, keepdim=True).to(device), n)

        # Display the true image
        plt.imshow(xs[0, :, :, :].squeeze().cpu().numpy(), cmap='gray')
        plt.title('True Image')
        plt.show()

        # Create clean and noisy observations.
        # The Gaussian noise of each sample is scaled by noise_level times the
        # mean absolute value of that sample's clean observation.
        clean_observations = forward_operator(xs)
        mean_abs = torch.mean(torch.abs(clean_observations).view(xs.shape[0], -1), dim=1)
        noise_scale = mean_abs.view(-1, 1, 1, 1)
        noisy_observations1 = clean_observations + (noise_level * torch.randn_like(clean_observations) * noise_scale)

        # Display the noisy observation
        plt.imshow(noisy_observations1[0, :, :, :].squeeze().cpu().numpy(), cmap='gray')
        plt.title('Noisy Observations')
        plt.show()

        # Initial reconstruction from noisy observations
        inpt = init_recon(noisy_observations1)

        # Display the initial reconstruction
        plt.imshow(inpt[0, :, :, :].squeeze().cpu().numpy(), cmap='gray')
        plt.title('Initial Reconstruction')
        plt.show()

        # Every method starts from the same point, so the initial objective
        # value is evaluated once and shared by all traces.
        initial_objective = objective_function(inpt, noisy_observations1).item()
        obj_pointwise = [initial_objective]
        obj_convolution = [initial_objective]
        obj_scalar = [initial_objective]
        obj_nag = [initial_objective]

        # Each method iterates on its own copy of the initial reconstruction
        inpt_pointwise = inpt.clone()
        inpt_convolution = inpt.clone()
        inpt_scalar = inpt.clone()
        inpt_nag = inpt.clone()
        inpt_nag_m1 = inpt.clone()

        # Momentum sequence of the NAG method
        tnew = 0
        told = 0

        # Accelerated Gradient Descent (NAG) method with fixed step 1/L.
        # Its best objective value serves as the reference minimum below.
        for k in tqdm(range(301)):
            tnew, told = (1 + np.sqrt(1 + 4 * told ** 2)) / 2, tnew
            alphat = (told - 1) / tnew
            yk = inpt_nag + alphat * (inpt_nag - inpt_nag_m1)
            grad_nag = grad_objective(yk, noisy_observations1)
            inpt_nag, inpt_nag_m1 = yk - one_over_L * grad_nag, inpt_nag
            obj_nag.append(objective_function(inpt_nag, noisy_observations1).item())

        # Plot convergence of NAG
        # NOTE(review): the plot titles say "Image Deblurring" although the
        # operator comes from get_ct_operator - confirm the intended wording.
        approx_min = min(obj_nag)
        iterations = np.arange(0, len(obj_nag))
        plt.figure(figsize=(10, 8), dpi=300)
        plt.semilogy(iterations, (-approx_min + np.array(obj_nag)), label='Accelerated Gradient Descent', color='yellow', **curve_style)
        plt.xlabel('Iteration $k$', fontsize=14)
        plt.ylabel(f'$f(x^k) - f(x^*)$ Averaged Over Dataset of size {N_TRAIN}', fontsize=14)
        plt.title('Training Convergence Plots for Image Deblurring', fontsize=16)
        plt.legend(fontsize=12, loc='upper right')
        plt.grid(True, which="both", linestyle='--', linewidth=0.5)
        plt.tight_layout()
        plt.show()

        # Main training loop for preconditioning methods.
        # For each iteration k a cached operator is used when one exists;
        # otherwise it is optimised, warm-started from the previous iteration's
        # value, and stored in the in-memory dictionary. Nothing is written
        # back to disk here.
        for k in tqdm(range(training_iterations)):
            grad_pointwise = grad_objective(inpt_pointwise, noisy_observations1)
            grad_convolution = grad_objective(inpt_convolution, noisy_observations1)
            grad_scalar = grad_objective(inpt_scalar, noisy_observations1)

            # Pointwise preconditioner optimization (initialised to the constant 1/L)
            if k == 0:
                pointwise_preconditioner = (1 / L_max) * torch.ones((1, 1, n, n)).to(device)

            try:
                pointwise_preconditioner = dict_pointwise[k]
            except KeyError:
                pointwise_preconditioner = optimisation_p_pointwise(lambda x: objective_function(x, noisy_observations1), lambda x: grad_objective(x, noisy_observations1), inpt_pointwise, pointwise_preconditioner, L, approx_min, max_iter=1000, tol=0.001)
                dict_pointwise[k] = pointwise_preconditioner

            # Alpha step optimization (initialised to 1/L)
            if k == 0:
                alpha_step = torch.tensor(1 / L_max).to(device)

            try:
                # Cached step sizes are stored as plain floats
                alpha_step = torch.tensor(dict_alpha[k]).to(device)
            except KeyError:
                alpha_step = optimisation_alpha(lambda x: objective_function(x, noisy_observations1), lambda x: grad_objective(x, noisy_observations1), inpt_scalar, alpha_step, L, approx_min, max_iter=1000)
                dict_alpha[k] = alpha_step.item()

            # Convolution kernel optimization (initialised to 1/L at the centre tap, zero elsewhere)
            if k == 0:
                kernel = torch.zeros((1, 1, kernel_width, kernel_height)).to(device)
                kernel[0, 0, center, center] = 1 / L_max

            try:
                kernel = dict_kernel[k]
            except KeyError:
                kernel = optimisation_convolution(lambda x: objective_function(x, noisy_observations1), lambda x: grad_objective(x, noisy_observations1), inpt_convolution, kernel, L, approx_min, kernel_width, kernel_height, max_iter=max_iter_conv, tol=0.001, manual=True)
                dict_kernel[k] = kernel

            # Update inputs using different methods
            inpt_pointwise = inpt_pointwise - pointwise_preconditioner * grad_pointwise
            inpt_scalar = inpt_scalar - alpha_step * grad_scalar
            inpt_convolution = inpt_convolution - convolution(grad_convolution, kernel)

            # Append objective values
            obj_pointwise.append(objective_function(inpt_pointwise, noisy_observations1).item())
            obj_convolution.append(objective_function(inpt_convolution, noisy_observations1).item())
            obj_scalar.append(objective_function(inpt_scalar, noisy_observations1).item())

            # Diagnostics are shown on every iteration past the cached ones.
            if k > FINAL_TRAIN_ITER:
                plt.plot(list(dict_alpha.values()))
                plt.title('Alpha Values')
                plt.xlabel('Iteration')
                plt.ylabel('Alpha')
                plt.show()

                # Display pointwise preconditioner and kernel
                plt.imshow(pointwise_preconditioner.squeeze().detach().cpu().numpy(), cmap='seismic')
                plt.colorbar()
                plt.title(f'Pointwise Iteration {k}')
                plt.show()

                plt.imshow(kernel.squeeze().detach().cpu().numpy(), cmap='seismic')
                plt.colorbar()
                plt.title(f'Kernel Iteration {k}')
                plt.show()

                # Plot convergence of different methods
                iterations = np.arange(0, len(obj_pointwise))
                # The NAG baseline ran for a fixed number of steps, while the
                # learned traces grow with k. Indexing the NAG trace with every
                # learned iteration raised IndexError once k passed the length
                # of the NAG run, so its curve is clipped to what exists.
                nag_iterations = iterations[:len(obj_nag)]

                gap_pointwise = -approx_min + np.array(obj_pointwise)
                gap_scalar = -approx_min + np.array(obj_scalar)
                gap_convolution = -approx_min + np.array(obj_convolution)[iterations]
                gap_nag = -approx_min + np.array(obj_nag)[nag_iterations]

                # Same curves twice: semilog-y over k, then log-log over k + 1
                # (the shift keeps iteration 0 on the logarithmic x axis).
                for plot_curve, x_shift in ((plt.semilogy, 0), (plt.loglog, 1)):
                    plt.figure(figsize=(10, 8), dpi=300)
                    plot_curve(x_shift + iterations, gap_pointwise, label='Pointwise', color='red', **curve_style)
                    plot_curve(x_shift + iterations, gap_scalar, label='Scalar', color='blue', **curve_style)
                    plot_curve(x_shift + iterations, gap_convolution, label=f'Convolution - {kernel_height} x {kernel_width}', color='black', **curve_style)
                    plot_curve(x_shift + nag_iterations, gap_nag, label='NAG', color='yellow', **curve_style)
                    plt.xlabel('Iteration', fontsize=14)
                    plt.ylabel(f'$f(x^k) - f(x^*)$ Averaged Over Dataset of size {N_TRAIN}', fontsize=14)
                    plt.title('Training Convergence Plots for Image Deblurring', fontsize=16)
                    plt.legend(fontsize=12, loc='upper right')
                    plt.grid(True, which="both", linestyle='--', linewidth=0.5)
                    plt.tight_layout()
                    plt.show()