import numpy as np
from matplotlib import pyplot as plt
from scipy.optimize import newton
from scipy.special import logsumexp
from numba import prange
from abc import ABC, abstractmethod

# Seed NumPy's global RNG so every run of this file draws the same samples.
np.random.seed(5555)

# Target number of rows of the output csv; the solver derives its logging
# interval from this value.
LOG_SIZE = 500

class Hedge:
    """Hedge (exponential weights) over ``m`` arms, kept in log-space.

    The weights are stored as log-probabilities in ``logq`` and renormalised
    with ``logsumexp`` after every update.

    Parameters
    ----------
    m : int
        Number of arms.
    eta : scalar or callable
        Step size. A scalar is used as a constant step size; anything else is
        called as ``eta(t)`` with the round counter ``t`` (starting at 1).
    """

    def __init__(self, m: int, eta) -> None:
        self.m = m
        if np.isscalar(eta):
            self.eta = lambda t: eta  # constant step size
        else:
            self.eta = eta
        self.t = 1  # round counter handed to eta
        self.logq = np.log(np.ones(m) / m)  # start from the uniform distribution

    @property
    def q(self):
        """Current distribution over arms, with every entry strictly positive.

        ``1e-10`` is added to each entry so callers can safely divide by
        ``q[i]``, and the result is renormalised to sum to one. Without the
        renormalisation the vector sums to ``1 + m * 1e-10``, which
        ``np.random.choice`` rejects ("probabilities do not sum to 1") once
        ``m`` is large.
        """
        q = np.exp(self.logq) + 1e-10  # avoid exact zeros
        return q / q.sum()

    def play(self) -> int:
        """Sample an arm index according to ``q``."""
        return np.random.choice(np.arange(self.m), p=self.q)

    def feedback(self, g, eps=1e-8):
        """Do a single step of Hedge with gradient ``g``.

        ``eta(t) * g`` is added to the log-weights, which are then
        renormalised, and the round counter is advanced.

        ``eps`` is not used; it is kept so existing callers (and subclasses
        forwarding it) keep working.
        """
        updated = self.logq + self.eta(self.t) * g
        self.logq = updated - logsumexp(updated)
        self.t += 1

# %%
class EXP3P(Hedge):
    """High probability version of Hedge.

    Behaves like ``Hedge`` except that every gradient passed to ``feedback``
    is first shifted by ``beta / q``.
    """

    def __init__(self, m, eta, beta) -> None:
        """Set up the underlying Hedge learner and remember the bias weight ``beta``."""
        super().__init__(m, eta)
        self.beta = beta

    def feedback(self, g, eps=1e-8):
        """Apply one Hedge step using the biased estimator ``g + beta / q``."""
        bias = self.beta / self.q
        biased_g = g + bias  # biased loss estimator
        return super().feedback(biased_g, eps=eps)

# %%
class TINF:
    """Tsallis-INF learner over ``m`` arms.

    Parameters
    ----------
    m : int
        Number of arms.
    eta : scalar or callable
        Step size. A scalar is used as a constant step size; anything else is
        called as ``eta(t)`` with the round counter ``t`` (starting at 1).
    """

    def __init__(self, m: int, eta) -> None:
        self.m = m
        # Wrap a scalar so that self.eta(t) works for both kinds of input.
        self.eta = (lambda t: eta) if np.isscalar(eta) else eta
        self.q = np.ones(m) / m  # uniform start
        self.x = .5  # last root found by `newton`, reused as the next start
        self.t = 1

    def play(self):
        """Sample an arm index according to ``q``."""
        arms = np.arange(self.m)
        return np.random.choice(arms, p=self.q)

    def newton(self, p):
        """Recompute ``q`` from the unnormalised weights ``p``.

        Finds ``x`` with ``sum((p**-0.5 - x)**-2) == 1`` by Newton's method,
        starting from the previously found root, then stores the resulting
        weights (renormalised to sum to one) in ``q`` and the root in ``x``.
        """
        a = 1 / np.sqrt(p)

        def residual(x):
            return np.sum((a - x)**(-2)) - 1

        def residual_prime(x):
            return 2 * np.sum((a - x)**(-3))

        # `newton` here is the module-level scipy.optimize.newton.
        root = newton(func=residual, x0=self.x, fprime=residual_prime)
        weights = (a - root)**(-2)
        self.q = weights / weights.sum()
        self.x = root

    def feedback(self, g):
        """Do a single update with gradient ``g`` and advance the round counter."""
        step = self.eta(self.t)
        p = self.q * (1 - step * np.sqrt(self.q) * g)**(-2)
        self.newton(p)
        self.t += 1


# %%
# simple test
def test_bandit():
    """Run Hedge, TINF and EXP3P on a Bernoulli bandit and plot cumulative regret.

    Arm ``k`` (1-based) pays a Bernoulli reward with mean ``k / m``. Each
    learner is run for ``T`` rounds with an importance-weighted reward
    estimate, and the cumulative regret of all learners is shown in one plot.
    """
    T = 10000
    m = 10
    means = np.arange(1, m+1)/m  # means of arm rewards
    ALGS = ('Hedge', 'TINF', 'EXP3P')

    # One factory per learner so each is only built when its run starts.
    factories = {
        'Hedge': lambda: Hedge(m, eta=np.sqrt(np.log(m)/ (m * T))),
        'TINF': lambda: TINF(m, eta=np.sqrt(1/T)),  # Tsallis-INF
        'EXP3P': lambda: EXP3P(m,
                               eta=np.sqrt(np.log(m)/(m*T)),
                               beta=np.sqrt(np.log(m/0.1) / (m * T))),
    }

    best_mean = means.max()
    inst_regrets = {}
    for algname in ALGS:
        qalg = factories[algname]()
        inst_regret = np.empty(T)  # instantaneous regret

        for t in range(T):
            i = qalg.play()
            inst_regret[t] = best_mean - means[i]
            reward = np.random.binomial(n=1, p=means[i])
            # Only the played arm gets a non-zero, importance-weighted entry.
            g = np.zeros(m)
            g[i] = reward / qalg.q[i]
            qalg.feedback(g)

        inst_regrets[algname] = inst_regret

    # plot
    for algname in ALGS:
        plt.plot(np.cumsum(inst_regrets[algname]), label=algname)
    plt.legend()
    plt.show()

class OGD:
    """Projected online gradient descent.

    Parameters
    ----------
    theta0 : array-like
        Initial point. It is copied (as floats), so the caller's array is
        never modified by later updates.
    proj : callable
        Maps a point to the feasible region; applied after every step.
    eta : scalar or callable
        Step size. A scalar is used as a constant step size; anything else is
        called as ``eta(t)`` with the round counter ``t`` (starting at 1).
    """

    def __init__(self, theta0, proj, eta) -> None:
        # Own a float copy: theta0 may be a view into the caller's data (or an
        # integer array), and the iterate must be free to change independently.
        self.theta = np.array(theta0, dtype=float)
        self.proj = proj
        if np.isscalar(eta):
            self.eta = lambda t: eta  # constant step size
        else:
            self.eta = eta
        self.t = 1
        self.n = self.theta.shape[0]

    def feedback(self, grad_theta):
        """Take one gradient step along ``-grad_theta`` and project the result.

        The step is computed out of place, so arrays previously returned by
        ``play`` keep the values they had when they were played.
        """
        stepped = self.theta - self.eta(self.t) * grad_theta
        self.theta = self.proj(stepped)
        self.t += 1

    def play(self):
        """Return the current iterate."""
        return self.theta

# %% [markdown]
# ### Simple test of OGD
# minimize
# $$
#     \mathbf{E}_{z \sim N(\mu, I)}[\ell(\theta; z)] = \mathbf{E}_{z \sim N(\mu, I)}[\frac{1}{2}\| \theta - z \|^2]
# $$
# on the unit ball.

# %%
def loss(theta, z):
    """Return the squared-error loss ``0.5 * ||theta - z||^2``.

    This is the loss stated in the accompanying notes and the one whose
    gradient in ``theta`` is ``theta - z`` (see ``loss_grad``); the norm must
    therefore be squared.
    """
    return np.linalg.norm(theta - z)**2 / 2

def loss_grad(theta, z):
    """Return ``theta - z``, the gradient in ``theta`` of ``0.5 * ||theta - z||^2``."""
    residual = theta - z
    return residual

def proj_on_ball(theta):
    """Project ``theta`` onto the unit ball.

    Points with norm at most one are returned unchanged in value; longer
    vectors are scaled down to norm one.
    """
    radius = np.linalg.norm(theta)
    # Only shrink: inside the ball the divisor is 1.
    divisor = 1 if radius < 1 else radius
    return theta / divisor

# %%
def test_OGD():
    """Run OGD on a noisy quadratic over the unit ball and plot convergence.

    A target ``mu`` is drawn inside the unit ball; each round the learner
    receives the gradient ``theta - z`` with ``z ~ N(mu, I)``. The loss between
    the running average of the played iterates and ``mu`` is recorded every
    round, plotted on log-log axes, and its final value is printed.
    """
    T = 50000
    n = 10
    theta0 = proj_on_ball(np.ones(n))
    mu = np.random.normal(np.zeros(n))
    mu = proj_on_ball(mu)
    alg = OGD(theta0, proj_on_ball, eta=np.sqrt(1/(2*T)))
    output = np.zeros(n)
    obj_hist = np.zeros(T)
    for t in range(T):
        theta = alg.play()
        g = theta - np.random.normal(mu)
        # Fold the played iterate into the average BEFORE giving feedback:
        # the learner may update its iterate in place, in which case `theta`
        # would no longer hold the point that was actually played.
        output = output * t / (t+1) + theta / (t+1)
        alg.feedback(g)
        obj_hist[t] = loss(output, mu)
    plt.grid()
    plt.loglog(obj_hist)
    print(obj_hist[-1])
    # Display the figure when run as a script, as test_bandit does.
    plt.show()

# %% [markdown]
# ## Abstract DRO instance class

# %%
class DROInstance(ABC):
    """Base DRO instance class.

    Subclasses must implement every method below. The solvers in this file
    additionally read an attribute ``m`` from the instance, which concrete
    subclasses are expected to set.
    """

    def __init__(self) -> None:
        """Nothing to initialise in the base class."""

    @abstractmethod
    def obj(self, theta):
        """DRO objective."""

    @abstractmethod
    def proj(self, theta):
        """Projection onto the feasible region."""

    @abstractmethod
    def sample(self, i):
        """Sample z from the ith distribution."""

    @abstractmethod
    def loss(self, theta, z):
        """Global loss function."""

    @abstractmethod
    def loss_grad(self, theta, z):
        """Gradient of global loss function."""

# %%
class SagawaEtAlGradEstimator:
    """Gradient estimator that picks the group uniformly at random.

    Parameters
    ----------
    droinstance
        Object providing ``sample(i)``, ``loss(theta, z)`` and
        ``loss_grad(theta, z)``.
    minibatch : int
        Number of samples drawn from the chosen group and averaged per call.
    """

    def __init__(self, droinstance, minibatch) -> None:
        self.droinstance = droinstance
        self.minibatch = minibatch

    def grad(self, theta, q):
        """Return ``(t_grad, q_grad)`` estimated from one uniformly drawn group.

        With ``m = len(q)`` and ``i`` the drawn group, ``t_grad`` averages
        ``m * q[i] * loss_grad(theta, z)`` over the minibatch; ``q_grad`` is
        zero except at ``i``, where it averages ``m * loss(theta, z)``.
        """
        inst = self.droinstance
        n_params = theta.shape[0]
        n_groups = q.shape[0]
        group = np.random.randint(n_groups)  # uniform over groups

        theta_grad = np.zeros(n_params)
        weight_grad = np.zeros(n_groups)
        theta_scale = n_groups * q[group]

        # minibatch for stability
        for _ in range(self.minibatch):
            z = inst.sample(group)
            theta_grad += theta_scale * inst.loss_grad(theta, z)
            weight_grad[group] += n_groups * inst.loss(theta, z)

        return theta_grad / self.minibatch, weight_grad / self.minibatch

class EXP3GradEstimator:
    """Gradient estimator that picks the group according to ``q``.

    Parameters
    ----------
    droinstance
        Object providing ``sample(i)``, ``loss(theta, z)`` and
        ``loss_grad(theta, z)``.
    minibatch : int
        Number of samples drawn from the chosen group and averaged per call.
    """

    def __init__(self, droinstance, minibatch) -> None:
        self.droinstance = droinstance
        self.minibatch = minibatch

    def grad(self, theta, q):
        """Return ``(t_grad, q_grad)`` estimated from one group drawn from ``q``.

        With ``i`` the drawn group, ``t_grad`` averages ``loss_grad(theta, z)``
        over the minibatch; ``q_grad`` is zero except at ``i``, where it
        averages the importance-weighted loss ``loss(theta, z) / q[i]``.
        """
        inst = self.droinstance
        n_params = theta.shape[0]
        n_groups = q.shape[0]
        groups = np.arange(n_groups)
        group = np.random.choice(groups, p=q)  # sample group w.r.t. q

        theta_grad = np.zeros(n_params)
        weight_grad = np.zeros(n_groups)

        # minibatch for stability
        for _ in range(self.minibatch):
            z = inst.sample(group)
            theta_grad += inst.loss_grad(theta, z)
            weight_grad[group] += inst.loss(theta, z) / q[group]

        return theta_grad / self.minibatch, weight_grad / self.minibatch

def no_regret_solve(droinstance, talg, qalg, T, grad_est, trial=None, csv_writer=None):
    """Solve a DRO instance by running two online learners against each other.

    Parameters
    ----------
    droinstance
        Problem instance; only ``obj`` is called here.
    talg
        Learner for theta, providing ``n``, ``play()`` and ``feedback(grad)``.
    qalg
        Learner for the group weights, providing ``q`` and ``feedback(grad)``.
    T : int
        Number of iterations.
    grad_est
        Estimator whose ``grad(theta, q)`` returns ``(t_grad, q_grad)``.
    trial : optional
        If truthy, ``trial.report(obj_val, t)`` is called at every logged
        iteration (an optuna trial).
    csv_writer : optional
        If truthy, ``csv_writer.writerow([t, obj_val])`` is called at every
        logged iteration.

    Returns
    -------
    tuple
        The running average of the played thetas and the list of logged
        ``(t, obj_val)`` pairs.
    """
    # Log once every `cycle` iterations (every iteration when T < LOG_SIZE).
    cycle = max(T // LOG_SIZE, 1)
    averaged = np.zeros(talg.n)
    history = []

    for t in range(T):
        # play
        theta = talg.play()
        q = qalg.q

        # Update the running average before any feedback is given.
        rounds = t + 1
        averaged = averaged * t / rounds + theta / rounds

        # estimate gradient and pass it to both learners
        t_grad, q_grad = grad_est.grad(theta, q)
        talg.feedback(t_grad)
        qalg.feedback(q_grad)

        if t % cycle != 0:
            continue

        # record objective value of the averaged iterate
        obj_val = droinstance.obj(averaged)
        history.append((t, obj_val))
        if trial:
            trial.report(obj_val, t)
        if csv_writer:
            csv_writer.writerow([t, obj_val])

    return averaged, history

def sagawa_et_al_solve(droinstance, T, theta0, eta_t, eta_q, beta, minibatch=5, trial=None, csv_writer=None, **kwargs):
    """Solve ``droinstance`` with OGD for theta, EXP3P for q and uniform group sampling.

    ``eta_t`` and ``eta_q`` are the step sizes of the two learners and ``beta``
    is the EXP3P bias weight. Extra keyword arguments are accepted and
    ignored. Returns whatever ``no_regret_solve`` returns.
    """
    return no_regret_solve(
        droinstance,
        talg=OGD(theta0, droinstance.proj, eta_t),
        qalg=EXP3P(droinstance.m, eta_q, beta),
        T=T,
        grad_est=SagawaEtAlGradEstimator(droinstance, minibatch),
        trial=trial,
        csv_writer=csv_writer,
    )

def EXP3P_solve(droinstance, T, theta0, eta_t, eta_q, beta, minibatch=5, trial=None, csv_writer=None, **kwargs):
    """Solve ``droinstance`` with OGD for theta, EXP3P for q and groups sampled from q.

    ``eta_t`` and ``eta_q`` are the step sizes of the two learners and ``beta``
    is the EXP3P bias weight. Extra keyword arguments are accepted and
    ignored. Returns whatever ``no_regret_solve`` returns.
    """
    return no_regret_solve(
        droinstance,
        talg=OGD(theta0, droinstance.proj, eta_t),
        qalg=EXP3P(droinstance.m, eta_q, beta),
        T=T,
        grad_est=EXP3GradEstimator(droinstance, minibatch),
        trial=trial,
        csv_writer=csv_writer,
    )

def TINF_solve(droinstance, T, theta0, eta_t, eta_q, minibatch=5, trial=None, csv_writer=None, **kwargs):
    """Solve ``droinstance`` with OGD for theta, TINF for q and groups sampled from q.

    ``eta_t`` and ``eta_q`` are the step sizes of the two learners. Extra
    keyword arguments are accepted and ignored. Returns whatever
    ``no_regret_solve`` returns.
    """
    return no_regret_solve(
        droinstance,
        talg=OGD(theta0, droinstance.proj, eta_t),
        qalg=TINF(droinstance.m, eta_q),
        T=T,
        grad_est=EXP3GradEstimator(droinstance, minibatch),
        trial=trial,
        csv_writer=csv_writer,
    )

# %% [markdown]
# ## DRO with synthetic data
#
# We use loss function
# $$
#     \ell(\theta; z) = \frac{1}{2}\| \theta - z \|_2^2, \quad \nabla_\theta \ell(\theta; z) = \theta - z
# $$
# and group distributions $P_i : z \sim N(\mu_i, I)$

# %%
class SyntheticDROInstance(DROInstance):
    """Synthetic DRO instance with Gaussian groups and squared-error loss.

    Group ``i`` draws ``z ~ N(mu[i], I)``; the loss is
    ``0.5 * ||theta - z||^2`` and the feasible region is the unit ball.

    Parameters
    ----------
    mu : ndarray
        Array whose row ``i`` is the mean of group ``i``; the number of rows
        is stored as ``m``.
    """

    def __init__(self, mu) -> None:
        self.mu = mu
        self.m = mu.shape[0]
        super().__init__()

    def _exp_loss(self, theta, mu):
        """Expected loss of ``theta`` for a group with mean ``mu``.

        NOTE(review): for ``z ~ N(mu, I)`` the exact expectation of
        ``0.5 * ||theta - z||^2`` also contains ``0.5 * ||mu||^2``, which is
        omitted here. That only shifts every group by the same constant when
        all means have equal norm (test_synthetic normalises them); confirm
        before using means of different norms.
        """
        return (np.linalg.norm(theta)**2 - 2 * theta.T @ mu + theta.shape[0]) / 2

    def obj(self, theta):
        """DRO objective: the largest expected loss over all groups."""
        return max(self._exp_loss(theta, group_mean) for group_mean in self.mu)

    def loss(self, theta, z):
        """Squared-error loss ``0.5 * ||theta - z||^2``.

        The norm is squared so that this agrees with ``loss_grad`` (whose
        value ``theta - z`` is the gradient of the squared loss) and with the
        expected loss used by ``obj``.
        """
        return np.linalg.norm(theta - z)**2 / 2

    def loss_grad(self, theta, z):
        """Gradient of the loss in ``theta``."""
        return theta - z

    def proj(self, theta):
        """Projection onto the unit ball."""
        return theta / max(np.linalg.norm(theta), 1)

    def sample(self, i):
        """Draw one sample from ``N(mu[i], I)``."""
        return np.random.normal(self.mu[i,:])

def test_synthetic():
    """Compare DRO solvers on a random synthetic instance against a CVX optimum.

    Builds an instance with ``m`` unit-norm group means, computes the optimal
    objective with cvxpy, runs each solver in ``SETTINGS`` and plots the
    relative gap to the optimum over the iterations.
    """
    # generate random DRO instance
    m = 2
    n = 10
    mu = np.zeros((m, n))
    for i in range(m):
        g = np.random.normal(np.zeros(n))
        mu[i, :] = g / np.linalg.norm(g)
    droinstance = SyntheticDROInstance(mu)

    # true solution via CVX
    import cvxpy as cp
    x = cp.Variable(n)
    t = cp.Variable()

    def cp_exp_loss(x, mu):
        return (cp.norm(x)**2 - 2 * x @ mu + n)/2

    prob = cp.Problem(cp.Minimize(t),
                    [cp_exp_loss(x, mu[i,:]) <= t for i in range(m)] + [cp.norm(x) <= 1])
    prob.solve()
    ans = x.value
    opt = t.value
    assert np.abs(droinstance.obj(ans) - opt) / opt < 1e-6

    T = 20000
    B = 10 # batch size
    # NOTE: the 'Hedge' label is what appears in the legend; it runs EXP3P_solve.
    SETTINGS = ['Sagawa_et_al', 'Hedge']
    obj_hists = {}
    theta_outs = {}
    # mu[0, :] is a view into the instance's means. Copy it so that a learner
    # updating its iterate in place cannot alter the problem data.
    theta0 = mu[0, :].copy()

    # A plain loop: prange only parallelises inside numba-compiled functions.
    for setting in SETTINGS:
        if setting == 'Sagawa_et_al':
            alg = sagawa_et_al_solve
        else:
            alg = EXP3P_solve
        output, obj_hist = alg(
            droinstance, T,
            theta0.copy(),  # every run starts from the same, untouched point
            eta_t=lambda t: np.sqrt(1/(2*t)),
            eta_q=np.sqrt(np.log(m)/(m*T)),
            beta=np.sqrt(np.log(m/0.1) / (m * T)),
            minibatch=B
        )
        theta_outs[setting] = output
        obj_hists[setting]  = obj_hist

    # Draw on a fresh figure so earlier plots are not drawn over.
    plt.figure()
    plt.grid()
    for setting in SETTINGS:
        obj_hist = obj_hists[setting]
        steps = np.array([a[0] for a in obj_hist])
        values = np.array([a[1] for a in obj_hist])
        plt.plot(steps, (values - opt)/opt, label=setting)
        print(obj_hist[-1])
    plt.legend()
    # Display the figure when run as a script, as test_bandit does.
    plt.show()

if __name__ == '__main__':
    # Run each demo in turn, printing its banner first.
    for banner, demo in (
        ('test bandit', test_bandit),
        ('test OGD', test_OGD),
        ('test synthetic DRO', test_synthetic),
    ):
        print(banner)
        demo()
