import torch
import torch.nn.functional as F
from nesim.utils.tensor_mapping import apply_mapping

batch_size = 10
in_features = 3
out_features = 2
x = torch.randn(batch_size, in_features)  # (batch_size, in_features)
weight = torch.randn(out_features, in_features)  # (out_features, in_features)
bias = torch.randn(out_features)  # (out_features,)

# F.linear computes y = x @ weight.T + bias, so y has shape
# (batch_size, out_features). Each column of `weight` pairs with one input
# feature (a column of x); each row of `weight` and each entry of `bias`
# pairs with one output feature (a column of y).

assert weight.shape[1] == in_features
print(f"weight.shape: {weight.shape}")
print(f"bias.shape: {bias.shape}")
print(f"x.shape: {x.shape}")

y = F.linear(input=x, weight=weight, bias=bias)
print(f"y.shape: {y.shape}")
# Sanity-check that the formula documented above is what F.linear computes.
assert torch.allclose(y, x @ weight.T + bias)

## now let's shuffle inputs with mapping: [1,0,2]
# Permute the input features of x and the matching columns of weight, both
# along dim=1. Assuming apply_mapping reorders indices along `dim`, each
# output is still sum_k x[:, k] * weight[:, k]; only the summation order
# changes.
input_mapping = [1, 0, 2]
x_remapped = apply_mapping(x, mapping=input_mapping, dim=1)
weight_remapped = apply_mapping(weight, mapping=input_mapping, dim=1)
## no need to remap bias: it is indexed by output feature, not input feature

print(f"[post input-remapping] weight.shape: {weight_remapped.shape}")
print(f"[post input-remapping] x.shape: {x_remapped.shape}")

y1 = F.linear(input=x_remapped, weight=weight_remapped, bias=bias)
print(f"[post input-remapping] y.shape: {y1.shape}")

## even if the input space is remapped, the output remains identical
# ...up to floating-point rounding. Reordering the terms of each dot product
# can change the last bits of the result (float addition is not
# associative), so an exact `==` comparison can fail spuriously. An explicit
# atol covers outputs near zero, where the default atol=1e-8 is too strict.
assert torch.allclose(y1, y, atol=1e-5)


## now let's shuffle outputs with mapping: [1,0]
# Permute the output features: the rows of weight and the entries of bias,
# both along dim=0. Assuming apply_mapping reorders indices along `dim`,
# this should permute the columns of y the same way.
output_mapping = [1, 0]

## no need to remap x: it is indexed by input feature, not output feature
weight_remapped = apply_mapping(weight, mapping=output_mapping, dim=0)
bias_remapped = apply_mapping(bias, mapping=output_mapping, dim=0)
print(f"[post output-remapping] weight.shape: {weight_remapped.shape}")
print(f"[post output-remapping] bias.shape: {bias_remapped.shape}")

y1 = F.linear(input=x, weight=weight_remapped, bias=bias_remapped)
print(f"[post output-remapping] y.shape: {y1.shape}")

## the output now gets remapped along the last axis, but still remains the same
# Compare with a tolerance rather than exact `==`. Matmul kernels do not
# guarantee bitwise-identical results when the weight layout changes, and
# this keeps the check consistent with the input-remapping check.
assert torch.allclose(
    y1, apply_mapping(y, mapping=output_mapping, dim=1), atol=1e-5
)