import torch
from torch import nn
from src.navierstokes import utils as ns_utils
from src import utils


class Net(nn.Module):
    """Container for two independent MLPs, both evaluated at ``(x, y)``.

    ``u_net`` is built with two outputs and backs :meth:`velocity`; ``p_net``
    is built with one output and backs :meth:`pressure`. The attribute names
    are part of the module's interface because they become the submodule
    keys in ``state_dict``.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Construct and register the velocity network before the pressure
        # network, as before, so submodule/parameter ordering (and any random
        # initialization inside ns_utils.MLP) is unchanged.
        velocity_mlp = ns_utils.MLP(num_outputs=2)
        pressure_mlp = ns_utils.MLP(num_outputs=1)
        self.u_net = velocity_mlp
        self.p_net = pressure_mlp

    def velocity(self, x, y):
        """Return the output of the two-output network ``u_net`` at ``(x, y)``."""
        model = self.u_net
        return model(x, y)

    def pressure(self, x, y):
        """Return the output of the one-output network ``p_net`` at ``(x, y)``."""
        model = self.p_net
        return model(x, y)


def get_loss_function_and_network(*, boundary_weight=None):
    """Build a :class:`Net` and a loss closure bound to one set of collocation points.

    A single ``ns_utils.CollocationPoints()`` instance is created here and
    captured by the returned closure, so every loss evaluation uses it.

    Args:
        boundary_weight: Multiplier applied to the ``ns_utils.dirichlet_loss``
            term. When ``None`` (the default), ``utils.WEIGHT_FACTOR`` is read
            at each loss evaluation, which matches the previous behavior.

    Returns:
        A ``(loss, net)`` tuple. ``net`` is a freshly constructed
        :class:`Net`. ``loss(net)`` returns the weighted Dirichlet term plus
        ``ns_utils.ns_loss`` plus ``ns_utils.divergence_loss``, with the last
        two evaluated on ``collocation_points.interior``.
    """
    collocation_points = ns_utils.CollocationPoints()
    net = Net()

    def loss(net):
        # The network is taken as a parameter (it shadows the enclosing
        # ``net``), so the closure is not tied to the instance built above.
        if boundary_weight is None:
            # Looked up lazily so runtime changes to utils.WEIGHT_FACTOR are
            # still honored, as in the original implementation.
            weight = utils.WEIGHT_FACTOR
        else:
            weight = boundary_weight
        boundary = weight * ns_utils.dirichlet_loss(net, collocation_points)
        # ``interior`` is read separately for each term, exactly as before.
        # NOTE(review): if it is a resampling property, the two terms see
        # different points; confirm whether that is intended.
        ns_loss = ns_utils.ns_loss(net, *collocation_points.interior)
        divergence_loss = ns_utils.divergence_loss(net, *collocation_points.interior)
        return boundary + ns_loss + divergence_loss

    return loss, net
