import torch
from torch import nn
from src import utils
from src.navierstokes import utils as ns_utils


class DivergenceFreeNetwork(nn.Module):
    """Velocity model built from the derivatives of a single scalar network.

    A one-output ``ns_utils.MLP`` produces a scalar field ``phi``; the two
    velocity components are then formed as ``u = d(phi, y)`` and
    ``v = -d(phi, x)``.

    NOTE(review): presumably ``utils.d(f, z)`` returns the derivative of ``f``
    with respect to ``z`` via autograd, which would make the resulting field
    divergence-free by construction -- confirm against ``src.utils``.
    """

    def __init__(self):
        super().__init__()
        # Scalar network; its single output is differentiated in ``forward``.
        self.base = ns_utils.MLP(num_outputs=1)

    def forward(self, x):
        """Return the stacked components ``[u, v]`` for the given coordinates.

        Args:
            x: 2-D tensor whose column 0 holds the x-coordinates and whose
                column 1 holds the y-coordinates.

        Returns:
            Tensor with ``u`` and ``v`` stacked along ``dim=1``.
        """
        # Split the columns once so the same tensors feed both the network
        # and the derivative calls below.
        xs, ys = x[:, 0], x[:, 1]
        scalar_field = self.base(xs, ys)

        u = utils.d(scalar_field, ys)
        v = -utils.d(scalar_field, xs)
        return torch.stack((u, v), dim=1)


class Net(nn.Module):
    """Container pairing a velocity sub-network with a pressure sub-network."""

    def __init__(self):
        super().__init__()
        self.u_net = DivergenceFreeNetwork()
        self.p_net = ns_utils.MLP(num_outputs=1)

    def velocity(self, x, y):
        """Evaluate the velocity sub-network at the points ``(x, y)``.

        ``x`` and ``y`` are stacked along ``dim=1`` before being passed to
        ``u_net``.

        Returns:
            Tuple ``(u, v)`` taken from columns 0 and 1 of the ``u_net`` output.
        """
        coords = torch.stack((x, y), dim=1)
        field = self.u_net(coords)
        u, v = field[:, 0], field[:, 1]
        return u, v

    def pressure(self, x, y):
        """Evaluate the pressure sub-network at the points ``(x, y)``."""
        p = self.p_net(x, y)
        return p


def get_loss_function_and_network():
    """Build the loss callable together with a freshly constructed ``Net``.

    Returns:
        Tuple ``(loss, net)``. ``loss(net)`` returns the sum of
        ``ns_utils.dirichlet_loss`` and ``ns_utils.ns_loss``, both evaluated
        on one ``ns_utils.CollocationPoints`` instance created here and shared
        across calls.
    """
    # Created before the network, matching the original construction order.
    points = ns_utils.CollocationPoints()

    def loss(net):
        """Return the boundary term plus the interior Navier-Stokes term."""
        boundary_term = ns_utils.dirichlet_loss(net, points)
        interior_term = ns_utils.ns_loss(net, *points.interior)
        return boundary_term + interior_term

    return loss, Net()
