# %% [markdown]
# # Group DRO Synthetic Experiment

# %%
import numpy as np
from solver import DROInstance, sagawa_et_al_solve, EXP3P_solve
from matplotlib import pyplot as plt
# Seed NumPy's global RNG so repeated runs produce the same random draws.
np.random.seed(5555)

# %% [markdown]
# ## Synthetic Classification DRO instance
# This experimental setup is adopted from (Namkoong and Duchi 2016). For group $i$, we set
# - $\theta_i^* \in \mathbb{R}^d$: true classifier parameter for group $i$
# - $a \sim N(0, I)$: feature vector
# - $b = sign(a^\top \theta^*_i)$ (flipped with prob 0.1)
#
# We use the hinge loss function: $\ell(\theta; (a, b)) = (1 - ba^\top \theta)_+$.
# The feasible region is the L2-ball with radius $D = 10$.

# %%
def generate_data(theta_star):
    """
    Draw one labelled sample (a, b) for the true classifier theta_star.

     a : standard normal feature vector with the same length as theta_star
     b : sign(a . theta_star), negated with probability 0.1 (label noise)
    """
    dim = theta_star.shape[0]
    features = np.random.normal(np.zeros(dim))
    label = np.sign(features @ theta_star)
    # The Bernoulli(0.1) draw happens after the feature draw on every call.
    flipped = np.random.binomial(1, p=.1)
    return (features, -label if flipped else label)

# def generate_data_at_once(theta_star, size):
#     """
#     generate (a,b) with true classifier theta_star
#     output contains `size` rows of (a, b)
#     """
#     n = theta_star.shape[0]
#     A = np.random.normal(np.zeros(n), size=(size, n))
#     b = np.sign(np.dot(A, theta_star))
#     flip = (np.random.binomial(size, p=.1) == 1)
#     b[flip] = -b[flip]

def hinge_loss(x, A, b):
    """
    Mean hinge loss max(0, 1 - b * <a, x>) over the given samples.

     x : parameter vector (n,)
     A : feature matrix (k, n), or a single feature vector (n,)
     b : label vector (k,), or a single label
    """
    margins = b * np.dot(A, x)
    slack = 1 - margins
    # Only samples with margin below 1 contribute a positive loss.
    clipped = np.maximum(0, slack)
    return clipped.mean()

# %%
class SyntheticDROInstance(DROInstance):
    """
    Group DRO instance over synthetic hinge-loss classification data.

    Each of the m groups owns a fixed sample set (A, b). The objective is the
    largest average hinge loss over the groups, and the feasible region is
    the L2 ball of radius D.
    """

    def __init__(self, theta_star, D, group_data) -> None:
        """
        Store the problem data.

         theta_star : (m, n) array, one true classifier per group
         D          : radius of the L2-ball feasible region
         group_data : mapping from group index to a pair (A, b) of that
                      group's feature matrix and label vector
        """
        self.theta_star = theta_star
        self.D = D
        self.m = theta_star.shape[0]  # number of groups
        self.n = theta_star.shape[1]  # feature dimension
        self.group_data = group_data  # data for compute obj
        self.name = 'synthetic'

    def obj(self, theta):
        """Return the worst-case (maximum) group objective at theta."""
        return max(self._group_obj(theta, i) for i in range(self.m))

    def _group_obj(self, theta, i):
        """Return the mean hinge loss of theta on group i's stored data."""
        A, b = self.group_data[i]
        return hinge_loss(theta, A, b)

    def sample(self, i):
        """
        Draw one (a, b) pair uniformly at random from group i's stored data.

        Samples are taken from the fixed data set (with replacement across
        calls), not freshly generated.
        """
        A, b = self.group_data[i]
        n_sample = b.shape[0]
        j = np.random.randint(n_sample)
        return A[j, :], b[j]

    def loss(self, theta, z):
        """Return the hinge loss of theta on the sample z = (a, b)."""
        a, b = z
        return hinge_loss(theta, a, b)

    def loss_grad(self, theta, z):
        """
        Return a subgradient of the hinge loss at theta for z = (a, b).

        This is -b * a where the loss is positive and the zero vector of
        length n otherwise.
        """
        a, b = z
        if self.loss(theta, z) > 0:
            return -b * a
        return np.zeros(self.n)

    def proj(self, theta):
        """
        Project theta onto the L2 ball of radius D.

        Points already inside the ball are returned unchanged.
        """
        # Compute the norm once; it is needed for both the test and the scaling.
        norm = np.linalg.norm(theta)
        if norm > self.D:
            return theta / norm * self.D
        return theta

def synthetic_DRO_instance(m, n, D, validation_size=1000):
    """
    Build a random SyntheticDROInstance.

     m               : number of groups
     n               : feature dimension
     D               : radius handed to SyntheticDROInstance
     validation_size : number of (a, b) samples generated per group

    Each group's true classifier is a standard normal vector rescaled to unit
    L2 norm. All classifiers are drawn first, then the per-group samples are
    generated with generate_data.
    """
    theta_star = np.zeros((m, n))
    for g in range(m):
        direction = np.random.normal(np.zeros(n))
        theta_star[g, :] = direction / np.linalg.norm(direction)

    group_data = {}
    for g in range(m):
        draws = [generate_data(theta_star[g, :]) for _ in range(validation_size)]
        features = np.array([a for a, _ in draws])
        labels = np.array([b for _, b in draws])
        group_data[g] = (features, labels)

    return SyntheticDROInstance(theta_star, D, group_data)

# %%
def test():
    """
    Compare the stochastic DRO solvers against a cvxpy reference solution.

    Builds a synthetic instance, solves the min-max hinge-loss problem with
    cvxpy, runs both solvers on the same instance, and plots each solver's
    recorded values minus the cvxpy optimum on log-log axes.
    """
    import cvxpy as cp
    from matplotlib import pyplot as plt

    dim, n_groups, radius = 500, 10, 10
    instance = synthetic_DRO_instance(n_groups, dim, radius)
    data = instance.group_data

    # Reference solution: minimise an epigraph variable that upper-bounds
    # every group's mean hinge loss, subject to the norm-ball constraint.
    theta = cp.Variable(dim)
    worst = cp.Variable()

    def mean_hinge(var, g):
        A, b = data[g]
        return cp.sum(cp.pos(1 - cp.multiply(b, A @ var))) / b.shape[0]

    constraints = [mean_hinge(theta, g) <= worst for g in range(n_groups)]
    constraints.append(cp.norm(theta) <= radius)
    cp.Problem(cp.Minimize(worst), constraints).solve()
    opt = worst.value
    print(instance.obj(theta.value))
    print(opt)

    print("starting DRO...")
    print(instance.m)
    horizon = 100000
    solvers = {
        'Sagawa_et_al': sagawa_et_al_solve,
        'Hedge': EXP3P_solve,
    }
    # These step sizes do not depend on the solver, so compute them once.
    eta_q = np.sqrt(np.log(n_groups)/(n_groups*horizon))
    beta = np.sqrt(np.log(n_groups/0.1) / (n_groups * horizon))

    histories = {}
    for label, solve in solvers.items():
        print(label)
        _, history = solve(
            instance,
            horizon,
            theta0=np.zeros(dim),
            eta_t=lambda t: radius * np.sqrt(1/t),
            eta_q=eta_q,
            beta=beta,
            minibatch=5,
        )
        histories[label] = history

    # plot
    # NOTE(review): each history entry is indexed at [0] and [1]; presumably
    # (iteration, objective) pairs -- confirm against the solver module.
    for label, history in histories.items():
        print(history[-1])
        steps = np.array([entry[0] for entry in history])
        values = np.array([entry[1] for entry in history])
        plt.loglog(steps, values - opt, label=label)
    plt.legend()
    plt.grid(which='both')


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == '__main__':
    test()