import torch
import torch.nn as nn


# Define the HyperNetwork
class HyperNetwork(nn.Module):
    """Two-layer perceptron: linear -> ReLU -> linear.

    Args:
        input_size: Number of features in each input vector.
        hidden_size: Width of the hidden layer.
        output_size: Number of values produced per input vector.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int) -> None:
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``fc2(relu(fc1(x)))``; no activation follows the last layer."""
        hidden = nn.functional.relu(self.fc1(x))
        return self.fc2(hidden)


# Define the target network with HyperLoRA
class TargetNetwork(nn.Module):
    """Linear + ReLU layer whose weight gets a hypernetwork-generated low-rank update.

    On every forward pass the hypernetwork output is split into two factors,
    ``A`` of shape ``(hidden_size, rank)`` and ``B`` of shape
    ``(rank, input_size)``; their product is added to ``fc.weight`` before the
    linear map is applied.

    Args:
        input_size: Number of input features of the adapted layer.
        hidden_size: Number of output features of the adapted layer.
        output_size: Accepted for interface compatibility; not used.
        hypernet: Module mapping a hyper input to a flat vector of
            ``rank * (hidden_size + input_size)`` values per sample
            (``A`` row-major first, then ``B`` row-major).
        rank: Rank of the generated weight update.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        hypernet: nn.Module,
        rank: int = 4,
    ) -> None:
        super().__init__()
        self.fc = nn.Linear(input_size, hidden_size)
        self.hypernet = hypernet
        self.rank = rank

        # Static low-rank matrices. ``forward`` takes its factors from the
        # hypernetwork instead, so these do not influence the output; they are
        # kept so the module's attributes and state_dict keys stay unchanged.
        self.A = nn.Parameter(torch.randn(hidden_size, rank))
        self.B = nn.Parameter(torch.randn(rank, input_size))

    def forward(self, x: torch.Tensor, hyper_input: torch.Tensor) -> torch.Tensor:
        """Apply the adapted layer to ``x``.

        Args:
            x: Input whose last dimension is ``input_size``.
            hyper_input: Input for the hypernetwork. If it yields a single
                factor vector (1-D output or batch of one), the same weight
                update is used for all of ``x``. If it yields ``n > 1``
                vectors, ``x`` must be ``(n, input_size)`` and each row gets
                its own weight update.

        Returns:
            ``relu(x @ (fc.weight + A @ B).T + fc.bias)``.

        Raises:
            ValueError: If the hypernetwork output does not have
                ``rank * (hidden_size + input_size)`` values per sample, or if
                per-sample updates are requested with a mismatching ``x``.
        """
        out_features = self.fc.out_features
        in_features = self.fc.in_features
        a_numel = out_features * self.rank
        b_numel = self.rank * in_features

        generated = self.hypernet(hyper_input)
        if generated.dim() == 1:
            generated = generated.unsqueeze(0)
        if generated.dim() != 2 or generated.size(-1) != a_numel + b_numel:
            raise ValueError(
                "hypernet must output rank * (hidden_size + input_size) = "
                f"{a_numel + b_numel} values per sample, got shape "
                f"{tuple(generated.shape)}"
            )

        # Slicing by element count (not by column index ``rank``) is what
        # makes the two pieces reshapeable into matrices that can be multiplied.
        lora_a = generated[:, :a_numel].reshape(-1, out_features, self.rank)
        lora_b = generated[:, a_numel:].reshape(-1, self.rank, in_features)
        weight_update = torch.bmm(lora_a, lora_b)  # (n, hidden, input)

        if weight_update.size(0) == 1:
            # One shared update: fold it into the weight and reuse F.linear.
            adapted_weight = self.fc.weight + weight_update[0]
            return torch.relu(
                nn.functional.linear(x, adapted_weight, self.fc.bias)
            )

        if x.dim() != 2 or x.size(0) != weight_update.size(0):
            raise ValueError(
                "per-sample weight updates need x of shape "
                f"({weight_update.size(0)}, {in_features}), got "
                f"{tuple(x.shape)}"
            )
        # Per-sample update: every row of x is multiplied by its own weight.
        adapted_weight = self.fc.weight + weight_update
        projected = torch.bmm(adapted_weight, x.unsqueeze(-1)).squeeze(-1)
        return torch.relu(projected + self.fc.bias)


# Example usage
input_size = 10
hidden_size = 20
output_size = 10
hyper_input_size = 5
rank = 4

# A rank-`rank` update of a (hidden_size x input_size) weight is described by
# two factors, (hidden_size x rank) and (rank x input_size), so the
# hypernetwork has to emit rank * (hidden_size + input_size) values.
hypernet = HyperNetwork(hyper_input_size, 50, rank * (hidden_size + input_size))
target_net = TargetNetwork(input_size, hidden_size, output_size, hypernet, rank=rank)

# Forward pass
x = torch.randn(1, input_size)
hyper_input = torch.randn(1, hyper_input_size)
output = target_net(x, hyper_input)
print(output)