from utils import *
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold


class Teacher(object):
    """Ensemble of logistic-regression "teacher" models with noisy vote aggregation.

    The private data is split into ``K`` stratified folds and one
    ``LogisticRegression`` is fitted on each fold. Predictions on public data
    are obtained by counting the teachers' votes, optionally perturbed with
    Gaussian noise.

    NOTE(review): the vote counting sums the predicted labels, so it assumes
    binary labels encoded as 0/1 -- confirm against callers.
    """

    def __init__(self, private_x, private_y, options=None):
        """Fit one teacher per stratified fold of the private data.

        Args:
            private_x: 2-D numpy array of features (indexed as ``x[rows, :]``).
            private_y: 1-D numpy array of labels aligned with ``private_x``.
            options: optional dict with keys ``'K'`` (number of teachers,
                default 100) and ``'data_seed'`` (seed for the fold split,
                default 1). Missing keys fall back to their defaults.

        Side effect: reseeds NumPy's global random generator with
        ``data_seed``.
        """
        # Build the settings per call; a dict literal as a default argument
        # would be shared (and mutable) across all calls.
        settings = {'K': 100, 'data_seed': 1}
        if options:
            settings.update(options)

        self.K = settings['K']  # number of teachers
        data_seed = settings['data_seed']
        np.random.seed(data_seed)

        skf = StratifiedKFold(n_splits=self.K, random_state=data_seed, shuffle=True)
        # Each teacher is trained on the held-out part of one split: those
        # parts are disjoint across splits, so no two teachers share a sample.
        self.teachers = [
            LogisticRegression().fit(private_x[fold_index, :], private_y[fold_index])
            for _, fold_index in skf.split(private_x, private_y)
        ]

    def agg_voting(self, public_x):
        """Count the teachers' votes for every public sample. No privacy involved.

        Args:
            public_x: samples passed to each teacher's ``predict``.

        Returns:
            Array of shape ``(n_samples, 2)``; column 0 holds the number of
            votes for label 0 and column 1 the number of votes for label 1.
        """
        # Rows are teachers, columns are samples.
        per_teacher = np.asarray(
            [teacher.predict(public_x).reshape(-1) for teacher in self.teachers]
        )
        votes_for_one = per_teacher.sum(axis=0)
        votes_for_zero = len(self.teachers) - votes_for_one
        return np.column_stack((votes_for_zero, votes_for_one))

    def _noisy_votes(self, public_x, noise_std, seed):
        """Return the vote counts with seeded Gaussian noise added to every entry."""
        # Seeding the global generator keeps results reproducible per ``seed``.
        np.random.seed(seed)
        votes = self.agg_voting(public_x)
        return votes + np.random.normal(0, noise_std, votes.shape)

    def private_prediction(self, public_x, noise_std, seed):
        """Perform private majority voting.

        Args:
            public_x: samples to label.
            noise_std: standard deviation of the Gaussian noise added to the
                vote counts.
            seed: seed for NumPy's global random generator (reseeded here).

        Returns:
            1-D int array: 1 where the noisy count for label 1 is at least the
            noisy count for label 0 (ties go to label 1), else 0.
        """
        noisy = self._noisy_votes(public_x, noise_std, seed)
        return (noisy[:, 1] >= noisy[:, 0]).astype(int)

    def private_soft_proba(self, public_x, noise_std, seed):
        """Return private soft labels, i.e. noisy probabilities of label 0 and label 1.

        Args:
            public_x: samples to label.
            noise_std: standard deviation of the Gaussian noise added to the
                vote counts.
            seed: seed for NumPy's global random generator (reseeded here).

        Returns:
            Array of shape ``(n_samples, 2)`` whose rows sum to 1; column 0 is
            the share for label 0 and column 1 the share for label 1.
        """
        noisy = self._noisy_votes(public_x, noise_std, seed)
        # Noise can push a count below zero; clip to a tiny positive floor so
        # the shares are non-negative and the row sum is never zero.
        noisy = np.clip(noisy, 1e-10, np.inf)
        return noisy / noisy.sum(axis=1)[:, np.newaxis]