"""Modified from https://github.com/kijai/ComfyUI-MochiWrapper
"""
import importlib.util

import torch
import torch.nn as nn

def replace_parameters_by_name(module, name_keywords, device):
    from torch import nn
    for name, param in list(module.named_parameters(recurse=False)):
        if any(keyword in name for keyword in name_keywords):
            if isinstance(param, nn.Parameter):
                tensor = param.data
                delattr(module, name)
                setattr(module, name, tensor.to(device=device))
    for child_name, child_module in module.named_children():
        replace_parameters_by_name(child_module, name_keywords, device)

if importlib.util.find_spec("pai_fuser") is not None:
    from pai_fuser.core import (convert_model_weight_to_float8,
                                convert_weight_dtype_wrapper)
    print("Enable PAI Quantization Turbo")
else:
    def convert_model_weight_to_float8(model, exclude_module_name=['embed_tokens'], device=None):
        for name, module in model.named_modules():
            flag = False
            for _exclude_module_name in exclude_module_name:
                if _exclude_module_name in name:
                    flag = True
            if flag:
                continue
            for param_name, param in module.named_parameters():
                flag = False
                for _exclude_module_name in exclude_module_name:
                    if _exclude_module_name in param_name:
                        flag = True
                if flag:
                    continue
                param.data = param.data.to(torch.float8_e4m3fn)

    def autocast_model_forward(cls, origin_dtype, *inputs, **kwargs):
        weight_dtype = cls.weight.dtype
        cls.to(origin_dtype)

        # Convert all inputs to the original dtype
        inputs = [input.to(origin_dtype) for input in inputs]
        out = cls.original_forward(*inputs, **kwargs)

        cls.to(weight_dtype)
        return out

    def convert_weight_dtype_wrapper(module, origin_dtype):
        for name, module in module.named_modules():
            if name == "" or "embed_tokens" in name:
                continue
            original_forward = module.forward
            if hasattr(module, "weight") and module.weight is not None:
                setattr(module, "original_forward", original_forward)
                setattr(
                    module,
                    "forward",
                    lambda *inputs, m=module, **kwargs: autocast_model_forward(m, origin_dtype, *inputs, **kwargs)
                )