# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

import torch


def test_lora_layer_replacement(lit_llama):
    """Building a model inside the ``lora`` context swaps in LoRA attention.

    Every transformer block's ``attn`` module must be the LoRA variant of
    ``CausalSelfAttention`` rather than the base implementation.
    """
    from lit_llama.lora import lora, CausalSelfAttention as LoRACausalSelfAttention
    from lit_llama.model import LLaMA, LLaMAConfig

    # Pass the tiny sizes to the constructor (as test_lora_merge_unmerge does)
    # instead of mutating a default config afterwards, so that any fields the
    # config may derive at construction time also reflect these sizes.
    config = LLaMAConfig(n_layer=2, n_head=4, n_embd=8, block_size=8, vocab_size=8)

    with lora(r=8, alpha=8, dropout=0.1):
        model = LLaMA(config)

    # Check every block rather than hard-coding indices, so the test keeps
    # covering all layers if n_layer changes.
    blocks = model.transformer.h
    assert len(blocks) == config.n_layer
    for block in blocks:
        assert isinstance(block.attn, LoRACausalSelfAttention)


def test_lora_merge_unmerge(lit_llama):
    """LoRA weights merge into the base weight on ``eval()`` and unmerge on ``train()``.

    Also checks two more properties. An optimizer step with only LoRA
    parameters trainable leaves the base weight untouched. Repeated
    ``eval()``/``train()`` calls do not merge or unmerge more than once.
    """
    from lit_llama.lora import lora, mark_only_lora_as_trainable
    from lit_llama.model import LLaMA, LLaMAConfig

    # Seed so that parameter init, dropout and the random input batch are
    # reproducible. The "merging changes the weight" check below depends on
    # the values produced by this one training step.
    torch.manual_seed(0)

    config = LLaMAConfig(n_layer=1, n_head=2, n_embd=8, block_size=8, vocab_size=8)

    with lora(r=8, alpha=8, dropout=0.1):
        model = LLaMA(config)

    # Module.train()/eval() propagate mode changes to children without
    # replacing them, so caching this reference is equivalent to re-resolving
    # the attribute path each time.
    c_attn = model.transformer.h[0].attn.c_attn

    initial_weight = c_attn.weight.clone()
    model.train()
    assert torch.equal(c_attn.weight, initial_weight)

    # Perform one update. Only LoRA parameters are left trainable, so the
    # frozen base weight must not move.
    mark_only_lora_as_trainable(model)
    optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
    model(torch.randint(0, 8, size=(2, 4), dtype=torch.int64)).sum().backward()
    optimizer.step()
    optimizer.zero_grad()
    assert torch.equal(c_attn.weight, initial_weight)

    # 'merge' and then 'unmerge' should neutralize each other.
    weight_before = c_attn.weight.clone()
    model.eval()
    assert not torch.equal(c_attn.weight, weight_before)
    model.train()
    # In floating point, `W + dW - dW == W` does not hold exactly (dW being the
    # low-rank LoRA update), hence `allclose` rather than `equal`.
    assert torch.allclose(c_attn.weight, weight_before)

    # Calling eval/train multiple times in a row should not merge/unmerge multiple times.
    model.eval()
    assert c_attn.merged
    weight_after = c_attn.weight.clone()
    model.eval()
    model.eval()
    assert torch.equal(c_attn.weight, weight_after)
    model.train()
    assert not c_attn.merged
    weight_after = c_attn.weight.clone()
    model.train()
    model.train()
    assert torch.equal(c_attn.weight, weight_after)
