from typing import Tuple

from opacus.accountants import IAccountant, create_accountant
from opacus.grad_sample import AbstractGradSampleModule
from torch import nn, optim
from torch.utils.data import DataLoader

from ipp import IPP
from ipp.data_loader import IPPDataLoader
from ipp.model import IPPModel
from ipp.optimizer import IPPOptimizer


class IPrivacyEngine:
    """
    Privacy engine adapted from `opacus.PrivacyEngine` to support
    individualized privacy progression (IPP).

    Attributes:
        accountant (IAccountant): Accountant built once at construction time
            with `create_accountant(mechanism='rdp')`.
    """

    def __init__(self):
        # NOTE(review): `make_private` does not read or attach this accountant;
        # confirm whether it is consumed elsewhere.
        self.accountant: IAccountant = create_accountant(mechanism='rdp')

    def make_private(self, *,
                     data_loader: DataLoader,
                     model: nn.Module,
                     optimizer: optim.Optimizer,
                     ipp: IPP) -> Tuple[IPPDataLoader, AbstractGradSampleModule, IPPOptimizer]:
        """
        Wrap the training components so they follow the given IPP setup.

        Args:
            data_loader: Loader converted with
                `IPPDataLoader.from_data_loader(data_loader, ipp)`.
            model: Module wrapped in `IPPModel`; the object returned by the
                wrapper's `get_model()` is what gets returned.
            optimizer: Optimizer wrapped in `IPPOptimizer`.
            ipp: IPP configuration. Within this method it is only passed to
                the data-loader conversion.

        Returns:
            A 3-tuple of the wrapped data loader, the model returned by
            `IPPModel(...).get_model()`, and the wrapped optimizer, in that
            order.
        """
        # Build the wrappers in a fixed order: loader, then model, then optimizer.
        private_loader = IPPDataLoader.from_data_loader(data_loader, ipp)
        model_wrapper = IPPModel(model)
        private_model = model_wrapper.get_model()
        private_optimizer = IPPOptimizer(optimizer)
        # NOTE(review): `ipp` is not forwarded to the model or optimizer
        # wrappers here — confirm that is intended.
        return private_loader, private_model, private_optimizer
    