import numpy as np
from scipy.special import erf
from scipy.optimize import minimize

# compute the gamma values for the equations
def _gamma_11(tau, b, m_y, m_a, coef):
    """Return ``coef * exp(-(m_y + tau*m_a + b)**2 / (2*(m_y + tau**2*m_a)))``.

    Both ``tau * m_a`` and ``b`` enter the squared offset with a plus sign.
    """
    offset = m_y + tau * m_a + b
    spread = m_y + tau ** 2 * m_a
    weight = np.exp(-offset ** 2 / 2 / spread)
    weight *= coef
    return weight

def _gamma_10(tau, b, m_y, m_a, coef):
    """Return ``coef * exp(-(m_y - tau*m_a + b)**2 / (2*(m_y + tau**2*m_a)))``.

    ``tau * m_a`` is subtracted and ``b`` is added in the squared offset.
    """
    offset = m_y - tau * m_a + b
    spread = m_y + tau ** 2 * m_a
    weight = np.exp(-offset ** 2 / 2 / spread)
    weight *= coef
    return weight

def _gamma_00(tau, b, m_y, m_a, coef):
    """Return ``coef * exp(-(m_y + tau*m_a - b)**2 / (2*(m_y + tau**2*m_a)))``.

    ``tau * m_a`` is added and ``b`` is subtracted in the squared offset.
    """
    offset = m_y + tau * m_a - b
    spread = m_y + tau ** 2 * m_a
    weight = np.exp(-offset ** 2 / 2 / spread)
    weight *= coef
    return weight

def _gamma_01(tau, b, m_y, m_a, coef):
    """Return ``coef * exp(-(m_y - tau*m_a - b)**2 / (2*(m_y + tau**2*m_a)))``.

    Both ``tau * m_a`` and ``b`` are subtracted in the squared offset.
    """
    offset = m_y - tau * m_a - b
    spread = m_y + tau ** 2 * m_a
    weight = np.exp(-offset ** 2 / 2 / spread)
    weight *= coef
    return weight

# compute the accuracy of the corresponding group
def _acc_11(tau, b, m_y, m_a):
    """Return ``(1 + erf((m_y + tau*m_a + b) / sqrt(2*(m_y + tau**2*m_a)))) / 2``."""
    margin = m_y + tau*m_a + b
    scale = np.sqrt(2 * (m_y + tau**2*m_a))
    return (1 + erf(margin / scale)) / 2
def _acc_10(tau, b, m_y, m_a):
    """Return ``(1 + erf((m_y - tau*m_a + b) / sqrt(2*(m_y + tau**2*m_a)))) / 2``."""
    margin = m_y - tau*m_a + b
    scale = np.sqrt(2 * (m_y + tau**2*m_a))
    return (1 + erf(margin / scale)) / 2
def _acc_00(tau, b, m_y, m_a):
    """Return ``(1 + erf((m_y + tau*m_a - b) / sqrt(2*(m_y + tau**2*m_a)))) / 2``."""
    margin = m_y + tau*m_a - b
    scale = np.sqrt(2 * (m_y + tau**2*m_a))
    return (1 + erf(margin / scale)) / 2
def _acc_01(tau, b, m_y, m_a):
    """Return ``(1 + erf((m_y - tau*m_a - b) / sqrt(2*(m_y + tau**2*m_a)))) / 2``."""
    margin = m_y - tau*m_a - b
    scale = np.sqrt(2 * (m_y + tau**2*m_a))
    return (1 + erf(margin / scale)) / 2

# compute the adjusted accuracy
def acc_all(tau, b, m_y, m_a):
    """Return the unweighted mean of the four group accuracies.

    The groups are evaluated in the order 11, 10, 00, 01, each with the same
    ``(tau, b, m_y, m_a)`` arguments.
    """
    params = (tau, b, m_y, m_a)
    total = _acc_11(*params)
    for group_acc in (_acc_10, _acc_00, _acc_01):
        total = total + group_acc(*params)
    return total / 4

# equation 1 of the system
def equation_1(tau, b, m_y, m_a, coef_id, coefs):
    """Return the residual ``tau - N / D`` of the first equation.

    ``coefs`` is a 4-tuple ordered ``(coef_11, coef_00, coef_01, coef_10)``;
    each entry is indexed with ``coef_id`` to pick the weight for its gamma.
    ``N = g11 + g00 - g10 - g01`` and ``D = g11 + g00 + g10 + g01``.
    """
    w11, w00, w01, w10 = coefs
    point = (tau, b, m_y, m_a)

    g11 = _gamma_11(*point, w11[coef_id])
    g00 = _gamma_00(*point, w00[coef_id])
    g01 = _gamma_01(*point, w01[coef_id])
    g10 = _gamma_10(*point, w10[coef_id])

    # Shared leading partial sum keeps the original left-to-right evaluation.
    same_sign = g11 + g00
    numerator = same_sign - g10 - g01
    denominator = same_sign + g10 + g01

    return tau - numerator / denominator

# equation 2 of the system
def equation_2(tau, b, m_y, m_a, coef_id, coefs):
    """Return the residual ``g11 + g10 - g00 - g01`` of the second equation.

    ``coefs`` is a 4-tuple ordered ``(coef_11, coef_00, coef_01, coef_10)``;
    each entry is indexed with ``coef_id`` to pick the weight for its gamma.
    """
    w11, w00, w01, w10 = coefs
    point = (tau, b, m_y, m_a)

    g11 = _gamma_11(*point, w11[coef_id])
    g00 = _gamma_00(*point, w00[coef_id])
    g01 = _gamma_01(*point, w01[coef_id])
    g10 = _gamma_10(*point, w10[coef_id])

    residual = g11 + g10 - g00 - g01
    return residual

#### Here `Acc` is a list of two accuracy values we use to interpolate; it is passed in as an argument ####
# equation 3 of the system -- to match the first accuracy
def equation_3(tau, b, m_y, m_a, Acc):
    """Return ``Acc[0]`` minus the mean group accuracy at ``(tau, b, m_y, m_a)``."""
    target = Acc[0]
    predicted = acc_all(tau, b, m_y, m_a)
    return target - predicted

# equation 4 of the system -- to match the second accuracy
def equation_4(tau, b, m_y, m_a, Acc):
    """Return ``Acc[1]`` minus the mean group accuracy at ``(tau, b, m_y, m_a)``."""
    target = Acc[1]
    predicted = acc_all(tau, b, m_y, m_a)
    return target - predicted

## The class that only solves equation 1 and equation 2
## That is, this is used to compute tau and b AFTER m_y,m_a are available
class System():
    """Residuals of equations 1 and 2 in ``(tau, b)`` for fixed ``m_y``, ``m_a``.

    ``alphas`` is indexed at positions 0 and 1 (alpha_0, alpha_1); together
    with ``beta`` they define the four weights applied to the gamma terms.
    All of them, and their products, are rounded to ``precision`` decimals.
    """

    def __init__(self, alphas, beta, m_y, m_a, precision = 10):
        self.precision = precision
        self.m_y = m_y
        self.m_a = m_a

        def _rounded(value):
            # Single rounding rule shared by every stored ratio and weight.
            return round(value, self.precision)

        self.alphas = [_rounded(a) for a in alphas]
        self.beta = _rounded(beta)

        # Weights of the four gamma terms, fixed once so every evaluation
        # of the equations sees exactly the same numbers.
        self._alpha1_beta = _rounded(self.alphas[1] * self.beta)
        self._1malpha0_1mbeta = _rounded((1-self.alphas[0]) * (1-self.beta))
        self._alpha0_1mbeta = _rounded(self.alphas[0] * (1-self.beta))
        self._1malpha1_beta = _rounded((1-self.alphas[1]) * self.beta)

    def _weighted_kernel(self, offset, tau, weight):
        """Return ``weight * exp(-offset**2 / (2*(m_y + tau**2*m_a)))``."""
        val = np.exp(-offset ** 2 / 2 / (self.m_y + tau ** 2 * self.m_a))
        val *= weight
        return val

    def _all_gammas(self, tau, b):
        """Return ``(gamma_11, gamma_00, gamma_01, gamma_10)`` at ``(tau, b)``."""
        return (
            self._gamma_11(tau, b),
            self._gamma_00(tau, b),
            self._gamma_01(tau, b),
            self._gamma_10(tau, b),
        )

    def equation_1(self, tau, b):
        """Return ``tau - N / D`` with N, D the signed and plain gamma sums."""
        g11, g00, g01, g10 = self._all_gammas(tau, b)

        same_sign = g11 + g00
        numerator = same_sign - g10 - g01
        denominator = same_sign + g10 + g01

        return tau - numerator / denominator

    def equation_2(self, tau, b):
        """Return ``gamma_11 + gamma_10 - gamma_00 - gamma_01``."""
        g11, g00, g01, g10 = self._all_gammas(tau, b)
        return g11 + g10 - g00 - g01

    def _gamma_11(self, tau, b):
        """Kernel at offset ``m_y + tau*m_a + b``, weighted by alpha_1 * beta."""
        offset = self.m_y + tau * self.m_a + b
        return self._weighted_kernel(offset, tau, self._alpha1_beta)

    def _gamma_10(self, tau, b):
        """Kernel at offset ``m_y - tau*m_a + b``, weighted by (1-alpha_1) * beta."""
        offset = self.m_y - tau * self.m_a + b
        return self._weighted_kernel(offset, tau, self._1malpha1_beta)

    def _gamma_00(self, tau, b):
        """Kernel at offset ``m_y + tau*m_a - b``, weighted by alpha_0 * (1-beta)."""
        offset = self.m_y + tau * self.m_a - b
        return self._weighted_kernel(offset, tau, self._alpha0_1mbeta)

    def _gamma_01(self, tau, b):
        """Kernel at offset ``m_y - tau*m_a - b``, weighted by (1-alpha_0) * (1-beta)."""
        offset = self.m_y - tau * self.m_a - b
        return self._weighted_kernel(offset, tau, self._1malpha0_1mbeta)


def estimate(zeta_list, acc_array, shift_type, precision=10):
    '''
    Fit (m_y, m_a) to two observed accuracies, then predict accuracies for
    the remaining zeta values.

    Stage 1 minimises the squared residuals of equations 1-4 jointly over
    (tau, b, tau', b', m_y, m_a) at two fixed interpolation points.
    Stage 2 keeps the fitted (m_y, m_a), solves equations 1-2 for (tau, b)
    at every zeta in ``zeta_list[:-4]`` and evaluates ``acc_all`` there.

    zeta_list: an ascending list of zeta values from 0.5 to 0.999 [0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99, 0.999]
    acc_array: an array of accuracy with a shape of (num_zeta, num_seeds);
        only the seed-averaged rows -1 and -4 are read (zeta=0.999 and
        zeta=0.9 for the list above)
    shift_type: 0 for spurious correlation; 1 for under-representation; 2 for class imbalance
    precision: number of decimals used when rounding the mixture weights,
        in both stages

    Returns a list with one estimated accuracy per zeta in ``zeta_list[:-4]``.
    Raises ValueError if ``shift_type`` is not 0, 1 or 2.
    '''

    # (alpha_0, alpha_1) and beta at the two interpolation points; index 0 is
    # the first point, index 1 the second.
    if shift_type == 0:
        ## Spurious Correlation: beta = 0.5 to keep class balanced
        Alpha = [
            [0.999, 0.999],
            [0.9, 0.9],
        ]
        Beta = [0.5, 0.5]
    elif shift_type == 1:
        ## Under Representation: beta = 0.5 to keep class balanced
        Alpha = [
            [0.999, 1-0.999],
            [0.9, 1-0.9],
        ]
        Beta = [0.5, 0.5]
    elif shift_type == 2:
        ## Class Imbalance: alpha_0 = P(A=-1|Y=-1), alpha_1 = P(A=1|Y=1) are both balanced
        Alpha = [
            [0.5, 0.5],
            [0.5, 0.5],
        ]
        Beta = [0.999, 0.9]
    else:
        # Without this branch Alpha/Beta would be unbound and the failure
        # would surface later as an opaque UnboundLocalError.
        raise ValueError(
            "shift_type must be 0, 1 or 2, got {!r}".format(shift_type))

    # Pre-compute the mixture weights once so every evaluation is consistent
    coef_11 = [round(Alpha[i][1] * Beta[i], precision) for i in range(2)]
    coef_10 = [round((1-Alpha[i][1]) * Beta[i], precision) for i in range(2)]
    coef_00 = [round(Alpha[i][0] * (1-Beta[i]), precision) for i in range(2)]
    coef_01 = [round((1-Alpha[i][0]) * (1-Beta[i]), precision) for i in range(2)]
    coefs = coef_11, coef_00, coef_01, coef_10

    # Target accuracies: seed-averaged last row first, fourth-from-last second
    mean_acc = acc_array.mean(1)
    Acc = [mean_acc[-1], mean_acc[-4]]

    # The error function to minimize in stage 1
    def H(x):
        tau, b, taup, bp, m_y, m_a = x
        val = equation_1(tau, b, m_y, m_a, 0, coefs)**2 \
            + equation_2(tau, b, m_y, m_a, 0, coefs)**2 \
            + 10*equation_3(tau, b, m_y, m_a, Acc)**2 \
            + equation_1(taup, bp, m_y, m_a, 1, coefs)**2 \
            + equation_2(taup, bp, m_y, m_a, 1, coefs)**2 \
            + 10*equation_4(taup, bp, m_y, m_a, Acc)**2 ## Accuracy of the second point
        return val

    # Every variable is constrained to be >= 0, and tau, tau' (x[0], x[2]) <= 1.
    # `i=i` freezes the index per lambda (avoids late binding in the loop).
    stage1_constraints = [
        {'type': 'ineq', 'fun': lambda x, i=i: x[i]} for i in range(6)
    ]
    stage1_constraints.append({'type': 'ineq', 'fun': lambda x: 1-x[0]})
    stage1_constraints.append({'type': 'ineq', 'fun': lambda x: 1-x[2]})

    initial_guess = [0., 0., 0., 0., 2., 2.]

    fit = minimize(H, initial_guess,
                   constraints=stage1_constraints, method='SLSQP')
    if not fit.success:
        # Best effort: report the problem but keep going with fit.x, in the
        # same spirit as the per-zeta failure message below.
        print("FAILED to fit m_y and m_a: {}".format(fit.message))

    # Only the shared parameters are carried over to stage 2
    m_y, m_a = fit.x[4], fit.x[5]

    initial_guesses = [
        [0.3, 0.3],
    ]

    # Stage 2 constraints: 0 <= tau <= 1 (b is unconstrained here)
    stage2_constraints = [
        {'type': 'ineq', 'fun': lambda x: x[0]},
        {'type': 'ineq', 'fun': lambda x: 1-x[0]},
    ]

    Acc_est = []
    # estimate accuracy for every zeta except the last four entries
    # (zeta < 0.9 for the zeta_list documented above)
    for zeta in zeta_list[:-4]:

        if shift_type == 0:
            alphas = [zeta, zeta]
            beta = .5
        elif shift_type == 1:
            alphas = [zeta, 1 - zeta]
            beta = .5
        elif shift_type == 2:
            alphas = [.5, .5]
            beta = zeta

        # Forward `precision` so both stages round the weights identically
        system = System(alphas, beta, m_y, m_a, precision)

        def h(x, system=system):
            tau, b = x
            return system.equation_1(tau, b) ** 2 + system.equation_2(tau, b) ** 2

        success = False
        for guess in initial_guesses:
            result = minimize(h, guess, constraints=stage2_constraints, method='SLSQP')
            if result.success:
                success = True
                break
        if not success:
            print("FAILED at alpha: [{:.2f}, {:.2f}], beta: [{:.2f}]".format(alphas[0], alphas[1], beta))

        # The last attempt's solution is used even if it did not converge
        Acc_est.append(acc_all(result.x[0], result.x[1], m_y, m_a))

    return Acc_est