from typing import Type

from federatedscope.core.trainers import GeneralTorchTrainer
from federatedscope.attack.auxiliary.utils import get_data_property


def wrap_ActivePIATrainer(
        base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
    """Copy the property-loss weight from the trainer config to its context.

    Reads ``base_trainer._cfg.attack.alpha_prop_loss`` and stores it as
    ``base_trainer.ctx.alpha_prop_loss`` so that hooks receiving only the
    context can read the weight.

    Args:
        base_trainer: The trainer to modify in place. It must expose
            ``ctx`` and ``_cfg.attack.alpha_prop_loss``.

    Returns:
        The same ``base_trainer`` object, so the call can be used as
        ``trainer = wrap_ActivePIATrainer(trainer)``.
    """
    base_trainer.ctx.alpha_prop_loss = base_trainer._cfg.attack.alpha_prop_loss

    # NOTE(review): this function does not register the batch hooks defined
    # in this module on the trainer; presumably that is done by the caller --
    # confirm.

    # The signature promises a trainer; without this return, callers that
    # rebind the result would end up holding ``None``.
    return base_trainer


def hood_on_batch_start_get_prop(ctx):
    """Store the property of the current data batch on the context.

    Passes ``ctx.data_batch`` to ``get_data_property`` and saves the result
    as ``ctx.prop``.

    Args:
        ctx: Context object exposing ``data_batch``; ``prop`` is written
            to it.
    """
    # NOTE(review): "hood" looks like a typo for "hook"; the spelling is kept
    # because it is the name other code would reference.
    current_batch = ctx.data_batch
    batch_property = get_data_property(current_batch)
    ctx.prop = batch_property


def hook_on_batch_forward_add_PIA_loss(ctx):
    """Blend the batch loss with a loss computed against ``ctx.prop``.

    Overwrites ``ctx.loss_batch`` with::

        alpha * loss_batch + (1 - alpha) * criterion(y_prob, prop)

    where ``alpha`` is ``ctx.alpha_prop_loss``.

    Args:
        ctx: Context object exposing ``alpha_prop_loss``, ``loss_batch``,
            ``criterion``, ``y_prob`` and ``prop``; ``loss_batch`` is
            replaced in place.
    """
    alpha = ctx.alpha_prop_loss
    # Term order matches the weighted sum above: existing loss first, then
    # the criterion evaluated on the stored property.
    weighted_batch_loss = alpha * ctx.loss_batch
    weighted_prop_loss = (1 - alpha) * ctx.criterion(ctx.y_prob, ctx.prop)
    ctx.loss_batch = weighted_batch_loss + weighted_prop_loss
