from ._base import ModulePosthoc
from .._reducer import ModulePosthocReduce

class RawAttention(ModulePosthoc):
    """Posthoc explanation that reports the backward input averaged over dim 1.

    Only ``backward_in`` is consulted. The forward tensors and ``backward_out``
    are accepted to satisfy the ``ModulePosthoc.solve`` signature but are not
    used.

    NOTE(review): presumably ``ModulePosthoc`` routes the module's attention
    weights into ``backward_in``, and dim 1 is the head axis. Neither is visible
    from this file, so confirm against the base class.
    """

    def solve(self, forward_in, forward_out, backward_in, backward_out):
        """Return ``backward_in`` reduced by its mean along dimension 1.

        Args:
            forward_in: Unused.
            forward_out: Unused.
            backward_in: Tensor-like object exposing ``.mean(dim=...)``.
            backward_out: Unused.

        Returns:
            The result of ``backward_in.mean(dim=1)``.
        """
        averaged = backward_in.mean(dim=1)
        return averaged


# Register how two RawAttention results are reduced together: the combining
# function multiplies them as ``b * a`` (operand order kept deliberately,
# since ``*`` need not be commutative for every operand type).
ModulePosthocReduce.register(
    RawAttention,
    RawAttention,
    method=lambda a, b: b * a,
)