import numpy as np
import torch

# Inner dimension shared by each random factor pair in the rank experiment below.
rank = 8
# Outer dimension: every product matrix in that experiment is shape x shape.
shape = 1024

def plus_matrices(matrices):
    """Return the rank of the sum of all matrices in ``matrices``.

    The matrices are added left to right, starting from the first one, and
    the rank of the accumulated sum is computed with
    ``np.linalg.matrix_rank``.

    Args:
        matrices: A non-empty, sliceable sequence of operands that can be
            added together with ``+``.

    Returns:
        The value produced by ``np.linalg.matrix_rank`` for the sum.

    Raises:
        IndexError: If ``matrices`` is an empty list (there is no first
            element to start from).
    """
    total = matrices[0]
    remaining = matrices[1:]

    # Rebind instead of using ``+=`` so the caller's first matrix is never
    # modified in place.
    for addend in remaining:
        total = total + addend

    return np.linalg.matrix_rank(total)

def multiple_matrices(matrices):
    """Return the rank of the chained matrix product of ``matrices``.

    The product is evaluated strictly left to right with the ``@`` operator
    (``((M0 @ M1) @ M2) ...``), and the rank of the final product is computed
    with ``np.linalg.matrix_rank``.

    Args:
        matrices: A non-empty, sliceable sequence of operands whose
            neighbouring shapes are compatible for ``@``.

    Returns:
        The value produced by ``np.linalg.matrix_rank`` for the product.

    Raises:
        IndexError: If ``matrices`` is an empty list (there is no first
            element to start from).
    """
    product, factors = matrices[0], matrices[1:]

    # Keep the fixed left-to-right association; a re-associated product could
    # round differently and change the numerically determined rank.
    for factor in factors:
        product = product @ factor

    rank_of_product = np.linalg.matrix_rank(product)
    return rank_of_product

# Rank experiment: build three random square matrices, each the product of a
# (shape x rank) factor and a (rank x shape) factor, so each has rank at most
# `rank`.
matrices = []

for i in range(3):
    A = np.random.random((shape, rank))  # tall factor, shape x rank
    B = np.random.random((rank, shape))  # wide factor, rank x shape
    W = A @ B  # shape x shape product of the two factors
    matrices.append(W)

# Rank of each individual matrix.
print([np.linalg.matrix_rank(m) for m in matrices])
# Rank of the sum of the three matrices.
print(plus_matrices(matrices))
# Rank of the left-to-right matrix product of the three matrices.
print(multiple_matrices(matrices))

# ------------------------------------------
# Back-of-the-envelope trainable-parameter estimate, printed next to a rank-8
# baseline labelled "lora r 8".
w_shape = (4096, 4096)  # (rows, cols) of one weight matrix
in_features = 4096
out_features = 4096
r0 = 4  # rank used by the `local` term
r1 = 2  # referenced only by the commented-out `intra` formula below
r2 = 2  # referenced only by the commented-out `inter` formula below
layers = 32
modules = 3
groups = 32  # not referenced anywhere in the visible code

# Two rank-r0 factors (rows x r0 and r0 x cols) per module, per layer.
local = layers * modules * (w_shape[0] * r0 + r0 * w_shape[1])
# Earlier variants that additionally counted a rank-r1 pair per layer (intra)
# and a single rank-r2 pair (inter):
# intra = layers * (modules * (w_shape[0] + in_features + out_features + w_shape[1]) + (
#         in_features * r1 + r1 * out_features))
# inter = layers * modules * (w_shape[0] + in_features + out_features + w_shape[1]) + (
#         in_features * r2 + r2 * out_features)

# NOTE(review): as written, `intra` and `inter` evaluate to the same number;
# both count only the four dimension-sized terms per module per layer and omit
# the r1/r2 terms of the commented-out variants. Confirm this is intentional.
intra = layers * (modules * (w_shape[0] + in_features + out_features + w_shape[1]))
inter = layers * modules * (w_shape[0] + in_features + out_features + w_shape[1])

trainable_param = local + intra + inter

# Baseline: the same two-factor count as `local`, but with rank 8.
print("lora r 8: ", layers * modules * (w_shape[0] * 8 + 8 * w_shape[1]))
print("ours:", trainable_param)

# Small shape demo: (2 x 2) @ (2 x 1) @ (1 x 4) yields a 2 x 4 tensor.
a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
up = torch.randn(2, 1)
down = torch.randn(1, 4)

result = a @ up @ down
print(result)

def compute_trainable_parameters(target_modules):
    """Count trainable parameters for the given target modules.

    The total is the sum of three components, named after the variables that
    hold them. ``r`` holds one rank per component, in the order
    ``[local, intra, inter]``:

    * ``local``: for every target module in every layer, two rank-``r[0]``
      factors, i.e. ``(rows + cols) * r[0]`` parameters.
    * ``intra``: per layer, one rank-``r[1]`` pair over
      ``in_features``/``out_features`` plus the size-matching terms below.
    * ``inter``: one rank-``r[2]`` pair over ``in_features``/``out_features``
      counted once (not per layer), plus the size-matching terms per layer.

    Size-matching terms: a module whose row count differs from
    ``in_features`` adds ``rows + in_features``; one whose column count
    differs from ``out_features`` adds ``cols + out_features``.

    A component whose rank is 0 is treated as disabled and contributes
    nothing, including its size-matching terms.

    Args:
        target_modules: Iterable of module names; each must be one of
            ``"q"``, ``"k"``, ``"v"``, ``"up"``, ``"down"``.

    Returns:
        The total parameter count as an ``int``. With the built-in settings,
        ``["q", "k", "v", "up", "down"]`` yields 8011776.

    Raises:
        KeyError: If a name in ``target_modules`` is not a known module.
    """
    layer_num = 32
    r = [4, 0, 4]  # ranks for the local, intra and inter components

    module_shapes = {
        "q": (4096, 4096),
        "k": (4096, 4096),
        "v": (4096, 4096),
        "up": (4096, 11008),
        "down": (11008, 4096),
    }

    in_features = 4096
    out_features = 4096

    local = 0
    # Per-layer size-matching terms for modules whose shape does not line up
    # with in_features/out_features. The same amount applies to both the
    # intra and the inter component, so one accumulator is enough.
    size_matching = 0

    for module in target_modules:
        rows, cols = module_shapes[module]

        local += layer_num * (rows + cols) * r[0]

        if rows != in_features:
            size_matching += rows + in_features

        if cols != out_features:
            size_matching += cols + out_features

    # The previous code assigned `intra`/`inter` from scratch after the loop,
    # which silently discarded the size-matching terms accumulated above.
    # NOTE(review): the per-layer scaling and the "rank 0 disables the whole
    # component" rule are inferred from the recorded expected total (8011776)
    # for all five modules; confirm against the intended design.
    intra = 0
    if r[1] > 0:
        intra = layer_num * (size_matching + (in_features + out_features) * r[1])

    inter = 0
    if r[2] > 0:
        inter = layer_num * size_matching + (in_features + out_features) * r[2]

    return local + intra + inter

# Count trainable parameters with all five module types targeted and print the total.
# NOTE(review): the figure in the trailing comment is the author's recorded
# value; confirm it matches what the function actually returns.
parameters = compute_trainable_parameters(["q", "k", "v", "up", "down"])
print(parameters)  # 8011776
