from torch.nn import Module
from opacus.grad_sample import (
    AbstractGradSampleModule,
    GradSampleModule,
    get_gsm_class,
    wrap_model,
)


def get_pp_model(
    module: Module,
    *,
    batch_first: bool = True,
    loss_reduction: str = "mean",
    grad_sample_mode: str = "hooks",
) -> Module:
    """Return ``module`` as an Opacus grad-sample module with grad accumulation forbidden.

    If ``module`` is already an ``AbstractGradSampleModule`` it is reused as-is,
    provided its configuration matches the requested one. Otherwise it is wrapped
    with ``opacus.grad_sample.wrap_model`` using the requested settings.

    Args:
        module: A plain ``nn.Module`` or a pre-existing grad-sample module.
        batch_first: Requested ``batch_first`` setting (default ``True``).
        loss_reduction: Requested loss reduction (default ``"mean"``).
        grad_sample_mode: Requested grad-sample mode; the expected module class
            is looked up with ``get_gsm_class`` (default ``"hooks"``).

    Returns:
        The grad-sample module, after ``forbid_grad_accumulation()`` has been
        called on it. When ``module`` was already a grad-sample module the same
        object is returned, so that call also affects the caller's module.

    Raises:
        ValueError: If a pre-existing grad-sample module's ``batch_first``,
            ``loss_reduction`` or exact type differ from the requested ones,
            or (defensively) if no model was obtained.
    """
    if isinstance(module, AbstractGradSampleModule):
        # Exact type comparison, not isinstance: even a subclass of the class
        # returned by get_gsm_class(grad_sample_mode) is rejected.
        if (
            module.batch_first != batch_first
            or module.loss_reduction != loss_reduction
            or type(module) is not get_gsm_class(grad_sample_mode)
        ):
            raise ValueError(
                "Pre-existing GradSampleModule doesn't match new arguments. "
                f"Got: module.batch_first: {module.batch_first}, "
                f"module.loss_reduction: {module.loss_reduction}, "
                f"type(module): {type(module)}. "
                f"Requested: batch_first: {batch_first}, "
                f"loss_reduction: {loss_reduction}, "
                f"grad_sample_mode: {grad_sample_mode}. "
                "Please pass vanilla nn.Module instead"
            )
        model = module
    else:
        model = wrap_model(
            module,
            grad_sample_mode=grad_sample_mode,
            batch_first=batch_first,
            loss_reduction=loss_reduction,
        )

    # Defensive guard: both branches above assign a module, so this should not
    # trigger. Compare against None explicitly instead of relying on the
    # truthiness of an nn.Module.
    if model is None:
        raise ValueError("Empty model detected.")

    model.forbid_grad_accumulation()
    return model



class IPPModel:
    """Holder for the grad-sample module produced by ``get_pp_model``.

    The wrapping is done once, at construction time. ``get_model`` returns
    that same stored object on every call.
    """

    def __init__(self, model: Module):
        wrapped = get_pp_model(model)
        self.model = wrapped

    def get_model(self):
        """Return the model that was prepared in ``__init__``."""
        stored = self.model
        return stored
