import types
from scaling_cache.forwards.flux.flux_forward import flux_forward
from scaling_cache.forwards.flux.double_stream_block_forward import double_stream_block_forward, double_stream_block_taylor_forward, double_stream_block_scaling_forward
from scaling_cache.forwards.flux.single_stream_block_forward import single_stream_block_forward, single_stream_block_taylor_forward, single_stream_block_scaling_forward


def enhance_model_forwards(self, *, update_alpha=None):
    """Replace the model / block ``forward`` methods with the variants for ``self.mode``.

    The model's own ``forward`` is always rebound to ``flux_forward``. Every
    block in ``self.double_blocks`` and ``self.single_blocks`` is then rebound
    to the forward function selected by the mode:

    * ``"Original"`` -> ``double_stream_block_forward`` /
      ``single_stream_block_forward``
    * ``"Taylor"``   -> ``double_stream_block_taylor_forward`` /
      ``single_stream_block_taylor_forward``
    * ``"Scaling"``  -> ``double_stream_block_scaling_forward`` /
      ``single_stream_block_scaling_forward``

    Args:
        self: The model to patch. It must expose ``mode``, ``double_blocks``
            and ``single_blocks``.
        update_alpha: If truthy and ``self.mode`` is neither ``"Original"``
            nor ``"Taylor"``, the scaling forwards are installed.
            ``None`` (the default) falls back to ``self.update_alpha`` when
            that attribute exists, otherwise ``False``.

    Raises:
        ValueError: If ``self.mode`` is not a recognised mode and
            ``update_alpha`` is falsy. It is raised before anything is
            patched, so the model is never left half-modified.
    """
    if update_alpha is None:
        # The original code referenced a bare, undefined ``update_alpha`` name,
        # which raised NameError for any mode other than the three listed.
        # NOTE(review): presumably the flag is meant to live on the model or
        # be passed by the caller -- confirm the intended source.
        update_alpha = getattr(self, "update_alpha", False)

    # mode -> (double-stream block forward, single-stream block forward)
    block_forwards = {
        "Original": (double_stream_block_forward, single_stream_block_forward),
        "Taylor": (double_stream_block_taylor_forward, single_stream_block_taylor_forward),
        "Scaling": (double_stream_block_scaling_forward, single_stream_block_scaling_forward),
    }

    mode = self.mode
    # Known modes take precedence; update_alpha only forces the scaling
    # forwards for an otherwise-unrecognised mode. This matches the order
    # of the original if/elif chain.
    if mode not in block_forwards and update_alpha:
        mode = "Scaling"
    if mode not in block_forwards:
        raise ValueError(
            f"Unknown mode {self.mode!r}; expected one of {sorted(block_forwards)}"
        )
    double_forward, single_forward = block_forwards[mode]

    # replace model / model.block forward function
    self.forward = types.MethodType(flux_forward, self)
    for block in self.double_blocks:
        block.forward = types.MethodType(double_forward, block)
    for block in self.single_blocks:
        block.forward = types.MethodType(single_forward, block)
    