import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributions as dist
from torchvision import datasets, transforms
from torch.utils.data import TensorDataset, DataLoader, Subset
from tqdm import tqdm
import copy
import warnings
# Suppress all FutureWarnings to keep the training console output readable.
warnings.filterwarnings("ignore", category=FutureWarning)

# Select device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Default colour cycle for the mode-0 figures (the script driver swaps in a
# different palette before drawing the mode-1 figures).
custom_colors = ['darkorange', '#ba2406', '#5B478B', '#356792']
# Global matplotlib style shared by every figure this script writes.
# 'text.usetex': True requires a working LaTeX installation on the machine.
plt.rcParams.update({
    'font.family': 'serif',
    'text.usetex': True,
    # NOTE(review): a 1pt base font size looks like a typo. Titles, axis labels,
    # ticks and legends are sized explicitly below, but any other text (and
    # presumably relative sizes such as 'small') inherit this value. Confirm.
    'font.size': 1,
    'axes.titlesize': 21,
    'axes.labelsize': 21,
    'legend.fontsize': 16,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
    'figure.dpi': 400,
    'figure.figsize': [6, 4],
    'axes.linewidth': 0.8,
    'lines.linewidth': 2.0,
    'lines.markersize': 5,
    'axes.grid': True,
    'grid.linestyle': '--',
    'grid.linewidth': 0.15,
    'legend.frameon': True,
    'legend.framealpha': 0.75,
    'legend.loc': 'best',
    'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.02,
    'axes.spines.top': True,
    'axes.spines.right': True,
    'axes.prop_cycle': plt.cycler('color', custom_colors)
})

# Shared activation modules: `sigmoid` is used by grad_ngd_bonnet and
# `logsigmoid` by log_prob_logistic.
sigmoid = nn.Sigmoid()
logsigmoid = nn.LogSigmoid()
# Placeholder problem size. experiment_cifar() overwrites both via `global`
# with the filtered dataset size and the flattened image dimension.
num_points = 5000
d = 200

def projection(q_xi, q_Xi, U, D):
    """Clip a diagonal Gaussian, given in mean parameters, to a box.

    The mean parameters are ``xi = E[w]`` and ``Xi = E[w**2]`` (elementwise).
    They are converted to (mean, variance). The mean is clipped to ``[-U, U]``
    and the variance to ``[1/D, D]``. The pair is then converted back to mean
    parameters.

    Args:
        q_xi: first moment E[w].
        q_Xi: second moment E[w**2].
        U: bound on the absolute value of the mean.
        D: the variance is kept within [1/D, D].

    Returns:
        Tuple ``(xi_projected, Xi_projected)``.
    """
    variance = q_Xi - q_xi * q_xi
    clipped_mean = torch.clamp(q_xi, -U, U)
    clipped_var = torch.clamp(variance, 1 / D, D)
    second_moment = clipped_var + clipped_mean * clipped_mean
    return clipped_mean, second_moment

def prepare_dataset(X, y, batch):
    """Wrap paired feature/label tensors in a shuffling DataLoader.

    Args:
        X: feature tensor; its first dimension indexes examples.
        y: label tensor aligned with X along the first dimension.
        batch: minibatch size.

    Returns:
        A ``DataLoader`` yielding shuffled ``(X_batch, y_batch)`` pairs.
    """
    return DataLoader(TensorDataset(X, y), batch_size=batch, shuffle=True)

def log_prob_logistic(X, y, samples):
    """Logistic-regression log-likelihood for each weight sample.

    Args:
        X: design matrix with one data point per row (presumably (n, d)).
        y: labels, one per row of X; the margin y * x.w is used, so labels
            are presumably in {-1, +1}.
        samples: weight samples, one per row (presumably (m, d)).

    Returns:
        One value per sample: ``sum_i log sigmoid(y_i * x_i . w)``.
    """
    # (n, m) scores, transposed so that each row corresponds to one sample.
    per_sample_logits = (X @ samples.t()).t()
    margins = y.unsqueeze(0) * per_sample_logits
    return logsigmoid(margins).sum(dim=1)

def elbo_gd(X, y, q_mu, q_cov_sq, prior, m):
    """Monte Carlo ELBO for q = N(q_mu, diag(q_cov_sq**2)), minibatch-rescaled.

    Args:
        X, y: one minibatch of features and labels.
        q_mu: variational mean.
        q_cov_sq: per-coordinate standard deviation (squared to get the
            variance); gradients flow through it via reparameterisation.
        prior: distribution compatible with ``kl_divergence`` against a
            ``MultivariateNormal``.
        m: number of Monte Carlo samples.

    Returns:
        ``(elbo, kl, log_likelihood, free_energy)``. ``log_likelihood`` is the
        sample-averaged minibatch log-likelihood scaled to the full dataset
        (``num_points`` examples). ``free_energy = -log_likelihood + kl + entropy``,
        which equals the "energy" ``-log_likelihood - E_q[log prior]``, i.e. the
        negative ELBO without its entropy term.
    """
    q_cov = q_cov_sq * q_cov_sq
    q_dist = dist.MultivariateNormal(q_mu, torch.diag_embed(q_cov))
    kl = dist.kl_divergence(q_dist, prior)

    # Reparameterised samples w = mu + sigma * eps, one column per sample.
    # Broadcasting sigma over eps equals diag(sigma) @ eps, but it avoids
    # building a (d, d) matrix and an O(d^2 m) matmul on every step.
    eps = torch.randn((q_mu.shape[0], m), device=q_mu.device)
    samples = q_mu.unsqueeze(1) + q_cov_sq.unsqueeze(1) * eps
    log_probs = log_prob_logistic(X, y, samples.t())

    # Rescale the minibatch estimate to the full dataset size.
    log_likelihood = log_probs.mean() * num_points / y.shape[0]

    entropy = q_dist.entropy()

    elbo = log_likelihood - kl
    free_energy = -log_likelihood + kl + entropy
    return elbo, kl, log_likelihood, free_energy

def elbo_ngd(X, y, q_xi, q_Xi, prior, m):
    """Monte Carlo ELBO for a diagonal Gaussian given in mean parameters.

    Args:
        X, y: one minibatch of features and labels.
        q_xi: first moment E[w] (the mean).
        q_Xi: second moment E[w**2]; the variance is ``q_Xi - q_xi**2``.
        prior: distribution compatible with ``kl_divergence`` against a
            ``MultivariateNormal``.
        m: number of Monte Carlo samples.

    Returns:
        ``(elbo, kl, log_likelihood, samples)``. ``log_likelihood`` is scaled
        to the full dataset (``num_points`` examples). ``samples`` holds one
        weight sample per column, so it can be reused by the gradient
        estimator.
    """
    q_cov = q_Xi - q_xi * q_xi
    q_dist = dist.MultivariateNormal(q_xi, torch.diag_embed(q_cov))
    kl = dist.kl_divergence(q_dist, prior)

    q_cov_sq = torch.sqrt(q_cov)

    # w = mu + sigma * eps via broadcasting. This equals diag(sigma) @ eps
    # without materialising a (d, d) matrix.
    eps = torch.randn((q_xi.shape[0], m), device=q_xi.device)
    samples = q_xi.unsqueeze(1) + q_cov_sq.unsqueeze(1) * eps
    log_probs = log_prob_logistic(X, y, samples.t())
    log_likelihood = log_probs.mean() * num_points / y.shape[0]

    elbo = log_likelihood - kl
    return elbo, kl, log_likelihood, samples
        
def true_parameter_gd(data, prior, m, q_mu, q_cov_sq, num_iter, lr, method, id):
    """Fit a mean-field Gaussian with SGD on (mean, standard deviation).

    Args:
        data: iterable of ``(X, y)`` minibatches.
        prior: prior distribution passed to ``elbo_gd``.
        m: Monte Carlo samples per step.
        q_mu: variational mean (leaf tensor, updated in place).
        q_cov_sq: per-coordinate standard deviation (leaf tensor, updated in place).
        num_iter: number of epochs; epoch ``t`` uses step size ``lr / sqrt(t)``.
        lr: base step size.
        method: ``('proj', D)`` runs SGD on -ELBO and then clamps the std to at
            least ``1/sqrt(D)``. ``('prox',)`` runs SGD on the energy and then
            applies the closed-form proximal step for the entropy. ``None``
            runs plain SGD on -ELBO.
        id: run identifier, used only in the output filename.

    Returns:
        ``(elbo_history, q_mu, variance)``, where ``variance = q_cov_sq ** 2``.
        The same data is also saved to ``model/cifar_gd_*.pt``.

    Raises:
        ValueError: if ``method[0]`` is neither ``'proj'`` nor ``'prox'``.
    """
    lr0 = lr
    method_name = None if method is None else method[0]
    if method_name not in (None, 'proj', 'prox'):
        raise ValueError(f"unknown method {method_name!r}; expected 'proj', 'prox' or None")
    optimizer = optim.SGD([q_mu, q_cov_sq], lr=lr)
    elbo_history = []
    for iteration in tqdm(range(1, 1 + num_iter)):
        # The step size decays per epoch, so set it once before the minibatches.
        curr_lr = lr0 / np.sqrt(iteration)
        for param_group in optimizer.param_groups:
            param_group['lr'] = curr_lr
        for X, y in data:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            elbo, kl, log_likelihood, free_energy = elbo_gd(X, y, q_mu, q_cov_sq, prior, m)
            elbo_history.append(elbo.item())
            # 'prox' takes the gradient step on the energy only; the entropy is
            # handled exactly by the proximal update below.
            loss = free_energy if method_name == 'prox' else -elbo
            loss.backward()
            optimizer.step()
            with torch.no_grad():
                if method_name == 'proj':
                    q_cov_sq.clamp_(min=1 / torch.sqrt(method[1]))
                elif method_name == 'prox':
                    # Closed-form prox of -curr_lr * log(sigma), elementwise:
                    # sigma <- (sigma + sqrt(sigma^2 + 4 * curr_lr)) / 2  (> 0).
                    q_cov_sq += (torch.sqrt(q_cov_sq * q_cov_sq + 4 * curr_lr) - q_cov_sq) / 2

        if iteration % 10 == 0:
            print(f"Iteration {iteration}: ELBO = {elbo.item():.4f}, KL = {kl.item():.4f}, Log-Likelihood = {log_likelihood.item():.4f}")

    torch.save({'mu': q_mu.detach().cpu(), 'cov': (q_cov_sq * q_cov_sq).detach().cpu(), 'elbo_history': elbo_history},
               f'model/cifar_gd_d={d}_lr={lr0}_method={method_name}_id={id}.pt')
    return elbo_history, q_mu, q_cov_sq * q_cov_sq
    
def grad_ngd_ad(log_likelihood, q_xi, q_Xi):
    """Autograd gradient of a scalar objective w.r.t. the mean parameters.

    Args:
        log_likelihood: scalar tensor whose graph depends on q_xi and q_Xi.
        q_xi, q_Xi: tensors with ``requires_grad=True``.

    Returns:
        Tuple ``(d log_likelihood / d q_xi, d log_likelihood / d q_Xi)``.
    """
    wrt = [q_xi, q_Xi]
    grads = torch.autograd.grad(log_likelihood, wrt)
    return grads

def grad_ngd_bonnet(X, y, samples, q_xi):
    """Bonnet/Price estimate of the expected log-likelihood gradient with
    respect to the Gaussian mean parameters (xi, Xi) = (E[w], E[w^2]).

    With g = E_q[grad log p] and h = diag(E_q[Hessian log p]), the gradient
    w.r.t. xi is ``g - h * xi`` and the gradient w.r.t. Xi is ``h / 2``.

    Args:
        X, y: one minibatch of features and labels.
        samples: weight samples, one per column (as returned by elbo_ngd).
        q_xi: current mean.

    Returns:
        ``[grad_xi, grad_Xi]``, each scaled from the minibatch to ``num_points``.
    """
    logits = (X @ samples).t()                      # one row per sample
    s = sigmoid(-y.unsqueeze(0) * logits)           # sigma(-y * x.w)
    # g: sum_i y_i * E[sigma(-y_i x_i.w)] * x_i
    expected_grad = X.t() @ (y * s.mean(dim=0))
    # h: -sum_i E[s (1 - s)] * x_i^2  (non-positive)
    expected_hess_diag = -(X * X).t() @ (s * (1 - s)).mean(dim=0)

    grad_xi = expected_grad - expected_hess_diag * q_xi
    grad_Xi = expected_hess_diag / 2
    grad_xi *= num_points / y.shape[0]
    grad_Xi *= num_points / y.shape[0]
    return [grad_xi, grad_Xi]
    
def true_parameter_ngd(data, prior, m, q_mu, q_cov_sq, num_iter, lr, proj, id):
    """Fit a mean-field Gaussian with stochastic natural-gradient descent (SNGD).

    The iterate is stored as the natural parameters (eta_lambda, eta_Lambda)
    of a diagonal Gaussian. Each minibatch step is
        eta <- (1 - lr) * eta + lr * (eta_prior + grad_omega),
    where grad_omega is the Bonnet/Price estimate (grad_ngd_bonnet) of the
    expected log-likelihood gradient w.r.t. the mean parameters
    (xi, Xi) = (E[w], E[w^2]).

    Args:
        data: iterable of ``(X, y)`` minibatches.
        prior: prior passed to elbo_ngd, where it is used only for the
            reported KL/ELBO. The update itself hard-codes a N(0, I) prior
            (see eta_p_Lambda), which matches the prior built in experiment_cifar.
        m: Monte Carlo samples per step.
        q_mu, q_cov_sq: initial mean and per-coordinate standard deviation.
        num_iter: number of epochs; epoch ``t`` uses step size ``lr / sqrt(t)``.
        lr: base step size.
        proj: ``None``, or ``(U, D)`` to clip the mean to [-U, U] and the
            variance to [1/D, D] via projection() after every step.
        id: run identifier, used only in the output filename.

    Returns:
        ``(elbo_history, mean, variance)`` of the final iterate. The same data
        is also saved to ``model/cifar_ngd_*.pt``.
    """
    lr0 = lr
    # Natural parameter of the prior precision, -Lambda_p / 2 with Lambda_p = I.
    eta_p_Lambda = -torch.ones(d, device=device) / 2
    # Natural parameters of q = N(mu, diag(sigma^2)):
    #   eta_Lambda = -1 / (2 sigma^2),  eta_lambda = mu / sigma^2.
    eta_Lambda = -1 / (q_cov_sq * q_cov_sq) / 2
    eta_lambda = -2 * eta_Lambda * q_mu
    elbo_history = []
    # Mean parameters: xi = E[w] = mu, Xi = E[w^2] = sigma^2 + mu^2 (elementwise).
    q_xi = q_mu
    q_Xi = q_cov_sq * q_cov_sq + q_mu * q_mu
    for iteration in tqdm(range(1, 1 + num_iter)):
        for X, y in data:
            X, y = X.to(device), y.to(device)
            # The step size decays with the epoch index: lr0 / sqrt(t).
            lr = lr0 / np.sqrt(iteration)
            elbo, kl, log_likelihood, samples = elbo_ngd(X, y, q_xi, q_Xi, prior, m)
            elbo_history.append(elbo.item())
            # Gradient w.r.t. (xi, Xi), reusing the samples drawn for the ELBO.
            grad_omega = grad_ngd_bonnet(X, y, samples, q_xi)

            with torch.no_grad():
                # Convex combination in natural-parameter space. The prior mean
                # is zero, so eta_lambda gets no prior term. grad_omega[1] is
                # non-positive (see grad_ngd_bonnet), so for lr <= 1 eta_Lambda
                # stays negative and the variance below stays positive.
                eta_lambda = (1 - lr) * eta_lambda + lr * grad_omega[0]
                eta_Lambda = (1 - lr) * eta_Lambda + lr * (grad_omega[1] + eta_p_Lambda)
            # Map the natural parameters back to mean parameters.
            q_cov = -1 / eta_Lambda / 2
            q_xi = q_cov * eta_lambda
            q_Xi = q_cov + q_xi * q_xi
            if proj is not None:
                # NOTE(review): only q_xi/q_Xi (used for the next sampling step
                # and the saved result) are projected. eta_lambda/eta_Lambda keep
                # their unprojected values, so the next update starts from the
                # unprojected iterate. Confirm this "lazy" projection is intended.
                with torch.no_grad():
                    q_xi, q_Xi = projection(q_xi, q_Xi, proj[0], proj[1])
            # Fresh leaf tensors for the next step. grad_ngd_bonnet does not use
            # autograd; requires_grad only matters for an autograd path such as
            # grad_ngd_ad.
            q_xi = q_xi.detach().requires_grad_(True)
            q_Xi = q_Xi.detach().requires_grad_(True)

        if iteration % 10 == 0:
            print(f"Iteration {iteration}: ELBO = {elbo.item():.4f}, KL = {kl.item():.4f}, Log-Likelihood = {log_likelihood.item():.4f}")

    # `proj` appears in the filename via its str(), e.g. 'None' or '(4, 20)'.
    torch.save({'mu': q_xi.detach().cpu(), 'cov': (q_Xi - q_xi * q_xi).detach().cpu(), 'elbo_history': elbo_history}, 
               f'model/cifar_ngd_proj={proj}_d={d}_lr={lr0}_id={id}.pt')
    return elbo_history, q_xi, q_Xi - q_xi * q_xi

def experiment_cifar():
    """Run the full CIFAR-10 binary logistic-regression VI benchmark.

    Classes 3 and 5 of the CIFAR-10 training set (cat and dog in torchvision's
    class order) are kept, with labels mapped to +1 / -1. The flattened images
    are loaded onto ``device``. The module globals ``num_points`` and ``d`` are
    set from the data. Then, for 5 random initialisations and 6 base step
    sizes, the following are trained:
    Proj-SGD, Prox-SGD, SNGD, and SNGD projected with (U, D) in
    {(4, 20), (3, 15), (2, 10)}. Each trainer saves its result under ``model/``.
    """
    import os

    global num_points
    global d

    torch.manual_seed(42)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.view(-1))
    ])
    cifar_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    # Select classes from the stored label list. Iterating the dataset itself
    # would decode and transform every one of the 50k images just to read labels.
    indices = [i for i, label in enumerate(cifar_dataset.targets) if label in (3, 5)]
    filtered_dataset = Subset(cifar_dataset, indices)
    # Single pass over the subset, so each image is decoded only once.
    images, labels = [], []
    for image, label in filtered_dataset:
        images.append(image)
        labels.append(1 if label == 3 else -1)
    Xs = torch.stack(images).to(device)
    ys = torch.tensor(labels, device=device)
    data_loader = DataLoader(TensorDataset(Xs, ys), batch_size=2000, shuffle=True)
    num_points = len(filtered_dataset)
    d = Xs.shape[1]

    num_iter = 3000
    m = 2000
    prior = dist.MultivariateNormal(torch.zeros(d, device=device), torch.eye(d, device=device))
    learning_rates = [5e-3, 2e-3, 1e-3, 5e-4, 2e-4, 1e-4]
    gd_methods = [('proj', torch.tensor([20], device=device)), ('prox', )]
    # SNGD projection boxes (U, D); None is the unprojected baseline.
    ngd_projections = [None, (4, 20), (3, 15), (2, 10)]
    # The trainers write to model/; create it up front so a long run cannot
    # fail at torch.save.
    os.makedirs('model', exist_ok=True)
    for run_id in range(5):
        q_mu0 = torch.nn.Parameter(torch.randn(d, device=device))
        q_cov_sq0 = torch.nn.Parameter(torch.exp(torch.randn(d, device=device)))

        for lr in learning_rates:
            # Every method starts from a fresh copy of the same initialisation.
            for method in gd_methods:
                true_parameter_gd(data_loader, prior, m, copy.deepcopy(q_mu0), copy.deepcopy(q_cov_sq0), num_iter, lr, method, run_id)
            for proj in ngd_projections:
                true_parameter_ngd(data_loader, prior, m, copy.deepcopy(q_mu0), copy.deepcopy(q_cov_sq0), num_iter, lr, proj, run_id)

def plot_results_convergence(mode):
    """Plot negative-ELBO convergence curves for the saved CIFAR runs.

    For each method, the five saved runs (ids 0-4) are loaded, the ELBO is
    negated, and the 5 minibatch values of each epoch are averaged. The run
    closest in L1 distance to the pointwise median is drawn as a line, and the
    25th-75th percentile band across runs is shaded. The figure is written to
    ``figure/cifar convergence {mode+1}.pdf``.

    Args:
        mode: 0 compares Prox-SGD / Proj-SGD with SNGD / Proj-SNGD;
            1 compares SNGD under several projection boxes (U, D).

    Raises:
        ValueError: if ``mode`` is not 0 or 1.
    """
    d = 3072
    num_show = 3000
    num_runs = 5
    # Each saved history holds num_epochs * batches_per_epoch minibatch ELBOs.
    num_epochs = 3000
    batches_per_epoch = 5
    if mode == 0:
        gd_models = ['prox', 'proj']
        ngd_models = ['None', (4, 20)]
    elif mode == 1:
        gd_models = []
        ngd_models = ['None', (4, 20), (3, 15), (2, 10)]
    else:
        raise ValueError(f"mode must be 0 or 1, got {mode!r}")
    gd_lrs = [2e-4]
    ngd_lrs = [1e-3]

    # NOTE(review): the trainers save to 'model/cifar_*.pt', but these paths read
    # 'cifar/cifar2_*.pt'; presumably the files were moved/renamed by hand.
    def gd_path(model, lr, run_id):
        return f'cifar/cifar2_gd_d={d}_lr={lr}_method={model}_id={run_id}.pt'

    def ngd_path(model, lr, run_id):
        return f'cifar/cifar2_ngd_proj={model}_d={d}_lr={lr}_id={run_id}.pt'

    def load_runs(path_fn, model, lr):
        # (num_runs, num_epochs) array of per-epoch mean negative ELBO.
        runs = [[-v for v in torch.load(path_fn(model, lr, run_id))['elbo_history']]
                for run_id in range(num_runs)]
        return np.array(runs).reshape(num_runs, num_epochs, batches_per_epoch).mean(axis=2)

    results = {}
    for model in gd_models:
        results[model] = {lr: load_runs(gd_path, model, lr) for lr in gd_lrs}
    for model in ngd_models:
        results[model] = {lr: load_runs(ngd_path, model, lr) for lr in ngd_lrs}

    iterations = np.arange(1, 1 + num_show)

    def plot_band(runs, label):
        # Draw the most median-like run (when it has a label) and the IQR band.
        elbos_trimmed = runs[:, :num_show]
        elbo_lower = np.percentile(elbos_trimmed, 25, axis=0)
        elbo_upper = np.percentile(elbos_trimmed, 75, axis=0)
        median_elbo = np.median(elbos_trimmed, axis=0)
        rep_index = np.argmin(np.sum(np.abs(elbos_trimmed - median_elbo), axis=1))
        if label is not None:
            plt.plot(iterations, elbos_trimmed[rep_index], label=label)
        plt.fill_between(iterations, elbo_lower, elbo_upper, alpha=0.3)

    def ngd_label(model):
        if model == 'None':
            return 'SNGD' if mode == 0 else r'U=$\infty$, D=$\infty$'
        return 'Proj-SNGD' if mode == 0 else f'U={model[0]}, D={model[1]}'

    gd_labels = {'proj': 'Proj-SGD', 'prox': 'Prox-SGD'}
    # Plot order (gd first, then ngd) fixes which colour each method gets.
    for model in gd_models:
        for lr in gd_lrs:
            plot_band(results[model][lr], gd_labels.get(model))
    for model in ngd_models:
        for lr in ngd_lrs:
            plot_band(results[model][lr], ngd_label(model))

    plt.ylim(top=150000, bottom=20000)
    plt.xlabel('Epoch')
    plt.ylabel('Negative ELBO')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(f'figure/cifar convergence {mode+1}.pdf')
    plt.clf()
    
def plot_results_robustness(mode):
    """Plot how many iterations each method needs to reach an ELBO threshold.

    For every method and base step size, each of the five saved runs is
    reduced to the first minibatch index whose 5-step mean ELBO exceeds
    -35000. If that never happens, the history length is used. Medians across
    runs are drawn against the step size (log x-axis), with a shaded
    25th-75th percentile band. The figure is written to
    ``figure/cifar robust {mode+1}.pdf``.

    Args:
        mode: 0 compares Prox-SGD / Proj-SGD with SNGD / Proj-SNGD;
            1 compares SNGD under several projection boxes (U, D).

    Raises:
        ValueError: if ``mode`` is not 0 or 1.
    """
    d = 3072
    window = 5
    num_runs = 5

    def find_index(elbo_history, threshold):
        # First i whose window elbo_history[i:i + window] has mean > threshold.
        # Every full window is checked, including the last (i = len - window).
        for i in range(len(elbo_history) - window + 1):
            if sum(elbo_history[i:i + window]) > threshold * window:
                return i
        return len(elbo_history)

    threshold = -35000
    if mode == 0:
        gd_models = ['prox', 'proj']
        ngd_models = ['None', (4, 20)]
    elif mode == 1:
        gd_models = []
        ngd_models = ['None', (4, 20), (3, 15), (2, 10)]
    else:
        raise ValueError(f"mode must be 0 or 1, got {mode!r}")
    lrs = [5e-3, 2e-3, 1e-3, 5e-4, 2e-4, 1e-4]

    # NOTE(review): the trainers save to 'model/cifar_*.pt', but these paths read
    # 'cifar/cifar2_*.pt'; presumably the files were moved/renamed by hand.
    results = {}
    for gd_model in gd_models:
        results[gd_model] = {}
        for lr in lrs:
            counts = np.zeros(num_runs)
            for run_id in range(num_runs):
                elbo_history = torch.load(f'cifar/cifar2_gd_d={d}_lr={lr}_method={gd_model}_id={run_id}.pt')['elbo_history']
                counts[run_id] = find_index(elbo_history, threshold)
            results[gd_model][lr] = counts
    for ngd_model in ngd_models:
        results[ngd_model] = {}
        for lr in lrs:
            counts = np.zeros(num_runs)
            for run_id in range(num_runs):
                elbo_history = torch.load(f'cifar/cifar2_ngd_proj={ngd_model}_d={d}_lr={lr}_id={run_id}.pt')['elbo_history']
                counts[run_id] = find_index(elbo_history, threshold)
            results[ngd_model][lr] = counts

    def label_and_marker(model):
        # Returns (label, marker), or None when the model has no assigned style.
        if mode == 0:
            if model == 'proj':
                return 'Proj-SGD', '^'
            if model == 'prox':
                return 'Prox-SGD', 'v'
            if model == "None":
                return 'SNGD', 'x'
            return 'Proj-SNGD', 'o'
        if model == "None":
            return r'U=$\infty$, D=$\infty$', 'x'
        markers = {4: 'o', 3: 's', 2: '*'}
        if model[0] in markers:
            return f'U={model[0]}, D={model[1]}', markers[model[0]]
        return None

    for model, lr_dict in results.items():
        sorted_lrs = sorted(lr_dict)
        medians = np.array([np.median(lr_dict[lr]) for lr in sorted_lrs])
        p25 = np.array([np.percentile(lr_dict[lr], 25) for lr in sorted_lrs])
        p75 = np.array([np.percentile(lr_dict[lr], 75) for lr in sorted_lrs])

        style = label_and_marker(model)
        if style is not None:
            label, marker = style
            plt.plot(sorted_lrs, medians, label=label, marker=marker)
        plt.fill_between(sorted_lrs, p25, p75, alpha=0.3)

    plt.xscale('log')
    plt.xlabel(r'$\gamma_0$')
    plt.ylabel('Number of Iterations')
    plt.legend(loc='upper right')
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(f'figure/cifar robust {mode+1}.pdf')
    plt.clf()

if __name__ == "__main__":
    # Train every configuration, then draw both figure sets. The guard keeps a
    # plain `import` of this module from launching the (very long) experiment.
    experiment_cifar()
    for mode in range(2):
        if mode == 1:
            # Mode-1 figures use a different four-colour palette.
            custom_colors = ['#5B478B', '#356792', '#59A4A4', '#3C8358']
            plt.rcParams.update({'axes.prop_cycle': plt.cycler('color', custom_colors)})
        plot_results_convergence(mode)
        plot_results_robustness(mode)
