import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributions as dist
import torch.nn.functional as F
from tqdm import tqdm
import logging

# Global matplotlib styling shared by every figure this script produces.
# `custom_colors` is installed as the default colour cycle at the end of the dict.
custom_colors = ['darkorange', '#ba2406', '#5B478B', '#356792']
plt.rcParams.update({
    # --- text ---
    'font.family': 'serif',
    'text.usetex': True,  # text is rendered through LaTeX, so a LaTeX installation is required
    # NOTE(review): a base font size of 1 is unusually small; the title, label,
    # tick and legend sizes are set explicitly below -- confirm this is intended.
    'font.size': 1,
    'axes.titlesize': 21,
    'axes.labelsize': 21,
    'legend.fontsize': 16,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
    # --- figure geometry ---
    'figure.dpi': 400,
    'figure.figsize': [6, 4],      
    # --- lines and markers ---
    'axes.linewidth': 0.8,
    'lines.linewidth': 2.0,             
    'lines.markersize': 5,            
    # --- grid ---
    'axes.grid': True,
    'grid.linestyle': '--',
    'grid.linewidth': 0.15,
    # --- legend ---
    'legend.frameon': True,
    'legend.framealpha': 0.75,
    'legend.loc': 'best',
    # --- saving ---
    'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.02,
    # --- axes frame and colour cycle ---
    'axes.spines.top': True,
    'axes.spines.right': True,
    'axes.prop_cycle': plt.cycler('color', custom_colors)
})

# One entry per experimental mode: (initial variance, projection bounds).
# The first item is square-rooted in `experiment` to obtain the initial standard
# deviation; the second is passed as `proj` to `true_parameter_ngd` and is either
# None (no projection) or a pair (U, D) forwarded to `projection`.
modes = [(2.0, None), (0.4, None), (2.0, (4, 25))]
# Step sizes compared for the natural-gradient update.
lr_lst_ngd = [0.5, 0.3, 0.1]
# Number of natural-gradient iterations per run.
num_iter = 51
# Number of leading iterations displayed on each plot.
num_show = 51
# NOTE(review): `sigmoid` is not referenced anywhere in the visible code -- confirm it is still needed.
sigmoid = nn.Sigmoid()

def projection(q_xi, q_Xi, U, D):
    """Project moment parameters onto a box in (mean, variance) space.

    The variance is recovered as ``q_Xi - q_xi**2`` using the *unclamped*
    mean, the mean is clamped to ``[-U, U]`` and the variance to
    ``[1/D, D]``; the second moment is then rebuilt from the clamped pair.

    Args:
        q_xi: Tensor of first moments (means).
        q_Xi: Tensor of second moments, same shape as ``q_xi``.
        U: Bound on the absolute value of the mean.
        D: Upper bound on the variance; ``1/D`` is the lower bound.

    Returns:
        Tuple ``(mean, second_moment)`` after projection.
    """
    variance = q_Xi - q_xi * q_xi
    mean_clamped = torch.clamp(q_xi, min=-U, max=U)
    variance_clamped = torch.clamp(variance, min=1 / D, max=D)
    second_moment = variance_clamped + mean_clamped * mean_clamped
    return (mean_clamped, second_moment)

def poisson_data(x):
    """Draw a single Poisson observation for a one-feature design.

    Builds a ``1 x 1`` design matrix holding ``x``, uses the fixed weight
    ``4.0`` and samples ``y ~ Poisson(exp(X @ w))`` from torch's global RNG.

    Args:
        x: Scalar feature value (a float, so that it matches the float weight).

    Returns:
        Tuple ``(X, y, w)``: the design matrix, the sampled count and the
        weight vector used to generate it.
    """
    w = torch.tensor([4.0])
    X = torch.tensor([[x]])
    rate = (X @ w).exp()
    return X, torch.poisson(rate), w

def elbo_ngd(X, y, q_xi, q_Xi, prior):
    """Evaluate the ELBO of a diagonal Gaussian posterior given by its moments.

    The variational distribution has mean ``q_xi`` and per-coordinate
    variance ``q_Xi - q_xi**2``. The likelihood term is
    ``y^T X q_xi - sum(exp(X q_xi + 0.5 * diag(X Sigma X^T)))``; no
    ``log(y!)`` term is included.

    Args:
        X: Design matrix, indexed as (observation, feature) by the einsum.
        y: Observed counts.
        q_xi: First moments (mean) of the variational distribution.
        q_Xi: Second moments of the variational distribution.
        prior: Distribution accepted by ``dist.kl_divergence`` as the
            second argument.

    Returns:
        Tuple ``(elbo, kl, log_likelihood)`` with ``elbo = log_likelihood - kl``.
    """
    variance = q_Xi - q_xi * q_xi
    cov_matrix = torch.diag_embed(variance)

    # KL(q || prior) for the Gaussian built from the moment parameters.
    posterior = dist.MultivariateNormal(q_xi, cov_matrix)
    kl = dist.kl_divergence(posterior, prior)

    # Per-observation 0.5 * x_i^T Sigma x_i correction inside the exponential.
    linear_pred = X @ q_xi
    quad_term = 0.5 * torch.einsum('ij,jk,ik->i', X, cov_matrix, X)
    expected_rate = torch.exp(linear_pred + quad_term)
    log_likelihood = y.t() @ X @ q_xi - torch.sum(expected_rate)

    return log_likelihood - kl, kl, log_likelihood

def grad_ngd_ad(log_likelihood, q_xi, q_Xi):
    """Differentiate the likelihood term w.r.t. both moment parameters.

    Args:
        log_likelihood: Scalar tensor whose graph depends on ``q_xi`` and ``q_Xi``.
        q_xi: First-moment tensor with ``requires_grad=True``.
        q_Xi: Second-moment tensor with ``requires_grad=True``.

    Returns:
        Tuple ``(d/d q_xi, d/d q_Xi)`` as produced by ``torch.autograd.grad``.
    """
    grad_xi, grad_Xi = torch.autograd.grad(
        outputs=log_likelihood,
        inputs=[q_xi, q_Xi],
    )
    return (grad_xi, grad_Xi)

def true_parameter_ngd(X, y, prior, q_mu, q_cov_sq, num_iter, lr, proj):
    """Run natural-gradient updates on a diagonal Gaussian posterior.

    Two coupled parameterisations are tracked: natural parameters, updated
    as a convex combination with weight ``lr``, and moment parameters
    ``(q_xi, q_Xi)``, at which the likelihood gradient is obtained by
    autograd. The prior enters the update only through the hard-coded
    second-order term ``-1/2`` (the value for a standard-normal prior);
    ``prior`` itself is used only for the reported KL / ELBO.

    Args:
        X: Design matrix forwarded to ``elbo_ngd``.
        y: Observations forwarded to ``elbo_ngd``.
        prior: Prior distribution forwarded to ``elbo_ngd``.
        q_mu: Initial mean, a 1-D tensor.
        q_cov_sq: Initial standard deviation (it is squared to get the variance).
        num_iter: Number of update iterations.
        lr: Step size of the natural-gradient update.
        proj: ``None`` for no projection, otherwise an indexable whose first
            two items are passed as ``U`` and ``D`` to ``projection``.

    Returns:
        List of floats: the ELBO evaluated at the start of every iteration,
        i.e. before that iteration's update is applied.
    """
    dim = q_mu.shape[0]
    prior_nat_quad = -torch.ones(dim) / 2

    # Initial natural parameters derived from the initial mean / variance.
    init_var = q_cov_sq * q_cov_sq
    nat_quad = -1 / init_var / 2
    nat_lin = -2 * nat_quad * q_mu

    # Matching moment parameters, as fresh leaves for autograd.
    q_xi = q_mu.detach().requires_grad_(True)
    q_Xi = (init_var + q_mu * q_mu).detach().requires_grad_(True)

    elbo_history = []
    for iteration in tqdm(range(1, 1 + num_iter)):
        elbo, kl, log_likelihood = elbo_ngd(X, y, q_xi, q_Xi, prior)
        grad_xi, grad_Xi = grad_ngd_ad(log_likelihood, q_xi, q_Xi)

        # Natural-parameter step: blend the old value with the new target.
        with torch.no_grad():
            nat_lin = (1 - lr) * nat_lin + lr * grad_xi
            nat_quad = (1 - lr) * nat_quad + lr * (grad_Xi + prior_nat_quad)

        # Map natural parameters back to (mean, second moment).
        variance = -1 / nat_quad / 2
        q_xi = variance * nat_lin
        q_Xi = variance + q_xi * q_xi

        # Only the moment parameters are projected; the natural parameters
        # carried to the next iteration are left as they are.
        if proj is not None:
            with torch.no_grad():
                q_xi, q_Xi = projection(q_xi, q_Xi, proj[0], proj[1])

        # Re-create leaves so the next gradient call starts from a clean graph.
        q_xi = q_xi.detach().requires_grad_(True)
        q_Xi = q_Xi.detach().requires_grad_(True)

        elbo_history.append(elbo.item())
        if iteration % 10 != 0:
            continue
        print(f"Iteration {iteration}: ELBO = {elbo.item():.4f}, KL = {kl.item():.4f}, Log-Likelihood = {log_likelihood.item():.4f}")

    return elbo_history

def experiment():
    """Run the Poisson NGD study and save one convergence figure per mode.

    For 10 random initial means (uniform on [-3, 0)), every mode in ``modes``
    and every step size in ``lr_lst_ngd``, ``true_parameter_ngd`` is run and
    its ELBO trace stored. Traces are converted to gaps against the best ELBO
    seen anywhere (plus ``1e-4`` so the gap stays positive on a log axis).
    Per mode, one figure is written to ``figure/poisson <k> new.pdf`` showing,
    for each step size, the run closest to the pointwise median together with
    the 25th-75th percentile band.

    The output directory is created if missing, so a long computation is not
    lost to a ``FileNotFoundError`` at save time.

    Returns:
        None. The figures on disk are the only output.
    """
    import os  # function-scope import: only needed to create the output directory

    num_runs = 10
    num_modes = len(modes)

    torch.manual_seed(42)
    d = 1
    X, y, _ = poisson_data(0.9)
    prior = dist.MultivariateNormal(torch.zeros(d), torch.eye(d))

    # results[run][mode][lr] -> ELBO trace as a 1-D array.
    results = {}
    for run in range(num_runs):
        results[run] = {}
        # One initial mean per run, shared by all modes and step sizes.
        init = torch.distributions.Uniform(low=-3.0, high=0.0).sample((1, ))
        for mode, (init_var, proj) in enumerate(modes):
            results[run][mode] = {}
            init_cov_sq = torch.sqrt(torch.tensor(init_var))
            for lr_ngd in lr_lst_ngd:
                q_mu = torch.nn.Parameter(torch.tensor([init]))
                q_cov_sq = torch.nn.Parameter(torch.tensor([init_cov_sq]))
                history = true_parameter_ngd(X, y, prior, q_mu, q_cov_sq, num_iter, lr_ngd, proj)
                results[run][mode][lr_ngd] = np.array(history)

    # Reference value: best ELBO over everything, nudged up so gaps are > 0.
    f_star = max(results[run][mode][lr_ngd].max()
                 for run in range(num_runs)
                 for mode in range(num_modes)
                 for lr_ngd in lr_lst_ngd) + 1e-4

    # gaps_by_setting[mode][lr] -> (num_runs, num_iter) array of ELBO gaps.
    gaps_by_setting = {
        mode: {
            lr_ngd: f_star - np.vstack([results[run][mode][lr_ngd] for run in range(num_runs)])
            for lr_ngd in lr_lst_ngd
        }
        for mode in range(num_modes)
    }

    os.makedirs('figure', exist_ok=True)

    for mode in range(num_modes):
        for lr in lr_lst_ngd:
            gaps = gaps_by_setting[mode][lr][:, :num_show]
            gaps_lower = np.percentile(gaps, 25, axis=0)
            gaps_upper = np.percentile(gaps, 75, axis=0)
            median_gaps = np.median(gaps, axis=0)

            # Representative run: smallest L1 distance to the median curve.
            diff_sum = np.sum(np.abs(gaps - median_gaps), axis=1)
            rep_index = np.argmin(diff_sum)

            # Size the x-axis from the data so it stays valid when
            # num_show exceeds the number of recorded iterations.
            iterations = np.arange(gaps.shape[1])
            plt.plot(iterations, gaps[rep_index], label=f'step size={lr}')
            plt.fill_between(iterations, gaps_lower, gaps_upper, alpha=0.3)
        plt.yscale('log')
        plt.xlabel('Iteration')
        plt.ylabel('Function Value Gap')
        plt.legend()
        plt.ylim(bottom=5e-5, top=100000)
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(f'figure/poisson {mode+1} new.pdf')
        plt.clf()
        
# Run the full experiment. There is no `__main__` guard, so this also executes
# whenever the file is imported.
experiment()
