import torch
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch.nn.functional as F
from torch.utils.data import random_split
from util.objective_functions import fast_huber_TV, fast_huber_grad
from util.optimisation_functions import convolution
import time
import torchvision
# Use GPU if available, else fallback to CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# In the code shown, this flag is only used inside checkpoint filenames.
same_params = False

# Problem and training parameters.
# Weight c of the extra 0.5 * c * ||x||^2 term in reg_func (0 disables it);
# also reused as the strong-convexity constant mu further down.
strong_conv_constant = 0
# Noise std relative to each observation's mean absolute value.
noise_level = 0.0025
# Huber smoothing parameter passed as `delta` to the TV regulariser and its gradient.
huber_const = 0.01
# Weight of the Huber-TV regulariser.
reg_const = 1e-05
# Number of images in the training / validation splits.
N_TRAIN = 100
N_VAL = 100
# Side length (pixels) training images are resized to; also the learned kernel size.
n = 96
# Number of unrolled gradient steps, i.e. number of learned kernels.
training_iterations = 20
batch_size = 8  # mini-batch size of the training DataLoader

# Factory for the normalised Gaussian blur kernel.
def gaussian_kernel(size: int, sigma: float) -> torch.Tensor:
    """Build a normalised, isotropic 2-D Gaussian kernel.

    Args:
        size: side length of the square kernel. Pixel offsets run from
            ``-(size // 2)`` to ``size - 1 - size // 2``.
        sigma: standard deviation of the Gaussian, in pixels.

    Returns:
        Tensor of shape ``(1, 1, size, size)`` whose entries sum to 1.
    """
    offsets = torch.arange(size).float() - size // 2
    rows = offsets.view(-1, 1)
    cols = offsets.view(1, -1)
    sq_dist = rows ** 2 + cols ** 2
    weights = torch.exp(-0.5 * sq_dist / sigma ** 2)
    normalised = weights / weights.sum()
    # Add batch and channel axes for use as a conv weight.
    return normalised[None, None]

# Blur kernel parameters: size x size Gaussian with standard deviation sigma (pixels).
size = 5
sigma = 1.5
# NOTE(review): this rebinds `gaussian_kernel` from the factory function above
# to the kernel tensor itself, so the factory can no longer be called after this
# line and importers of this module only see the tensor. A separate name would
# be cleaner, but forward_operator reads this global by name.
gaussian_kernel = gaussian_kernel(size, sigma).to(device)

def forward_operator(x):
    """Apply the blur A: convolve ``x`` with the module-level Gaussian kernel tensor."""
    blurred = convolution(x, gaussian_kernel)
    return blurred

# The Gaussian kernel is symmetric, so the blur is used as its own adjoint.
# NOTE(review): this also assumes `convolution`'s boundary handling is
# self-adjoint; confirm in util.optimisation_functions.
adjoint_operator = forward_operator

def init_recon(x):
    """Initial reconstruction: the observation itself, returned unchanged (no copy)."""
    initial_guess = x
    return initial_guess

# Smoothness constant L of the per-image objective (printed below); 1/L is the
# plain gradient-descent step used later.
# The Gaussian kernel is non-negative and sums to 1, so the blur's operator
# norm is taken as 1 (assumes `convolution` does not amplify at the borders;
# not verifiable from this file).
operator_norm = 1.0
# Huber-TV term: presumably uses ||D||^2 <= 8 for 2-D finite differences,
# giving 8 / delta times the regularisation weight. TODO confirm against
# util.objective_functions.
L_smooth_reg = 8 * reg_const / huber_const
L = operator_norm ** 2 + L_smooth_reg + strong_conv_constant
print(f'f is L={L} smooth.')

# Load the cached train/validation split if present. Otherwise take the first
# N_TRAIN + N_VAL images of the STL10 test split, split them at random, and
# cache the result.
# weights_only=False is required to unpickle the cached Subset objects. Since
# PyTorch 2.6, torch.load defaults to weights_only=True, which rejects them.
# Every run would then fall into the rebuild branch and overwrite the cache
# with a new random split. Full unpickling can execute arbitrary code, so only
# load files this script produced.
try:
    train_dataset = torch.load(f'train_dataset_blur_{N_TRAIN}.pt', weights_only=False)
    val_dataset = torch.load(f'val_dataset_blur_{N_VAL}.pt', weights_only=False)
    print('Loaded previous dataset')
except (OSError, EOFError, RuntimeError) as err:
    # Missing or unreadable cache: rebuild it, but say why instead of
    # swallowing every possible exception.
    print(f'Could not load cached dataset ({err}); rebuilding it')
    dataset = torchvision.datasets.STL10('STL', split='test', transform=torchvision.transforms.ToTensor(), folds=1, download=True)
    dataset = torch.utils.data.Subset(dataset, list(range(N_TRAIN + N_VAL)))
    # NOTE: random_split draws from the global RNG, so a fresh build gives a
    # different split each time; the cache keeps it fixed afterwards.
    train_dataset, val_dataset = random_split(dataset, [N_TRAIN, N_VAL])
    torch.save(train_dataset, f'train_dataset_blur_{N_TRAIN}.pt')
    torch.save(val_dataset, f'val_dataset_blur_{N_VAL}.pt')

# Shuffled mini-batches for training. The validation loader uses
# batch_size=len(val_dataset), so it yields the whole validation set as one batch.
data_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=batch_size)
val_dataloader = DataLoader(dataset=val_dataset, shuffle=True, batch_size=len(val_dataset))

# Objective F(x; y) = data_fit + reg_func, each divided by the batch size.
def data_fit(x, y):
    """Least-squares fidelity 0.5 * ||A x - y||^2 over the whole batch tensor, divided by the batch size."""
    residual = forward_operator(x) - y
    n_samples = x.shape[0]
    return 0.5 * torch.norm(residual) ** 2 / n_samples

def reg_func(x):
    """Huber-TV regulariser plus 0.5 * strong_conv_constant * ||x||^2, divided by the batch size."""
    tv_term = reg_const * fast_huber_TV(x, delta=huber_const)
    ridge_term = 0.5 * strong_conv_constant * torch.norm(x) ** 2
    return (tv_term + ridge_term) / x.shape[0]

def objective_function(x, y):
    """Full objective F(x; y): data fidelity plus regulariser (both batch-averaged)."""
    fidelity = data_fit(x, y)
    return fidelity + reg_func(x)

# Gradients used by the iterative schemes. Unlike data_fit / reg_func, these
# are not divided by the batch size, so each image receives the gradient of its
# own objective.
def grad_reg_func(x):
    """Regulariser gradient: reg_const * fast_huber_grad(x) + strong_conv_constant * x.

    fast_huber_grad is presumably the gradient of fast_huber_TV (TODO confirm).
    """
    smooth_tv_grad = reg_const * fast_huber_grad(x, delta=huber_const)
    return smooth_tv_grad + strong_conv_constant * x

def grad_data_fit(x, y):
    """Per-image gradient of 0.5 * ||A x - y||^2, i.e. A^T (A x - y)."""
    residual = forward_operator(x) - y
    return adjoint_operator(residual)

def grad_objective(x, y):
    """Per-image gradient of the objective: data-fit gradient plus regulariser gradient."""
    data_term = grad_data_fit(x, y)
    return data_term + grad_reg_func(x)

# Step-size constants derived from the smoothness constant L and the
# strong-convexity constant mu.
mu = strong_conv_constant
two_over_mu_plus_L = 2 / (mu + L)
one_over_L = 1 / L

# Previously learned per-iteration kernels (indexed by iteration k below).
# map_location puts the tensors on `device` regardless of where they were
# saved. A GPU-saved dictionary therefore loads on a CPU-only machine, and a
# CPU-saved one sits on the same device as the data it is convolved with.
dict_kernel = torch.load('learned_operators/kernel_dictionary_blur.pt', map_location=device)

# NOTE: `L` is rebound here from a scalar to a list of batch_size copies of the
# same value. Only its first entry (the original scalar) is used, via L_max.
L = [L] * batch_size
L_max = L[0]

def _add_relative_noise(clean):
    """Return ``clean`` plus Gaussian noise scaled per sample.

    Each sample's noise std is ``noise_level`` times that sample's mean
    absolute value. Assumes 4-D (batch, channel, height, width) input, as the
    rest of this script does.
    """
    per_sample_scale = torch.abs(clean).flatten(1).mean(dim=1).view(-1, 1, 1, 1)
    return clean + noise_level * torch.randn_like(clean) * per_sample_scale


def _nag_reference_minimum(x0, y, n_iters=301):
    """Estimate min_x F(x; y) with Nesterov-accelerated gradient descent.

    Runs ``n_iters`` FISTA-style momentum steps with step size 1/L from ``x0``.
    Returns the smallest objective value seen, which serves as F(x*) in the
    validation plots.
    """
    t_new = 0.0
    x = x0.clone()
    x_prev = x0.clone()
    values = [objective_function(x0, y).item()]
    for _ in tqdm(range(n_iters)):
        # t_{k+1} = (1 + sqrt(1 + 4 t_k^2)) / 2 must be computed from t_k.
        # Using t_{k-1} makes t grow at about half the rate and weakens the
        # momentum. t starts at 0, so the first momentum (-1) multiplies
        # x - x_prev = 0 and has no effect.
        t_old = t_new
        t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2
        momentum = (t_old - 1) / t_new
        look_ahead = x + momentum * (x - x_prev)
        x, x_prev = look_ahead - one_over_L * grad_objective(look_ahead, y), x
        values.append(objective_function(x, y).item())
    return min(values)


def _preconditioned_objectives(x0, y, kernels, n_steps):
    """Objective values along x_{k+1} = x_k - convolution(grad F(x_k), kernels[k]).

    Returns ``n_steps + 1`` floats, starting with F(x0).
    """
    x = x0.clone()
    values = [objective_function(x, y).item()]
    for k in range(n_steps):
        x = x - convolution(grad_objective(x, y), kernels[k])
        values.append(objective_function(x, y).item())
    return values


def _unrolled_training_loss(x0, y, kernels, n_steps):
    """Differentiable sum of F(x_k; y), k = 1..n_steps, along the learned preconditioned iteration."""
    x = x0
    loss = 0
    for k in range(n_steps):
        x = x - convolution(grad_objective(x, y), kernels[k])
        loss = loss + objective_function(x, y)
    return loss


def _plot_validation_curves(obj_greedy, obj_unrolled, f_star, initial_gap, n_steps):
    """Log-log plot of the relative gap (F(x_t) - F*) / (F(x_0) - F*) for both kernel sets."""
    iterations = np.arange(0, n_steps)
    plt.figure(figsize=(10, 8), dpi=300)
    plt.loglog(1 + iterations, (-f_star + np.array(obj_greedy)[iterations]) / initial_gap, label='Learned Convolution Greedy', linestyle='-', color='black', linewidth=4)
    plt.loglog(1 + iterations, (-f_star + np.array(obj_unrolled)[iterations]) / initial_gap, label='Learned Convolution Unrolling', linestyle='-', color='cyan', linewidth=4)
    plt.xlabel(r'Iteration $t$', fontsize=28)
    plt.ylabel(r'$\frac{F(x_t) - F(x^*)}{F(x_0) - F(x^*)}$', fontsize=32)
    plt.legend(fontsize=28, loc='upper right')
    plt.tick_params(axis='both', which='major', labelsize=28)
    plt.grid(True, linestyle='-', linewidth=1)
    plt.tight_layout()
    plt.show()
    # Free the figure; non-interactive backends would otherwise keep every
    # figure alive for the whole training run.
    plt.close()


def _plot_training_loss(losses):
    """Semilog plot of the per-epoch training loss recorded so far."""
    plt.figure(figsize=(10, 8), dpi=300)
    plt.semilogy(losses, linestyle='-', linewidth=4, color='black')
    plt.xlabel(r'Epoch', fontsize=30)
    plt.ylabel(r'Training Loss', fontsize=30)
    plt.grid(True, which="both", linestyle='--', linewidth=0.5)
    plt.tight_layout()
    plt.tick_params(axis='both', which='major', labelsize=30)
    plt.show()
    plt.close()


# Training loop using mini-batches
if __name__ == '__main__':
    # ---- Validation reference problem ----
    # val_dataloader uses batch_size=len(val_dataset), so it yields exactly one
    # batch containing the whole validation set.
    xs_val, _ = next(iter(val_dataloader))
    xs_val = xs_val.mean(dim=1, keepdim=True).to(device)  # average channels -> 1 channel
    clean_observations_val = forward_operator(xs_val)
    noisy_observations1_val = _add_relative_noise(clean_observations_val)
    inpt_val = init_recon(noisy_observations1_val)

    # F(x*) estimate and initial gap, used to normalise the validation curves.
    approx_min_val = _nag_reference_minimum(inpt_val, noisy_observations1_val)
    objective_difference_iter0_val = -approx_min_val + objective_function(inpt_val, noisy_observations1_val).item()

    n_epochs = 20000

    # Initial learned kernel: a single tap of 1/L_max at index n//2 - 1.
    # Presumably this reproduces gradient descent with step 1/L if `convolution`
    # treats that index as the kernel centre (TODO confirm).
    kernel = torch.zeros((1, 1, n, n)).to(device)
    kernel[0, 0, n // 2 - 1, n // 2 - 1] = 1 / L_max

    # Resume from a saved kernel dictionary if available. start_epoch only
    # selects the checkpoint file; the epoch loop below still starts at 0.
    start_epoch = 13303
    checkpoint_path = f'learned_operators/kernel_dictionary_epoch_{start_epoch}_{same_params}_scratch_2809.pt'
    try:
        kernels_training = torch.load(checkpoint_path, map_location=device)
        print('Loaded previous kernel dictionary')
    except (OSError, EOFError, RuntimeError) as err:
        print(f'No usable kernel checkpoint ({err}); starting from scratch')
        kernels_training = [kernel.clone().detach().requires_grad_(True) for _ in range(training_iterations)]

    loss_iters_kernel = [0] * n_epochs
    optimizer_kernel = torch.optim.Adam(kernels_training, lr=0.00001)
    start_time = time.time()

    for epoch in tqdm(range(n_epochs)):
        for xs, _ in data_loader:
            xs = xs.mean(dim=1, keepdim=True).to(device)
            xs = F.interpolate(xs, size=(n, n), mode='bilinear', align_corners=False)

            noisy_observations1 = _add_relative_noise(forward_operator(xs))
            inpt = init_recon(noisy_observations1)

            # Unrolled loss: sum of objectives along the learned iteration.
            loss_convolution = _unrolled_training_loss(inpt, noisy_observations1, kernels_training, training_iterations)

            optimizer_kernel.zero_grad()
            loss_convolution.backward()
            optimizer_kernel.step()

            # objective_function already averages over the batch. Weight by the
            # batch size so the epoch entry is a per-image mean (the last batch
            # may be smaller than batch_size).
            loss_iters_kernel[epoch] += loss_convolution.item() * xs.shape[0] / N_TRAIN

        if epoch % 100 == 1:
            elapsed_time_hours = (time.time() - start_time) / 3600

            # Evaluation only: avoid building an autograd graph through the kernels.
            with torch.no_grad():
                obj_val_conv = _preconditioned_objectives(inpt_val, noisy_observations1_val, dict_kernel, training_iterations)
                obj_val_conv_unroll = _preconditioned_objectives(inpt_val, noisy_observations1_val, kernels_training, training_iterations)

            _plot_validation_curves(obj_val_conv, obj_val_conv_unroll, approx_min_val, objective_difference_iter0_val, training_iterations)
            _plot_training_loss(loss_iters_kernel[:epoch + 1])

            print(f"Elapsed time: {(elapsed_time_hours):.2f} hours")

            # Uncomment to checkpoint; the path matches the loader above (the
            # learned_operators/ directory must exist).
            # torch.save(kernels_training, f'learned_operators/kernel_dictionary_epoch_{epoch}_{same_params}_scratch_2809.pt')



