import types
from .wan_forward import wan_forward
from .xfusers_wan_forward import xfusers_wan_forward
from .wan_block_forward import wan_block_forward
from .wan_block_forward import wan_block_taylor_forward
from .wan_block_forward import wan_block_scaling_forward

def _pick_block_forward(mode):
    """Return the unbound block-level forward function for ``mode``.

    ``"Taylor"`` selects ``wan_block_taylor_forward`` and ``"Scaling"``
    selects ``wan_block_scaling_forward``. Any other value (including
    ``None``) falls back to the original ``wan_block_forward``.
    """
    if mode == "Taylor":
        return wan_block_taylor_forward
    if mode == "Scaling":
        return wan_block_scaling_forward
    # Default / Original behaviour.
    return wan_block_forward


def enhance_model_forwards(self):
    """Rebind ``forward`` on ``self`` and on every entry of ``self.blocks``.

    Each block receives a bound ``forward`` chosen from ``self.mode`` (see
    ``_pick_block_forward``). After that, ``self.forward`` is bound to
    ``xfusers_wan_forward`` when ``self.use_usp`` is truthy, or to
    ``wan_forward`` otherwise.

    The functions are attached per instance via ``types.MethodType``. The
    patch therefore affects these objects, not their classes.
    """
    for blk in self.blocks:
        chosen = _pick_block_forward(self.mode)
        blk.forward = types.MethodType(chosen, blk)

    model_forward = xfusers_wan_forward if self.use_usp else wan_forward
    self.forward = types.MethodType(model_forward, self)