
from typing import Dict, Union, Optional, Iterable, Tuple

import torch
import torch.nn as nn
from transformers import PreTrainedModel

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig,
                         ParallelConfig, SchedulerConfig)
from vllm.model_executor.model_loader import BaseModelLoader
from vllm.model_executor.model_loader.loader import _initialize_model
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.distributed.communication_op import tensor_model_parallel_all_gather

from .config import ModelConfig, LoadFormat, LoadConfig
from .megatron_weight_loaders import load_megatron_weights, update_megatron_weight_loader
from .dtensor_weight_loaders import load_dtensor_weights, update_dtensor_weight_loader
from .hf_weight_loader import update_hf_weight_loader


def get_model(actor_model: Union[PreTrainedModel, Dict],
              model_config: ModelConfig,
              load_config: LoadConfig,
              device_config: DeviceConfig,
              parallel_config: ParallelConfig,
              scheduler_config: SchedulerConfig,
              lora_config: Optional[LoRAConfig],
              multimodal_config: Optional[MultiModalConfig],
              cache_config: CacheConfig = None) -> nn.Module:
    """Build the vLLM model via the loader selected by ``load_config``.

    Args:
        actor_model: the actor as an ``nn.Module`` or a dict of weights. It is
            forwarded to the loader only when the load format does not start
            with ``'dummy'``.
        model_config, load_config, device_config, parallel_config,
        scheduler_config, lora_config, multimodal_config, cache_config:
            forwarded unchanged to the loader's ``load_model``.

    Returns:
        Whatever the selected loader's ``load_model`` returns.

    NOTE(review): ``get_model_loader`` also accepts a loader *class* as
    ``load_format``; in that case ``.startswith`` below would likely raise
    AttributeError — confirm whether that path is reachable from here.
    """
    loader = get_model_loader(load_config)

    load_kwargs = dict(model_config=model_config,
                       device_config=device_config,
                       lora_config=lora_config,
                       multimodal_config=multimodal_config,
                       parallel_config=parallel_config,
                       scheduler_config=scheduler_config,
                       cache_config=cache_config)

    # Dummy loaders take no actor weights; every other loader needs them.
    if not load_config.load_format.startswith('dummy'):
        load_kwargs['actor_model'] = actor_model

    return loader.load_model(**load_kwargs)


def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
    """Return the verl model loader for ``load_config.load_format``.

    If ``load_format`` is itself a class, it is instantiated with
    ``load_config`` and returned. Otherwise the format is matched against the
    supported ``LoadFormat`` values. The matching weight-loader patch
    (``update_*_weight_loader``) is installed, and the corresponding loader is
    instantiated. ``LoadFormat.AUTO`` behaves exactly like
    ``LoadFormat.MEGATRON``.

    Raises:
        ValueError: if the load format is not one of the supported formats.
    """
    load_format = load_config.load_format

    # A loader class may be supplied directly; instantiate it as-is.
    if isinstance(load_format, type):
        return load_format(load_config)

    # Ordered (format, weight-loader patch, loader class) table. A linear scan
    # with ``==`` is used instead of a dict lookup so matching keeps the exact
    # equality semantics of the original if-chain (e.g. a plain string
    # comparing equal to an enum member).
    candidates = (
        (LoadFormat.AUTO, update_megatron_weight_loader, MegatronLoader),
        (LoadFormat.MEGATRON, update_megatron_weight_loader, MegatronLoader),
        (LoadFormat.HF, update_hf_weight_loader, HFLoader),
        (LoadFormat.DTENSOR, update_dtensor_weight_loader, DTensorLoader),
        (LoadFormat.DUMMY_HF, update_hf_weight_loader, DummyModelLoader),
        (LoadFormat.DUMMY_MEGATRON, update_megatron_weight_loader, DummyModelLoader),
        (LoadFormat.DUMMY_DTENSOR, update_dtensor_weight_loader, DummyModelLoader),
    )
    for fmt, patch_weight_loader, loader_cls in candidates:
        if load_format == fmt:
            patch_weight_loader()
            return loader_cls(load_config)

    # List every accepted format; the previous message named only MEGATRON
    # and HF even though DTENSOR, AUTO and the dummy formats are accepted.
    supported = ', '.join('{}'.format(fmt) for fmt, _, _ in candidates)
    raise ValueError('load format not supported in verl: {}, only support {}'.format(load_format, supported))


class DummyModelLoader(BaseModelLoader):
    """Loader that only constructs the vLLM model and copies no actor weights.

    Whatever weight initialization happens is done inside
    ``_initialize_model`` (driven by ``self.load_config``); this class adds
    none of its own.
    """

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        extra_config = load_config.model_loader_extra_config
        # Any truthy extra config is rejected for this loader.
        if extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def load_model(self, *, model_config: ModelConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig],
                   multimodal_config: Optional[MultiModalConfig], parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig, cache_config: CacheConfig) -> nn.Module:
        """Build the model under the configured dtype and device and return it in eval mode.

        ``parallel_config`` is accepted for interface compatibility but not used.
        """
        with set_default_torch_dtype(model_config.dtype), torch.device(device_config.device):
            built = _initialize_model(model_config, self.load_config, lora_config, multimodal_config, cache_config,
                                      scheduler_config)
        return built.eval()


class MegatronLoader(BaseModelLoader):
    """Loader that builds the vLLM model and copies actor weights into it via ``load_megatron_weights``."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def _get_weights_iterator(self, actor_model: Union[PreTrainedModel, Dict]):
        """Unused placeholder; always returns ``None``.

        Previously declared without ``self``, which made calling it as an
        instance method raise TypeError. Weight transfer for this loader goes
        through ``load_megatron_weights`` in ``load_model`` instead.
        """
        return None

    @staticmethod
    def _process_weights_after_loading(model: nn.Module) -> None:
        """Run per-module post-load hooks: quantization methods first, then module-level hooks."""
        for module in model.modules():
            quant_method = getattr(module, "quant_method", None)
            if quant_method is not None:
                quant_method.process_weights_after_loading(module)

            if hasattr(module, "process_weights_after_loading"):
                module.process_weights_after_loading()

    def load_model(self, actor_model: Union[PreTrainedModel, Dict], model_config: ModelConfig,
                   device_config: DeviceConfig, lora_config: Optional[LoRAConfig],
                   multimodal_config: Optional[MultiModalConfig], parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig, cache_config: CacheConfig) -> nn.Module:
        """Build the vLLM model, load ``actor_model``'s weights into it, and return it on CUDA in eval mode.

        Args:
            actor_model: an ``nn.Module`` (all parameters, duplicates included,
                are collected by name) or an already-prepared weights dict
                passed through as-is.
            parallel_config: accepted for interface compatibility; not used.
        """
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config, lora_config, multimodal_config, cache_config,
                                          scheduler_config)

            # remove_duplicate=False keeps every name of shared (tied) parameters.
            if isinstance(actor_model, nn.Module):
                actor_weights = dict(actor_model.named_parameters(remove_duplicate=False))
            else:
                actor_weights = actor_model
            load_megatron_weights(actor_weights=actor_weights, vllm_model=model)

            self._process_weights_after_loading(model)

        model = model.cuda()
        return model.eval()


class HFLoader(BaseModelLoader):
    """Loader that feeds actor weights through the vLLM model's own ``load_weights`` method."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def _get_weights_iterator(self, actor_model: Union[PreTrainedModel, Dict]):
        """Return ``(name, weight)`` pairs from a weights dict or from a module's parameters.

        Raises:
            ValueError: if ``actor_model`` is neither a dict nor an ``nn.Module``.
        """
        # Use the builtin ``dict`` for the runtime check; ``typing.Dict`` is a
        # deprecated alias and gives the same answer.
        if isinstance(actor_model, dict):
            return actor_model.items()
        if isinstance(actor_model, nn.Module):
            return dict(actor_model.named_parameters()).items()
        raise ValueError(f'actor model should be Dict or nn.Module, but get {type(actor_model)}')

    @staticmethod
    def _process_weights_after_loading(model: nn.Module) -> None:
        """Run per-module post-load hooks: quantization methods first, then module-level hooks."""
        for module in model.modules():
            quant_method = getattr(module, "quant_method", None)
            if quant_method is not None:
                quant_method.process_weights_after_loading(module)

            if hasattr(module, "process_weights_after_loading"):
                module.process_weights_after_loading()

    def load_model(self, actor_model: Union[PreTrainedModel, Dict], model_config: ModelConfig,
                   device_config: DeviceConfig, lora_config: Optional[LoRAConfig],
                   multimodal_config: Optional[MultiModalConfig], parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig, cache_config: CacheConfig) -> nn.Module:
        """Build the vLLM model, call its ``load_weights`` with the actor weights, and return it on CUDA in eval mode.

        ``device_config`` and ``parallel_config`` are accepted for interface
        compatibility but not used here.
        """
        with set_default_torch_dtype(model_config.dtype):
            # Unlike the other loaders there is no torch.device context here:
            # the model is created on the current default device (presumably
            # CPU) and only moved with .cuda() at the end.
            model = _initialize_model(model_config, self.load_config, lora_config, multimodal_config, cache_config,
                                      scheduler_config)
            model.load_weights(self._get_weights_iterator(actor_model))
            self._process_weights_after_loading(model)

        model = model.cuda()
        return model.eval()


class DTensorLoader(BaseModelLoader):
    """Loader that builds the vLLM model and copies actor weights into it via ``load_dtensor_weights``."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def _get_weights_iterator(self, actor_model: Union[PreTrainedModel, Dict]):
        """Unused placeholder; always returns ``None``.

        Previously declared without ``self``, which made calling it as an
        instance method raise TypeError. Weight transfer for this loader goes
        through ``load_dtensor_weights`` in ``load_model`` instead.
        """
        return None

    @staticmethod
    def _process_weights_after_loading(model: nn.Module) -> None:
        """Run per-module post-load hooks: quantization methods first, then module-level hooks."""
        for module in model.modules():
            quant_method = getattr(module, "quant_method", None)
            if quant_method is not None:
                quant_method.process_weights_after_loading(module)

            if hasattr(module, "process_weights_after_loading"):
                module.process_weights_after_loading()

    def load_model(self, actor_model: Union[PreTrainedModel, Dict], model_config: ModelConfig,
                   device_config: DeviceConfig, lora_config: Optional[LoRAConfig],
                   multimodal_config: Optional[MultiModalConfig], parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig, cache_config: CacheConfig) -> nn.Module:
        """Build the vLLM model, load ``actor_model``'s weights into it, and return it on CUDA in eval mode.

        Args:
            actor_model: an ``nn.Module`` (all parameters, duplicates included,
                are collected by name) or an already-prepared weights dict
                passed through as-is.
            parallel_config: accepted for interface compatibility; not used.
        """
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config, lora_config, multimodal_config, cache_config,
                                          scheduler_config)

            # remove_duplicate=False keeps every name of shared (tied) parameters.
            if isinstance(actor_model, nn.Module):
                actor_weights = dict(actor_model.named_parameters(remove_duplicate=False))
            else:
                actor_weights = actor_model
            load_dtensor_weights(actor_weights=actor_weights, vllm_model=model)

            self._process_weights_after_loading(model)

        model = model.cuda()
        return model.eval()



def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
                embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
    """Project hidden states onto the embedding matrix and all-gather the result.

    Written as a free function taking ``self``. It reads
    ``self.org_vocab_size``, so it is presumably meant to be bound onto vLLM's
    ``LogitsProcessor``; confirm where it is installed.

    Args:
        hidden_states: multiplied by ``embedding.t()``.
        embedding: the output-embedding weight. Presumably each
            tensor-parallel rank holds a vocab shard, which is why the result
            is all-gathered — TODO confirm.
        embedding_bias: optional bias added to the local logits in place.

    Returns:
        The gathered logits truncated to the first ``self.org_vocab_size``
        entries along dim 1, or ``None`` if the all-gather returned ``None``.
    """
    local_logits = hidden_states @ embedding.t()
    if embedding_bias is not None:
        local_logits += embedding_bias

    gathered = tensor_model_parallel_all_gather(local_logits)
    if gathered is None:
        return None
    # Drop any padding columns beyond the original vocabulary size.
    return gathered[:, :self.org_vocab_size]


from vllm.model_executor.layers.logits_processor import LogitsProcessor


def logitsprocessor_init(self,
                         vocab_size: int,
                         org_vocab_size: Optional[int] = None,
                         scale: float = 1.0,
                         logits_as_input: bool = False,
                         soft_cap: Optional[float] = None) -> None:
    """Replacement ``__init__`` for vLLM's ``LogitsProcessor`` (installed by the assignment below).

    Args:
        vocab_size: stored as ``self.vocab_size``.
        org_vocab_size: stored as ``self.org_vocab_size``; falls back to
            ``vocab_size`` when falsy (``None`` or ``0``).
        scale: stored as ``self.scale``.
        logits_as_input: stored as ``self.logits_as_input``.
        soft_cap: stored as ``self.soft_cap``.
    """
    # Skip LogitsProcessor's original initializer and run its parent's instead.
    # The explicit two-argument super() is required because this is a free
    # function, not a method defined inside the class body.
    super(LogitsProcessor, self).__init__()
    self.vocab_size = vocab_size
    self.org_vocab_size = org_vocab_size if org_vocab_size else vocab_size
    self.scale = scale
    self.logits_as_input = logits_as_input
    self.soft_cap = soft_cap
    # NOTE(review): presumably makes vLLM all-gather logits across ranks
    # rather than gather them to one rank (this file's _get_logits also uses
    # all-gather) — confirm against vLLM's LogitsProcessor.
    self.use_gather = False


# Monkey-patch: LogitsProcessor instances created after this module is imported use the initializer above.
LogitsProcessor.__init__ = logitsprocessor_init
