from typing import Dict, List, Optional, Set, Tuple

import torch

from sglang.srt.distributed import divide
from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.layers import BaseLayerWithLoRA
from sglang.srt.lora.lora import LoRAAdapter
from sglang.srt.lora.utils import (
    ROW_PARALLELISM_LINEAR_LORA_NAMES,
    LoRAType,
    get_hidden_dim,
    get_stacked_multiply,
    get_weight_name,
)


class LoRAMemoryPool:
    """Class for memory pool management of lora modules"""

    def __init__(
        self,
        base_hf_config: AutoConfig,
        max_loras_per_batch: int,
        max_lora_dim: int,
        dtype: torch.dtype,
        tp_size: int,
        tp_rank: int,
        lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]],
    ):

        self.base_hf_config: AutoConfig = base_hf_config
        self.num_layer: int = base_hf_config.num_hidden_layers
        self.max_loras_per_batch: int = max_loras_per_batch
        self.max_lora_dim: int = max_lora_dim
        self.dtype: torch.dtype = dtype
        self.tp_size: int = tp_size
        self.tp_rank: int = tp_rank
        self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = lora_modules

        # Both A_buffer and B_buffer maps lora weight names to its buffer space.
        # A_buffer contains num_layer number of row-major tensors with shape
        #   (max_loras_per_batch, stacked_num * max_lora_dim, input_dim)
        # B_buffer contains num_layer number of column-major tensors with shape
        #   (stacked_num, max_loras_per_batch, output_dim, max_lora_dim)
        self.A_buffer: Dict[str, List[torch.Tensor]] = {}
        self.B_buffer: Dict[str, List[torch.Tensor]] = {}

        # Lora uid -> buffer idx in memory pool
        self.uid_to_buffer_id: Dict[Optional[str], int] = {}

        # Buffer idx -> lora uid in memory pool
        # All uids are initalized as empty strings for empty buffer slots
        # Here we don't initalize to None since None is a valid uid
        self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch

    def get_lora_A_shape(
        self, module_name: str, base_model: torch.nn.Module
    ) -> Tuple[int, int, int]:
        """Return the shape of one per-layer LoRA A buffer for ``module_name``.

        Args:
            module_name: LoRA weight name; may be a stacked name, in which
                case the rank axis is multiplied by the value returned from
                ``get_stacked_multiply``.
            base_model: Base model, forwarded to ``get_hidden_dim``.

        Returns:
            ``(max_loras_per_batch, max_lora_dim * stacked_num, input_dim)``.
        """
        input_dim, _ = get_hidden_dim(module_name, self.base_hf_config, base_model)
        stacked_num = get_stacked_multiply(module_name)
        # Under tensor parallelism, only the modules listed in
        # ROW_PARALLELISM_LINEAR_LORA_NAMES have their input dim split
        # across ranks; every other module keeps the full input dim.
        if self.tp_size > 1 and module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
            input_dim = divide(input_dim, self.tp_size)
        return (
            self.max_loras_per_batch,
            self.max_lora_dim * stacked_num,
            input_dim,
        )

    def get_lora_B_shape(
        self, module_name: str, base_model: torch.nn.Module
    ) -> Tuple[int, int, int, int]:
        """Return the shape of one per-layer LoRA B buffer for ``module_name``.

        Args:
            module_name: LoRA weight name; may be a stacked name, in which
                case the leading axis has the size returned by
                ``get_stacked_multiply``.
            base_model: Base model, forwarded to ``get_hidden_dim``.

        Returns:
            ``(stacked_num, max_loras_per_batch, output_dim, max_lora_dim)``.
        """
        _, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
        stacked_num = get_stacked_multiply(module_name)
        # Under tensor parallelism, every module that is NOT listed in
        # ROW_PARALLELISM_LINEAR_LORA_NAMES has its output dim split across
        # ranks; the listed modules keep the full output dim.
        if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
            output_dim = divide(output_dim, self.tp_size)
        return (
            stacked_num,
            self.max_loras_per_batch,
            output_dim,
            self.max_lora_dim,
        )

    def init_buffers(
        self,
        lora_weight_names: Set[Tuple[str]],
        base_model: torch.nn.Module,
    ):

        # lora_weight_names is a set of name pairs indicating each pair of lora modules to load
        #   e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
        self.lora_weight_names: Set[Tuple[str]] = lora_weight_names
        device = next(base_model.parameters()).device
        lora_module_A_names = set([name[0] for name in lora_weight_names])
        lora_module_B_names = set([name[1] for name in lora_weight_names])
        # Init A tensor, column_major=False
        for module_A in lora_module_A_names:
            lora_A_shape = self.get_lora_A_shape(module_A, base_model)
            self.A_buffer[module_A] = [
                torch.empty(
                    lora_A_shape,
                    dtype=self.dtype,
                    device=device,
                )
                for i in range(self.num_layer)
            ]
        # Init B tensor, column_major=True
        for module_B in lora_module_B_names:
            lora_B_shape = self.get_lora_B_shape(module_B, base_model)
            self.B_buffer[module_B] = [
                torch.empty(
                    lora_B_shape,
                    dtype=self.dtype,
                    device=device,
                )
                for _ in range(self.num_layer)
            ]

    def prepare_lora_batch(
        self,
        cur_uids: Set[Optional[str]],
        lora_adapters: Dict[str, LoRAAdapter],
    ):

        def get_available_buffer_slot():
            for buffer_id in range(self.max_loras_per_batch):
                # Prioritize empty slots
                if self.buffer_id_to_uid[buffer_id] == "":
                    return buffer_id, ""

            for buffer_id in range(self.max_loras_per_batch):
                # Evict unneeded lora
                if self.buffer_id_to_uid[buffer_id] not in cur_uids:
                    return buffer_id, self.buffer_id_to_uid[buffer_id]

            raise ValueError(
                "No available buffer slots found. Please ensure the number of active loras is less than max_loras_per_batch."
            )

        for uid in cur_uids:
            if uid not in self.uid_to_buffer_id:
                buffer_id, evicted_lora_uid = get_available_buffer_slot()
                if evicted_lora_uid != "":
                    self.uid_to_buffer_id.pop(evicted_lora_uid)
                self.load_lora_weight_to_buffer(
                    uid, buffer_id, lora_adapters.get(uid, None)
                )
                self.uid_to_buffer_id[uid] = buffer_id
                self.buffer_id_to_uid[buffer_id] = uid

    def load_lora_weight_to_buffer(
        self, uid: Optional[str], buffer_id: int, lora_adapter: Optional[LoRAAdapter] = None
    ) -> None:
        """Write one adapter's weights into slot ``buffer_id`` of every layer buffer.

        Args:
            uid: Adapter uid. ``None`` takes a special path: the A slot is
                zero-filled in every layer and nothing is copied.
            buffer_id: Slot index inside the pool to overwrite.
            lora_adapter: Adapter supplying the weights; must be provided
                whenever ``uid`` is not ``None``.

        Raises:
            AssertionError: If ``uid`` is not ``None`` but no adapter is given.
        """

        if uid is None:
            # Only the A buffers are cleared; the B buffers are left untouched.
            # NOTE(review): presumably a zero A is enough to cancel the LoRA
            # contribution regardless of what B holds -- confirm against the
            # consumers of these buffers.
            for i in range(self.num_layer):
                for k in self.A_buffer.keys():
                    self.A_buffer[k][i][buffer_id] = 0
            return

        assert lora_adapter is not None
        # Rank of this adapter, read from its config entry "r".
        lora_rank = lora_adapter.config.hf_config["r"]
        for layer_id in range(self.num_layer):
            layer_weights = lora_adapter.layers[layer_id].weights
            # Per-layer staging dicts, keyed by the name returned from get_weight_name.
            temp_A_buffer: Dict[str, torch.Tensor] = {}
            temp_B_buffer: Dict[str, torch.Tensor] = {}
            for name, weights in layer_weights.items():
                # Any weight whose name does not contain "lora_A" is treated as a B weight.
                if "lora_A" in name:
                    lora_weight_name = get_weight_name(
                        name, self.lora_weight_names, LoRAType.LORA_A
                    )
                    temp_A_buffer[lora_weight_name] = weights
                else:
                    lora_weight_name = get_weight_name(
                        name, self.lora_weight_names, LoRAType.LORA_B
                    )
                    temp_B_buffer[lora_weight_name] = weights

            if self.tp_size > 1:
                # Tensor parallel: let each layer module slice the staged
                # weights for this rank, replacing the staged entries in place.
                cur_layer_modules = self.lora_modules[layer_id]
                for module_name, module in cur_layer_modules:
                    if "qkv_proj" in module_name:
                        # qkv_proj stages a single A entry ("qkv_proj") but two
                        # B entries ("q_proj" and "kv_proj"), sliced together.
                        temp_A_buffer["qkv_proj"] = module.slice_lora_a_weights(
                            temp_A_buffer["qkv_proj"], self.tp_rank
                        )
                        temp_B_buffer["q_proj"], temp_B_buffer["kv_proj"] = (
                            module.slice_lora_b_weights(
                                [temp_B_buffer["q_proj"], temp_B_buffer["kv_proj"]],
                                self.tp_rank,
                            )
                        )
                    else:
                        # The name resolved with LoRAType.LORA_A is used as the
                        # key for both the A and the B staging dict here.
                        weight_name = get_weight_name(
                            module_name, self.lora_weight_names, LoRAType.LORA_A
                        )
                        temp_A_buffer[weight_name] = module.slice_lora_a_weights(
                            temp_A_buffer[weight_name], self.tp_rank
                        )
                        temp_B_buffer[weight_name] = module.slice_lora_b_weights(
                            temp_B_buffer[weight_name], self.tp_rank
                        )

            # Copy A into the first lora_rank * c rows of the slot.
            # NOTE(review): rows beyond lora_rank * c (and B columns beyond
            # lora_rank below) are not cleared, so they keep whatever the slot
            # held before -- confirm consumers only read the first lora_rank
            # entries when the adapter rank is below max_lora_dim.
            for name, weights in temp_A_buffer.items():
                c = get_stacked_multiply(name)
                self.A_buffer[name][layer_id][buffer_id][: lora_rank * c, :].copy_(
                    weights
                )

            # Copy B into the first lora_rank columns of the slot; stacked
            # names (c > 1) copy weights[stacked_id] into each stacked slice.
            for name, weights in temp_B_buffer.items():
                c = get_stacked_multiply(name)
                if c > 1:
                    for stacked_id in range(c):
                        self.B_buffer[name][layer_id][stacked_id][buffer_id][
                            :, :lora_rank
                        ].copy_(weights[stacked_id])
                else:
                    self.B_buffer[name][layer_id][0][buffer_id][:, :lora_rank].copy_(
                        weights
                    )

    def get_tensor(
        self, weight_name: str, layer_id: int, lora_type: LoRAType
    ) -> torch.Tensor:
        """Look up the pooled buffer of ``weight_name`` for layer ``layer_id``.

        ``LoRAType.LORA_A`` selects ``A_buffer``; any other value selects
        ``B_buffer``.
        """
        per_weight = self.A_buffer if lora_type == LoRAType.LORA_A else self.B_buffer
        return per_weight[weight_name][layer_id]

    def get_buffer_id(self, lora_uid: str):
        return self.uid_to_buffer_id[lora_uid]
