import itertools
import operator

import matplotlib.pyplot as plt
import numpy as np

"""
What is this?
=============
Plot NCE loss values of two representations f1 and f2 in Ash et al. (2021).

Purpose
=======
Ash et al. (2021) claim that f2 is a better representation than f1, yet the NCE loss
prefers f2 to f1 only when #NS / #class (number of negative samples divided by the
number of classes) takes a mid-range value. See their Figure 2.
This script checks whether the same holds when the loss is logistic.

Usage
=====
Just run
```
python comparison_ash.py
```
then you will get a plot whose horizontal axis is #NS and vertical axis is NCE loss value.

Reference
=========
Ash et al. (2021): https://arxiv.org/abs/2106.09943
"""


# reference: https://stackoverflow.com/a/46378809
def multinomial(lst):
    """Return the multinomial coefficient ``sum(lst)! / (lst[0]! * lst[1]! * ...)``.

    The largest entry is skipped and the remaining entries are folded in one
    factor at a time, dividing right after every multiplication so the running
    value never grows far beyond the final result.

    Parameters
    ----------
    lst : sequence of int
        Non-negative group sizes. Any integer-like value is accepted
        (including NumPy integer scalars); non-integers raise ``TypeError``.

    Returns
    -------
    int
        The exact coefficient as an arbitrary-precision Python ``int``.
        An empty ``lst`` yields 1.
    """
    # Convert to plain Python ints up front: with fixed-width NumPy integers
    # the running product silently wraps around once it exceeds 64 bits.
    counts = [operator.index(a) for a in lst]
    if not counts:
        return 1

    remaining = sum(counts)
    largest = counts.index(max(counts))
    result = 1
    for count in counts[:largest] + counts[largest + 1:]:
        for j in range(1, count + 1):
            result = result * remaining // j
            remaining -= 1
    return result


def contrastive_loss(c=10, k=1, eps=0.1):
    """Return the contrastive loss value for error parameter ``eps``.

    The result combines six sub-losses. Each sub-loss is a weighted sum of
    ``log(1 + k1 * exp(v1) + k2 * exp(v2) + k3 * exp(v3))`` over every split
    ``(k1, k2, k3)`` of the ``k`` negative samples with ``k1 + k2 + k3 == k``,
    using the weight
    ``0.5 * multinomial([k1, k2, k3]) * (1 / c) ** (k1 + k2) * (1 - 2 / c) ** k3``.

    Parameters
    ----------
    c : int
        Number of classes.
    k : int
        Number of negative samples (a Python or NumPy integer).
    eps : float
        Error parameter in the definition of f2; the script passes 0 to
        obtain the value for f1.

    Returns
    -------
    float
        The combined loss
        ``1 / c * (l1 + l2 + l4 + l5) + (1 - 2 / c) * (l3 + l6)``.
    """
    # Enumerate every split together with its weight. The weights do not
    # depend on the exponents, so they are computed once and shared by all
    # six sub-losses.
    weighted_splits = []
    # range(k + 1): a group may receive all k samples, so k1 == k and
    # k2 == k must be included.
    for k1, k2 in itertools.product(range(k + 1), repeat=2):
        if k1 + k2 > k:
            # Skip rather than break: product() is one flat iterator, so a
            # break would also discard every valid split with a larger k1.
            continue
        # int() keeps the coefficient arithmetic in arbitrary-precision
        # Python ints even when k is a fixed-width NumPy integer.
        k3 = int(k - k1 - k2)
        coeff_comb = multinomial([k1, k2, k3])
        coeff_prob = 0.5 * (1 / c) ** (k1 + k2) * (1 - 2 / c) ** k3
        weighted_splits.append((k1, k2, k3, coeff_comb * coeff_prob))

    def _subloss(v1, v2, v3):
        # Weighted sum of the log term over all splits for one exponent triple.
        e1, e2, e3 = np.exp(v1), np.exp(v2), np.exp(v3)
        loss = 0
        for k1, k2, k3, weight in weighted_splits:
            loss += weight * np.log(1 + k1 * e1 + k2 * e2 + k3 * e3)
        return loss

    # Exponent triples built from (1 + eps) ...
    l1 = _subloss(0, -2 * eps * (1 + eps), -(1 + eps) ** 2)
    l2 = _subloss(2 * eps * (1 + eps), 0, -(1 + eps) * (1 - eps))
    l3 = _subloss((1 + eps) ** 2, (1 + eps) * (1 - eps), 0)

    # ... and their counterparts with eps replaced by -eps.
    l4 = _subloss(0, 2 * eps * (1 - eps), -(1 - eps) ** 2)
    l5 = _subloss(-2 * eps * (1 - eps), 0, -(1 + eps) * (1 - eps))
    l6 = _subloss((1 - eps) ** 2, (1 + eps) * (1 - eps), 0)

    return 1 / c * (l1 + l2 + l4 + l5) + (1 - 2 / c) * (l3 + l6)


def ovs_bound(c=10, k=1, eps=0.1):
    """Return the contrastive loss scaled by a ratio of two logarithms.

    The scaling factor is ``log(1 + c * e**2) / log(1 + k * e**-2)`` and the
    loss is ``contrastive_loss(c, k, eps)`` with the same arguments.
    """
    base_loss = contrastive_loss(c, k, eps)
    numerator = np.log(1 + c * np.exp(2))
    denominator = np.log(1 + k * np.exp(-2))
    scale = numerator / denominator
    return scale * base_loss


if __name__ == "__main__":
    # Experiment settings.
    num_classes = 40                     # number of classes
    negative_counts = np.arange(1, 100)  # numbers of negative samples to sweep
    epsilon = 0.7                        # error parameter (in definition of f2)
    # epsilon = 0.35

    # One curve per representation: f1 uses an error parameter of 0, f2 uses
    # epsilon. To plot the scaled bound instead, swap contrastive_loss for
    # ovs_bound in both comprehensions.
    curves = [
        ("f1 (bad)", [contrastive_loss(num_classes, n, 0) for n in negative_counts]),
        ("f2 (good)", [contrastive_loss(num_classes, n, epsilon) for n in negative_counts]),
    ]

    for curve_label, loss_values in curves:
        plt.plot(negative_counts, loss_values, label=curve_label)

    plt.xlabel("# of negative samples")
    plt.ylabel("contrastive loss value")
    plt.grid()
    plt.legend()

    plt.show()
