import numpy as np
import torch
import matplotlib.pyplot as plt
import imageio
import os

# Problem size: the first variable takes n_1 values, the second takes n_2.
n_1, n_2 = 5, 10

# Seed both RNGs so the randomly drawn target below is reproducible.
np.random.seed(1)
torch.manual_seed(1)

eps = 1E-9     # added to probabilities before taking the log (avoids log(0))
scaling = 0.5  # symmetric Dirichlet concentration used for every draw

# Target over the first variable: one Dirichlet draw -> log-probs of shape (n_1,).
p_C_draw = np.random.dirichlet(np.full(n_1, scaling))
log_p_C = torch.tensor(np.log(p_C_draw + eps))
# Renormalise in log space, since adding eps makes the probabilities sum to > 1.
log_p_C = log_p_C - torch.logsumexp(log_p_C, dim=-1, keepdim=True)

# Conditional target of the second variable given the first: n_1 independent
# Dirichlet draws -> log-probs of shape (n_1, n_2), each row normalised.
p_i_draw = np.random.dirichlet(np.full(n_2, scaling), size=n_1)
log_p_i = torch.tensor(np.log(p_i_draw + eps))
log_p_i = log_p_i - torch.logsumexp(log_p_i, dim=-1, keepdim=True)

# Output directory for every array saved by this run.
# NOTE(review): the "_05" suffix presumably mirrors scaling = 0.5 -- confirm,
# since it is hard-coded rather than derived from `scaling`.
dir_name = f'results/{n_1}_{n_2}_05'
# exist_ok=True makes re-runs a no-op, while still raising if the path exists
# but is not a directory (the previous try/except silently hid that case).
os.makedirs(dir_name, exist_ok=True)

# Figure 1 receives the final KL curves; figure 2 is the per-snapshot bar chart.
fig_1 = plt.figure(1)
fig_2 = plt.figure(2)

# Sweep configuration: S_list[j] mixture components are trained with lrs[j].
K_tot = 10                  # number of samples drawn per component
S_list = list(range(1, 6))  # S = 1, ..., 5
lrs = [0.01, 0.1, 0.1, 0.2, 0.25]

# Filled in by the training loop: one KL trace and one K value per entry of S_list.
kl_lists = []
K_list = []

np.save(f'{dir_name}/learning_rates', lrs)

# Learning rates recorded as best during tuning:
#   S = 1, K = 20 -> lr = 0.01
#   S = 2, K = 10 -> lr = 0.1
#   S = 3, K = 10 -> lr = 0.1
#   S = 4, K = 10 -> lr = 0.2
#   S = 5, K = 4  -> lr = 0.025
# NOTE(review): the S = 5 note says 0.025 but `lrs` uses 0.25 -- confirm which
# value is intended.
# NOTE(review): the notes quote K = 20 and K = 4 for S = 1 and S = 5, whereas
# K_tot is 10 here -- the notes may predate the current setting; verify.
# Scratch directory for the animation frames. plt.savefig does not create
# missing directories, so make sure it exists before the first frame is saved.
os.makedirs('figs', exist_ok=True)

# For every (S, lr) pair: fit a uniform mixture of S categorical distributions
# q_s(C) q_s(i | C) to the target p(C) p(i | C) with SGD on a score-function
# surrogate loss, record the exact KL(p || q_mixture) after every iteration,
# save the learned logits, and export the snapshots of figure 2 as a GIF.
for S, lr in zip(S_list, lrs):
    # Re-seed so every value of S starts from the same random stream.
    np.random.seed(0)
    torch.manual_seed(0)

    # Unnormalised logits of the S components:
    # phi_C is (S, n_1), phi_i is (S, n_1, n_2).
    phi_C = torch.tensor(np.random.normal(0, 0.1, (S, n_1)), requires_grad=True)
    phi_i = torch.tensor(np.random.normal(0, 0.1, (S, n_1, n_2)), requires_grad=True)
    opt = torch.optim.SGD(params=[phi_C] + [phi_i], lr=lr)
    # The division by S is disabled, so every component draws K_tot samples.
    K = K_tot # // S
    K_list.append(K)
    print(f"S = {S}, K = {K}, lr = {lr}")

    kl_list = []    # exact KL per iteration, stored as Python floats
    fig_names = []  # frame files written for this S, in display order
    for iter_ in range(20001):
        loss = 0

        # Normalise the logits along the last axis to get log-probabilities.
        log_q_C = (phi_C - torch.logsumexp(phi_C, dim=-1, keepdim=True))
        log_q_i = (phi_i - torch.logsumexp(phi_i, dim=-1, keepdim=True))

        # K draws of the first variable from each component: C is (S, K).
        C = torch.multinomial(log_q_C.exp(), K, replacement=True)
        # Average of the per-component bound estimates; accumulated below but
        # not read anywhere else in this loop.
        L_hat = 0
        for s in range(S):
            # One draw of the second variable for each of the K values in C[s].
            i = torch.multinomial(log_q_i[s, C[s]].exp() + eps, 1, replacement=True).flatten()

            # Log-probability of the K samples under the component that drew them...
            log_q_s = log_q_C[s, C[s]] + log_q_i[s, C[s], i]
            # ...and under the uniform mixture of all S components.
            log_q = torch.logsumexp(log_q_C[:, C[s]] + log_q_i[:, C[s], i], dim=0) - np.log(S)

            log_p = log_p_C[C[s]] + log_p_i[C[s], i]

            # Log importance weights, shape (K,), and the K-sample bound estimate.
            log_f = log_p - log_q
            L_hat_s = torch.logsumexp(log_p - log_q, dim=-1) - np.log(K)
            L_hat += L_hat_s / S
            if K > 1:
                # Leave-one-out baseline: column k of the (K, K) matrix holds
                # log_f with entry k replaced by the mean of the other K - 1
                # entries; logsumexp over each column gives the bound estimate
                # with sample k "removed".
                mean_exclude_signal = (torch.sum(log_f) - log_f) / (K-1.)
                control_variates = torch.logsumexp(log_f.view(-1, 1).repeat(1, K) - log_f.diag() + mean_exclude_signal.diag() - np.log(K), dim=0)
                # NOTE(review): the second term is an unweighted sum of -log_q
                # rather than L_hat_s itself -- confirm this is the intended
                # surrogate before changing anything here.
                loss += -torch.sum((L_hat_s - control_variates).detach() * log_q_s - log_q, dim=0) / S
            else:
                # VIMCO is not applicable: with a single sample there are no
                # other samples to build the leave-one-out baseline from.
                loss += -torch.sum(L_hat_s.detach() * log_q_s - log_q, dim=0) / S

        # plot before parameter update
        if (iter_ % 2000) == 0:
            if iter_ > 0:
                # kl_list[-1] is the KL recorded at the end of the previous iteration.
                print("Iter: ", iter_, "KL = ", kl_list[-1])

            plt.figure(2)
            max_ = np.max(log_p_C.exp().numpy()) + 0.2
            # Top panel: target marginal p(C).
            plt.subplot(311)
            plt.ylim(0, max_)
            plt.bar(np.arange(n_1), log_p_C.exp().detach().numpy(), color='r')
            # Middle panel: each component's q_s(C), scaled by its 1/S mixture weight.
            plt.subplot(312)
            plt.ylim(0, max_)
            for s in range(S):
                plt.bar(np.arange(n_1), log_q_C[s].exp().detach().numpy() / S, alpha=0.5)
            # Bottom panel: mixture marginal (black) next to the target (red).
            plt.subplot(313)
            plt.ylim(0, max_)
            plt.bar(np.arange(n_1), np.sum(log_q_C.exp().detach().numpy(), axis=0) / S, color='Black', align='edge',
                    width=0.4)
            plt.bar(np.arange(n_1), log_p_C.exp().detach().numpy(), color='r', align='edge', width=-0.4)
            plt.title(f"{iter_}")
            for k in range(1, 5):
                # The same figure is written four times (k = 1..4), so each
                # snapshot is held for four frames of the GIF.
                filename = f"figs/{iter_ + k}.png"
                fig_names.append(filename)
                plt.savefig(filename)
            plt.close()

        loss.backward()
        # take step
        opt.step()
        # reset gradients
        opt.zero_grad()

        # Exact KL(p || q_mixture) summed over all (C, i) pairs. log_q_C and
        # log_q_i were computed before the step above, so this is the KL at
        # the parameters that generated this iteration's samples.
        with torch.no_grad():
            kl = log_p_C.exp() @ torch.einsum('ci, ci -> c', log_p_i.exp(),
                               (log_p_C.view((n_1, 1)) + log_p_i -
                                torch.logsumexp(log_q_C.view((S, n_1, 1)) + log_q_i - np.log(S), dim=0))
                               )
        # Keep a plain float instead of a 0-d tensor so the trace can be
        # saved with np.save and plotted without per-element conversion.
        kl_list.append(kl.item())

    # Persist the learned logits for this S.
    np.save(dir_name + f"/upper_dist_{S}", phi_C.detach().numpy())
    np.save(dir_name + f"/lower_dist_{S}", phi_i.detach().numpy())
    kl_lists.append(kl_list)
    # Assemble the frames into one GIF, deleting each frame file once it has
    # been appended.
    with imageio.get_writer(f'S_{S}.gif', mode='I') as writer:
        for filename in fig_names:
            image = imageio.v2.imread(filename)
            writer.append_data(image)
            os.remove(filename)

# Persist the KL traces and the per-run sample counts next to the other results.
np.save(f"{dir_name}/kl_lists", kl_lists)
np.save(f"{dir_name}/K_list", K_list)

# Draw one KL curve per number of mixture components on figure 1.
plt.figure(1)
for n_components, n_samples, kl_trace in zip(S_list, K_list, kl_lists):
    plt.plot(kl_trace, label=f"S={n_components} (K={n_samples})", alpha=0.5)
plt.title(f'n_1={n_1}, n_2={n_2}')
plt.legend(loc='best')
# No extension is given, so matplotlib falls back to its default output format.
plt.savefig(f'{n_1}_{n_2}_05_seed_1_kl_curves')
plt.show()
