import numpy as np
from scipy.optimize import minimize


def softmax(x):
    """
    Compute row-wise softmax values for each set of scores in x.

    Parameters:
        x (numpy.ndarray): array containing m samples with n-dimensions (m,n)
    Returns:
        x_softmax (numpy.ndarray) softmaxed values for initial (m,n) array;
        every row sums to 1.
    """
    x = np.asarray(x)
    # Shift each row by its OWN maximum. Shifting by the global maximum is
    # mathematically equivalent, but a row lying far below the global max
    # underflows to all zeros and then yields 0/0 = NaN.
    shifted = x - np.max(x, axis=1, keepdims=True)
    e_x = np.exp(shifted)
    return e_x / e_x.sum(axis=1, keepdims=True)

def mse_t(t, *args):
    """
    MSE objective used to find the optimal temperature.

    Parameters:
        t: temperature dividing the logits (scalar, or the length-1 array
            that scipy.optimize.minimize passes to its objective).
        *args: exactly ``(logit, label)`` where ``logit`` is an (m, n) array
            of raw scores and ``label`` is an array broadcastable against the
            (m, n) probabilities (one-hot targets in this file's callers).
    Returns:
        Mean squared error between softmax(logit / t) and label.
    """
    logit, label = args
    scaled = logit / t
    # Subtract the row maximum before exponentiating: small temperatures
    # blow the logits up and a plain np.exp overflows to inf -> inf/inf = NaN.
    scaled = scaled - np.max(scaled, axis=1, keepdims=True)
    e = np.exp(scaled)
    p = e / np.sum(e, axis=1, keepdims=True)
    mse = np.mean((p - label) ** 2)
    return mse


def ll_t(t, *args):
    """
    Cross-entropy objective used to find the optimal temperature.

    Parameters:
        t: temperature dividing the logits (scalar, or the length-1 array
            that scipy.optimize.minimize passes to its objective).
        *args: exactly ``(logit, label)`` where ``logit`` is an (m, n) array
            of raw scores and ``label`` is an array broadcastable against the
            (m, n) probabilities (one-hot targets in this file's callers).
    Returns:
        Cross-entropy between label and softmax(logit / t), averaged over
        the m rows.
    """
    logit, label = args
    scaled = logit / t
    # Subtract the row maximum before exponentiating: small temperatures
    # blow the logits up and a plain np.exp overflows to inf -> inf/inf = NaN,
    # which the clip below cannot repair.
    scaled = scaled - np.max(scaled, axis=1, keepdims=True)
    e = np.exp(scaled)
    # Clip so that log() never sees an exact 0.
    p = np.clip(e / np.sum(e, axis=1, keepdims=True), 1e-20, 1 - 1e-20)
    N = p.shape[0]
    ce = -np.sum(label * np.log(p)) / N
    return ce


def mse_w(w, *args):
    """
    MSE objective used to find the optimal ensemble weight coefficients.

    Parameters:
        w: sequence of three weights, one per probability array.
        *args: exactly ``(p0, p1, p2, label)``; the three probability arrays
            are combined as ``w[0]*p0 + w[1]*p1 + w[2]*p2``.
    Returns:
        Mean squared error between the row-normalised weighted mixture and
        label.
    """
    probs_a, probs_b, probs_c, target = args
    mixture = w[0] * probs_a + w[1] * probs_b + w[2] * probs_c
    # Renormalise every row so the mixture sums to 1 even when the weights
    # themselves do not.
    row_totals = np.sum(mixture, 1)[:, None]
    residual = mixture / row_totals - target
    return np.mean(residual ** 2)


def ll_w(w, *args):
    """
    Cross-entropy objective used to find the optimal ensemble weight
    coefficients.

    Parameters:
        w: sequence of three weights, one per probability array.
        *args: exactly ``(p0, p1, p2, label)``; the three probability arrays
            are combined as ``w[0]*p0 + w[1]*p1 + w[2]*p2``.
    Returns:
        Cross-entropy between label and the weighted mixture, averaged over
        the rows of the mixture.
    """
    p0, p1, p2, label = args
    p = w[0] * p0 + w[1] * p1 + w[2] * p2
    # Floor the mixture before taking the log: an exact 0 probability gives
    # log(0) = -inf, and 0 * -inf = NaN for the classes whose label is 0.
    # Only the lower side is clipped so the objective is otherwise unchanged.
    p = np.clip(p, 1e-20, None)
    N = p.shape[0]
    ce = -np.sum(label * np.log(p)) / N
    return ce


class mix_n_match():
    """
    Calibrator that fits a temperature and then three mixing weights.

    The calibrated probability is
    ``w[0]*softmax(logits/t) + w[1]*softmax(logits) + w[2]*uniform``,
    where ``t`` and ``w`` are found by ``fit``.
    """

    def __init__(self, n_class=2, temp=1, maxiter=50, solver="BFGS"):
        """
        Initialize class
        Params:
            n_class (int): number of classes, default 2
            temp (float): starting temperature, default 1
            maxiter (int): maximum iterations done by optimizer, however 8 iterations have been maximum.
            solver (str): optimizer name.

        NOTE: temp, maxiter and solver are only stored on the instance; the
        optimizers in this class use their own fixed start point and methods.
        """
        self.temp = temp
        self.maxiter = maxiter
        self.solver = solver
        self.n_class = n_class

    @staticmethod
    def _row_softmax(scores):
        """Row-wise softmax of an (m, n) array, shifted by the row max so exp cannot overflow."""
        shifted = scores - np.max(scores, axis=1, keepdims=True)
        e = np.exp(shifted)
        return e / np.sum(e, axis=1, keepdims=True)

    @staticmethod
    def temperature_scaling(logit, label, loss):
        """
        Find the temperature minimizing the chosen loss.

        Params:
            logit: raw scores, shape [samples, classes]
            label: one-hot targets, same shape as logit
            loss: 'ce' (cross-entropy) or 'mse'
        Returns:
            the optimizer solution ``x`` (array with one element), searched
            within [0.05, 5.0] starting from 1.0
        Raises:
            ValueError: if loss is neither 'ce' nor 'mse'
        """
        # Validate first so an unknown loss fails with a clear message instead
        # of an UnboundLocalError further down.
        if loss == 'ce':
            objective = ll_t
        elif loss == 'mse':
            objective = mse_t
        else:
            raise ValueError("loss must be 'ce' or 'mse', got %r" % (loss,))
        bnds = ((0.05, 5.0),)
        result = minimize(objective, 1.0, args=(logit, label), method='L-BFGS-B', bounds=bnds, tol=1e-12)
        return result.x

    @staticmethod
    def ensemble_scaling(logit, label, loss, t, n_class):
        """
        Find the three mixing weights minimizing the chosen loss.

        Params:
            logit: raw scores, shape [samples, classes]
            label: one-hot targets, same shape as logit
            loss: 'ce' (cross-entropy) or 'mse'
            t: temperature applied to the logits for the first component
            n_class: number of classes, used for the uniform component
        Returns:
            the optimizer solution ``x`` (array of three weights), each
            bounded to [0, 1] and constrained to sum to 1
        Raises:
            ValueError: if loss is neither 'ce' nor 'mse'
        """
        if loss == 'ce':
            objective = ll_w
        elif loss == 'mse':
            objective = mse_w
        else:
            raise ValueError("loss must be 'ce' or 'mse', got %r" % (loss,))
        logit = np.asarray(logit)
        p1 = mix_n_match._row_softmax(logit)        # uncalibrated probabilities
        p0 = mix_n_match._row_softmax(logit / t)    # temperature-scaled probabilities
        p2 = np.ones_like(p0) / n_class             # uniform distribution
        bnds_w = ((0.0, 1.0), (0.0, 1.0), (0.0, 1.0),)

        def my_constraint_fun(x):
            return np.sum(x) - 1

        constraints = {"type": "eq", "fun": my_constraint_fun}
        result = minimize(objective, (1.0, 0.0, 0.0), args=(p0, p1, p2, label), method='SLSQP',
                          constraints=constraints, bounds=bnds_w, tol=1e-12, options={'disp': True})
        return result.x

    def fit(self, logtis, label):
        '''
        Fit the temperature and the mixing weights (both with the MSE loss)
        and store them as ``self.t`` and ``self.w``.

        Params:
            logtis: the output from neural network for each class (shape [samples, classes]).
                (The parameter name is misspelled; kept for backward compatibility.)
            label: true labels as integer class indices in [0, n_class).
        '''
        # One-hot encode with the configured number of classes rather than a
        # hard-coded 2, so n_class > 2 works.
        label = np.eye(self.n_class)[np.asarray(label).astype(int)]
        t = self.temperature_scaling(logtis, label, loss='mse')  # loss can change to 'ce'
        print("temperature = " + str(t))
        w = self.ensemble_scaling(logtis, label, 'mse', t, self.n_class)
        print("weight = " + str(w))
        self.t = t
        self.w = w

    def predict_proba(self, logits, temp=None):
        """
        Scales logits based on the temperature and returns calibrated probabilities
        Params:
            logits: logits values of data (output from neural network) for each class (shape [samples, classes])
            temp: temperature to use; if not set, the temperature found by fit (self.t) is used.
        Returns:
            calibrated probabilities (nd.array with shape [samples, classes])
        """
        logits = np.asarray(logits)
        t = self.t if temp is None else temp
        p1 = self._row_softmax(logits)
        p0 = self._row_softmax(logits / t)
        p2 = np.ones_like(p0) / self.n_class
        return self.w[0]*p0 + self.w[1]*p1 + self.w[2]*p2