import torch
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import random_split
from util.operators import get_ct_operator
from util.objective_functions import fast_huber_TV, fast_huber_grad
from util.optimisation_functions import optimisation_p_full
from util.utils import apply_p_full
# Use GPU if available, else fall back to CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

dataset_name = 'data/SARS-COV-2_CT_non-COVID-128.pt'

# Define parameters
training_iterations = 1000  # number of iterations over which the optimisation algorithm is learned
noise_level = 0.01  # relative noise level of the simulated measurements
huber_const = 0.01  # epsilon used in the definition of the Huber norm
reg_const = 0.0001  # scalar multiplying the regulariser
N_TRAIN = 1000  # number of images (functions) to train over
N_VAL = 100  # number of images (functions) to test over
n = 40  # side length of the (square) image crop
n_angles = 90  # number of angles used in the Radon transform

# Get the CT forward operator and its properties
forward_operator, adjoint_operator, operator_norm, init_recon, W, W_adj = get_ct_operator(n, n_angles)
# Smoothness constant of the regulariser term.
# NOTE(review): 8 / huber_const looks like the usual bound ||grad||^2 <= 8 for
# 2-D finite differences divided by the Huber epsilon -- confirm against
# util.objective_functions.fast_huber_grad.
L_smooth_reg = 8 * reg_const / huber_const
L = operator_norm ** 2 + L_smooth_reg
print(f'f is L={L} smooth.')

# Load datasets, in order of preference:
#   1. a previously cached train/validation split,
#   2. a fresh split of the CT tensor file (which is then cached),
#   3. STL10 test images (downloaded) when the CT data is unavailable.
# `except Exception` (rather than a bare `except`) keeps this best-effort
# fallback while letting KeyboardInterrupt / SystemExit propagate.
# NOTE: torch.load unpickles its input; only load trusted local files.
try:
    try:
        train_dataset = torch.load(f'train_dataset_CT_{N_TRAIN}.pt')
        val_dataset = torch.load(f'val_dataset_CT_{N_VAL}.pt')
        print('LOADED DATASETS')
    except Exception:
        # Load and prepare the dataset, then cache the split for later runs
        dataset = torch.load(dataset_name).to(device)
        dataset = torch.utils.data.Subset(dataset, list(range(N_TRAIN + N_VAL)))
        train_dataset, val_dataset = random_split(dataset, [N_TRAIN, N_VAL])
        torch.save(train_dataset, f'train_dataset_CT_{N_TRAIN}.pt')
        torch.save(val_dataset, f'val_dataset_CT_{N_VAL}.pt')
except Exception as err:
    # Report why the CT path failed so a broken file is not mistaken for a missing one
    print(f'Could not load the CT dataset ({err!r}).')
    print('so that you dont have to donwload extra data, lets see how it performs on STL data')
    test_train = False
    import torchvision
    dataset = torchvision.datasets.STL10('STL', split='test', transform=torchvision.transforms.ToTensor(), folds=1, download=True)
    dataset = torch.utils.data.Subset(dataset, list(range(N_TRAIN + N_VAL)))
    train_dataset, val_dataset = random_split(dataset, [N_TRAIN, N_VAL])

# Create data loaders for training and validation; batch sizes equal the split
# sizes, so each loader yields its whole split as a single batch.
data_loader = DataLoader(train_dataset, batch_size=N_TRAIN, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=N_VAL, shuffle=True)

def crop_center(image, crop_size):
    """Return the central ``crop_size`` x ``crop_size`` window of a 4-D batch.

    ``image`` is unpacked as ``(N, C, H, W)``; the crop is taken over the last
    two axes and all samples/channels are kept. Objects without a ``.shape``
    raise ``AttributeError`` and non-4-D shapes raise ``ValueError``.
    Assumes ``crop_size`` does not exceed ``H`` or ``W`` (no bounds check).
    """
    _, _, height, width = image.shape
    half = crop_size // 2
    top = height // 2 - half
    left = width // 2 - half
    rows = slice(top, top + crop_size)
    cols = slice(left, left + crop_size)
    return image[:, :, rows, cols]

# Overall objective: data fidelity plus Huber-TV regularisation, averaged over the batch
def objective_function(x, y):
    """Return the batch-averaged objective value for reconstructions ``x``.

    Computes ``(0.5 * ||A(x) - y||^2 + reg_const * HuberTV(x)) / B`` where ``A`` is
    ``forward_operator``, the norm is ``torch.norm`` over the whole residual
    tensor, the Huber-TV term uses ``delta=huber_const``, and ``B = x.shape[0]``.
    """
    residual = forward_operator(x) - y
    data_fit = 0.5 * torch.norm(residual) ** 2
    regulariser = reg_const * fast_huber_TV(x, delta=huber_const)
    batch_size = x.shape[0]
    return (data_fit + regulariser) / batch_size

# Gradient of the objective WITHOUT dividing by the batch size x.shape[0]; that
# factor is accounted for in the smoothness constant L instead.
def grad_objective(x, y):
    """Return ``A^T(A(x) - y) + reg_const * grad HuberTV(x)``.

    ``A`` / ``A^T`` are ``forward_operator`` / ``adjoint_operator``, and the
    regulariser gradient comes from ``fast_huber_grad`` with ``delta=huber_const``.
    """
    data_grad = adjoint_operator(forward_operator(x) - y)
    reg_grad = fast_huber_grad(x, delta=huber_const)
    return data_grad + reg_const * reg_grad


# Constant step size for the plain gradient baselines; computed from the scalar
# L before the name L is rebound to a list just below.
one_over_L = 1 / L

# From here on L is a list holding one (identical) smoothness constant per
# training sample; L_max is its largest entry.
L = N_TRAIN * [L]
L_max = max(L)

# Learned preconditioners keyed by iteration index. Starts empty; previously
# saved results could instead be restored with
# torch.load('full_dictionary_CT0609.pt').
dict_full = {}


def _to_images(batch):
    """Return the centre-cropped ``n`` x ``n`` image batch for one loader batch.

    CT batches are tensors and are cropped directly. The STL10 fallback yields
    ``[images, labels]`` lists, which have no ``.shape``, so ``crop_center``
    raises ``AttributeError``; the images are then averaged over dim 1
    (presumably the RGB channels), moved to ``device`` and cropped. Catching only
    ``AttributeError`` keeps other failures (e.g. a batch of unexpected rank)
    from being silently routed down the STL10 path.
    """
    try:
        return crop_center(batch, n)
    except AttributeError:
        return crop_center(batch[0].mean(dim=1, keepdim=True).to(device), n)


def _add_noise(clean):
    """Return ``clean`` plus Gaussian noise scaled separately for each sample.

    Sample ``i`` gets noise with standard deviation ``noise_level`` times the
    mean absolute value of ``clean[i]``. Assumes the batch is on dim 0 and that
    ``clean`` is 4-D (the per-sample scale is broadcast as ``(B, 1, 1, 1)``).
    """
    batch_size = clean.shape[0]
    scale = torch.mean(torch.abs(clean).reshape(batch_size, -1), dim=1).reshape(batch_size, 1, 1, 1)
    return clean + noise_level * torch.randn_like(clean) * scale


def _run_nag(x0, y, n_iter=301):
    """Run Nesterov (FISTA-style) accelerated gradient descent; return objective history.

    Uses the constant step ``one_over_L`` with the momentum recursion
    ``t_{k+1} = (1 + sqrt(1 + 4 t_k^2)) / 2`` and ``beta_k = (t_k - 1) / t_{k+1}``.

    Args:
        x0: initial reconstruction (not modified).
        y: noisy observations defining the objective.
        n_iter: number of accelerated steps.

    Returns:
        List of ``n_iter + 1`` objective values, starting with that of ``x0``.
    """
    x = x0.clone()
    x_prev = x0.clone()
    t = 0
    history = [objective_function(x, y).item()]
    for _ in tqdm(range(n_iter)):
        # t_{k+1} must be computed from the current t_k (not from t_{k-1}),
        # otherwise each t repeats twice and momentum grows too slowly.
        t_prev, t = t, (1 + np.sqrt(1 + 4 * t ** 2)) / 2
        beta = (t_prev - 1) / t  # first step gives -1, harmless since x == x_prev
        z = x + beta * (x - x_prev)
        x, x_prev = z - one_over_L * grad_objective(z, y), x
        history.append(objective_function(x, y).item())
    return history


def _show_first(batch, title):
    """Display the first element of a 4-D batch as a grayscale image."""
    plt.imshow(batch[0, :, :, :].squeeze().cpu().numpy(), cmap='gray')
    plt.title(title)
    plt.show()


def _plot_nag_convergence(obj_nag, approx_min):
    """Semilog plot of the NAG optimality gap ``f(x_k) - approx_min``."""
    iterations = np.arange(0, len(obj_nag))
    marker_every = 10
    plt.figure(figsize=(10, 8), dpi=300)
    plt.semilogy(iterations, (-approx_min + np.array(obj_nag)), label='Accelerated Gradient Descent', linestyle='-.', marker='x', markersize=4, markevery=marker_every, color='yellow')
    plt.xlabel('Iteration $k$', fontsize=14)
    plt.ylabel(f'$f(x^k) - f(x^*)$ Averaged Over Dataset of size {N_TRAIN}', fontsize=14)
    plt.title('Training Convergence Plots for CT Reconstruction', fontsize=16)
    plt.legend(fontsize=12, loc='upper right')
    plt.grid(True, which="both", linestyle='--', linewidth=0.5)
    plt.tight_layout()
    plt.show()


def _plot_full_convergence(obj_full, obj_val_full, approx_min, approx_min_val):
    """Log-log plot of the normalised optimality gap of the learned method.

    Each curve is ``(f(x_k) - f*) / (f(x_0) - f*)`` where ``f*`` is the NAG-based
    approximate minimum; train and validation are normalised separately and the
    validation history is truncated to the training history's length.
    """
    train = np.array(obj_full)
    val = np.array(obj_val_full)[:len(train)]
    iterations = np.arange(0, len(train))
    plt.figure(figsize=(10, 8), dpi=300)
    plt.loglog(1 + iterations, (train - approx_min) / (train[0] - approx_min), label='Learned Full Operator - Train', linestyle='-', color='purple', linewidth=2)
    plt.loglog(1 + iterations, (val - approx_min_val) / (val[0] - approx_min_val), label='Learned Full Operator - Test', linestyle='--', color='purple', linewidth=4)
    plt.xlabel('Iteration', fontsize=24)
    plt.ylabel('Optimality in Function Value', fontsize=30)
    plt.tick_params(axis='both', which='major', labelsize=24)
    plt.ylim(1e-10, 1)
    plt.legend(fontsize=24, loc='lower left')
    plt.grid(True, linestyle='-', linewidth=1)
    plt.tight_layout()
    plt.show()


if __name__ == '__main__':
    n_learned_steps = 25  # number of learned preconditioned steps
    max_iter_full = 5000  # iteration cap for fitting each preconditioner
    train_more = True  # False replays the preconditioners already in dict_full

    # Validation batch: build observations and an approximate minimum
    for data in val_dataloader:
        # Preprocess the images and display the true image
        xs_val = _to_images(data)
        _show_first(xs_val, 'True Image')

        # Create noisy observations and display them
        noisy_observations1_val = _add_noise(forward_operator(xs_val))
        _show_first(noisy_observations1_val, 'Noisy Observations')

        # Initial reconstruction from noisy observations
        inpt_val = init_recon(noisy_observations1_val)
        _show_first(inpt_val, 'Initial Reconstruction')

        obj_val_full = [objective_function(inpt_val, noisy_observations1_val).item()]
        inpt_val_full = inpt_val.clone()

        # Approximate the minimum validation objective with a long NAG run
        approx_min_val = min(_run_nag(inpt_val, noisy_observations1_val))

    for data in data_loader:
        # Preprocess the images and build noisy observations
        xs = _to_images(data)
        noisy_observations1 = _add_noise(forward_operator(xs))

        # Initial reconstruction from noisy observations
        inpt = init_recon(noisy_observations1)

        obj_full = [objective_function(inpt, noisy_observations1).item()]
        inpt_full = inpt.clone()

        # Approximate the minimum training objective with NAG and plot convergence
        obj_nag = _run_nag(inpt, noisy_observations1)
        approx_min = min(obj_nag)
        _plot_nag_convergence(obj_nag, approx_min)

        def train_objective(x):
            return objective_function(x, noisy_observations1)

        def train_gradient(x):
            return grad_objective(x, noisy_observations1)

        # Seed the dictionary with the plain gradient step (1 / L_max) * I
        if not dict_full:
            dict_full[0] = (1 / L_max) * torch.eye(n ** 2, device=device)

        # Main training loop for preconditioning methods
        for k in tqdm(range(n_learned_steps)):
            grad_full = grad_objective(inpt_full, noisy_observations1)

            if train_more:
                # Warm-start from this step's stored preconditioner if present,
                # otherwise from the previous step's.
                p_init = dict_full[k] if k in dict_full else dict_full[k - 1]
                p_full = optimisation_p_full(train_objective, train_gradient, inpt_full, p_init, L, approx_min, max_iter=max_iter_full, tol=1e-3)
                dict_full[k] = p_full
            else:
                p_full = dict_full[k]

            # Apply the learned step to the training batch ...
            inpt_full = inpt_full - apply_p_full(p_full, grad_full)
            obj_train = objective_function(inpt_full, noisy_observations1)
            print(f'Training Objective value for full preconditioning at iteration {k}: {obj_train}')
            obj_full.append(obj_train.item())

            # ... and the same learned step to the validation batch
            grad_val_full = grad_objective(inpt_val_full, noisy_observations1_val)
            inpt_val_full = inpt_val_full - apply_p_full(p_full, grad_val_full)
            obj_val = objective_function(inpt_val_full, noisy_observations1_val)
            print(f'Validation Objective value for full preconditioning at iteration {k}: {obj_val}')
            obj_val_full.append(obj_val.item())

            _plot_full_convergence(obj_full, obj_val_full, approx_min, approx_min_val)