import copy

import torch
from torch import nn


def fc_nn(indim: int, hidden_dim: list[int], outdim: int, activation=nn.ReLU()) -> nn.Sequential:
    """Build a fully connected network as an ``nn.Sequential``.

    The network is ``Linear -> activation`` repeated once per entry of
    ``hidden_dim``, followed by a final ``Linear`` with no activation after
    it. With an empty ``hidden_dim`` the result is a single
    ``Linear(indim, outdim)``.

    Args:
        indim: Number of input features of the first linear layer.
        hidden_dim: Widths of the hidden layers, in order. Any iterable of
            ints is accepted (list, tuple, ...); it is not modified.
        outdim: Number of output features of the last linear layer.
        activation: Template module placed after every hidden linear layer.
            It is deep-copied for each position, so the instance passed in
            (or the shared default) never becomes part of the returned model.

    Returns:
        The assembled ``nn.Sequential``.
    """
    widths = [indim, *hidden_dim]
    layers = []
    for fan_in, fan_out in zip(widths, widths[1:]):
        layers.append(nn.Linear(fan_in, fan_out))
        # Copy instead of reusing one instance: the default argument is
        # created once at definition time, and a reused module would share
        # hooks and any learnable parameters (e.g. nn.PReLU) across layers
        # and across every model built by this function.
        layers.append(copy.deepcopy(activation))
    layers.append(nn.Linear(widths[-1], outdim))
    return nn.Sequential(*layers)


def model_difference(model1: nn.Module, model2: nn.Module, check_equality=True):
    """Return the per-parameter norm of the difference between two models.

    Parameters are paired positionally, in the order each model yields them,
    and each result is stored under the parameter's name in ``model1``.

    Args:
        model1: First model; its parameter names become the result keys.
        model2: Second model, compared against ``model1``.
        check_equality: When true, require both models to expose the same
            parameter names in the same order. Only names are compared, not
            shapes. When false, no check is made and pairing stops at the
            shorter parameter list.

    Returns:
        Dict mapping parameter name to
        ``torch.linalg.vector_norm(p1 - p2)`` as a Python float. Empty when
        the models have no parameters.

    Raises:
        AssertionError: If ``check_equality`` is true and the parameter
            names differ.
    """
    if check_equality:
        # Collect names directly; unpacking through zip(*...) raises
        # StopIteration for a model that has no parameters.
        model1_names = [name for name, _ in model1.named_parameters()]
        model2_names = [name for name, _ in model2.named_parameters()]
        if model1_names != model2_names:
            # Raised explicitly so the check survives ``python -O``, which
            # strips ``assert`` statements; the exception type is unchanged.
            raise AssertionError(
                f"parameter names differ: {model1_names} != {model2_names}"
            )
    diffs = {}
    # Only scalar values are returned, so skip autograd bookkeeping.
    with torch.no_grad():
        for (name, p1), p2 in zip(model1.named_parameters(), model2.parameters()):
            diffs[name] = torch.linalg.vector_norm((p1 - p2).flatten()).item()
    return diffs
