import torch
import torch.distributions as D
import torch.nn as nn

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LogNorm

class LipschitzFunction(nn.Module):
    """Small MLP mapping ``input_dim`` features to a single scalar.

    Architecture: Linear(input_dim -> hidden_dim) -> Softplus -> Linear(hidden_dim -> 1).
    Nothing in this class enforces a Lipschitz bound; any such constraint has
    to come from how it is trained (e.g. a gradient penalty).
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # lin1 is created before lin3 on purpose: parameter initialisation draws
        # from the global RNG, so this order fixes the initial weights.
        self.lin1 = nn.Linear(input_dim, hidden_dim)
        # Historical name: this is a Softplus, not a ReLU. The attribute name is
        # kept so module names (named_modules / named_children) stay stable.
        self.relu1 = nn.Softplus()
        self.lin3 = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return phi(x) with a trailing dimension of size 1 (input cast to float32)."""
        hidden = self.relu1(self.lin1(x.float()))
        return self.lin3(hidden)




# Target distribution: a 2-D Gaussian mixture with three diagonal-covariance
# components. The first tensor holds the component means (one row each); the
# second holds per-coordinate standard deviations (D.Normal's `scale`).
# D.Independent(..., 1) turns the last dimension into the event dimension, so
# each component is a bivariate normal.
mix = D.Categorical(torch.tensor([0.1, 0.2, 0.7]))
comp = D.Independent(D.Normal(torch.tensor([[0.0, 0.0],
                                            [20.0, 20.0],
                                            [-10.0, 20.0]]),
                              torch.tensor([[2.0, 2.0],
                                            [3.0, 3.0],
                                            [1.0, 1.5]])), 1)
gmm = D.MixtureSameFamily(mix, comp)

# One draw from the target. Sampling advances the global RNG, so removing this
# line would change every later random draw.
x = gmm.sample()

# Variational approximation q: an axis-aligned Gaussian with a trainable mean
# and a fixed spread.
# NOTE(review): `var` is passed to D.Normal as `scale`, i.e. it is a standard
# deviation (5 -> variance 25), although it is named and printed as a variance.
# Confirm which was intended.
# Alternative (non-trainable) initialisations:
# mean = torch.tensor([0.0, 0.0])
# var = torch.tensor([1.0, 1.0])
# mean = torch.nn.Parameter(torch.tensor([15.0, 10.0]), requires_grad=True)
mean = torch.nn.Parameter(torch.tensor([-10.0, -10.0]), requires_grad=True)
var = torch.nn.Parameter(torch.tensor([5.0, 5.0]), requires_grad=False)
# A plain (non-Independent) Normal: batch_shape is (2,), so log_prob returns one
# value per coordinate. Sum over the last dimension to get the joint log-density.
# gau = D.Independent(D.Normal(mean, var), 1)
gau = D.Normal(mean, var)

# Advances the global RNG as well; kept for reproducibility.
y = gau.sample()

print(' mean = ', mean.data, ' variance = ', var.data)

# Fit q = gau to the target gmm with Adam. Only `mean` is actually trained:
# `var` was created with requires_grad=False, so it never receives a gradient
# and Adam skips it.
lr = 0.01
n_steps = 20000
log_every = 1000  # print progress periodically rather than on every step
optimizer = torch.optim.Adam([{'params': mean, 'lr': lr},
                              {'params': var, 'lr': lr}])

n_step = 0

while n_step < n_steps:
    # KL divergence, I-projection, KL[q || gmm] = E_q[log q] - E_q[log gmm]
    # Single-sample Monte Carlo estimate. rsample() (reparameterisation trick)
    # keeps x differentiable w.r.t. `mean`, so gradients flow through the sample.
    x = gau.rsample()
    # gau yields one log-prob per coordinate; summing gives the joint log-density.
    gau_log_prob = gau.log_prob(x).sum()
    gmm_log_prob = gmm.log_prob(x)
    loss = gau_log_prob - gmm_log_prob

    # # inverse KL divergence, M-projection, KL[gmm || q] = E_gmm[log gmm] - E_gmm[log q]
    # x = gmm.sample()
    # gau_log_prob = gau.log_prob(x).sum()
    # gmm_log_prob = gmm.log_prob(x)
    # loss = gmm_log_prob - gau_log_prob

    # # wasserstein distance
    # NOTE(review): as written, this variant re-creates the critic and its
    # optimizer on every outer step. Its inner objective (-l1_loss) maximises
    # |phi(x) - phi(y)| rather than the usual critic objective
    # E[phi(x)] - E[phi(y)]. Confirm this is intended before re-enabling.
    # x = gmm.sample()
    # y = gau.rsample()
    #
    # # phi function optimization
    # lipschitz_phi = LipschitzFunction(2, 8)
    # phi_optimiser = torch.optim.SGD(lipschitz_phi.parameters(), lr=0.001)
    # wasserstein_threshold = 0.01
    #
    # lipschitz_phi.train()
    # n_phi_steps = 10
    # for i in range(n_phi_steps):
    #     objective = - torch.nn.functional.l1_loss(lipschitz_phi(x), lipschitz_phi(y.detach()))
    #
    #     # gradient norm constraint
    #     eps = torch.rand(x.shape)
    #     xy = eps * x + (1 - eps) * y.detach()
    #     xy.requires_grad = True
    #     phi_xy = lipschitz_phi(xy)
    #     gradients = torch.autograd.grad(phi_xy, xy, create_graph=True, retain_graph=True, only_inputs=True,
    #                         grad_outputs=torch.ones(phi_xy.size()))[0]
    #     f_gradient_norm = gradients.norm(2)
    #
    #     penalty = ((f_gradient_norm - 1) ** 2).mean()
    #
    #     objective = objective + 20 * penalty
    #
    #     # optimize
    #     phi_optimiser.zero_grad()
    #     # objective.backward(retain_graph=True)
    #     objective.backward()
    #     phi_optimiser.step()
    #
    #     if wasserstein_threshold is not None:
    #         # Gradient Norm
    #         params = lipschitz_phi.parameters()
    #         grad_norm = torch.cat([p.grad.data.flatten() for p in params]).norm()
    #
    #         if grad_norm < wasserstein_threshold:
    #             break
    #
    # lipschitz_phi.eval()
    #
    # # calculate distance
    # loss = torch.abs(lipschitz_phi(x) - lipschitz_phi(y))

    # optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    n_step = n_step + 1

    # Report the number of completed steps and the latest (noisy, single-sample) loss.
    if n_step % log_every == 0:
        print(' step = ', n_step, ' loss = ', loss.item())

# Final parameters, detached from autograd so they can be printed and plotted.
updated_mean = mean.detach()
updated_var = var.detach()

print(' mean = ', updated_mean, ' variance = ', updated_var)

# --- Plot the target GMM (filled contours) and the fitted Gaussian (lines). ---

# Evaluation grid: np.linspace's default 50 points per axis over a hand-picked
# window. XX stacks the grid into an (N, 2) array of (x, y) points.
xx = np.linspace(-20.0, 30.0)
yy = np.linspace(-20.0, 40.0)
X, Y = np.meshgrid(xx, yy)
XX = np.array([X.ravel(), Y.ravel()]).T

# Negative log-density of the target on the grid, drawn on a log colour scale.
Z = (-gmm.log_prob(torch.from_numpy(XX))).reshape(X.shape).numpy()
CS = plt.contourf(X, Y, Z,
                  norm=LogNorm(vmin=1.0, vmax=1000.0),
                  levels=np.logspace(0, 3, 10))

# Negative log-likelihood of each grid point under the fitted Gaussian. gau
# returns one log-prob per coordinate, so they are summed over axis 1. detach()
# is required because `mean` requires grad.
Z2 = (-gau.log_prob(torch.from_numpy(XX)).sum(1)).detach().numpy().reshape(X.shape)
# plt.contour(
#     X, Y, Z2, colors='magenta', linewidths=3, levels=np.logspace(0, 3, 10)
# )
plt.contour(X, Y, Z2, colors='red', linewidths=3, levels=np.logspace(0, 3, 10))

CB = plt.colorbar(CS, shrink=0.8, extend="both")
CB.ax.tick_params(labelsize=20)
# Red star at the optimised mean.
plt.scatter(updated_mean[0], updated_mean[1], s=300, c="red", marker="*")
# plt.scatter(mean.data[0], mean.data[1], s=300, c="magenta", marker="*")

# Title for whichever variant is being run:
# plt.title("I-projection")
# plt.title("M-projection")
# plt.title("Wasserstein")
# plt.title("prior", fontsize=40)
plt.axis("tight")
plt.tick_params(labelsize=20)
fig = plt.gcf()
fig_width, fig_height = 10, 8
fig.set_size_inches(fig_width, fig_height)

# Save before show(): show() may leave the current figure empty once closed.
plt.savefig('./I-projection1.pdf', bbox_inches='tight', pad_inches=0)

plt.show()

