from utils import *

class MLPNet(nn.Module):
    """
    MLP with a single hidden layer and a sigmoid output unit.

    The hidden-layer activation is chosen by ``options['activation']``:
    ``'relu'`` -> ReLU, ``'tanh'`` -> Tanh, anything else -> Sigmoid.

    Args:
        options: dict with at least ``'num_feats'`` (input dimension) and
            ``'activation'``. The hidden layer has ``int(num_feats / 2)``
            units; the ``'width'`` key in the default options is currently
            not used. Defaults to
            ``{'num_feats': 20, 'activation': 'relu', 'width': 20}``.
    """

    def __init__(self, options=None):
        super(MLPNet, self).__init__()
        if options is None:
            # Built per call so instances never share a mutable default dict.
            options = {'num_feats': 20, 'activation': 'relu', 'width': 20}
        activation = options['activation']
        # elif chain: with two independent `if`s the trailing `else`
        # overwrote ReLU with Sigmoid whenever 'relu' was requested.
        if activation == 'relu':
            self.act_func = nn.ReLU()
        elif activation == 'tanh':
            self.act_func = nn.Tanh()
        else:
            self.act_func = nn.Sigmoid()
        self.sigmoid = nn.Sigmoid()
        hidden = int(options['num_feats'] / 2)
        self.input_layer = nn.Linear(options['num_feats'], hidden)
        self.o_layer = nn.Linear(hidden, 1)

    def forward(self, x):
        """Return sigmoid probabilities with a trailing dimension of size 1."""
        output = self.act_func(self.input_layer(x))
        output = self.o_layer(output)
        return self.sigmoid(output)

### 1.2 IMPLEMENT PRIVATE/NON-PRIVATE CLASSIFIER TRAINING ###
class CLF(object):
    """
    Train an ``MLPNet`` classifier and log test-set metrics after each epoch.

    Metrics are recorded overall and per group, where the group of each test
    example is given by ``a_test``.

    Args:
        train_loader: iterable of ``(inputs, targets)`` batches; only column 0
            of ``targets`` is used as the label.
        x_test: test inputs fed directly to the model.
        y_test: test labels, compared element-wise with the model output
            (presumably shaped like the output, e.g. ``(N, 1)`` — TODO confirm).
        a_test: group labels for the test set; groups ``0 .. max(a_test)``
            get their own log entries.
    """

    def __init__(self, train_loader, x_test, y_test, a_test):
        self.train_loader = train_loader
        self.x_test = x_test
        self.y_test = y_test
        self.a_test = a_test
        # Groups are assumed to be labelled 0..K-1; K is inferred from the max label.
        self.num_z = int(torch.max(self.a_test).item()) + 1
        self.softmax_func = nn.Softmax(dim=1)  # not used by the methods of this class
        self.logs = {'all_acc': [], 'all_loss': [], 'ind_loss': []}
        for i in range(self.num_z):
            self.logs['acc_{}'.format(i)] = []
            self.logs['loss_{}'.format(i)] = []
            self.logs['dist_boundary_{}'.format(i)] = []

    def write_logs(self, model):
        """
        Evaluate ``model`` on the test set and append the metrics to ``self.logs``.

        Appends overall accuracy, mean BCE loss and per-example BCE losses.
        Stores per-example ``p * (1 - p)`` under ``'ind_dist_boundary'``; this
        entry is overwritten on every call, not appended. For every group
        ``0 .. num_z - 1`` it appends accuracy, BCE loss and mean
        ``p * (1 - p)``; groups with no test examples get ``nan``.

        The model's train/eval mode is restored before returning.
        """
        was_training = model.training
        model.eval()
        loss_func = nn.BCELoss()
        ind_loss_func = nn.BCELoss(reduction='none')
        try:
            # Pure evaluation: no autograd graph is needed.
            with torch.no_grad():
                y_pred = model(self.x_test)
                y_true = self.y_test
                y_hard_pred = (y_pred > 0.5).float()
                self.logs['all_acc'].append(torch.mean((y_hard_pred == y_true).double()).item())
                self.logs['all_loss'].append(loss_func(y_pred, y_true).item())
                self.logs['ind_loss'].append(ind_loss_func(y_pred, y_true).detach().numpy())
                y_soft_pred = y_pred.detach().numpy()
                self.logs['ind_dist_boundary'] = y_soft_pred * (1 - y_soft_pred)

                # Iterate over every group created in __init__ (previously a
                # hard-coded range(2), which raised KeyError for a single group
                # and silently skipped groups >= 2).
                for i in range(self.num_z):
                    mask = self.a_test == i
                    if not bool(mask.any()):
                        nan = float('nan')
                        self.logs['dist_boundary_{}'.format(i)].append(nan)
                        self.logs['acc_{}'.format(i)].append(nan)
                        self.logs['loss_{}'.format(i)].append(nan)
                        continue
                    y_group_pred = y_pred[mask]
                    y_group_true = self.y_test[mask]
                    self.logs['dist_boundary_{}'.format(i)].append(
                        torch.mean(y_group_pred * (1 - y_group_pred)).item())
                    y_hard_group_pred = (y_group_pred > 0.5).float()
                    self.logs['acc_{}'.format(i)].append(
                        torch.mean((y_hard_group_pred == y_group_true).double()).item())
                    self.logs['loss_{}'.format(i)].append(
                        loss_func(y_group_pred, y_group_true).item())
        finally:
            model.train(was_training)

    def fit(self, options):
        """
        Train a new ``MLPNet`` with Adam and BCE loss, logging after each epoch.

        ``options`` is passed to ``MLPNet`` and must also provide ``'lr'`` and
        ``'epochs'``. torch's global RNG is seeded with 0 before the model is
        built. The trained network is stored in ``self.model`` in eval mode.
        """
        torch.manual_seed(0)
        model = MLPNet(options)
        loss_func = nn.BCELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=options['lr'])
        for _ in range(options['epochs']):
            model.train()
            for inputs, targets in self.train_loader:
                # Only the first target column is the label; shape it as a column.
                targets = targets[:, 0].reshape(-1, 1)
                optimizer.zero_grad()
                clf_loss = loss_func(model(inputs), targets)
                clf_loss.backward()
                optimizer.step()
            self.write_logs(model)
        model.eval()
        self.model = model



class LRNet(nn.Module):
    """
    Logistic-regression model: a single linear layer followed by a sigmoid.

    Args:
        options: dict providing ``'num_feats'`` (input dimension).
            Defaults to ``{'num_feats': 20}``.
    """

    def __init__(self, options=None):
        super(LRNet, self).__init__()
        if options is None:
            # Built per call so instances never share a mutable default dict.
            options = {'num_feats': 20}
        self.i_layer = nn.Linear(options['num_feats'], 1)
        self._of = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid probabilities (trailing dim 1), or ``None`` if ``x`` is empty."""
        if len(x) == 0:
            return None
        o = self.i_layer(x)
        return self._of(o)

    def predict(self, x):
        """
        Return hard 0/1 predictions (threshold 0.5) as a numpy array.

        The array has the same shape and dtype as ``forward``'s output. An
        empty input yields an empty ``(0, 1)`` array.
        """
        # forward returns a single tensor; the former `pred, _ = ...` unpacked
        # its rows and failed for every batch size other than 2.
        with torch.no_grad():
            probs = self.forward(x)
        if probs is None:
            return self.i_layer.weight.new_zeros((0, 1)).cpu().numpy()
        # Threshold out of place rather than mutating the forward output.
        return (probs >= 0.5).to(probs.dtype).cpu().numpy()