import gc

import torch
import torch.nn as nn
import torch.optim as optim
from torch.profiler import ProfilerActivity, profile, record_function


class LoRALayer(nn.Module):
    """
    Bias-free linear layer with a frozen base weight plus a trainable low-rank update.

    Holds a frozen ``weight`` of shape (out_features, in_features) and two
    trainable factors ``lora_A`` (r, in_features) and ``lora_B`` (out_features, r).
    The forward pass computes ``x @ (weight + lora_B @ lora_A).T`` in one of two
    mathematically equivalent ways, selected by ``weight_based``.
    """

    def __init__(self, in_features, out_features, r, weight_based=True):
        """
        Create the frozen weight and both LoRA factors from a standard normal.

        Args:
            in_features (int): Size of the last dimension of the input.
            out_features (int): Size of the last dimension of the output.
            r (int): Rank of the low-rank update.
            weight_based (bool): If True, merge the low-rank product into the
                weight before the matmul; if False, apply the low-rank factors
                to the activations and add the result to the base output.
        """
        super().__init__()
        self.weight_based = weight_based
        self.bias = None

        # Parameters are created in the order weight -> lora_A -> lora_B so the
        # random draws and the order of ``parameters()`` stay fixed.
        base = torch.randn(out_features, in_features)
        self.weight = nn.Parameter(base, requires_grad=False)
        self.lora_A = nn.Parameter(torch.randn(r, in_features))
        self.lora_B = nn.Parameter(torch.randn(out_features, r))

    def forward(self, x):
        """Apply the adapted linear map to ``x`` (last dim must be in_features)."""
        if not self.weight_based:
            # Activation path: never materializes an (out_features, in_features) delta.
            base_out = x @ self.weight.t()
            down = x @ self.lora_A.t()
            up = down @ self.lora_B.t()
            combined = base_out + up
            return combined

        # Weight path: build the full-size delta, merge it, then do one linear.
        delta = self.lora_B @ self.lora_A
        merged = delta + self.weight
        return nn.functional.linear(x, merged, self.bias)


class MultiLoRALayer(nn.Module):
    """Stack of ``LoRALayer`` modules applied one after another."""

    def __init__(self, layers_config, r, weight_based=True):
        """
        Build one LoRALayer per entry of ``layers_config``.

        Args:
            layers_config (list of tuple): (in_features, out_features) for each
                layer, in application order.
            r (int): Rank for the LoRA adaptation, shared by every layer.
            weight_based (bool): Forwarded unchanged to every LoRALayer.
        """
        super().__init__()
        self.layers = nn.ModuleList()
        for n_in, n_out in layers_config:
            self.layers.append(LoRALayer(n_in, n_out, r, weight_based))

    def forward(self, x):
        """Feed ``x`` through every layer in order and return the final output."""
        out = x
        for lora in self.layers:
            out = lora(out)
        return out


# Example usage: build the inputs, model and optimizer for the profiled step.

device = "cpu"  # this example is meant to run on CPU
print(device)

rank = 8
batch_size = 8
seq_len = 1024
# Raise the multiplier to stack more identical layers.
layers_config = [(2048, 2048)] * 1
in_features = layers_config[0][0]
print("MLP configuration:", layers_config)

input_tensor = torch.randn(
    batch_size, seq_len, in_features, requires_grad=True
).to(device)

# Flip `weight_based` here to compare the memory profile of the weight-based
# and the activation-based LoRA forward pass.
multi_lora_layer = MultiLoRALayer(layers_config, rank, weight_based=False)
multi_lora_layer = multi_lora_layer.to(device)

# Make sure the base weights stay frozen; only the LoRA factors are trained.
for layer in multi_lora_layer.layers:
    layer.weight.requires_grad_(False)

optimizer = optim.SGD(multi_lora_layer.parameters(), lr=1e-3)


def run_training_step(input_tensor, lora_layer, optimizer, loss_fn):
    """
    Run one optimization step: forward pass, loss, backward pass, parameter update.

    Args:
        input_tensor (torch.Tensor): Input passed to ``lora_layer``.
        lora_layer (nn.Module): The model to train; called as ``lora_layer(input_tensor)``.
        optimizer (torch.optim.Optimizer): Optimizer whose gradients are cleared
            before the step and which applies the update afterwards.
        loss_fn (callable or None): Maps the model output to a scalar loss
            tensor. When None, the mean absolute value of the output is used.

    Returns:
        None. This function does not measure memory itself; wrap the call in a
        profiler to capture memory usage.
    """
    optimizer.zero_grad()
    output = lora_layer(input_tensor)
    if loss_fn is None:
        # Default objective: any scalar works for profiling the backward pass.
        loss = torch.abs(output).mean()
    else:
        loss = loss_fn(output)
    loss.backward()
    optimizer.step()


# Run a full garbage-collection pass (presumably to drop unreferenced objects
# before the profiled section — confirm intent).
gc.collect()


def print_prof(prof, which="profile"):
    """
    Print a profiler summary, export a Chrome trace, and print the largest per-op memory figure.

    Args:
        prof: A finished profiler object exposing ``key_averages()`` and
            ``export_chrome_trace()`` (e.g. ``torch.profiler.profile``).
        which (str): Label inserted into the exported file name
            ``trace_{which}.json``. Defaults to ``"profile"`` so the function
            can be called with the profiler alone.
    """
    print(
        prof.key_averages(group_by_input_shape=True).table(
            sort_by="self_cpu_memory_usage", row_limit=20
        )
    )
    prof.export_chrome_trace(f"trace_{which}.json")

    # default=0 avoids a ValueError when the profiler recorded no events.
    peak_usage = max(
        (event.cpu_memory_usage for event in prof.key_averages()), default=0
    )
    # Scaled by 1024**2 (bytes -> MiB, assuming the profiler reports bytes).
    print("max_memory_usage", peak_usage / 1024**2)


# Collect garbage right before profiling (presumably so stale objects do not
# pollute the memory numbers).
gc.collect()

with profile(
    activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=True
) as prof:
    with record_function("model_inference"):
        run_training_step(input_tensor, multi_lora_layer, optimizer, None)

# print_prof uses a label in the trace file name; derive it from the mode the
# model was built with so weight-based and activation-based runs do not
# overwrite each other's trace.
trace_label = (
    "weight_based" if multi_lora_layer.layers[0].weight_based else "activation_based"
)
print_prof(prof, trace_label)
