#!/usr/bin/env python3

"""
Binary logistic regression with a separate logit for each class
(i.e., softmax).

For a pair of logits (x0, x1):
    sm(x0, x1) = [e^x0 / (e^x0 + e^x1), e^x1 / (e^x0 + e^x1)]

The positive class is associated with the second logit.
Intended to be combined with L2 regularization.

Not yet finished.
"""

from jax import random
from jax import numpy as jnp

from fairgym.utils.misc import _augment_feat_array
from fairgym.utils.logistic_regression_jax import LogisticRegression


def _logits(params, aug_feat_array):
    """
    Compute the logits of every sample as a matrix product.

    Parameters
    ----------
    params: array-like, shape (n_features, n_logits)
        One column of parameters per logit

    aug_feat_array: array-like, shape (n_samples, n_features)
        Data matrix

    Returns
    -------
    logits: array-like, shape (n_samples, n_logits)
    """
    # einsum (rather than `@`) so that both operands are required to be 2-D
    return jnp.einsum("ij,jk->ik", aug_feat_array, params)


def _softmax_fn(logits):
    """
    Softmax over the last axis,
    with added numerical stability (i.e., not blowing up)

    Parameters
    ----------
    logits: array-like, shape (n_samples, n_logits)

    Returns
    -------
    probs: array-like, same shape as logits
        Entries along the last axis are non-negative and sum to 1
    """

    # subtract max logit from each row, making all values non-positive so
    # exp() cannot overflow; softmax is unchanged by a per-row shift
    shifted = logits - jnp.max(logits, axis=-1, keepdims=True)

    exps = jnp.exp(shifted)

    # keepdims so the per-row normalizer broadcasts across that row
    return exps / jnp.sum(exps, axis=-1, keepdims=True)


def _sample_loss(logits, labels, sample_weight):
    """
    Per-sample cross-entropy of the two-logit softmax.

    Parameters
    ----------
    logits: array-like, shape (n_samples, 2)
        Second column is the logit of the positive class

    labels: array-like, shape (n_samples,)
        true labels, in {0, 1}

    sample_weight: array-like, shape (n_samples,)
        Currently unused by this function.
        NOTE(review): confirm whether the per-sample loss should be weighted.

    Returns
    -------
    loss: array-like, shape (n_samples,)
        -log of the probability assigned to each sample's true class
    """

    # one-hot targets: column 0 <-> label 0, column 1 <-> label 1,
    # so the positive class is paired with the second logit
    targets = jnp.stack((1 - labels, labels), axis=-1)

    # log-softmax computed directly (shift by the row max, then subtract the
    # log-sum-exp) instead of log(softmax(.)): a probability that underflows
    # to 0 would otherwise give log(0) = -inf and 0 * inf = nan below
    shifted = logits - jnp.max(logits, axis=-1, keepdims=True)
    log_probs = shifted - jnp.log(jnp.sum(jnp.exp(shifted), axis=-1, keepdims=True))

    # row-wise dot product picks out -log p(true class)
    return jnp.einsum("ij,ij->i", targets, -log_probs)


def _loss(logits, labels, sample_weight):
    """
    Total loss: the values returned by `_sample_loss` summed over all samples.

    All three arguments are forwarded unchanged to `_sample_loss`.
    """
    per_sample = _sample_loss(logits, labels, sample_weight)

    # "i->" contracts the only axis, i.e. a plain sum down to a scalar
    return jnp.einsum("i->", per_sample)


def _nat_loss_grad(params, aug_feat_array, labels, sample_weight):
    """
    Gradients given by natural gradient descent

    Parameters
    ----------
    params: array-like, shape (n_features, 2)
        Bias row, then weights; one column per logit

    aug_feat_array: array-like, shape (n_samples, n_features)
        Data matrix, with column of 1s at beginning

    labels: array-like, shape (n_samples,)
        true labels, in {0, 1}

    sample_weight: array-like, shape (n_samples,)
        per-sample weights

    Returns
    -------
    grad: array-like, shape (n_features, 2)
        Same shape as params
    """

    n = aug_feat_array.shape[0]

    logits = _logits(params, aug_feat_array)
    p = _softmax_fn(logits)

    # one-hot targets so they line up with the two softmax columns;
    # the positive class is paired with the second logit
    targets = jnp.stack((1 - labels, labels), axis=-1)

    # (p - targets) is the cross-entropy gradient w.r.t. the logits; each
    # row is rescaled by 1 / (p0 * p1). The trailing axis of length 1 makes
    # the per-sample scale broadcast over both logit columns.
    # NOTE(review): p0 * p1 reaches 0 when a prediction saturates — this is
    # presumably what the L2 regularization is meant to prevent; confirm.
    scale = (p[:, 0] * p[:, 1])[:, None]
    dd_nat = (p - targets) / scale

    # allow sample_weight to weight samples
    # return weighted average of sample_loss gradients
    loss_grad = jnp.einsum("ij,ik,i->jk", aug_feat_array, dd_nat, sample_weight) / n

    return loss_grad


class NaturalLogisticRegression(LogisticRegression):
    """
    Use natural gradient descent and double the parameters with
    L2 regularization to deal with numerical issues, as per
    "Symmetric (Optimistic) Natural Policy Gradient for
    Multi-Agent Learning with Parameter Convergence"
    """

    def __init__(
        self, random_state=0, max_iter=100, learn_rate=1e-1, stopping_update_size=1e-5
    ):
        """
        Store the hyperparameters and select the natural-gradient update.

        Parameters
        ----------
        random_state: int
            Seed for the JAX PRNG key stored in `self.key`

        max_iter, learn_rate, stopping_update_size:
            Stored as attributes of the same name; they are not used in
            this class itself (presumably consumed by the inherited
            fitting code — verify against LogisticRegression)
        """
        self.key = random.PRNGKey(random_state)
        self.max_iter = max_iter
        self.learn_rate = learn_rate
        self.stopping_update_size = stopping_update_size

        # no parameters until the model has been fit
        self.params = None

        self.loss_grad_method = _nat_loss_grad

    @staticmethod
    def params_shape(num_dimensions):
        """
        Shape of the parameter matrix for `num_dimensions` input features:
        one extra row (for the augmented feature) and two columns
        (one per logit).
        """
        return (num_dimensions + 1, 2)

    def predict_proba(self, feat_array):
        """
        return probability of label=1

        Parameters
        ----------
        feat_array: array-like
            Features, passed through `_augment_feat_array` before use

        Returns
        -------
        proba: array-like, shape (n_samples,)
        """

        aug_feat_array = _augment_feat_array(feat_array)

        logits = _logits(self.params, aug_feat_array)

        # Pr(Y=1): the positive class is associated with the second logit,
        # so take the second column of the softmax
        return _softmax_fn(logits)[:, 1]


if __name__ == "__main__":
    # Timing / smoke run on synthetic 1-D data where Pr(y=1 | x) = x
    import time

    N = 1000
    X = jnp.linspace(0, 1, N)

    key = random.PRNGKey(0)
    y = random.bernoulli(key, p=X)

    # Use the same (N, 1) column matrix for fitting and for evaluating the
    # loss, so both calls see features of the same shape.
    # NOTE(review): the inherited `loss` is not visible here — confirm it
    # expects the same feature layout as `fit`.
    feat_array = X.reshape((N, 1))

    clf = NaturalLogisticRegression(
        random_state=0, max_iter=1000, learn_rate=1e0, stopping_update_size=1e-5
    )

    now = time.time_ns()
    clf.fit(feat_array, y)
    print((time.time_ns() - now) / 1e9, "seconds")
    print(clf.loss(feat_array, y), "loss")
