import numpy as np

class Fairness_Learning:
    """First-order oracle for a logistic fairness saddle-point problem.

    The variable ``z`` stacks the model weights ``x`` (its first ``dim``
    entries) and one adversary parameter ``y`` (entry ``dim``).  With
    ``l(t) = log(1 + exp(-t))`` the quantities returned by ``grad_x``,
    ``grad_y``, ``GDA_field`` and ``Jacobian_GDA_field`` are derivatives of

        f(x, y) = mean(l(b * Ax)) - beta * mean(l(c * Ax * y))
                  + lamb * ||x||^2 - gamma * y^2

    ``grad_x`` is ``df/dx``; ``grad_y`` is ``-df/dy``, so that the pair
    ``(grad_x, grad_y)`` is the gradient-descent-ascent (GDA) vector field.

    NOTE(review): ``compute_losses``, the ``grad_la_*`` helpers and the
    penalty hinge use ``log(1 + exp(+c * Ax * y))`` as the adversary loss,
    i.e. the opposite sign in the exponent compared to ``f`` above.  This is
    kept as found; confirm that it is intended.

    All formulas multiply ``b`` and ``c`` elementwise with ``A @ x``, so the
    labels must have the same shape as ``A @ x``.  ``_gda_field`` writes its
    result into a ``(dim + 1, 1)`` array, which presumably means ``z`` (and
    hence ``b``, ``c``) are expected to be column vectors -- verify against
    callers.
    """

    def __init__(self, A, b, c, lamb, gamma, beta, penalty_eps=0.0):
        """Store the data and hyper-parameters.

        Args:
            A: Data matrix; ``A.shape`` defines ``(m, dim)``.
            b: Labels multiplied elementwise with ``A @ x`` in the
                prediction loss.
            c: Labels multiplied elementwise with ``A @ x * y`` in the
                adversary loss.
            lamb: Coefficient of the ``||x||^2`` term.
            gamma: Coefficient of the ``y^2`` term.
            beta: Weight of the adversary term.
            penalty_eps: Default hinge threshold used by the ``*_penalty``
                field methods when no ``eps`` is passed.
        """
        self.m, self.dim = A.shape
        self.d = self.dim + 1
        self.gamma = gamma
        self.beta = beta
        self.lamb = lamb
        self.penalty_eps = penalty_eps
        self.A = A
        self.b = b
        self.c = c
        # Mini-batch sampler state: the first pass walks the data in its
        # stored order; later passes use a fresh random permutation.
        self._perm = np.arange(self.m)
        self._idx_ptr = 0

    # ------------------------------------------------------------------
    # Numerically stable building blocks
    # ------------------------------------------------------------------
    @staticmethod
    def _softplus(t):
        """Return ``log(1 + exp(t))`` elementwise without overflow."""
        return np.logaddexp(0.0, t)

    @staticmethod
    def _inv_one_plus_exp(t):
        """Return ``1 / (1 + exp(t))`` elementwise without overflow."""
        # exp(-log(1 + exp(t))) never forms exp(t) for large t.
        return np.exp(-np.logaddexp(0.0, t))

    def _resolve_idx(self, batch_size, idx):
        """Return ``idx`` if given, otherwise draw one from the sampler."""
        return self._sample_indices(batch_size) if idx is None else idx

    # ------------------------------------------------------------------
    # Losses
    # ------------------------------------------------------------------
    def compute_losses(self, z):
        """Return ``(pred_loss, adv_loss)`` on the full data set.

        ``pred_loss = mean(log(1 + exp(-b * Ax)))`` and
        ``adv_loss = mean(log(1 + exp(c * Ax * y)))``; neither includes the
        quadratic regularisers.
        """
        x = z[: self.dim]
        y = z[self.dim]
        Ax = self.A @ x
        pred_loss = self._softplus(-self.b * Ax).mean()
        adv_loss = self._softplus(self.c * Ax * y).mean()
        return pred_loss, adv_loss

    # ------------------------------------------------------------------
    # Per-batch gradients
    # ------------------------------------------------------------------
    def _grad_lp_x_batch(self, A, b, z):
        """x-gradient of ``mean(log(1 + exp(-b * Ax))) + lamb * ||x||^2``."""
        x = z[: self.dim]
        p1 = self._inv_one_plus_exp(b * (A @ x))
        m_batch = A.shape[0]
        return -A.T @ (p1 * b / m_batch) + 2 * self.lamb * x

    def _grad_la_x_batch(self, A, c, z):
        """x-gradient of ``mean(log(1 + exp(c * Ax * y)))``."""
        x = z[: self.dim]
        y = z[self.dim]
        u = c * (A @ x) * y
        sig_u = self._inv_one_plus_exp(-u)  # sigmoid(u)
        m_batch = A.shape[0]
        return A.T @ (c * y * sig_u / m_batch)

    def _grad_la_y_batch(self, A, c, z):
        """y-gradient of ``mean(log(1 + exp(c * Ax * y))) + gamma * y^2``."""
        x = z[: self.dim]
        y = z[self.dim]
        Ax = A @ x
        sig_u = self._inv_one_plus_exp(-(c * Ax * y))  # sigmoid(u)
        m_batch = A.shape[0]
        return np.sum(Ax * c * sig_u / m_batch) + 2 * self.gamma * y

    def _grad_x_batch(self, A, b, c, z, beta=None):
        """Return ``df/dx`` on the batch ``(A, b, c)``.

        ``beta`` overrides ``self.beta`` as the adversary weight when given.
        """
        beta_val = self.beta if beta is None else beta
        x = z[: self.dim]
        y = z[self.dim]
        Ax = A @ x
        p1 = self._inv_one_plus_exp(b * Ax)
        p2 = self._inv_one_plus_exp(c * Ax * y)
        m_batch = A.shape[0]
        return (
            -A.T @ (p1 * b / m_batch)
            + A.T @ (c * y * p2 * beta_val / m_batch)
            + 2 * self.lamb * x
        )

    def _grad_y_batch(self, A, b, c, z, beta=None):
        """Return ``-df/dy`` on the batch (``b`` is accepted but unused).

        ``beta`` overrides ``self.beta`` as the adversary weight when given.
        """
        beta_val = self.beta if beta is None else beta
        x = z[: self.dim]
        y = z[self.dim]
        Ax = A @ x
        p2 = self._inv_one_plus_exp(c * Ax * y)
        m_batch = A.shape[0]
        return -np.sum(Ax * c * beta_val * p2 / m_batch) + 2 * self.gamma * y

    def _penalty_beta(self, A, c, z, rho, eps):
        """Return ``beta + rho * max(L_adv - eps, 0)`` on the batch.

        ``L_adv`` is ``mean(log(1 + exp(c * Ax * y)))``, the same adversary
        loss that ``compute_losses`` reports.
        """
        Ax = A @ z[: self.dim]
        L_adv = self._softplus(c * Ax * z[self.dim]).mean()
        hinge = max(L_adv - eps, 0.0)
        return self.beta + rho * hinge

    def _grad_x_batch_penalty(self, A, b, c, z, rho, eps):
        """``_grad_x_batch`` with the hinge-inflated adversary weight."""
        beta_val = self._penalty_beta(A, c, z, rho, eps)
        return self._grad_x_batch(A, b, c, z, beta=beta_val)

    def _grad_y_batch_penalty(self, A, b, c, z, rho, eps):
        """``_grad_y_batch`` with the hinge-inflated adversary weight."""
        beta_val = self._penalty_beta(A, c, z, rho, eps)
        return self._grad_y_batch(A, b, c, z, beta=beta_val)

    # ------------------------------------------------------------------
    # Full-batch gradients
    # ------------------------------------------------------------------
    def grad_lp_x(self, z):
        """Full-data version of ``_grad_lp_x_batch``."""
        return self._grad_lp_x_batch(self.A, self.b, z)

    def grad_la_x(self, z):
        """Full-data version of ``_grad_la_x_batch``."""
        return self._grad_la_x_batch(self.A, self.c, z)

    def grad_la_y(self, z):
        """Full-data version of ``_grad_la_y_batch``."""
        return self._grad_la_y_batch(self.A, self.c, z)

    def grad_x(self, z):
        """Return ``df/dx`` on the full data set."""
        return self._grad_x_batch(self.A, self.b, self.c, z)

    def grad_y(self, z):
        """Return ``-df/dy`` on the full data set."""
        return self._grad_y_batch(self.A, self.b, self.c, z)

    # ------------------------------------------------------------------
    # Stochastic gradients
    #
    # Each method uses the rows selected by ``idx``; when ``idx`` is None a
    # batch is drawn with ``_sample_indices(batch_size)`` (which yields the
    # full data set for ``batch_size=None`` or ``batch_size >= m``).  Any
    # NumPy-style index is honoured, including partial slices.
    # ------------------------------------------------------------------
    def grad_x_stochastic(self, z, batch_size=None, idx=None):
        """Mini-batch estimate of ``grad_x``."""
        idx = self._resolve_idx(batch_size, idx)
        return self._grad_x_batch(self.A[idx], self.b[idx], self.c[idx], z)

    def grad_y_stochastic(self, z, batch_size=None, idx=None):
        """Mini-batch estimate of ``grad_y``."""
        idx = self._resolve_idx(batch_size, idx)
        return self._grad_y_batch(self.A[idx], self.b[idx], self.c[idx], z)

    def grad_lp_x_stochastic(self, z, batch_size=None, idx=None):
        """Mini-batch estimate of ``grad_lp_x``."""
        idx = self._resolve_idx(batch_size, idx)
        return self._grad_lp_x_batch(self.A[idx], self.b[idx], z)

    def grad_la_x_stochastic(self, z, batch_size=None, idx=None):
        """Mini-batch estimate of ``grad_la_x``."""
        idx = self._resolve_idx(batch_size, idx)
        return self._grad_la_x_batch(self.A[idx], self.c[idx], z)

    def grad_la_y_stochastic(self, z, batch_size=None, idx=None):
        """Mini-batch estimate of ``grad_la_y``."""
        idx = self._resolve_idx(batch_size, idx)
        return self._grad_la_y_batch(self.A[idx], self.c[idx], z)

    def grad_x_stochastic_penalty(self, z, rho, eps, batch_size=None, idx=None):
        """Mini-batch ``df/dx`` with the penalty weight from the same batch."""
        idx = self._resolve_idx(batch_size, idx)
        return self._grad_x_batch_penalty(
            self.A[idx], self.b[idx], self.c[idx], z, rho, eps
        )

    def grad_y_stochastic_penalty(self, z, rho, eps, batch_size=None, idx=None):
        """Mini-batch ``-df/dy`` with the penalty weight from the same batch."""
        idx = self._resolve_idx(batch_size, idx)
        return self._grad_y_batch_penalty(
            self.A[idx], self.b[idx], self.c[idx], z, rho, eps
        )

    # ------------------------------------------------------------------
    # GDA vector field
    # ------------------------------------------------------------------
    @staticmethod
    def _gda_field(A, b, c, z, lamb, gamma, beta):
        """Return the GDA field ``(df/dx, -df/dy)`` as a ``(dim + 1, 1)`` array."""
        m_batch, dim = A.shape
        x = z[:dim]
        y = z[dim]
        Ax = A @ x
        p1 = Fairness_Learning._inv_one_plus_exp(b * Ax)
        p2 = Fairness_Learning._inv_one_plus_exp(c * Ax * y)

        gx = (
            -A.T @ (p1 * b / m_batch)
            + A.T @ (c * y * p2 * beta / m_batch)
            + 2 * lamb * x
        )
        gy = np.sum(Ax * c * beta * p2 / m_batch) - 2 * gamma * y  # df/dy

        g = np.zeros((dim + 1, 1))
        g[:dim] = gx
        g[dim] = -gy  # the y-player ascends, so the field carries -df/dy
        return g

    def GDA_field(self, z):
        """Return the GDA field on the full data set."""
        return self._gda_field(
            self.A, self.b, self.c, z, self.lamb, self.gamma, self.beta
        )

    def GDA_field_penalty(self, z, rho, eps=None):
        """Full-data GDA field with the adversary weight from ``_penalty_beta``.

        ``eps`` defaults to ``self.penalty_eps``.
        """
        eps_val = self.penalty_eps if eps is None else eps
        adv_weight = self._penalty_beta(self.A, self.c, z, rho, eps_val)
        return self._gda_field(
            self.A, self.b, self.c, z, self.lamb, self.gamma, adv_weight
        )

    def stochastic_GDA_field(self, z, batch_size=None):
        """GDA field on a sampled mini-batch (full data if no batch applies)."""
        idx = self._sample_indices(batch_size)
        return self._gda_field(
            self.A[idx], self.b[idx], self.c[idx], z,
            self.lamb, self.gamma, self.beta,
        )

    def stochastic_GDA_field_penalty(self, z, rho, batch_size=None, eps=None):
        """Mini-batch GDA field; the penalty weight uses the same batch.

        ``eps`` defaults to ``self.penalty_eps``.
        """
        eps_val = self.penalty_eps if eps is None else eps
        idx = self._sample_indices(batch_size)
        A_batch = self.A[idx]
        b_batch = self.b[idx]
        c_batch = self.c[idx]
        adv_weight = self._penalty_beta(A_batch, c_batch, z, rho, eps_val)
        return self._gda_field(
            A_batch, b_batch, c_batch, z, self.lamb, self.gamma, adv_weight
        )

    def _sample_indices(self, batch_size):
        """Return the index for the next mini-batch.

        Returns ``slice(None)`` (all rows) when ``batch_size`` is None or at
        least ``m``.  Otherwise consecutive chunks of a stored permutation
        are handed out.  When fewer than ``batch_size`` entries remain, the
        leftover entries are dropped and a new ``np.random.permutation`` is
        drawn.  The initial permutation is the identity, so batches before
        the first reshuffle follow the data order.

        Raises:
            ValueError: If ``batch_size`` is zero or negative.
        """
        if batch_size is None or batch_size >= self.m:
            return slice(None)
        if batch_size <= 0:
            # An empty batch divides by zero downstream and a negative size
            # would move the pointer backwards.
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        if self._idx_ptr + batch_size > self.m:
            self._perm = np.random.permutation(self.m)
            self._idx_ptr = 0
        start = self._idx_ptr
        self._idx_ptr = start + batch_size
        return self._perm[start : self._idx_ptr]

    def Jacobian_GDA_field(self, z):
        """Return the ``(d, d)`` Jacobian of ``GDA_field`` at ``z``.

        Entry ``[i, j]`` is the derivative of field component ``i`` with
        respect to ``z[j]``.  The ``b**2`` / ``c**2`` factors make the result
        exact for arbitrary label values; they equal one for +-1 labels.
        """
        x = z[: self.dim]
        y = z[self.dim]
        Ax = self.A @ x

        p1 = self._inv_one_plus_exp(self.b * Ax)
        p2 = self._inv_one_plus_exp(self.c * Ax * y)

        # Per-sample curvature weights of the two logistic terms.
        w_pred = self.b ** 2 * p1 * (1 - p1) / self.m
        w_adv = self.c ** 2 * p2 * (1 - p2) * self.beta / self.m

        Hxx = (
            self.A.T @ (w_pred * self.A)
            - self.A.T @ (w_adv * y ** 2 * self.A)
            + 2 * self.lamb * np.eye(self.dim)
        )
        Hyy = -np.sum(w_adv * Ax ** 2) - 2 * self.gamma
        Hxy = (
            -self.A.T @ (Ax * y * w_adv)
            + self.A.T @ (p2 * self.beta / self.m * self.c)
        )

        # The y-row of the field is -df/dy, hence the sign flips below.
        H = np.empty((self.d, self.d))
        H[: self.dim, : self.dim] = Hxx
        H[self.dim, self.dim] = -Hyy
        H[: self.dim, self.dim] = Hxy.reshape(-1)
        H[self.dim, : self.dim] = -Hxy.reshape(-1)
        return H
