import torch
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import random_split
from util.operators import get_ct_operator
from util.objective_functions import fast_huber_TV, fast_huber_grad
from util.optimisation_functions import optimisation_p_full
from util.utils import apply_p_full, power_method_square_func
# Run on the GPU when CUDA is usable; otherwise stay on the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# ---- Experiment configuration ----

# Path of the saved image dataset
dataset_name = 'data/SARS-COV-2_CT_non-COVID-128.pt'

# Number of iterations over which the optimisation algorithm is learned
training_iterations = 1000
# Noise level applied to the simulated observations
noise_level = 1e-2
# Epsilon used in the definition of the Huber norm
huber_const = 1e-2
# Scalar multiplying the regulariser
reg_const = 1e-4
# Number of images (functions) to train over
N_TRAIN = 1000
# Number of images (functions) used to calculate the regularisation parameters \lambda_t
N_VAL_REG = 100
# Number of images (functions) to test over
N_VAL = 100
# Side length of each image (images are assumed to be square)
n = 40
# Number of angles used in the Radon transform
n_angles = 90



# Build the CT operator for n x n images and n_angles projection angles,
# together with its adjoint, its norm, an initial-reconstruction map and
# the W / W_adj pair returned alongside them.
forward_operator, adjoint_operator, operator_norm, init_recon, W, W_adj = get_ct_operator(n, n_angles)

# Smoothness constant of the objective: the squared operator norm for the
# data-fit term plus 8 * reg_const / huber_const for the regulariser
# (presumably a bound for the Huber-TV gradient -- TODO confirm the factor 8).
L_smooth_reg = (8 * reg_const) / huber_const
L = L_smooth_reg + operator_norm ** 2
print(f'f is L={L} smooth.')


# Load datasets.
# Three sources are tried in order:
#   1. a previously saved train/validation split,
#   2. the raw dataset at `dataset_name` (split here and saved for next time),
#   3. the STL10 test split from torchvision as a last-resort fallback.
# `except Exception` (rather than a bare `except:`) keeps the best-effort
# fallback chain while letting KeyboardInterrupt / SystemExit propagate, so
# Ctrl-C during a slow load no longer silently triggers a download.
# NOTE: torch.load unpickles arbitrary Python objects -- only point it at
# trusted files.
# NOTE(review): recent PyTorch releases default torch.load to
# weights_only=True, which presumably rejects the pickled Subset objects
# saved below, so branch 1 may always fall through there -- confirm.
try:
    try:
        train_dataset = torch.load(f'train_dataset_CT_{N_TRAIN}_valreg.pt')
        val_dataset = torch.load(f'val_dataset_CT_{N_VAL}_valreg.pt')
        print('LOADED DATASETS')
    except Exception:
        # No usable cached split: build one from the raw dataset and cache it.
        dataset = torch.load(dataset_name).to(device)
        dataset = torch.utils.data.Subset(dataset, list(range(N_TRAIN + N_VAL)))
        train_dataset, val_dataset = random_split(dataset, [N_TRAIN, N_VAL])
        torch.save(train_dataset, f'train_dataset_CT_{N_TRAIN}_valreg.pt')
        torch.save(val_dataset, f'val_dataset_CT_{N_VAL}_valreg.pt')
except Exception:
    print('so that you dont have to donwload extra data, lets see how it performs on STL data')
    test_train = False
    # Imported lazily so torchvision is only required for this fallback.
    import torchvision
    dataset = torchvision.datasets.STL10('STL', split='test', transform=torchvision.transforms.ToTensor(), folds=1, download=True)
    dataset = torch.utils.data.Subset(dataset, list(range(N_TRAIN + N_VAL)))
    train_dataset, val_dataset = random_split(dataset, [N_TRAIN, N_VAL])
# Create data loaders for training and validation; the batch size equals the
# split size, so each loader yields its whole split as a single batch.
data_loader = DataLoader(train_dataset, batch_size=N_TRAIN, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=N_VAL, shuffle=True)


def crop_center(image, crop_size):
    """Return the central ``crop_size`` x ``crop_size`` window of a 4-D batch.

    Args:
        image: Tensor whose shape unpacks as ``(N, c, h, w)``.
        crop_size: Side length of the square window taken from the last
            two axes.

    Returns:
        ``image`` sliced on its last two axes; the first two axes are kept
        in full.
    """
    _, _, height, width = image.shape
    half = crop_size // 2
    top = height // 2 - half
    left = width // 2 - half
    rows = slice(top, top + crop_size)
    cols = slice(left, left + crop_size)
    return image[:, :, rows, cols]

def huber_grad(x, gamma):
    """Element-wise derivative of the Huber function with threshold ``gamma``.

    Entries with ``|x| < gamma`` map to ``x / gamma``; all other entries map
    to ``sign(x)``.
    """
    in_quadratic_zone = x.abs() < gamma
    linear_part = x / gamma
    saturated_part = x.sign()
    return torch.where(in_quadratic_zone, linear_part, saturated_part)

def huber(x, gamma):
    """Return the batch-mean Huber penalty of ``x``.

    Each entry contributes ``0.5 * x**2 / gamma`` when ``|x| < gamma`` and
    ``|x| - 0.5 * gamma`` otherwise. Contributions are summed over every
    axis except the first, and the per-sample sums are averaged.

    Args:
        x: Tensor with the batch on axis 0.
        gamma: Threshold between the quadratic and the linear regime.

    Returns:
        A 0-dim tensor holding the mean over axis 0 of the per-sample sums.
    """
    y = torch.where(torch.abs(x) < gamma,
                    0.5 * x**2 / gamma,
                    torch.abs(x) - 0.5 * gamma)
    # reshape (not view): the element-wise result can inherit a
    # non-contiguous layout from x (e.g. a permuted tensor), on which
    # view() raises; reshape copies only when it has to.
    y_flat = y.reshape(y.size(0), -1)
    loss = torch.mean(torch.sum(y_flat, dim=-1))
    return loss

# Define the overall objective function - mean over batch size
def objective_function(x, y):
    """Evaluate the objective at ``x`` for observations ``y``, averaged over the batch.

    The value is half the squared norm of ``forward_operator(x) - y`` plus
    ``reg_const`` times ``fast_huber_TV(x, delta=huber_const)``, divided by
    ``x.shape[0]``.
    """
    residual = forward_operator(x) - y
    data_fit = 0.5 * torch.norm(residual) ** 2
    regulariser = reg_const * fast_huber_TV(x, delta=huber_const)
    return (data_fit + regulariser) / x.shape[0]

# Gradient of the objective function, without dividing by x.shape[0] - that factor is absorbed into L
def grad_objective(x, y):
    """Gradient of the un-normalised objective at ``x`` for observations ``y``.

    Sums the adjoint operator applied to the residual
    ``forward_operator(x) - y`` and ``reg_const`` times
    ``fast_huber_grad(x, delta=huber_const)``. Unlike ``objective_function``,
    the result is not divided by ``x.shape[0]``.
    """
    residual = forward_operator(x) - y
    data_grad = adjoint_operator(residual)
    reg_grad = reg_const * fast_huber_grad(x, delta=huber_const)
    return data_grad + reg_grad


# Step size used by the accelerated-gradient baseline (computed while L is
# still the scalar smoothness constant).
one_over_L = 1 / L

# From here on L is rebound to a list holding one copy of the smoothness
# constant per training image; L_max is the largest entry.
L = [L] * N_TRAIN
L_max = max(L)

# Regularisation settings for the learned preconditioner
constant_lambda = True
lambda_value_constant = 1e-10

## too big to fit in supplementary materials.
# Learned preconditioners, keyed by iteration index
dict_full = dict()
    

def _prepare_batch(batch):
    """Return the centre ``n`` x ``n`` crop of the images in one loader batch.

    The batch is first handed to ``crop_center`` unchanged. If it cannot be
    unpacked as a 4-D image tensor, it is treated as a sequence whose first
    element holds the images (presumably ``(images, labels)`` batches from
    the STL10 fallback -- verify): those are averaged over dim 1 with
    ``keepdim=True``, moved to ``device`` and then cropped.
    """
    try:
        return crop_center(batch, n)
    except (AttributeError, ValueError):
        # AttributeError: the batch has no ``.shape`` (e.g. it is a list).
        # ValueError: its shape does not unpack into (N, c, h, w).
        return crop_center(batch[0].mean(dim=1, keepdim=True).to(device), n)


def _simulate_observations(images):
    """Return ``(clean, noisy)`` observations of ``images``.

    The noise is Gaussian, scaled by ``noise_level`` and by each sample's
    mean absolute clean observation.
    """
    clean = forward_operator(images)
    noise = noise_level * torch.randn_like(clean)
    per_sample_scale = torch.mean(torch.abs(clean).view(images.shape[0], -1), dim=1)
    noisy = clean + (noise * per_sample_scale.unsqueeze(1).unsqueeze(1).unsqueeze(1))
    return clean, noisy


def _run_nag(x0, observations, num_iters=301):
    """Run the accelerated-gradient baseline from ``x0`` with step ``one_over_L``.

    Returns:
        ``(objective_values, final_iterate)`` where ``objective_values``
        holds the objective at ``x0`` followed by one value per iteration.
    """
    t_new = 0
    t_old = 0
    objective_values = [objective_function(x0, observations).item()]
    x_curr = x0.clone()
    x_prev = x0.clone()
    for _ in tqdm(range(num_iters)):
        # NOTE(review): the new t is computed from t_old rather than from
        # the latest t, so the momentum schedule only advances every second
        # iteration; the textbook recurrence uses the latest t. Kept as-is
        # to preserve existing results -- confirm intent.
        t_new, t_old = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2, t_new
        momentum = (t_old - 1) / t_new
        y_k = x_curr + momentum * (x_curr - x_prev)
        grad = grad_objective(y_k, observations)
        x_curr, x_prev = y_k - one_over_L * grad, x_curr
        objective_values.append(objective_function(x_curr, observations).item())
    return objective_values, x_curr


if __name__ == '__main__':

    # ---- Validation data: observations, initial point and reference minimum ----
    for data in val_dataloader:
        # Preprocess the images
        xs_val = _prepare_batch(data)

        # Create clean and noisy observations
        clean_observations_val, noisy_observations1_val = _simulate_observations(xs_val)

        # Initial reconstruction from noisy observations
        inpt_val = init_recon(noisy_observations1_val)

        obj_val_full = [objective_function(inpt_val, noisy_observations1_val).item()]
        inpt_val_full = inpt_val.clone()

        # The smallest objective reached by the baseline stands in for f(x^*)
        obj_val_nag, inpt_val_nag = _run_nag(inpt_val, noisy_observations1_val)
        approx_min_val = min(obj_val_nag)

    # ---- Training data: baseline, then learn one preconditioner per iteration ----
    for data in data_loader:
        # Preprocess the images
        xs = _prepare_batch(data)

        # Create clean and noisy observations
        clean_observations, noisy_observations1 = _simulate_observations(xs)

        # Initial reconstruction from noisy observations
        inpt = init_recon(noisy_observations1)

        obj_full = [objective_function(inpt, noisy_observations1).item()]
        inpt_full = inpt.clone()

        # Accelerated Gradient Descent (NAG) baseline
        obj_nag, inpt_nag = _run_nag(inpt, noisy_observations1)

        # Plot convergence of NAG
        approx_min = min(obj_nag)
        iterations = np.arange(0, len(obj_nag))
        marker_every = 10
        plt.figure(figsize=(10, 8), dpi=300)
        plt.semilogy(iterations, (-approx_min + np.array(obj_nag)), label='Accelerated Gradient Descent', linestyle='-.', marker='x', markersize=4, markevery=marker_every, color='yellow')
        plt.xlabel('Iteration $k$', fontsize=14)
        plt.ylabel(f'$f(x^k) - f(x^*)$ Averaged Over Dataset of size {N_TRAIN}', fontsize=14)
        # NOTE(review): the title says 'Image Deblurring' although the
        # operator comes from get_ct_operator -- confirm the intended label.
        plt.title('Training Convergence Plots for Image Deblurring', fontsize=16)
        plt.legend(fontsize=12, loc='upper right')
        plt.grid(True, which="both", linestyle='--', linewidth=0.5)
        plt.tight_layout()
        plt.show()

        # Objective and gradient of the training problem as functions of x only
        def train_objective(x):
            return objective_function(x, noisy_observations1)

        def train_gradient(x):
            return grad_objective(x, noisy_observations1)

        # Initial optimality gaps used to normalise the convergence curves.
        # They do not change during training, so they are taken once from the
        # first recorded objective values instead of being re-evaluated on
        # every iteration.
        objective_difference_iter0_val = obj_val_full[0] - approx_min_val
        objective_difference_iter0 = obj_full[0] - approx_min

        # Main training loop for preconditioning methods
        for k in tqdm(range(200)):

            grad_full = grad_objective(inpt_full, noisy_observations1)

            if not dict_full:
                # Seed the dictionary with the scaled identity (1 / L_max) * I
                dict_full[0] = (1/L_max) * torch.eye(n ** 2).to(device)

            if k == 0:
                num_full_iters = 1
                p_full = optimisation_p_full(train_objective, train_gradient, inpt_full, dict_full[0], L, approx_min, max_iter=num_full_iters, tol=1e-6, lmbda=lambda_value_constant)
            else:
                try:
                    # Re-use a preconditioner already stored for this iteration
                    p_full = dict_full[k]
                except KeyError:
                    # None stored yet: learn one, starting from (1 / L_max) * I
                    num_full_iters = 5000
                    p_full = optimisation_p_full(train_objective, train_gradient, inpt_full, (1/L_max) * torch.eye(n ** 2).to(device), L, approx_min, max_iter=num_full_iters, tol=1e-6, lmbda=lambda_value_constant)

            ## now evaluate the new learned full preconditioner
            inpt_full = inpt_full - apply_p_full(p_full, grad_full)
            print(f'Training Objective value for full preconditioning at iteration {k}: {objective_function(inpt_full, noisy_observations1)}')

            # Apply the same preconditioner to the validation iterate
            grad_val_full = grad_objective(inpt_val_full, noisy_observations1_val)
            inpt_val_full = inpt_val_full - apply_p_full(p_full, grad_val_full)
            print(f'Validation Objective value for full preconditioning at iteration {k}: {objective_function(inpt_val_full, noisy_observations1_val)}')

            # Power-method estimate for the map x -> P x - x / L_max, minus
            # 1 / L_max; a negative value is reported as provable convergence.
            pwr_diff = power_method_square_func(lambda x: apply_p_full(p_full, x) - x / L_max, n=n) - 1 / L_max
            print(f'POWER DIFF: {pwr_diff}, IF <0 THEN PROVABLY CONVERGENT!')

            dict_full[k] = p_full
            #torch.save(dict_full, f'full_dictionary_regularised_{lambda_value_constant}_3.pt')

            obj_val_full.append(objective_function(inpt_val_full, noisy_observations1_val).item())
            obj_full.append(objective_function(inpt_full, noisy_observations1).item())

            # Normalised optimality gap on train and validation data so far
            iterations = np.arange(0, len(obj_full))
            plt.figure(figsize=(10, 8), dpi=300)
            plt.loglog(1+iterations, (-approx_min + np.array(obj_full)[iterations])/objective_difference_iter0, label='Learned Full Operator - Train', linestyle='-', color='purple', linewidth=2)
            plt.loglog(1+iterations, (-approx_min_val + np.array(obj_val_full)[iterations])/objective_difference_iter0_val, label='Learned Full Operator - Test', linestyle='--', color='purple', linewidth=4)
            plt.xlabel(r'Iteration $t$', fontsize=24)
            plt.ylabel(r'$\frac{F(x_t) - F(x^*)}{F(x_0) - F(x^*)}$', fontsize=30)
            plt.tick_params(axis='both', which='major', labelsize=24)
            plt.ylim(1e-9, 1)  # y-axis spans 1e-9 to 1
            plt.legend(fontsize=24, loc='lower left')
            plt.grid(True, linestyle='-', linewidth=1)
            plt.tight_layout()
            #plt.savefig(f'full_loglog_convergence_plot_m9_{lambda_value_constant}.png', format='png', dpi=300)
            plt.show()

            if pwr_diff < 0:
                print('PROVABLY CONVERGENT!')
                break