import torch
from torch import nn
from src.heat1 import utils as heat_utils
from src.heat1 import benchmark as heat_benchmark


class Net(nn.Module):
    """Fully connected 3 -> 1 network with a trainable, shared tanh slope.

    Architecture: three hidden ``tanh`` layers of width 64 followed by a
    linear output layer of size 1. Before every hidden ``tanh`` the affine
    output is multiplied by ``self.a * self.n``:

    * ``a`` -- one learnable scalar, stored as a 1x1 parameter and broadcast
      over the activations; it is shared by all three hidden layers.
    * ``n`` -- a fixed (non-trainable) multiplier equal to 10.

    The three inputs are presumably the (x, y, t) coordinates of a
    collocation point -- TODO confirm against the caller.
    """

    def __init__(self):
        super().__init__()
        width = 64
        # Assign the layers one by one and in this exact order. This keeps
        # the parameter names (fc1..fc4) and the order in which their
        # random initialisation draws from the RNG.
        self.fc1 = nn.Linear(3, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, width)
        self.fc4 = nn.Linear(width, 1)
        # Learnable slope, starting at 1, so the initial effective scale a*n
        # is 10. NOTE(review): adaptive-activation schemes often start with
        # a*n == 1 (a = 1/n); confirm the larger initial slope is intended.
        self.a = nn.Parameter(torch.ones(1, 1))
        # Fixed scale factor applied together with ``a``.
        self.n = 10

    def forward(self, x):
        """Return the network output for ``x`` (last dim 3 -> last dim 1)."""
        h = x
        for layer in (self.fc1, self.fc2, self.fc3):
            # Recompute a*n per layer (rather than once) so the autograd
            # graph matches a layer-by-layer formulation exactly.
            h = torch.tanh(self.a * self.n * layer(h))
        return self.fc4(h)


def get_loss_function_and_network():
    """Build a fresh ``Net`` and a loss closure for training it.

    Construction happens in this order: collocation points
    (``heat_utils.CollocationPoints()``), then the coefficient ``D`` read
    from ``heat_benchmark.Benchmark()``, then a new ``Net``. Keep this order:
    construction may consume the global RNG.

    Returns:
        ``(loss, net)``. ``loss(net)`` evaluates a given network and returns
        ``heat_utils.boundary_loss(...) + heat_utils.pde_loss(...)``. The PDE
        term is evaluated on the interior collocation points ``(x, y, t)``
        with coefficient ``D`` (presumably the diffusivity of the heat
        equation -- confirm in ``heat_benchmark``).
    """
    points = heat_utils.CollocationPoints()
    coefficient = heat_benchmark.Benchmark().D
    network = Net()

    def loss(net):
        # The boundary term is computed before ``points.interior`` is read,
        # matching the original call order.
        boundary_term = heat_utils.boundary_loss(net, points)
        x, y, t = points.interior
        residual_term = heat_utils.pde_loss(net, x, y, t, coefficient)
        return boundary_term + residual_term

    return loss, network
