from classifier import *
from torchdp import PrivacyEngine
from torch import autograd
from torchdp import PerSampleGradientClipper

class PrivCLF(CLF):
    """Differentially-private variant of ``CLF``.

    Builds the same ``MLPNet`` / ``LRNet`` models as the base classifier, but
    attaches a ``torchdp.PrivacyEngine`` to the Adam optimizer so each
    optimizer step uses gradients clipped to ``max_grad_norm`` and perturbed
    according to ``noise_multiplier``.
    """

    # Renyi-DP orders passed to the privacy engine's accountant. Stored as a
    # tuple so the shared class-level default cannot be mutated; a fresh list
    # is built for every engine.
    DEFAULT_ALPHAS = tuple([1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)))

    def __init__(self, train_loader, x_test, y_test, a_test):
        super().__init__(train_loader, x_test, y_test, a_test)

    def fit(self, options):
        """Train a neural network model under a torchdp ``PrivacyEngine``.

        Args:
            options: dict of training options. Keys read here:
                ``'model'`` -- ``'MLP'`` selects ``MLPNet``; any other value
                selects ``LRNet``.
                ``'lr'`` -- Adam learning rate.
                ``'bs'`` -- batch size reported to the privacy engine.
                ``'C'`` -- ``max_grad_norm`` for gradient clipping.
                ``'epochs'`` -- number of passes over ``self.train_loader``.
                ``'lambda'`` (optional, default 1e-3) -- weight of the
                parameter-norm penalty added to the loss.
                ``'noise_multiplier'`` (optional, default 10.0) -- passed to
                the privacy engine.
                ``'seed'`` (optional, default 0) -- passed to
                ``torch.manual_seed``.

        Side effects:
            Calls ``self.write_logs(model)`` after every epoch. On completion
            it sets ``self.model`` to the trained model and
            ``self.privacy_engine`` to the engine used.
        """
        torch.manual_seed(options.get('seed', 0))
        if options['model'] == 'MLP':
            model = MLPNet(options)
        else:
            model = LRNet(options)

        lambda_ = options.get('lambda', 1e-3)  # parameter-norm penalty weight
        optimizer = torch.optim.Adam(model.parameters(), lr=options['lr'])

        privacy_engine = PrivacyEngine(
            model,
            batch_size=options['bs'],
            sample_size=len(self.train_loader.dataset),
            alphas=list(self.DEFAULT_ALPHAS),
            noise_multiplier=options.get('noise_multiplier', 10.0),
            max_grad_norm=options['C'],
        )
        privacy_engine.attach(optimizer)

        loss_func = nn.BCELoss()
        for _ in range(options['epochs']):
            for inputs, targets in self.train_loader:
                optimizer.zero_grad()
                outputs = model(inputs)
                # BCELoss requires targets to match the predictions' dtype;
                # integer or double labels would otherwise raise. Outputs are
                # presumably (N, 1) -- the reshape assumes so.
                targets = targets.reshape(-1, 1).to(outputs.dtype)
                clf_loss = loss_func(outputs, targets)
                # Global L2 norm over all parameters (norm of per-tensor norms).
                # NOTE(review): torchdp rebuilds p.grad from clipped per-sample
                # gradients inside optimizer.step(), so this penalty's gradient
                # may be discarded -- confirm it has the intended effect.
                para_norm = torch.norm(torch.stack([torch.norm(p) for p in model.parameters()]))
                total_loss = clf_loss + lambda_ * para_norm
                total_loss.backward()
                optimizer.step()

            # Report on the current model once per epoch.
            self.write_logs(model)

        self.model = model
        # Kept so callers can query the privacy budget spent by this run.
        self.privacy_engine = privacy_engine