
import os

import matplotlib.pyplot as plt
import torch
from torch.distributions import Categorical, MixtureSameFamily, Normal, Beta

from adacat import Adacat
from tn import TruncatedNormal

# Each AdaCat model is parameterized by a flat vector of 2 * k values;
# both models start from all zeros.
k = 10
params_ns = torch.nn.Parameter(torch.zeros(2 * k))  # fitted without smoothing
params_s = torch.nn.Parameter(torch.zeros(2 * k))   # fitted with smoothing

# Target density on [0, 1]: an equal-weight mixture of Beta(2, 8) and
# Beta(8, 2) (named `gmm` even though the components are Beta, not Gaussian).
mix = Categorical(probs=torch.ones(2))
comp = Beta(
    concentration1=torch.tensor([2.0, 8.0]),
    concentration0=torch.tensor([8.0, 2.0]),
)
gmm = MixtureSameFamily(mixture_distribution=mix, component_distribution=comp)

# One independent Adam optimizer per model, same learning rate for both.
optim_ns, optim_s = (torch.optim.Adam([param], lr=3e-4) for param in (params_ns, params_s))

n_its = 10000  # number of training iterations
bs = 200       # target samples drawn per iteration

def visualize(ax, q_ns, q_s, p, n_its):
    """Redraw ``ax`` with the target density and both fitted densities on [0, 1].

    The densities are evaluated on 1000 evenly spaced points in [0, 1] and
    drawn as three labelled curves with fixed limits (x in [0, 1], y in [0, 5]).

    Args:
        ax: Axes to draw on (matplotlib Axes API); it is cleared first.
        q_ns: Model trained without smoothing. Its ``log_prob`` is called on
            the grid reshaped to a (1000, 1) column.
        q_s: Model trained with smoothing; evaluated the same way as ``q_ns``.
        p: Target distribution. Its ``log_prob`` is called on the flat
            (1000,) grid.
        n_its: Currently unused. The training loop passes the fraction of
            training completed here; kept so existing calls keep working.
    """
    xs = torch.linspace(0., 1., 1000)

    # Only the plotted values are needed, so skip recording an autograd graph.
    with torch.no_grad():
        column = xs.unsqueeze(-1)
        qs_ns = q_ns.log_prob(column).exp().detach().numpy()
        qs_s = q_s.log_prob(column).exp().detach().numpy()
        ps = p.log_prob(xs).exp().detach().numpy()

    xs = xs.numpy()

    ax.clear()
    ax.plot(xs, ps,    "--", color="black", alpha=0.65, label="Target Density")
    ax.plot(xs, qs_s,  "-",  color="blue",  alpha=0.65, label="AdaCat w/ smoothing")
    ax.plot(xs, qs_ns, "-",  color="red",   alpha=0.65, label="AdaCat w/o smoothing")
    ax.set_ylim(0., 5.)
    ax.set_xlim(0., 1.)
    ax.set_ylabel("p(x)")
    ax.set_xlabel("x")
    # Pass True explicitly: a bare ``ax.grid()`` toggles the grid, and
    # ``clear()`` resets it from rcParams["axes.grid"], so under a style that
    # enables grids by default the bare call would switch them off.
    ax.grid(True)
    ax.legend()

fig, ax = plt.subplots(figsize=(5, 3), dpi=200)

# Frames are written to figs_1d/; savefig does not create missing
# directories, so make sure it exists before the first snapshot.
os.makedirs("figs_1d", exist_ok=True)
# Zero-pad frame indices (at least 4 digits, as before) so the files still
# sort in iteration order when n_its - 1 needs more than 4 digits.
frame_width = max(4, len(str(n_its - 1)))

for it in range(n_its):
    # Snapshot every 50 iterations and on the last one. The snapshot is taken
    # before this iteration's update, so frame `it` shows the parameters after
    # `it` optimizer steps.
    if it % 50 == 0 or it == n_its - 1:
        # Plotting only needs values; don't build an autograd graph.
        with torch.no_grad():
            visualize(ax, Adacat(params_ns.unsqueeze(0)), Adacat(params_s.unsqueeze(0)), gmm, it / n_its)
        fig.savefig(f"figs_1d/adacat_1d_{it:0{frame_width}d}.png", bbox_inches="tight")

    # One batch of target samples, shape (bs, 1), shared by both models.
    s = gmm.sample((bs, 1))

    # Model without smoothing: minimize the mean negative log-likelihood of
    # the raw samples. Parameters are broadcast to shape (bs, 1, 2 * k).
    d_ns = Adacat(params_ns.unsqueeze(0).unsqueeze(0).expand(bs, 1, -1))
    loss_ns = -d_ns.log_prob(s).mean()
    optim_ns.zero_grad()
    loss_ns.backward()
    optim_ns.step()

    # Model with smoothing: each sample is wrapped in a narrow (scale 0.001)
    # TruncatedNormal with bounds a=0, b=1, and that distribution object is
    # passed to Adacat.log_prob instead of the raw samples.
    # NOTE(review): presumably Adacat.log_prob then computes an expected
    # log-likelihood under this smoothing kernel — confirm in adacat.
    d_s = Adacat(params_s.unsqueeze(0).unsqueeze(0).expand(bs, 1, -1))
    smoothed = TruncatedNormal(loc=s, scale=torch.ones_like(s) * 0.001, a=torch.zeros_like(s), b=torch.ones_like(s))
    loss_s = -d_s.log_prob(smoothed).mean()
    optim_s.zero_grad()
    loss_s.backward()
    optim_s.step()

    print("\rIt: {}, ns loss: {:.3f}, s loss: {:.3f}".format(it, float(loss_ns), float(loss_s)), end="")
print()
print("Done!!")


