
import contextlib
from typing import Dict, Type, Union

import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel
from megatron.core.tensor_parallel.utils import VocabUtility

from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.weight_utils import (get_quant_config, initialize_dummy_weights)

from .config import ModelConfig
from vllm.config import DeviceConfig, LoRAConfig
from .weight_loaders import *
from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors
from vllm.sequence import SamplerOutput
from typing import Optional
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.sampler import _prune_hidden_states, _apply_logits_processors, _apply_penalties, _apply_top_k_top_p, _apply_min_p, _apply_penalties, _sample, _get_logprobs, _build_sampler_output


@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):
    """Temporarily make ``dtype`` the default torch dtype.

    The current default is read on entry and put back on exit. The restore
    runs in a ``finally`` clause so that an exception raised inside the
    ``with`` body cannot leave the process-wide default dtype changed.

    Args:
        dtype: Passed to ``torch.set_default_dtype`` for the duration of the
            ``with`` block.

    Yields:
        None.
    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Runs on normal exit and when the body raises.
        torch.set_default_dtype(old_dtype)


def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
    """Return the first model class that ``ModelRegistry`` can load for ``config``.

    Args:
        config: Model configuration whose ``architectures`` attribute lists
            candidate architecture names, tried in order.

    Returns:
        The first non-``None`` result of ``ModelRegistry.load_model_cls``.

    Raises:
        ValueError: If no listed architecture resolves to a model class,
            including when ``architectures`` is missing, ``None`` or empty.
    """
    # The attribute may exist but be None; `getattr`'s default only covers the
    # missing-attribute case, so normalise both to an empty list here.
    architectures = getattr(config, "architectures", None) or []
    for arch in architectures:
        model_cls = ModelRegistry.load_model_cls(arch)
        if model_cls is not None:
            return model_cls
    raise ValueError(f"Model architectures {architectures} are not supported for now. "
                     f"Supported architectures: {ModelRegistry.get_supported_archs()}")


from vllm.model_executor.layers.linear import *
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead
from vllm.model_executor.layers.activation import ScaledActivation

# Layer class -> function assigned to that class's `weight_loader` attribute.
# Every layer type listed here shares `parallel_weight_loader`.
__LAYER_WEIGHT_LOADER_REGISTRY__ = {
    parallel_layer: parallel_weight_loader for parallel_layer in (
        ColumnParallelLinear,
        MergedColumnParallelLinear,
        QKVParallelLinear,
        RowParallelLinear,
        VocabParallelEmbedding,
        ParallelLMHead,
    )
}

# Runs at import time: set the class-level `weight_loader` on each registered
# layer class.
for layer_class, weight_loader in __LAYER_WEIGHT_LOADER_REGISTRY__.items():
    layer_class.weight_loader = weight_loader

# Model class name -> function that loads actor weights into that model.
# Looked up by `_get_model_weight_loader`; both spellings of the Llama class
# name map to the same loader.
__MODEL_WEIGHT_LOADER_REGISTRY__ = {
    'GPT2LMHeadModel': gpt2_weight_loader,
    'LlamaForCausalLM': llama_weight_loader,
    'LLaMAForCausalLM': llama_weight_loader,
    'MistralForCausalLM': mistral_weight_loader,
}


# Default value of the `padding_size` parameter of `vocab_init`.
DEFAULT_VOCAB_PADDING_SIZE = 64


def vocab_init(self,
               num_embeddings: int,
               embedding_dim: int,
               params_dtype: Optional[torch.dtype] = None,
               org_num_embeddings: Optional[int] = None,
               padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
    """Initialise a ``VocabParallelEmbedding`` holding one vocabulary slice.

    The row range owned by this tensor-parallel rank is obtained from
    ``VocabUtility.vocab_range_from_global_vocab_size`` and an uninitialised
    weight of shape ``(rows_in_range, embedding_dim)`` is allocated for it.

    Args:
        num_embeddings: Global vocabulary size that is split across ranks.
        embedding_dim: Width of each embedding row.
        params_dtype: Weight dtype; falls back to the current torch default
            dtype when ``None``.
        org_num_embeddings: Stored as ``org_vocab_size``; when falsy,
            ``num_embeddings`` is stored instead.
        padding_size: Not used by this implementation.
    """
    super(VocabParallelEmbedding, self).__init__()

    self.num_embeddings = num_embeddings
    self.org_vocab_size = num_embeddings if not org_num_embeddings else org_num_embeddings
    self.embedding_dim = embedding_dim

    weight_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
    self.tp_size = get_tensor_model_parallel_world_size()

    # Start/end of the vocabulary range assigned to this rank.
    tp_rank = get_tensor_model_parallel_rank()
    range_start, range_end = VocabUtility.vocab_range_from_global_vocab_size(self.num_embeddings, tp_rank,
                                                                             self.tp_size)
    self.vocab_start_index = range_start
    self.vocab_end_index = range_end
    self.num_embeddings_per_partition = range_end - range_start

    local_table = torch.empty(self.num_embeddings_per_partition, self.embedding_dim, dtype=weight_dtype)
    self.weight = Parameter(local_table)
    # Tag the parameter with the dimension it is partitioned along and the
    # loader to use for it.
    set_weight_attrs(self.weight, {"parallel_dim": 0, "weight_loader": self.weight_loader})


# Monkey-patch: replace the constructor of VocabParallelEmbedding with
# `vocab_init`, so instances created after this module is imported are
# initialised by it.
VocabParallelEmbedding.__init__ = vocab_init


def _get_model_weight_loader(arch: str):
    """Return the weight-loading function registered for ``arch``.

    Args:
        arch: Model class name used as the key into
            ``__MODEL_WEIGHT_LOADER_REGISTRY__``.

    Returns:
        The loader function registered under ``arch``.

    Raises:
        ValueError: If ``arch`` has no registered weight loader. The message
            lists the architectures that do have one.
    """
    loader = __MODEL_WEIGHT_LOADER_REGISTRY__.get(arch)
    if loader is not None:
        return loader
    # Report the keys of the registry actually consulted. An architecture can
    # be known to ModelRegistry and still have no weight loader here, so
    # listing ModelRegistry's architectures would be misleading.
    raise ValueError(f"Model architectures {arch} are not supported for now. "
                     f"Supported architectures: {list(__MODEL_WEIGHT_LOADER_REGISTRY__)}")


def _quantized_linear_method(model_config: ModelConfig):
    """Validate the configured quantization and return its linear method.

    Raises:
        ValueError: If the current GPU's compute capability is below the
            quantization method's minimum, or if ``model_config.dtype`` is not
            among the activation dtypes the method supports.
    """
    quant_config = get_quant_config(model_config.quantization, model_config.model, model_config.hf_config,
                                    model_config.download_dir)

    # Encode (major, minor) as a single integer, e.g. (8, 0) -> 80.
    major, minor = torch.cuda.get_device_capability()
    capability = major * 10 + minor
    if capability < quant_config.get_min_capability():
        raise ValueError(f"The quantization method {model_config.quantization} is not "
                         "supported for the current GPU. "
                         f"Minimum capability: {quant_config.get_min_capability()}. "
                         f"Current capability: {capability}.")

    supported_dtypes = quant_config.get_supported_act_dtypes()
    if model_config.dtype not in supported_dtypes:
        raise ValueError(f"{model_config.dtype} is not supported for quantization "
                         f"method {model_config.quantization}. Supported dtypes: "
                         f"{supported_dtypes}")
    return quant_config.get_linear_method()


def get_model(actor_model: Union[PreTrainedModel, Dict],
              model_config: ModelConfig,
              device_config: DeviceConfig,
              lora_config: Optional[LoRAConfig] = None) -> nn.Module:
    """Build the vLLM model for ``model_config`` and fill in its weights.

    Args:
        actor_model: Source of weights. An ``nn.Module`` is converted to a
            dict of its named parameters (duplicates kept); anything else is
            passed to ``load_weights`` unchanged.
        model_config: Supplies the HF config, dtype, quantization settings and
            ``load_format``.
        device_config: Not used by this function.
        lora_config: Not used by this function.

    Returns:
        The model, moved to CUDA and switched to eval mode.

    Raises:
        ValueError: If the architecture is unsupported, or if the requested
            quantization is incompatible with the GPU or with the dtype.
    """
    model_class = _get_model_architecture(model_config.hf_config)

    if model_config.quantization is None:
        linear_method = None
    else:
        linear_method = _quantized_linear_method(model_config)

    with _set_default_torch_dtype(model_config.dtype):
        model = model_class(model_config.hf_config, linear_method)

        load_format = model_config.load_format
        if load_format == "dummy":
            # Move to CUDA first, then fill with dummy values.
            model = model.cuda()
            initialize_dummy_weights(model)
        elif load_format in ('model', 'auto'):
            if isinstance(actor_model, nn.Module):
                actor_weights = dict(actor_model.named_parameters(remove_duplicate=False))
            else:
                actor_weights = actor_model
            load_weights(actor_weights=actor_weights, vllm_model=model)
        # Any other load_format leaves the weights exactly as constructed.

        model = model.cuda()
    return model.eval()


def load_weights(actor_weights: Dict, vllm_model: nn.Module):
    """Load ``actor_weights`` into ``vllm_model`` and move the model to CUDA.

    The loader is selected by the class name of ``vllm_model``.

    Args:
        actor_weights: Weights handed to the architecture-specific loader.
        vllm_model: Destination model.

    Raises:
        ValueError: If no loader is registered for the model's class name.
    """
    arch_name = vllm_model.__class__.__name__
    loader = _get_model_weight_loader(arch_name)
    loader(actor_weights, vllm_model)

    # The result of .cuda() is intentionally discarded; this function
    # returns None.
    vllm_model.cuda()



def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
                embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
    """Compute logits and gather them across tensor-parallel ranks.

    Args:
        hidden_states: Left operand of the projection.
        embedding: Embedding matrix; its transpose is the right operand.
        embedding_bias: Optional bias added in place to the local logits.

    Returns:
        The gathered logits truncated to the first ``self.org_vocab_size``
        columns, or ``None`` if the gather returned ``None``.
    """
    local_logits = hidden_states @ embedding.t()
    if embedding_bias is not None:
        local_logits += embedding_bias

    gathered = tensor_model_parallel_all_gather(local_logits)
    if gathered is None:
        return gathered
    # Keep only the first org_vocab_size columns.
    return gathered[:, :self.org_vocab_size]


def forward(
    self,
    embedding: torch.Tensor,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    embedding_bias: Optional[torch.Tensor] = None,
) -> Optional[SamplerOutput]:
    """Sample next tokens and report log-probabilities of the raw logits.

    Tokens are sampled from logits after logits processors, penalties,
    temperature, top-k/top-p and min-p have been applied. The log-probabilities
    passed to ``_get_logprobs`` come from a snapshot taken before any of those
    adjustments.

    Args:
        embedding: Embedding matrix used to project hidden states to logits.
        hidden_states: Model outputs; pruned via ``_prune_hidden_states``
            before projection.
        sampling_metadata: Per-batch sampling parameters and bookkeeping.
        embedding_bias: Optional bias added to the logits.

    Returns:
        The sampler output, or ``None`` when
        ``sampling_metadata.perform_sampling`` is false.
    """
    hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)

    # Always computed: `_get_logits` includes a cross-rank gather.
    logits = self._get_logits(hidden_states, embedding, embedding_bias)

    # Return before any sampling-only work, including the log-softmax below.
    if not sampling_metadata.perform_sampling:
        return None

    assert logits is not None

    # Snapshot of the unmodified distribution. It must be taken here because
    # later steps change `logits` (the temperature division is in place).
    origin_logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

    _, vocab_size = logits.shape

    logits = _apply_logits_processors(logits, sampling_metadata)

    (sampling_tensors, do_penalties, do_top_p_top_k,
     do_min_p) = SamplingTensors.from_sampling_metadata(sampling_metadata, vocab_size, logits.device, logits.dtype)

    if do_penalties:
        logits = _apply_penalties(logits, sampling_tensors.prompt_tokens, sampling_tensors.output_tokens,
                                  sampling_tensors.presence_penalties, sampling_tensors.frequency_penalties,
                                  sampling_tensors.repetition_penalties)

    # In-place division by the temperatures, unsqueezed to one per row.
    logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1))

    if do_top_p_top_k:
        logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, sampling_tensors.top_ks)

    if do_min_p:
        logits = _apply_min_p(logits, sampling_tensors.min_ps)

    # Both are computed in float32 from the adjusted logits and drive sampling.
    probs = torch.softmax(logits, dim=-1, dtype=torch.float)
    logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

    sample_results = _sample(probs, logprobs, sampling_metadata)

    # Reported log-probabilities come from the snapshot, not the adjusted logits.
    prompt_logprobs, sample_logprobs = _get_logprobs(origin_logprobs, sampling_metadata, sample_results)

    return _build_sampler_output(sample_results, sampling_metadata, prompt_logprobs, sample_logprobs)


from vllm.model_executor.layers.sampler import Sampler

# Monkey-patch: make Sampler use the `_get_logits` and `forward` functions
# defined in this module.
Sampler._get_logits = _get_logits
Sampler.forward = forward
