from classifier import *

class FairModel(CLF):
    """Classifier trained with two group-fairness regularizers.

    On top of the mean binary cross-entropy, ``fit`` penalizes
    (1) the gap between groups 0 and 1 in the loss term induced by per-sample
        gradient clipping (``<g_a, g_clip - g>``), and
    (2) the gap between groups 0 and 1 in the mean of ``p * (1 - p)`` over
        the model outputs ``p``.
    """

    def __init__(self, train_loader, x_test, y_test, a_test):
        super().__init__(train_loader, x_test, y_test, a_test)
        # One entry per epoch: mean clipping-induced loss term of group 0 / group 1.
        self.logs['loss_by_clipping_0'] = []
        self.logs['loss_by_clipping_1'] = []

    def _init_fit_model(self, options):
        """Create the network that ``fit`` will train.

        Builds an ``MLPNet`` when ``options['model'] == 'MLP'``. A truthy
        ``options['init_model']`` replaces it with a deep copy of that model.
        The CUDA move happens last, so a supplied ``init_model`` also ends up
        on the GPU when ``options['cuda']`` is set.

        Raises:
            ValueError: if ``options`` names no supported model and gives no
                ``init_model``.
        """
        model = None
        if options['model'] == 'MLP':
            torch.manual_seed(options['random_state_1'])
            model = MLPNet(options)
        if options['init_model']:
            model = copy.deepcopy(options['init_model'])
        if model is None:
            raise ValueError(
                'Unsupported model type {!r} and no init_model given'.format(options['model']))
        if options['cuda']:
            model.to('cuda')
        return model

    def _clipping_loss_terms(self, model, loss_list, group_labels):
        """Return ``(loss_group_0, loss_group_1)``, the clipping-induced loss terms.

        Definitions:
          * ``g_a`` is the gradient of the mean loss of group ``a``.
          * ``g = (n_0 * g_0 + n_1 * g_1) / n`` is the count-weighted group
            gradient.
          * ``g_clip`` is the average of per-sample gradients, each scaled by
            ``min(1, C / ||grad||)``.

        Each returned term is ``<g_a, g_clip - g>``. All gradients use
        ``create_graph=True``, so the terms remain differentiable with respect
        to the model parameters.

        Args:
            model: the network being trained.
            loss_list: per-sample losses (``reduction='none'``) for the batch.
            group_labels: group id of each sample, aligned with ``loss_list``.
        """
        grad_by_group = {}
        counts = []  # number of samples of each group in this mini-batch
        for g in range(self.z):
            model.zero_grad()
            clear_backprops(model)
            in_group = group_labels == g
            group_grad = autograd.grad(
                torch.mean(loss_list[in_group]), model.parameters(), create_graph=True)
            grad_by_group[g] = torch.cat([torch.flatten(p) for p in group_grad]).view(-1)
            counts.append(float(len(loss_list[in_group])))

        n_total = np.sum(counts)
        # g = g_0 * p_0 + g_1 * p_1 with p_a = n_a / n
        flatten_grad = (grad_by_group[0] * counts[0] + grad_by_group[1] * counts[1]) / n_total

        # Average of per-sample gradients, each clipped to norm at most C.
        # zeros_like keeps dtype/device consistent with flatten_grad (works on CUDA).
        flatten_clipped_grad = torch.zeros_like(flatten_grad)
        for ind_loss in loss_list:
            model.zero_grad()
            clear_backprops(model)
            ind_grad = autograd.grad(ind_loss, model.parameters(), create_graph=True)
            ind_norm = torch.norm(torch.stack([torch.norm(p) for p in ind_grad]))
            scale = torch.clamp(self.C / ind_norm, max=1.0)  # == min(1, C / ||grad||)
            flatten_clipped_grad += (
                torch.cat([torch.flatten(p) for p in ind_grad]).view(-1) * scale / n_total)

        excess = flatten_clipped_grad - flatten_grad
        loss_group_0 = torch.dot(grad_by_group[0], excess)
        loss_group_1 = torch.dot(grad_by_group[1], excess)
        return loss_group_0, loss_group_1

    def _decision_boundary_gap(self, outputs, group_labels):
        """Return ``|m_0 - m_1|``, where ``m_a`` is the mean of ``p * (1 - p)`` over group ``a``.

        ``p`` denotes the model outputs of the samples in that group.
        """
        dist_by_group = {}
        for g in range(self.z):
            p = outputs[group_labels == g]
            dist_by_group[g] = torch.mean(p * (1 - p))
        return torch.abs(dist_by_group[0] - dist_by_group[1])

    def fit(self, options):
        """
        Implement fair model by aligning the R^clip_a and R^{noise}_a over different groups
        as presented in the paper.

        For each mini-batch the minimized objective is::

            mean BCE + gamma_1 * |L_0 - L_1| + gamma_2 * |D_0 - D_1|

        Here ``L_a`` is the clipping-induced loss term of group ``a`` (see
        ``_clipping_loss_terms``). ``D_a`` is the mean of ``p * (1 - p)`` over
        the outputs of group ``a``. Batches whose group labels contain only
        one distinct value are skipped.

        Each batch from ``self.train_loader`` is ``(images, labels)``:
        column 0 of ``labels`` is the target and column 1 is the group label.

        Args:
            options: dict of settings. Keys read here: 'model', 'random_state_1',
                'random_state_2', 'cuda', 'init_model', 'lr', 'num_z', 'C',
                'gamma_1', 'gamma_2', 'sigma', 'DP', 'bs', 'epochs',
                'add_noise', 'write_logs'.

        Side effects:
            - Sets ``self.init_model``, ``self.z``, ``self.C``,
              ``self.gamma_1``, ``self.gamma_2``, ``self.sigma`` and
              ``self.model``.
            - Appends one value per epoch to
              ``self.logs['loss_by_clipping_0']`` and
              ``self.logs['loss_by_clipping_1']``.
        """
        model = self._init_fit_model(options)
        self.init_model = copy.deepcopy(model)
        torch.manual_seed(options['random_state_2'])
        loss_func = nn.BCELoss(reduction='none')
        optimizer = torch.optim.Adam(model.parameters(), lr=options['lr'])
        self.z = options['num_z']
        self.C = options['C']
        self.gamma_1 = options['gamma_1']
        self.gamma_2 = options['gamma_2']
        self.sigma = options['sigma']
        if options['DP']:
            # DP training through PrivacyEngine (noise_multiplier=sigma, max_grad_norm=C).
            privacy_engine = PrivacyEngine(
                model,
                batch_size=options['bs'],
                sample_size=len(self.train_loader.dataset),
                alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
                noise_multiplier=self.sigma,
                max_grad_norm=self.C
            )
            privacy_engine.attach(optimizer)

        # Inputs are moved to the model's device (no-op if already there).
        device = next(model.parameters()).device

        for epoch in range(options['epochs']):
            if epoch % 50 == 0:
                print('Done for {} epochs'.format(epoch))
            loss_group_0_list = []  # clipping-induced loss term of group 0, per batch
            loss_group_1_list = []  # clipping-induced loss term of group 1, per batch

            for images, labels in self.train_loader:
                model.zero_grad()
                optimizer.zero_grad()
                if options['DP']:
                    clear_backprops(model)
                images = images.to(device)
                labels = labels.to(device)
                # The group label is stored as an extra column appended to the labels.
                group_labels = labels[:, 1]
                labels = labels[:, 0].reshape(-1, 1)
                if torch.unique(group_labels).shape[0] <= 1:
                    # Both regularizers need samples from more than one group.
                    continue

                outputs = model(images)
                loss_list = loss_func(outputs, labels)

                loss_group_0, loss_group_1 = self._clipping_loss_terms(
                    model, loss_list, group_labels)
                loss_group_0_list.append(loss_group_0.item())
                loss_group_1_list.append(loss_group_1.item())
                reg_loss_1 = torch.abs(loss_group_0 - loss_group_1)
                model.zero_grad()
                clear_backprops(model)

                reg_loss_2 = self._decision_boundary_gap(outputs, group_labels)
                clf_loss = torch.mean(loss_list)
                total_loss = clf_loss + self.gamma_1 * reg_loss_1 + self.gamma_2 * reg_loss_2
                # The graph is not reused after this call, so it need not be retained.
                total_loss.backward()
                if options['add_noise']:
                    for p in model.parameters():
                        if p.grad is not None:
                            # empty_like keeps the noise on the gradient's device/dtype.
                            p.grad.add_(torch.empty_like(p.grad).normal_(0, self.sigma))
                optimizer.step()

            # If every batch of the epoch was skipped, record NaN explicitly
            # (np.mean([]) would also give NaN, but with a RuntimeWarning).
            self.logs['loss_by_clipping_0'].append(
                np.mean(loss_group_0_list) if loss_group_0_list else np.nan)
            self.logs['loss_by_clipping_1'].append(
                np.mean(loss_group_1_list) if loss_group_1_list else np.nan)
            if options['write_logs']:
                self.write_logs(model)
            self.model = model
