from network import *
from utils import *


def compute_grad(model):
    """Return the global L2 norm of the gradients currently stored on ``model``.

    The norm of each parameter's gradient is computed first, and then the norm
    of that vector of norms is taken; this equals the L2 norm of all gradient
    entries taken together.

    Args:
        model: Object exposing ``parameters()`` whose items carry a ``.grad``
            attribute (e.g. an ``nn.Module``).

    Returns:
        float: The overall gradient norm, or ``0.0`` when no parameter holds a
        gradient (no backward pass yet, frozen/unused parameters only, or a
        model without parameters).
    """
    # Parameters that did not take part in the last backward pass have
    # ``grad is None``; they contribute nothing to the norm.
    grad_norms = [
        torch.norm(p.grad) for p in model.parameters() if p.grad is not None
    ]
    if not grad_norms:
        # torch.stack rejects an empty list, so short-circuit explicitly.
        return 0.0
    return torch.norm(torch.stack(grad_norms)).item()


def clear_backprops(model: nn.Module) -> None:
    """Remove the ``backprops_list`` attribute from every module that has one.

    Walks ``model.modules()`` (the model itself and all nested sub-modules)
    and drops the attribute wherever it is present; modules without it are
    left untouched.
    """
    attr_name = "backprops_list"
    tagged_modules = (m for m in model.modules() if hasattr(m, attr_name))
    for module in tagged_modules:
        delattr(module, attr_name)

class CLF(object):
    """Train a binary classifier and record per-group diagnostics.

    Attributes:
        train_loader: Iterable yielding ``(images, labels)`` batches.
        x_test, y_test, a_test: Held-out inputs, targets and group ids used by
            ``write_logs``. Group ids are assumed to be integers in
            ``0..num_z - 1`` -- TODO confirm against the data pipeline.
        num_z: Number of groups, derived as ``max(a_test) + 1``.
        softmax_func: ``nn.Softmax(dim=1)``; not used by the methods of this
            class.
        logs: Dict of metric name -> list with one entry per ``write_logs``
            call.
        init_model / model: Set by ``fit``: a snapshot of the starting model
            and the model being trained.
    """

    # Per-group metric name prefixes; each is stored as ``'<prefix>_<group>'``.
    _GROUP_METRICS = ('acc', 'loss', 'trace_hessian', 'grad_norm', 'dist_boundary')

    def __init__(self, train_loader, x_test, y_test, a_test):
        """Store the data and create empty log lists for every metric."""
        self.train_loader = train_loader
        self.x_test = x_test
        self.y_test = y_test
        self.a_test = a_test
        self.num_z = int(torch.max(self.a_test).item()) + 1
        self.softmax_func = nn.Softmax(dim=1)
        self.logs = {'all_acc': [], 'all_loss': [], 'ind_loss': []}
        for i in range(self.num_z):
            for prefix in self._GROUP_METRICS:
                self.logs['{}_{}'.format(prefix, i)] = []

    def write_logs(self, model: nn.Module) -> None:
        """Append one snapshot of test-set metrics to ``self.logs``.

        Recorded overall: accuracy (``all_acc``), mean BCE loss (``all_loss``)
        and the per-sample BCE losses (``ind_loss``). Recorded per group ``i``:
        accuracy, BCE loss, mean of ``p * (1 - p)`` over the predictions
        (``dist_boundary_i``), the last element returned by
        ``hessian(...).trace()`` (``trace_hessian_i``) and the gradient norm
        of the group loss (``grad_norm_i``).

        Side effects: the model is switched to eval mode and, on return, holds
        the gradients of the last group's loss.

        Args:
            model: Classifier whose outputs are probabilities in ``[0, 1]``
                (required by ``nn.BCELoss``).
        """
        model.eval()
        loss_func = nn.BCELoss()
        ind_loss_func = nn.BCELoss(reduction='none')

        # The global metrics are never back-propagated, so skip graph building.
        with torch.no_grad():
            y_pred = model(self.x_test)
            y_true = self.y_test
            y_hard_pred = (y_pred > 0.5).float()
            acc = (y_hard_pred == y_true).double().mean().item()
            loss = ind_loss_func(y_pred, y_true).detach().cpu().numpy()
        self.logs['all_acc'].append(acc)
        self.logs['all_loss'].append(np.mean(loss))
        self.logs['ind_loss'].append(loss)

        for i in range(self.num_z):
            # Select the group's samples once and reuse them everywhere below.
            group_mask = self.a_test == i
            x_group = self.x_test[group_mask]
            y_group_true = self.y_test[group_mask]

            model.zero_grad()
            y_group_pred = model(x_group)
            # p * (1 - p) peaks at p = 0.5, i.e. at the 0.5 decision threshold.
            self.logs['dist_boundary_{}'.format(i)].append(
                torch.mean(y_group_pred * (1 - y_group_pred)).item())
            group_loss = loss_func(y_group_pred, y_group_true)

            # The Hessian helper receives a copy so it cannot disturb `model`.
            hessian_comp2 = hessian(copy.deepcopy(model), loss_func,
                                    data=(x_group, y_group_true), cuda=False)
            trace_ind = hessian_comp2.trace()
            # NOTE(review): only the last element of the returned sequence is
            # kept -- confirm this is the intended trace estimate.
            self.logs['trace_hessian_{}'.format(i)].append(trace_ind[-1])

            y_hard_group_pred = (y_group_pred > 0.5).float()
            acc = (y_hard_group_pred == y_group_true).double().mean().item()
            self.logs['acc_{}'.format(i)].append(acc)
            self.logs['loss_{}'.format(i)].append(group_loss.item())

            model.zero_grad()
            # The graph is not needed after this call, so it is not retained.
            group_loss.backward()
            self.logs['grad_norm_{}'.format(i)].append(compute_grad(model))

    @staticmethod
    def _build_model(options):
        """Return a copy of ``options['init_model']`` or a freshly built net."""
        if options['init_model']:
            return copy.deepcopy(options['init_model'])
        if options['model'] == 'MLP':
            return MLPNet(options)
        return LRNet(options)

    @staticmethod
    def _add_gradient_noise(model, sigma):
        """Add N(0, sigma) noise in place to every available gradient."""
        for p in model.parameters():
            if p.grad is None:
                # Parameter did not take part in the backward pass.
                continue
            # empty_like keeps the gradient's device and dtype, so a single
            # code path serves both CPU and CUDA.
            p.grad += torch.empty_like(p.grad).normal_(0, sigma)

    def fit(self, options) -> None:
        """Train a classifier; the result is stored in ``self.model``.

        Three training modes are supported, selected through ``options``:
        + Train regularly (no privacy, no clipping)
        + Train privately (clipping + adding noise to gradients)
        + Train by clipping gradients (only clipping gradients)
        The modes exist to study the effect of gradient clipping on
        fairness/accuracy.

        Keys read from ``options``:
            random_state_1: Seed set before the model is created.
            random_state_2: Seed set before training starts.
            model: ``'MLP'`` builds ``MLPNet(options)``; anything else builds
                ``LRNet(options)``. Ignored when ``init_model`` is truthy.
            init_model: If truthy, a deep copy of it is trained instead.
            cuda: If truthy, the model and every batch are moved to ``'cuda'``.
            lr: Adam learning rate.
            num_z, C, sigma: Stored on ``self`` as ``z``, ``C``, ``sigma``.
            DP: If truthy, a ``PrivacyEngine`` is attached to the optimizer
                with ``noise_multiplier=sigma`` and ``max_grad_norm=C``
                (``bs`` is also read as its batch size, and
                ``train_loader.dataset`` must support ``len``).
            add_noise: If truthy, N(0, sigma) noise is added to each gradient
                before the optimizer step.
            epochs: Number of passes over ``train_loader``.
            write_logs: If truthy, ``write_logs`` runs after every epoch.

        Each batch's labels must be 2-D; only column 0 is used as the target.
        """
        torch.manual_seed(options['random_state_1'])
        model = self._build_model(options)
        if options['cuda']:
            # Move once, before the optimizer is built, so that a supplied
            # init_model is covered as well and no per-batch move is needed.
            model.to('cuda')
        self.init_model = copy.deepcopy(model)
        # Publish the model up front so it exists even when epochs == 0.
        self.model = model

        torch.manual_seed(options['random_state_2'])
        criterion = nn.BCELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=options['lr'])
        self.z = options['num_z']
        self.C = options['C']
        self.sigma = options['sigma']
        if options['DP']:
            privacy_engine = PrivacyEngine(
                model,
                batch_size=options['bs'],
                sample_size=len(self.train_loader.dataset),
                alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
                noise_multiplier=self.sigma,
                max_grad_norm=self.C
            )
            privacy_engine.attach(optimizer)

        for _ in range(options['epochs']):
            # write_logs leaves the model in eval mode, so re-enable training.
            model.train()
            for images, labels in self.train_loader:
                model.zero_grad()
                optimizer.zero_grad()
                if options['DP']:
                    clear_backprops(model)
                # Use the first label column as an (N, 1) target.
                labels = labels[:, 0].reshape(-1, 1)
                if options['cuda']:
                    labels = labels.to('cuda')
                    images = images.to('cuda')

                outputs = model(images)
                clf_loss = criterion(outputs, labels)
                clf_loss.backward()
                if options['add_noise']:
                    self._add_gradient_noise(model, self.sigma)
                optimizer.step()

            if options['write_logs']:
                self.write_logs(model)