import torch
from torch import nn
from src import utils
from src.laplace2 import utils as laplace_utils


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 64, dtype=torch.complex64)
        self.fc2 = nn.Linear(64, 64, dtype=torch.complex64)
        self.fc3 = nn.Linear(64, 64, dtype=torch.complex64)
        self.fc4 = nn.Linear(64, 1, dtype=torch.complex64)

    def forward(self, x):
        out = (x[:, 0] + 1j * x[:, 1]).unsqueeze(1)
        out = self.fc1(out)
        out = torch.sin(out)
        out = self.fc2(out)
        out = torch.sin(out)
        out = self.fc3(out)
        out = torch.sin(out)
        out = self.fc4(out)
        return out.real


def get_loss_function_and_network():
    """Build a fresh ``Net`` together with a loss closure for it.

    The collocation points are created once here and captured by the closure.
    Every call to the returned loss therefore evaluates on the same point set.

    Returns:
        A ``(loss, net)`` tuple:
        ``loss`` takes a network and returns
        ``laplace_utils.boundary_loss(net, points)``;
        ``net`` is a newly constructed ``Net``.
    """
    points = laplace_utils.CollocationPoints()

    def loss(net):
        # The objective consists solely of laplace_utils.boundary_loss.
        return laplace_utils.boundary_loss(net, points)

    # Net() is built after the collocation points, as in the original order.
    return loss, Net()
