import torch
from torch import nn
from src.burgers import utils as burgers_utils
from src import utils


class Net(nn.Module):
    """Fully connected network mapping 2 input features to 1 output.

    Three hidden ``nn.Linear`` layers of width 64, each followed by ``tanh``;
    the final linear layer is left without an activation.
    """

    def __init__(self):
        super().__init__()
        width = 64
        # Layers are created in this exact order so parameter initialisation
        # draws from the RNG identically; the fcN names fix the state_dict keys.
        self.fc1 = nn.Linear(2, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, width)
        self.fc4 = nn.Linear(width, 1)

    def forward(self, x):
        """Apply the hidden tanh layers, then the linear output layer."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = torch.tanh(layer(hidden))
        return self.fc4(hidden)


def get_loss_function_and_network():
    """Build the Burgers training objective and a freshly initialised network.

    Returns:
        A ``(loss, net)`` pair. ``net`` is a new :class:`Net`. ``loss(net)``
        evaluates whichever network is passed to it (not necessarily the one
        returned here) and returns the mean squared error on the Dirichlet
        boundary points plus ``utils.WEIGHT_FACTOR`` times the PDE residual
        term from ``burgers_utils.pde_loss``.

    The returned ``loss`` raises ``ValueError`` if the network's boundary
    predictions do not have the same shape as the boundary targets.
    """
    collocation_points = burgers_utils.CollocationPoints()
    # Coefficient forwarded to the PDE residual; read once here so each loss
    # evaluation does not rebuild the benchmark object.
    k = burgers_utils.benchmark.Benchmark().k
    net = Net()

    def loss(net):
        # Boundary term: squared error against the prescribed boundary values.
        x, target = collocation_points.dirichlet_boundary
        # Drop only the trailing output dimension ((N, 1) -> (N,)); a bare
        # .squeeze() would also collapse the batch axis when N == 1.
        preds = net(x).squeeze(-1)
        if preds.shape != target.shape:
            # Explicit check rather than assert so it survives `python -O`:
            # e.g. (N,) vs (N, 1) would otherwise broadcast to (N, N) and
            # silently give a wrong loss.
            raise ValueError(
                f"prediction shape {tuple(preds.shape)} does not match "
                f"boundary target shape {tuple(target.shape)}"
            )
        boundary_loss = (preds - target).pow(2).mean()

        # Interior term: PDE residual at the interior collocation points,
        # weighted relative to the boundary term.
        x, t = collocation_points.interior
        pde_loss = utils.WEIGHT_FACTOR * burgers_utils.pde_loss(net, x, t, k)
        return boundary_loss + pde_loss

    return loss, net
