import types
from .hy_forward import hy_forward
from .hy_block_forward import \
    hy_single_stream_block_taylor_forward, \
    hy_single_stream_block_scaling_forward, \
    hy_single_stream_block_forward, \
    hy_double_stream_block_taylor_forward, \
    hy_double_stream_block_scaling_forward, \
    hy_double_stream_block_forward

def enhance_model_forwards(self):
    """Rebind the ``forward`` methods of the model and all of its blocks.

    The model's own ``forward`` is always replaced with ``hy_forward``.
    Every block in ``self.double_blocks`` and ``self.single_blocks`` then
    receives a stream-specific forward chosen by ``self.mode``:

    * ``"Taylor"``  -> the ``*_taylor_forward`` variant
    * ``"Scaling"`` -> the ``*_scaling_forward`` variant
    * anything else -> the plain (original) ``*_forward``

    Each replacement is bound to its target object with ``types.MethodType``,
    so it receives that object as its first argument when called.
    """

    def _choose(taylor_variant, scaling_variant, original_variant):
        # The mode is looked up again for every block, and the same ordered
        # equality checks are used, so selection matches a per-block if/elif.
        if self.mode == "Taylor":
            return taylor_variant
        if self.mode == "Scaling":
            return scaling_variant
        return original_variant  # Original / fallback for any other mode

    self.forward = types.MethodType(hy_forward, self)

    for double_block in self.double_blocks:
        picked = _choose(
            hy_double_stream_block_taylor_forward,
            hy_double_stream_block_scaling_forward,
            hy_double_stream_block_forward,
        )
        double_block.forward = types.MethodType(picked, double_block)

    for single_block in self.single_blocks:
        picked = _choose(
            hy_single_stream_block_taylor_forward,
            hy_single_stream_block_scaling_forward,
            hy_single_stream_block_forward,
        )
        single_block.forward = types.MethodType(picked, single_block)