from network import *
# TODO: extend to multiclass classification and to multiple protected attributes.

class CLF:
    """Train a binary classifier and log overall and per-group test metrics.

    Group membership comes from the protected-attribute tensor ``a_test``;
    groups are assumed to be labelled by consecutive integers
    ``0..num_z-1`` where ``num_z = max(a_test) + 1``.

    Args:
        train_loader: iterable of ``(inputs, targets)`` mini-batches used by
            :meth:`fit`.
        x_test: test inputs fed to the model in :meth:`write_logs`.
        y_test: test labels compared against the model output; presumably
            float 0/1 values shaped like the model output, e.g. (N, 1) --
            TODO confirm.
        a_test: integer protected-attribute value for each test example.
    """

    def __init__(self, train_loader, x_test, y_test, a_test):
        self.train_loader = train_loader
        self.x_test = x_test
        self.y_test = y_test
        self.a_test = a_test
        # Number of protected groups, assuming labels 0..max(a_test).
        self.num_z = int(torch.max(self.a_test).item()) + 1
        self.softmax_func = nn.Softmax(dim=1)
        self.logs = {'all_acc': [], 'all_loss': [], 'ind_loss': [], 'ind_dist_boundary': [], 'ind_acc': []}
        for i in range(self.num_z):
            self.logs['acc_{}'.format(i)] = []
            self.logs['loss_{}'.format(i)] = []
            self.logs['dist_boundary_{}'.format(i)] = []

    def write_logs(self, model):
        """Evaluate ``model`` on the test set and append metrics to ``self.logs``.

        Appends the following:

        * overall accuracy and mean BCE loss;
        * per-example BCE loss, correctness (0/1) and ``p * (1 - p)``, which
          peaks at p = 0.5, i.e. at the decision boundary;
        * for every group ``0..num_z-1``: accuracy, BCE loss and the mean of
          ``p * (1 - p)``.

        A group with no test examples is expected to produce NaN means.

        The model's train/eval mode is restored afterwards, so a caller that
        keeps training (as :meth:`fit` does) is unaffected.
        """
        was_training = model.training
        model.eval()
        try:
            # Evaluation only: no autograd graph is needed.
            with torch.no_grad():
                self._record_metrics(model(self.x_test))
        finally:
            model.train(was_training)

    def _record_metrics(self, y_pred):
        """Append overall and per-group metrics for predicted probabilities ``y_pred``."""
        loss_func = nn.BCELoss()
        ind_loss_func = nn.BCELoss(reduction='none')
        y_true = self.y_test
        y_hard_pred = (y_pred > 0.5).float()
        correct = y_hard_pred == y_true

        self.logs['ind_dist_boundary'].append((y_pred * (1 - y_pred)).numpy())
        self.logs['ind_loss'].append(ind_loss_func(y_pred, y_true).numpy())
        self.logs['ind_acc'].append(correct.numpy().astype(int))
        self.logs['all_acc'].append(correct.double().mean().item())
        self.logs['all_loss'].append(loss_func(y_pred, y_true).item())

        # Cover every group created in __init__ (not a hard-coded 2).
        for i in range(self.num_z):
            mask = self.a_test == i
            y_group_pred = y_pred[mask]
            y_group_true = y_true[mask]
            y_hard_group_pred = (y_group_pred > 0.5).float()
            self.logs['dist_boundary_{}'.format(i)].append(
                torch.mean(y_group_pred * (1 - y_group_pred)).item())
            self.logs['acc_{}'.format(i)].append(
                (y_hard_group_pred == y_group_true).double().mean().item())
            self.logs['loss_{}'.format(i)].append(
                loss_func(y_group_pred, y_group_true).item())

    def fit(self, options):
        """Train a neural network model, logging test metrics after every epoch.

        Args:
            options: dict, also passed to the network constructor. Keys read here:

                * ``'model'``: ``'MLP'`` selects ``MLPNet``; any other value
                  selects ``LRNet``.
                * ``'lr'``: Adam learning rate.
                * ``'epochs'``: number of passes over ``train_loader``.
                * ``'label'``: ``'soft-labels'`` means ``targets[:, 1]`` and
                  ``targets[:, 0]`` weight the loss against label 1 and
                  label 0 respectively. Otherwise targets are treated as
                  hard labels.
                * ``'lambda'`` (optional, default 1e-3): weight of the
                  parameter-norm penalty.

        The trained model is stored in ``self.model``.
        """
        torch.manual_seed(0)  # deterministic weight initialisation
        if options['model'] == 'MLP':
            model = MLPNet(options)
        else:
            model = LRNet(options)

        lambda_ = options.get('lambda', 1e-3)  # regularization weight
        loss_func = nn.BCELoss(reduction='none')
        optimizer = torch.optim.Adam(model.parameters(), lr=options['lr'])
        for _ in range(options['epochs']):
            # write_logs() switches to eval mode for evaluation; train in train mode.
            model.train()
            for inputs, targets in self.train_loader:
                optimizer.zero_grad()
                outputs = model(inputs)
                if options['label'] == 'soft-labels':
                    # Per-example BCE against both labels, weighted by the soft targets.
                    loss_1 = loss_func(outputs, torch.ones_like(outputs))
                    weight_1 = targets[:, 1].reshape(loss_1.shape)
                    loss_0 = loss_func(outputs, torch.zeros_like(outputs))
                    weight_0 = targets[:, 0].reshape(loss_0.shape)
                    clf_loss = torch.mean(loss_1 * weight_1 + loss_0 * weight_0)
                else:
                    # BCELoss needs a target with the output's shape and float dtype.
                    targets = targets.reshape(-1, 1).to(outputs.dtype)
                    clf_loss = torch.mean(loss_func(outputs, targets))

                # L2 norm over all parameters (not squared) as a regularizer.
                para_norm = torch.norm(torch.stack([torch.norm(p) for p in model.parameters()]))
                total_loss = clf_loss + lambda_ * para_norm
                total_loss.backward()
                optimizer.step()

            self.write_logs(model)
        self.model = model