# PROBABLY WILL NOT USE THIS VERSION
import copy

import numpy as np
import pandas as pd
import torch
from sklearn.linear_model import LogisticRegression
from torch import nn
from torch.utils.data import TensorDataset, DataLoader


class MLPNet(nn.Module):
    """
    MLP with a single hidden layer and a sigmoid output unit.

    ``options`` keys read here:
      - 'num_feats': input dimension; the hidden layer has int(num_feats / 2) units.
      - 'activation': 'relu' -> ReLU, 'tanh' -> Tanh, anything else -> Sigmoid
        (applied to the hidden layer).
    Other keys (including 'width' in the default options) are not read.
    """

    def __init__(self, options=None):
        super().__init__()
        # A fresh dict per call avoids sharing a mutable default between instances.
        if options is None:
            options = {'num_feats': 20, 'activation': 'relu', 'width': 20}
        activation = options['activation']
        # elif is essential: with two independent ifs, 'relu' was overwritten
        # by the final else branch and silently became Sigmoid.
        if activation == 'relu':
            self.act_func = nn.ReLU()
        elif activation == 'tanh':
            self.act_func = nn.Tanh()
        else:
            self.act_func = nn.Sigmoid()
        self.sigmoid = nn.Sigmoid()
        hidden_width = int(options['num_feats'] / 2)
        self.input_layer = nn.Linear(options['num_feats'], hidden_width)
        self.o_layer = nn.Linear(hidden_width, 1)

    def forward(self, x):
        """Return sigmoid probabilities of shape (batch, 1) for inputs ``x``."""
        hidden = self.act_func(self.input_layer(x))
        return self.sigmoid(self.o_layer(hidden))


###1.2 IMPLEMENT  PRIVATE/NON-PRIVATE CLASSIFIER TRAINING  ########
class CLF(object):
    """
    Train an MLPNet on ``train_loader`` and log test-set metrics after every epoch.

    Metrics are logged overall and per group, where the groups are the integer
    ids 0 .. max(a_test) found in ``a_test``.
    """

    def __init__(self, train_loader, x_test, y_test, a_test):
        """
        Args:
            train_loader: iterable of (inputs, targets) batches consumed by ``fit``.
            x_test: test inputs fed to the model in ``write_logs``.
            y_test: test labels compared with the model's probability outputs.
            a_test: integer group id per test sample; the number of groups is
                ``max(a_test) + 1``.
        """
        self.train_loader = train_loader
        self.x_test = x_test
        self.y_test = y_test
        self.a_test = a_test
        self.num_z = int(torch.max(self.a_test).item()) + 1
        self.softmax_func = nn.Softmax(dim=1)  # NOTE(review): not used inside this class
        self.logs = {'all_acc': [], 'all_loss': [], 'ind_loss': []}
        for i in range(self.num_z):
            self.logs['acc_{}'.format(i)] = []
            self.logs['loss_{}'.format(i)] = []
            self.logs['dist_boundary_{}'.format(i)] = []

    def write_logs(self, model):
        """
        Evaluate ``model`` on the test set and append the metrics to ``self.logs``.

        Logs the overall accuracy and BCE loss, and the per-sample BCE (as a numpy
        array). For every group in 0 .. num_z - 1 it also logs accuracy, BCE loss
        and mean p * (1 - p). A group with no test samples presumably yields NaN
        values. Leaves ``model`` in eval mode.
        """
        model.eval()
        loss_func = nn.BCELoss()
        ind_loss_func = nn.BCELoss(reduction='none')
        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            y_pred = model(self.x_test)
            y_true = self.y_test
            y_hard_pred = (y_pred > 0.5).float()
            acc = torch.mean((y_hard_pred == y_true).double()).item()
            loss = loss_func(y_pred, y_true).item()
            ind_loss = ind_loss_func(y_pred, y_true).detach().numpy()
            self.logs['ind_loss'].append(ind_loss)
            self.logs['all_acc'].append(acc)
            self.logs['all_loss'].append(loss)

            # Iterate over exactly the groups whose log keys were created in __init__.
            for i in range(self.num_z):
                group_mask = self.a_test == i
                y_group_pred = y_pred[group_mask]
                y_group_true = y_true[group_mask]
                # p * (1 - p) peaks at 0.25 when p = 0.5, so larger values mean
                # predictions sit closer to the 0.5 decision threshold.
                self.logs['dist_boundary_{}'.format(i)].append(
                    torch.mean(y_group_pred * (1 - y_group_pred)).item())
                group_loss = loss_func(y_group_pred, y_group_true).item()
                y_hard_group_pred = (y_group_pred > 0.5).float()
                group_acc = torch.mean((y_hard_group_pred == y_group_true).double()).item()
                self.logs['acc_{}'.format(i)].append(group_acc)
                self.logs['loss_{}'.format(i)].append(group_loss)

    def fit(self, options):
        """
        Train a fresh MLPNet and store it as ``self.model``.

        ``options`` keys read here: 'label' ('soft-prob' -> L1 loss on the
        (p, 1 - p) output pair, anything else -> BCE loss), 'lr', 'epochs',
        plus whatever MLPNet reads. Logs are written after every epoch.
        """
        torch.manual_seed(0)
        model = MLPNet(options)
        soft_labels = options['label'] == 'soft-prob'
        loss_func = nn.L1Loss() if soft_labels else nn.BCELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=options['lr'])
        for _ in range(options['epochs']):
            # write_logs switches the model to eval mode; restore training mode.
            model.train()
            for inputs, targets in self.train_loader:
                optimizer.zero_grad()
                outputs = model(inputs)
                if soft_labels:
                    # Pair each probability p with 1 - p. This presumably matches
                    # two-column soft targets; confirm against the data loader.
                    main_outputs = outputs.reshape(-1, 1)
                    outputs = torch.cat((main_outputs, 1 - main_outputs), 1)
                clf_loss = loss_func(outputs, targets)
                clf_loss.backward()
                optimizer.step()

            self.write_logs(model)
        self.model = model


class Teacher(object):
    """
    Ensemble of K logistic-regression teachers, each fit on a disjoint shard of
    the shuffled private data, with Gaussian-noised vote aggregation.
    """

    def __init__(self, private_x, private_y, options=None):
        """
        Args:
            private_x: numpy features, one row per sample.
            private_y: numpy labels, one per row of ``private_x``.
            options: dict with 'K' (number of teachers) and 'data_seed' (seed
                used before shuffling). Defaults to {'K': 100, 'data_seed': 1}.

        Note: reseeds NumPy's global RNG with options['data_seed'].
        """
        if options is None:
            options = {'K': 100, 'data_seed': 1}
        self.K = options.get('K')  # number of teachers
        data = np.c_[private_x, private_y]
        np.random.seed(options['data_seed'])
        np.random.shuffle(data)
        # The last column holds the label; every shard trains one teacher.
        self.teachers = [
            LogisticRegression().fit(shard[:, :-1], shard[:, -1])
            for shard in np.array_split(data, self.K)
        ]

    def agg_voting(self, public_x):
        """Return an (n_samples, K) array with each teacher's predicted label per column."""
        y_pred_list = [teacher.predict(public_x).reshape(-1, ) for teacher in self.teachers[:self.K]]
        return np.asarray(y_pred_list).T

    def _vote_counts(self, public_x):
        """
        Return an (n_samples, 2) array [sum of teacher labels, K - that sum].

        With 0/1 labels, column 0 counts votes for label 1 and column 1 counts
        votes for label 0.
        """
        votes = np.sum(self.agg_voting(public_x), axis=1)
        return np.asarray([votes, self.K - votes]).T

    def private_prediction(self, public_x, noise_std, seed):
        """
        Return a boolean array: noisy vote count for label 1 >= noisy count for label 0.

        Adds N(0, noise_std) noise to both vote counts. Careful: this reseeds
        NumPy's global RNG with ``seed``.
        """
        np.random.seed(seed)
        counts = self._vote_counts(public_x)
        noisy_counts = counts + np.random.normal(0, noise_std, counts.shape)
        return noisy_counts[:, 0] >= noisy_counts[:, 1]

    def private_soft_proba(self, public_x, noise_std, seed):
        """
        Return an (n_samples, 2) array of noisy vote fractions; each row sums to 1.

        Adds N(0, noise_std) noise to the vote counts and clips them to >= 1e-10.
        Then each row is normalised. Careful: this reseeds NumPy's global RNG
        with ``seed``.
        """
        np.random.seed(seed)
        counts = self._vote_counts(public_x)
        noisy_counts = counts + np.random.normal(0, noise_std, counts.shape)
        noisy_counts = np.clip(noisy_counts, 1e-10, np.inf)
        # Normalise per sample, not over the whole array, so each row is a distribution.
        return noisy_counts / np.sum(noisy_counts, axis=1, keepdims=True)


def get_partition(pd00, num_public):
    """
    Return a shuffled copy of ``pd00`` with a 'partition' column.

    ``num_public`` randomly chosen rows get partition 'public'. Each remaining row
    independently gets 1 with probability 0.2, otherwise 0. The input frame is not
    modified. Randomness comes from NumPy's global RNG.
    """
    shuffled = pd00.sample(frac=1)
    n_rows = len(shuffled)
    # Select public rows by position so the result does not depend on the index labels.
    is_public = np.zeros(n_rows, dtype=bool)
    is_public[np.random.choice(n_rows, num_public, replace=False)] = True
    public_pd = shuffled[is_public].assign(partition='public')
    res_pd = shuffled[~is_public]
    res_pd = res_pd.assign(partition=(np.random.rand(len(res_pd)) >= 0.8).astype(int))
    return pd.concat([res_pd, public_pd]).sample(frac=1)
