import os
from typing import List

import pytest
import torch
from safetensors.torch import load_file
from torch import nn

from vllm.config import LoRAConfig
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                              RowParallelLinearWithLoRA,
                              MergedColumnParallelLinearWithLoRA)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.models import (EMBEDDING_MODULES, LoRAModel, LoRAModelManager,
                              LRUCacheLoRAModelManager, LoRAMapping)
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
                                      WorkerLoRAManager)
from vllm.model_executor.layers.linear import RowParallelLinear


def test_from_lora_tensors(sql_lora_files):
    """Check the layers produced by ``LoRAModel.from_lora_tensors``.

    The adapter weights and the extra embeddings are read from the
    ``sql_lora_files`` directory. Every resulting layer must carry its own
    module name, rank 8, alpha 16 and A/B matrices that agree on the rank
    dimension. A layer whose name contains a key of ``EMBEDDING_MODULES``
    must hold the matching tensor from the embeddings file; every other
    layer must have no embeddings tensor.
    """
    adapter_path = os.path.join(sql_lora_files, "adapter_model.safetensors")
    embeddings_path = os.path.join(sql_lora_files,
                                   "new_embeddings.safetensors")
    tensors = load_file(adapter_path)
    new_embeddings = load_file(embeddings_path)

    lora_model = LoRAModel.from_lora_tensors(
        1, 8, 16, tensors, "cuda", embeddings=new_embeddings)

    for module_name, lora in lora_model.loras.items():
        # Basic metadata and presence of both low-rank matrices.
        assert lora.module_name == module_name
        assert lora.rank == 8
        assert lora.lora_alpha == 16
        assert lora.lora_a is not None
        assert lora.lora_b is not None

        # A's second dimension and B's first dimension are both the rank.
        assert (lora.lora_a.shape[1] == lora.lora_b.shape[0]
                ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
        assert lora.lora_a.shape[1] == 8

        # First EMBEDDING_MODULES key that occurs in the module name, if any.
        embeddings_module = None
        for candidate in EMBEDDING_MODULES:
            if candidate in module_name:
                embeddings_module = candidate
                break

        if not embeddings_module:
            assert lora.embeddings_tensor is None
            continue

        # Compare on the layer's own device so torch.equal sees both
        # tensors in the same place.
        assert torch.equal(
            lora.embeddings_tensor,
            new_embeddings[EMBEDDING_MODULES[embeddings_module]].to(
                device=lora.embeddings_tensor.device))


def create_lora(lora_id: int, model: nn.Module,
                sub_modules: List[str]) -> LoRAModel:
    """Build a rank-8 ``LoRAModel`` with random weights for ``sub_modules``.

    Args:
        lora_id: Identifier handed to the resulting ``LoRAModel``.
        model: Module whose sub-modules provide the weight shapes.
        sub_modules: Names resolvable through ``model.get_submodule``; each
            one gets its own ``LoRALayerWeights`` entry.

    Returns:
        A ``LoRAModel`` with one randomly initialised layer (rank 8,
        alpha 16, tensors on "cuda") per requested sub-module.
    """

    def _random_layer(target: str) -> LoRALayerWeights:
        # Size the A/B matrices from the target's weight: A is
        # [weight.shape[1], 8] and B is [8, weight.shape[0]].
        weight = model.get_submodule(target).weight
        lora_a = torch.rand([weight.shape[1], 8], device="cuda")
        lora_b = torch.rand([8, weight.shape[0]], device="cuda")
        return LoRALayerWeights(target, 8, 16, lora_a, lora_b)

    layers = {target: _random_layer(target) for target in sub_modules}
    return LoRAModel(lora_id, 8, layers)


def create_packed_lora(
    lora_id: int,
    model: nn.Module,
    module_name,
    replaced_module_names,
    empty_replaced_module_name=None,
) -> LoRAModel:
    """Build a ``LoRAModel`` holding one random layer per replaced module.

    Args:
        lora_id: Identifier handed to the resulting ``LoRAModel``.
        model: Module that contains the sub-module ``module_name``.
        module_name: Name of the sub-module whose weight provides the shapes.
        replaced_module_names: Names under which individual layers are
            stored; ``weight.shape[0]`` is divided evenly between them.
        empty_replaced_module_name: Optional entry of
            ``replaced_module_names`` that is left without a layer.

    Returns:
        A ``LoRAModel`` (rank 8, alpha 16, tensors on "cuda") with an entry
        for every replaced module except the one designated as empty.
    """
    weight = model.get_submodule(module_name).weight

    def _random_slice(target) -> LoRALayerWeights:
        # Every replaced module receives an equal share of weight.shape[0].
        lora_a = torch.rand([weight.shape[1], 8], device="cuda")
        lora_b = torch.rand(
            [8, weight.shape[0] // len(replaced_module_names)],
            device="cuda")
        return LoRALayerWeights(target, 8, 16, lora_a, lora_b)

    layers = {
        target: _random_slice(target)
        for target in replaced_module_names
        if target != empty_replaced_module_name
    }
    return LoRAModel(lora_id, 8, layers)


def test_replace_submodules(dist_init, dummy_model):
    """Check which sub-modules ``LoRAModelManager`` swaps for LoRA layers.

    With ``lora_target_modules=["dense1", "layer1.dense2"]`` the test expects
    ``dense1``, ``layer1.dense1`` and ``layer1.dense2`` to be LoRA-enabled
    layers, while the top-level ``dense2`` must still be an instance of
    ``RowParallelLinear``.
    """
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8)
    manager = LoRAModelManager(
        dummy_model,
        1,
        1,
        1,
        lora_config,
        lora_target_modules=["dense1", "layer1.dense2"],
    )
    wrapped_model = manager.model

    # (sub-module path, type it must be an instance of), checked in order.
    expectations = (
        ("dense1", ColumnParallelLinearWithLoRA),
        ("layer1.dense1", ColumnParallelLinearWithLoRA),
        ("dense2", RowParallelLinear),
        ("layer1.dense2", RowParallelLinearWithLoRA),
    )
    for path, expected_type in expectations:
        assert isinstance(wrapped_model.get_submodule(path), expected_type)


def test_lora_model_manager(dist_init, dummy_model):
    """Walk a ``LoRAModelManager`` through add/activate/remove transitions.

    Three adapters are created for a manager configured with ``max_loras=2``
    and ``max_cpu_loras=3``. After each step the test inspects
    ``manager.lora_index_to_id`` to check which adapter id sits at each
    index.
    """
    model = dummy_model
    model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
    manager = LoRAModelManager(
        model,
        2,
        2,
        2,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2),
        lora_target_modules=["dense1", "dense2", "lm_head"])
    # No index is occupied before any adapter has been added.
    assert all(x is None for x in manager.lora_index_to_id)
    # Adding and activating adapter 1 places it at index 0.
    assert manager.add_lora(model_lora1)
    assert manager.activate_lora(1)
    assert manager.lora_index_to_id[0] == 1
    # Repeating add/activate for the same adapter returns a falsy value.
    assert not manager.add_lora(model_lora1)
    assert not manager.activate_lora(1)
    # Adapter 2 takes index 1; adapter 1 stays at index 0.
    assert manager.add_lora(model_lora2)
    assert manager.activate_lora(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_lora(model_lora2)
    assert not manager.activate_lora(2)
    # Adapter 3 can be added without disturbing the two active indices ...
    assert manager.add_lora(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    # ... but activating it is expected to raise ValueError (presumably
    # because both indices are occupied with max_loras=2 -- confirm against
    # LoRAModelManager.activate_lora).
    with pytest.raises(ValueError):
        assert manager.activate_lora(3)
    # The failed activation leaves the index mapping untouched.
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    # Removing an adapter clears its index; removing it again returns falsy.
    assert manager.remove_lora(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_lora(model_lora2.id)
    assert manager.remove_lora(model_lora1.id)
    assert not manager.remove_lora(model_lora1.id)
    # Re-adding adapter 1 alone does not occupy any index.
    assert manager.add_lora(model_lora1)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] is None
    # Adapter 3 was never removed above, so it can now be activated and
    # lands at index 0; adapter 2 then fills index 1.
    assert manager.add_lora(model_lora2)
    assert manager.activate_lora(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] is None
    assert manager.activate_lora(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2


def test_lora_lru_cache_model_manager(dist_init, dummy_model):
    """Walk an ``LRUCacheLoRAModelManager`` through the same style of steps.

    The manager is configured with ``max_loras=2`` and ``max_cpu_loras=3``.
    Here, activating a third adapter while both indices are occupied is
    expected to succeed and replace an existing entry, instead of raising.
    """
    model = dummy_model
    model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
    manager = LRUCacheLoRAModelManager(
        model,
        2,
        2,
        2,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2),
        lora_target_modules=["dense1", "dense2", "lm_head"])
    # No index is occupied before any adapter has been added.
    assert all(x is None for x in manager.lora_index_to_id)
    # Adding and activating adapter 1 places it at index 0.
    assert manager.add_lora(model_lora1)
    assert manager.activate_lora(1)
    assert manager.lora_index_to_id[0] == 1
    # Repeating add/activate for the same adapter returns a falsy value.
    assert not manager.add_lora(model_lora1)
    assert not manager.activate_lora(1)
    # Adapter 2 takes index 1; adapter 1 stays at index 0.
    assert manager.add_lora(model_lora2)
    assert manager.activate_lora(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_lora(model_lora2)
    assert not manager.activate_lora(2)
    # Adding adapter 3 does not change the active indices ...
    assert manager.add_lora(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    # ... and activating it replaces adapter 1 at index 0.
    assert manager.activate_lora(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2
    # Removing adapter 2 clears index 1; a second removal returns falsy.
    assert manager.remove_lora(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_lora(model_lora2.id)
    # Adapter 1 is no longer active but can still be removed exactly once.
    assert manager.remove_lora(model_lora1.id)
    assert not manager.remove_lora(model_lora1.id)
    # Re-added adapter 1 goes into the free index 1.
    assert manager.add_lora(model_lora1)
    assert manager.activate_lora(1)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    # Deactivating adapter 3 frees index 0, which adapter 2 then takes.
    assert manager.add_lora(model_lora2)
    assert manager.deactivate_lora(3)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_lora(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    # With both indices taken again, activating adapter 3 replaces
    # adapter 1 at index 1.
    assert manager.activate_lora(3)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3


def test_lru_lora_model_manager(dist_init, dummy_model):
    """Exercise eviction in ``LRUCacheLoRAModelManager``.

    The manager is configured with ``max_loras=2`` and ``max_cpu_loras=2``
    while four adapters are cycled through it. The test checks both
    ``manager.list_loras()`` and ``manager.lora_index_to_id`` after each
    group of operations, including explicit ``remove_oldest_lora`` calls.
    """
    # This tests just the LRU cache functionality, everything else is
    # tested in test_lora_model_manager
    model = dummy_model
    model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
    model_lora4 = create_lora(4, model, ["dense1", "dense2", "lm_head"])
    manager = LRUCacheLoRAModelManager(
        model, 2, 2, 2,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2),
        ["dense1", "dense2", "lm_head"])

    # No index is occupied before any adapter has been added.
    assert all(x is None for x in manager.lora_index_to_id)

    # Add up to capacity: adapters 1 and 2 fill both indices.
    assert manager.add_lora(model_lora1)
    assert manager.add_lora(model_lora2)
    assert manager.activate_lora(1)
    assert manager.activate_lora(2)

    assert set(manager.list_loras()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Add over capacity: adapters 3 and 4 replace 1 and 2 entirely.
    assert manager.add_lora(model_lora3)
    assert manager.add_lora(model_lora4)
    assert manager.activate_lora(3)
    assert manager.activate_lora(4)

    assert set(manager.list_loras()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    # Re-adding and re-activating adapter 3 returns a falsy value because it
    # is already present. Adding adapter 2 afterwards is expected to push
    # out adapter 4 rather than adapter 3 (presumably the repeated calls
    # mark adapter 3 as most recently used -- confirm in the manager).
    assert not manager.add_lora(model_lora3)
    assert not manager.activate_lora(3)
    assert manager.add_lora(model_lora2)
    assert manager.activate_lora(2)

    assert set(manager.list_loras()) == {3, 2}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    # Remove manually: the first removal succeeds, the second returns falsy,
    # and index 0 is cleared.
    assert manager.remove_lora(3)
    assert not manager.remove_lora(3)

    assert set(manager.list_loras()) == {2}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 2

    # Adapter 3 takes the free index 0; adding adapter 4 then pushes out
    # adapter 2, and adapter 4 takes index 1.
    assert manager.add_lora(model_lora3)
    assert manager.activate_lora(3)
    assert manager.add_lora(model_lora4)
    assert manager.activate_lora(4)

    assert set(manager.list_loras()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    # remove_oldest_lora drops adapter 3 first ...
    assert manager.remove_oldest_lora()
    assert set(manager.list_loras()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4

    # ... then adapter 4, leaving the manager empty ...
    assert manager.remove_oldest_lora()
    assert set(manager.list_loras()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    # ... and returns a falsy value once nothing is left to remove.
    assert not manager.remove_oldest_lora()
    assert set(manager.list_loras()) == set()
    assert all(x is None for x in manager.lora_index_to_id)


def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings,
                                       sql_lora_files):
    """Drive ``LRUCacheWorkerLoRAManager.set_active_loras`` step by step.

    Each scenario requests a batch of adapter ids and then checks the set
    returned by ``list_loras()`` and the leading entries of the underlying
    manager's ``lora_index_to_id``. The expectations show that adapters from
    earlier requests stay loaded until room is needed. Requesting five
    adapters at once (``max_loras=4``) must raise ``RuntimeError``.
    """
    model = llama_2_7b_model_extra_embeddings
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    worker_lora_manager = LRUCacheWorkerLoRAManager(
        4, 2, model.config.vocab_size, lora_config, torch.device("cuda"))
    worker_lora_manager.create_lora_manager(model)
    mapping = LoRAMapping([], [])

    def activate(*lora_ids):
        # Every request points at the same adapter files; only name/id vary.
        worker_lora_manager.set_active_loras([
            LoRARequest(str(lora_id), lora_id, sql_lora_files)
            for lora_id in lora_ids
        ], mapping)

    # (requested ids, expected list_loras(), expected ids at index 0..n-1)
    scenarios = (
        ((1, 2), {1, 2}, (1, 2)),
        ((1, 3, 4), {1, 2, 3, 4}, (1, 2, 3, 4)),
        ((1, 2, 5), {1, 2, 4, 5}, (1, 2, 5, 4)),
        ((1, 1, 1), {1, 2, 4, 5}, (1, 2, 5, 4)),
        ((6, 7, 8), {1, 6, 7, 8}, (1, 7, 8, 6)),
    )
    for requested, expected_loaded, expected_slots in scenarios:
        activate(*requested)
        assert worker_lora_manager.list_loras() == expected_loaded
        for index, expected_id in enumerate(expected_slots):
            assert (worker_lora_manager._lora_manager.lora_index_to_id[index]
                    == expected_id)

    # Over capacity
    with pytest.raises(RuntimeError):
        activate(10, 11, 12, 13, 14)


def test_worker_lora_manager(llama_2_7b_model_extra_embeddings,
                             sql_lora_files):
    """Drive ``WorkerLoRAManager.set_active_loras`` step by step.

    The manager should remove every LoRA not specified in the request, so
    after each scenario ``list_loras()`` must equal exactly the requested
    ids. The leading entries of the underlying manager's
    ``lora_index_to_id`` are checked as well; ``None`` marks an index that
    must be empty. Requesting five adapters at once (``max_loras=4``) must
    raise ``RuntimeError``.
    """
    model = llama_2_7b_model_extra_embeddings
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    worker_lora_manager = WorkerLoRAManager(
        4, 2, model.config.vocab_size, lora_config, torch.device("cuda"))
    worker_lora_manager.create_lora_manager(model)
    mapping = LoRAMapping([], [])

    def activate(*lora_ids):
        # Every request points at the same adapter files; only name/id vary.
        worker_lora_manager.set_active_loras([
            LoRARequest(str(lora_id), lora_id, sql_lora_files)
            for lora_id in lora_ids
        ], mapping)

    # (requested ids, expected list_loras(), expected ids at index 0..n-1)
    scenarios = (
        ((1, 2), {1, 2}, (1, 2)),
        ((1, 3, 4), {1, 3, 4}, (1, 3, 4)),
        ((1, 2, 5), {1, 2, 5}, (1, 2, 5)),
        ((1, 1, 1), {1}, (1, None, None)),
        ((6, 7, 8), {6, 7, 8}, (8, 6, 7)),
    )
    for requested, expected_loaded, expected_slots in scenarios:
        activate(*requested)
        assert worker_lora_manager.list_loras() == expected_loaded
        for index, expected_id in enumerate(expected_slots):
            if expected_id is None:
                # Empty indices are checked by identity, not equality.
                assert (worker_lora_manager._lora_manager.
                        lora_index_to_id[index] is None)
            else:
                assert (worker_lora_manager._lora_manager.
                        lora_index_to_id[index] == expected_id)

    # Over capacity
    with pytest.raises(RuntimeError):
        activate(10, 11, 12, 13, 14)


def test_packed_loras(dist_init, dummy_model_gate_up):
    """Check that per-projection LoRAs are packed under ``gate_up_proj``.

    Two adapters are added to a ``LoRAModelManager`` targeting
    ``gate_up_proj``: one with weights for both ``gate_proj`` and
    ``up_proj``, and one with ``gate_proj`` left out. After ``add_lora``
    each adapter must expose a ``PackedLoRALayerWeights`` for
    ``gate_up_proj`` whose per-position A/B tensors match the individual
    layers, with ``None`` at the position of the omitted projection.
    """
    packed_names = ("gate_proj", "up_proj")
    full_lora = create_packed_lora(
        1,
        dummy_model_gate_up,
        module_name="gate_up_proj",
        replaced_module_names=list(packed_names))
    partial_lora = create_packed_lora(
        2,
        dummy_model_gate_up,
        module_name="gate_up_proj",
        replaced_module_names=list(packed_names),
        empty_replaced_module_name="gate_proj",
    )

    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2)
    manager = LoRAModelManager(dummy_model_gate_up, 2, 2, 2, lora_config,
                               ["gate_up_proj"])
    wrapped_model = manager.model

    assert isinstance(wrapped_model.get_submodule("gate_up_proj"),
                      MergedColumnParallelLinearWithLoRA)
    assert manager.add_lora(full_lora)
    assert manager.add_lora(partial_lora)

    # Adapter with both projections: every packed position must match the
    # corresponding stand-alone layer.
    packed_full = full_lora.get_lora("gate_up_proj")
    assert packed_full
    assert isinstance(packed_full, PackedLoRALayerWeights)
    for position, sub_name in enumerate(packed_names):
        assert torch.allclose(packed_full.lora_a[position],
                              full_lora.get_lora(sub_name).lora_a)
        assert torch.allclose(packed_full.lora_b[position],
                              full_lora.get_lora(sub_name).lora_b)

    # Adapter without gate_proj: position 0 is empty, position 1 holds
    # the up_proj weights.
    packed_partial = partial_lora.get_lora("gate_up_proj")
    assert packed_partial
    assert isinstance(packed_partial, PackedLoRALayerWeights)
    assert packed_partial.lora_a[0] is None
    assert packed_partial.lora_b[0] is None
    assert torch.allclose(packed_partial.lora_a[1],
                          partial_lora.get_lora("up_proj").lora_a)
    assert torch.allclose(packed_partial.lora_b[1],
                          partial_lora.get_lora("up_proj").lora_b)
