# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader

from typing import Dict

import torch.nn as nn
from vllm.model_executor.model_loader.utils import set_default_torch_dtype


def update_hf_weight_loader():
    """Announce that no HF weight loader needs updating.

    The function only prints a fixed notice to stdout and returns ``None``;
    it does not register or modify any loader.
    """
    notice = "no hf weight loader need to be updated"
    print(notice)


def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module):
    """Load ``actor_weights`` into ``vllm_model`` and finalize its submodules.

    Steps, in order:

    1. Feed the ``(name, tensor)`` pairs of ``actor_weights`` to
       ``vllm_model.load_weights`` while the default torch dtype is set to the
       dtype of the model's first parameter.
    2. For every submodule, call ``quant_method.process_weights_after_loading``
       when a ``quant_method`` attribute is present, and the submodule's own
       ``process_weights_after_loading`` hook when it defines one.
    3. Move the model to CUDA.

    Args:
        actor_weights: Mapping from parameter name to weight tensor. The
            mapping is not modified by this function.
        vllm_model: Model exposing ``config.tie_word_embeddings`` and a
            ``load_weights`` method that accepts an iterable of
            ``(name, tensor)`` pairs.

    Returns:
        None. ``vllm_model`` is updated in place.
    """
    # Imported once per call at function scope instead of once per submodule.
    import inspect

    assert isinstance(actor_weights, Dict)

    weights = actor_weights
    with set_default_torch_dtype(next(vllm_model.parameters()).dtype):  # TODO
        if vllm_model.config.tie_word_embeddings and "lm_head.weight" in actor_weights:
            # Skip the output-head entry when embeddings are tied (presumably
            # the tied model has no separate lm_head parameter to receive it
            # -- confirm). A filtered shallow copy is used so the caller's
            # dict is left untouched; the tensors themselves are not copied.
            weights = {name: tensor for name, tensor in actor_weights.items() if name != "lm_head.weight"}
        vllm_model.load_weights(weights.items())

    for _, module in vllm_model.named_modules():
        quant_method = getattr(module, "quant_method", None)
        if quant_method is not None:
            quant_method.process_weights_after_loading(module)

        # FIXME: Remove this after Mixtral is updated
        # to use quant_method.
        if not hasattr(module, "process_weights_after_loading"):
            continue

        post_load_hook = module.process_weights_after_loading

        # The hook signature differs between layers: some take no argument,
        # others expect an ``act_dtype`` argument.
        if "act_dtype" not in inspect.signature(post_load_hook).parameters:
            post_load_hook()
            continue

        # Resolve act_dtype: the module's own ``dtype`` attribute first, then
        # the dtype of its first parameter, then the model's first parameter.
        if hasattr(module, "dtype"):
            act_dtype = module.dtype
        else:
            try:
                act_dtype = next(module.parameters()).dtype
            except (StopIteration, RuntimeError):
                # Module has no parameters of its own; fall back to the model.
                act_dtype = next(vllm_model.parameters()).dtype

        post_load_hook(act_dtype)

    # nn.Module.cuda() moves parameters and buffers in place, so the return
    # value does not need to be rebound.
    vllm_model.cuda()
