import torch
import numpy as np
from torch import nn
from sklearn.linear_model import LinearRegression
from tqdm import tqdm
class LinearModel(nn.Module):
    """Fully connected feed-forward network with ReLU between layers.

    Despite the name, the network is non-linear: every layer except the
    last is followed by a ReLU. The last layer is a plain linear
    projection, so outputs are unconstrained and may be negative (which is
    required when regressing onto real-valued embeddings).

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
        hidden_dim: Width of every hidden layer.
        layer_num: Total number of ``nn.Linear`` layers. Values below 2
            still produce two layers (input -> hidden -> output).
    """

    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int = 128, layer_num: int = 3):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Linear(input_dim, hidden_dim)]
            + [nn.Linear(hidden_dim, hidden_dim) for _ in range(layer_num - 2)]
            + [nn.Linear(hidden_dim, output_dim)]
        )
        self.activation = nn.ReLU()

    def forward(self, x):
        """Return the network output for ``x`` (last dim ``input_dim``)."""
        # Activate only the hidden layers; a ReLU on the final projection
        # would clamp every output to >= 0.
        for layer in self.layers[:-1]:
            x = self.activation(layer(x))
        return self.layers[-1](x)
def nonlinear_mapping(corpus_emb_1, corpus_emb_2, overlap_ids, test_emb_1, test_emb_2, epoch=100, lr=0.01):
    """Fit a network mapping embedding space 1 onto space 2 and apply it.

    The rows of ``corpus_emb_1`` and ``corpus_emb_2`` selected by
    ``overlap_ids`` form the training pairs. A ``LinearModel`` is trained
    with full-batch AdamW to minimise the largest per-row L2 distance
    between its output and the target rows, then applied to ``test_emb_1``.

    Args:
        corpus_emb_1: Source embeddings, convertible to a 2-D float tensor.
        corpus_emb_2: Target embeddings, convertible to a 2-D float tensor.
        overlap_ids: Index applied to both corpora to pick training rows.
        test_emb_1: Source-space embeddings to transform.
        test_emb_2: Unused; kept so existing callers keep working.
        epoch: Number of full-batch optimisation steps.
        lr: AdamW learning rate.

    Returns:
        A tuple ``(test_emb_1_transformed, alpha_error)`` of NumPy arrays:
        the transformed test embeddings, and the per-training-row L2 norm
        of the residual after training.

    Raises:
        ValueError: If ``overlap_ids`` selects no training rows.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    corpus_emb_1 = torch.tensor(corpus_emb_1, device=device, dtype=torch.float32)
    corpus_emb_2 = torch.tensor(corpus_emb_2, device=device, dtype=torch.float32)
    test_emb_1 = torch.tensor(test_emb_1, device=device, dtype=torch.float32)

    train_emb_1, train_emb_2 = corpus_emb_1[overlap_ids], corpus_emb_2[overlap_ids]
    if train_emb_1.shape[0] == 0:
        # The max-reduction below is undefined on an empty batch; fail early
        # with a message that names the actual cause.
        raise ValueError("overlap_ids selects no training pairs")

    model = LinearModel(train_emb_1.shape[1], train_emb_2.shape[1]).to(device)
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    for _ in range(epoch):
        optimizer.zero_grad()
        output = model(train_emb_1)
        # Minimise the worst-case row-wise error (the objective the original
        # code effectively optimised; a previously computed MSE was discarded).
        loss = torch.linalg.norm(output - train_emb_2, dim=1).max()
        loss.backward()
        optimizer.step()

    model.eval()
    # Evaluation needs no autograd graph; run the training inputs through
    # the model once and reuse the residuals.
    with torch.no_grad():
        test_emb_1_transformed = model(test_emb_1).cpu().numpy()
        residual_norms = torch.linalg.norm(model(train_emb_1) - train_emb_2, dim=1)
        if train_emb_1.shape[1] == train_emb_2.shape[1]:
            baseline = torch.linalg.norm(train_emb_1 - train_emb_2, dim=1).max()
        else:
            # Raw spaces of different width cannot be subtracted directly.
            baseline = "n/a"
    # NOTE(review): the "before" figure is a max while the "after" figure is a
    # mean, as in the original message — confirm which comparison is intended.
    print(f"train error: {baseline} => {residual_norms.mean()}")
    alpha_error = residual_norms.cpu().numpy()
    return test_emb_1_transformed, alpha_error
