
from typing import Dict, Union, Optional, Iterable, Tuple

import torch
import torch.nn as nn
from transformers import PreTrainedModel

from vllm.config import (DeviceConfig, LoRAConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.model_executor.model_loader import BaseModelLoader
from vllm.model_executor.model_loader.loader import _initialize_model
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.distributed.communication_op import tensor_model_parallel_all_gather

from .config import ModelConfig, LoadFormat, LoadConfig
from .megatron_weight_loaders import load_megatron_weights, update_megatron_weight_loader
from .dtensor_weight_loaders import load_dtensor_weights, update_dtensor_weight_loader
from .hf_weight_loader import update_hf_weight_loader


def get_model(actor_model: Union[PreTrainedModel, Dict], model_config: ModelConfig, load_config: LoadConfig,
              device_config: DeviceConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig,
              lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
    """Pick the loader selected by ``load_config`` and use it to build the model.

    Args:
        actor_model: Source of the weights, either a ``PreTrainedModel`` or a dict
            (presumably parameter name -> tensor; confirm against callers). It is
            not forwarded when a ``dummy*`` load format is selected.
        model_config: Forwarded to the loader's ``load_model``.
        load_config: Its ``load_format`` selects the loader (see ``get_model_loader``).
        device_config: Forwarded to the loader's ``load_model``.
        parallel_config: Forwarded to the loader's ``load_model``.
        scheduler_config: Forwarded to the loader's ``load_model``.
        lora_config: Forwarded to the loader's ``load_model``; may be ``None``.
        vision_language_config: Forwarded to the loader's ``load_model``; may be ``None``.

    Returns:
        Whatever the selected loader's ``load_model`` returns.

    Raises:
        ValueError: Propagated from ``get_model_loader`` for an unsupported load format.
    """
    loader = get_model_loader(load_config)

    load_kwargs = dict(model_config=model_config,
                       device_config=device_config,
                       lora_config=lora_config,
                       vision_language_config=vision_language_config,
                       parallel_config=parallel_config,
                       scheduler_config=scheduler_config)

    # `load_format` may be a loader class instead of a string-like format (get_model_loader
    # instantiates it directly). A class has no `startswith`, so only string formats are
    # tested for the "dummy" prefix; a loader class is treated as a regular, non-dummy loader.
    load_format = load_config.load_format
    is_dummy = isinstance(load_format, str) and load_format.startswith('dummy')

    # The dummy loader's `load_model` does not accept `actor_model`, so pass it only otherwise.
    if not is_dummy:
        load_kwargs['actor_model'] = actor_model

    return loader.load_model(**load_kwargs)


def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
    """Return the model loader that corresponds to ``load_config.load_format``.

    If ``load_format`` is itself a class, it is instantiated with ``load_config`` and
    returned. Otherwise the format is compared against the known ``LoadFormat`` values;
    on a match the associated ``update_*_weight_loader()`` hook is called first and the
    associated loader is then constructed with ``load_config``.

    Args:
        load_config: Load configuration; only ``load_format`` is inspected here.

    Returns:
        A loader instance for the requested format.

    Raises:
        ValueError: If ``load_format`` matches none of the supported formats.
    """
    load_format = load_config.load_format

    # A loader class may be supplied directly in place of a LoadFormat value.
    if isinstance(load_format, type):
        return load_format(load_config)

    # (accepted formats, weight-loader hook to run, loader class), checked in order.
    # Kept as a sequence compared with `==` rather than a dict so that matching does not
    # depend on how the format value hashes.
    dispatch = (
        # AUTO is handled exactly like MEGATRON.
        ((LoadFormat.AUTO, LoadFormat.MEGATRON), update_megatron_weight_loader, MegatronLoader),
        ((LoadFormat.HF,), update_hf_weight_loader, HFLoader),
        ((LoadFormat.DTENSOR,), update_dtensor_weight_loader, DTensorLoader),
        ((LoadFormat.DUMMY_HF,), update_hf_weight_loader, DummyModelLoader),
        ((LoadFormat.DUMMY_MEGATRON,), update_megatron_weight_loader, DummyModelLoader),
        ((LoadFormat.DUMMY_DTENSOR,), update_dtensor_weight_loader, DummyModelLoader),
    )

    for formats, update_weight_loader, loader_cls in dispatch:
        if load_format in formats:
            update_weight_loader()
            return loader_cls(load_config)

    # Report every format that is actually accepted, not just a subset of them.
    supported = ', '.join('{}'.format(fmt) for formats, _, _ in dispatch for fmt in formats)
    raise ValueError('load format not supported in verl: {}, only support {}'.format(load_format, supported))


class DummyModelLoader(BaseModelLoader):
    """Loader that only initializes the model; no actor weights are copied into it."""

    def __init__(self, load_config: LoadConfig):
        """Store ``load_config`` and reject any ``model_loader_extra_config``.

        Raises:
            ValueError: If ``load_config.model_loader_extra_config`` is truthy.
        """
        super().__init__(load_config)
        extra_config = load_config.model_loader_extra_config
        if not extra_config:
            return
        raise ValueError(f"Model loader extra config is not supported for "
                         f"load format {load_config.load_format}")

    def load_model(self, *, model_config: ModelConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig],
                   vision_language_config: Optional[VisionLanguageConfig], parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig) -> nn.Module:
        """Initialize the model and return it in eval mode.

        All arguments are keyword-only. ``parallel_config`` and ``scheduler_config`` are
        accepted but not used by this method.

        Returns:
            The result of ``_initialize_model(...)`` after ``.eval()``.
        """
        # Initialize under the configured default dtype and on the configured device.
        with set_default_torch_dtype(model_config.dtype), torch.device(device_config.device):
            initialized = _initialize_model(model_config, self.load_config, lora_config, vision_language_config)

        return initialized.eval()


class MegatronLoader(BaseModelLoader):
    """Loader that initializes a model and fills it through ``load_megatron_weights``."""

    def __init__(self, load_config: LoadConfig):
        """Store ``load_config`` and reject any ``model_loader_extra_config``.

        Raises:
            ValueError: If ``load_config.model_loader_extra_config`` is truthy.
        """
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    @staticmethod
    def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]):
        """Placeholder with no implementation; always returns ``None``.

        ``load_model`` does not call it. It is a static method because it takes no
        ``self``: that keeps it callable both on the class and on an instance.
        """
        return None

    def load_model(self, actor_model: Union[PreTrainedModel,
                                            Dict], model_config: ModelConfig, device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig],
                   parallel_config: ParallelConfig, scheduler_config: SchedulerConfig) -> nn.Module:
        """Initialize the model, load the actor weights into it and return it in eval mode.

        Args:
            actor_model: Either an ``nn.Module``, whose ``named_parameters(remove_duplicate=False)``
                are collected into a dict, or an object passed to ``load_megatron_weights``
                as-is (presumably a name -> tensor dict; confirm against callers).
            model_config: Supplies the default dtype and is passed to ``_initialize_model``.
            device_config: Supplies the device the model is initialized on.
            lora_config: Passed to ``_initialize_model``; may be ``None``.
            vision_language_config: Passed to ``_initialize_model``; may be ``None``.
            parallel_config: Accepted but not used by this method.
            scheduler_config: Accepted but not used by this method.

        Returns:
            The loaded model after ``.cuda()`` and ``.eval()``.
        """
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config, lora_config, vision_language_config)

            # A module is flattened into a name -> parameter dict (duplicates kept);
            # anything else is handed to the weight loader unchanged.
            if isinstance(actor_model, nn.Module):
                actor_weights = dict(actor_model.named_parameters(remove_duplicate=False))
            else:
                actor_weights = actor_model
            load_megatron_weights(actor_weights=actor_weights, vllm_model=model)

            # Give every submodule its post-load hooks: first the one on its
            # `quant_method` (if any), then the one defined on the module itself.
            for _, module in model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    quant_method.process_weights_after_loading(module)

                if hasattr(module, "process_weights_after_loading"):
                    module.process_weights_after_loading()

        # The model is always moved with `.cuda()` here, independent of `device_config`.
        model = model.cuda()
        return model.eval()


class HFLoader(BaseModelLoader):
    """Loader that initializes a model and fills it through the model's own ``load_weights``."""

    def __init__(self, load_config: LoadConfig):
        """Store ``load_config`` and reject any ``model_loader_extra_config``.

        Raises:
            ValueError: If ``load_config.model_loader_extra_config`` is truthy.
        """
        super().__init__(load_config)
        extra_config = load_config.model_loader_extra_config
        if not extra_config:
            return
        raise ValueError(f"Model loader extra config is not supported for "
                         f"load format {load_config.load_format}")

    def _get_weights_iterator(self, actor_model: Union[PreTrainedModel, Dict]):
        """Return the ``(name, value)`` items of ``actor_model``.

        A dict yields its own items; an ``nn.Module`` yields the items of a dict built
        from ``named_parameters()``.

        Raises:
            ValueError: If ``actor_model`` is neither a dict nor an ``nn.Module``.
        """
        if isinstance(actor_model, Dict):
            return actor_model.items()
        if isinstance(actor_model, nn.Module):
            named_params = dict(actor_model.named_parameters())
            return named_params.items()
        raise ValueError(f'actor model should be Dict or nn.Module, but get {type(actor_model)}')

    @staticmethod
    def _run_post_load_hooks(model: nn.Module) -> None:
        """Call the post-load hooks of every submodule of ``model``.

        For each submodule, the hook on its ``quant_method`` (if present) runs first,
        then the submodule's own ``process_weights_after_loading`` (if it defines one).
        """
        for _, submodule in model.named_modules():
            quantizer = getattr(submodule, "quant_method", None)
            if quantizer is not None:
                quantizer.process_weights_after_loading(submodule)

            if hasattr(submodule, "process_weights_after_loading"):
                submodule.process_weights_after_loading()

    def load_model(self, actor_model: Union[PreTrainedModel,
                                            Dict], model_config: ModelConfig, device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig],
                   parallel_config: ParallelConfig, scheduler_config: SchedulerConfig) -> nn.Module:
        """Initialize the model, load the actor weights into it and return it in eval mode.

        ``device_config``, ``parallel_config`` and ``scheduler_config`` are accepted but
        not used by this method: the model is initialized without a ``torch.device``
        context and is only moved with ``.cuda()`` once the weights are loaded.

        Returns:
            The loaded model after ``.cuda()`` and ``.eval()``.
        """
        with set_default_torch_dtype(model_config.dtype):
            hf_model = _initialize_model(model_config, self.load_config, lora_config, vision_language_config)
            weights = self._get_weights_iterator(actor_model)
            hf_model.load_weights(weights)
            self._run_post_load_hooks(hf_model)

        return hf_model.cuda().eval()


class DTensorLoader(BaseModelLoader):
    """Loader that initializes a model and fills it through ``load_dtensor_weights``."""

    def __init__(self, load_config: LoadConfig):
        """Store ``load_config`` and reject any ``model_loader_extra_config``.

        Raises:
            ValueError: If ``load_config.model_loader_extra_config`` is truthy.
        """
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    @staticmethod
    def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]):
        """Placeholder with no implementation; always returns ``None``.

        ``load_model`` does not call it. It is a static method because it takes no
        ``self``: that keeps it callable both on the class and on an instance.
        """
        return None

    def load_model(self, actor_model: Union[PreTrainedModel,
                                            Dict], model_config: ModelConfig, device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig],
                   parallel_config: ParallelConfig, scheduler_config: SchedulerConfig) -> nn.Module:
        """Initialize the model, load the actor weights into it and return it in eval mode.

        Args:
            actor_model: Either an ``nn.Module``, whose ``named_parameters(remove_duplicate=False)``
                are collected into a dict, or an object passed to ``load_dtensor_weights``
                as-is (presumably a name -> tensor dict; confirm against callers).
            model_config: Supplies the default dtype and is passed to ``_initialize_model``.
            device_config: Supplies the device the model is initialized on.
            lora_config: Passed to ``_initialize_model``; may be ``None``.
            vision_language_config: Passed to ``_initialize_model``; may be ``None``.
            parallel_config: Accepted but not used by this method.
            scheduler_config: Accepted but not used by this method.

        Returns:
            The loaded model after ``.cuda()`` and ``.eval()``.
        """
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config, lora_config, vision_language_config)

            # A module is flattened into a name -> parameter dict (duplicates kept);
            # anything else is handed to the weight loader unchanged.
            if isinstance(actor_model, nn.Module):
                actor_weights = dict(actor_model.named_parameters(remove_duplicate=False))
            else:
                actor_weights = actor_model
            load_dtensor_weights(actor_weights=actor_weights, vllm_model=model)

            # Give every submodule its post-load hooks: first the one on its
            # `quant_method` (if any), then the one defined on the module itself.
            for _, module in model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    quant_method.process_weights_after_loading(module)

                if hasattr(module, "process_weights_after_loading"):
                    module.process_weights_after_loading()

        # The model is always moved with `.cuda()` here, independent of `device_config`.
        model = model.cuda()
        return model.eval()



def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
                embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
    """Project hidden states onto the embedding matrix and gather the resulting logits.

    Written as a method (it reads ``self.org_vocab_size``) although it is defined at
    module level.

    Args:
        hidden_states: Left operand of the matmul.
        embedding: Matrix whose transpose is the right operand of the matmul.
        embedding_bias: Optional bias added in place to the matmul result.

    Returns:
        The output of ``tensor_model_parallel_all_gather`` with its second dimension
        cut to ``self.org_vocab_size`` columns, or that output unchanged if it is ``None``.
    """
    local_logits = torch.matmul(hidden_states, embedding.t())
    if embedding_bias is not None:
        local_logits += embedding_bias

    gathered = tensor_model_parallel_all_gather(local_logits)
    if gathered is None:
        return gathered

    # Keep only the first `org_vocab_size` columns (presumably dropping vocabulary
    # padding -- confirm against the embedding layer).
    return gathered[:, :self.org_vocab_size]


# Imported here, next to its only use, rather than with the imports at the top of the file.
from vllm.model_executor.layers.logits_processor import LogitsProcessor

# Monkey-patch: replace `LogitsProcessor._get_logits` with the module-level `_get_logits`
# defined above. This happens as a side effect of importing this module and applies to the
# class itself, i.e. to every LogitsProcessor instance in the process.
LogitsProcessor._get_logits = _get_logits
