import torch
from torch import nn
from src.utils import d
from src.burgers import utils as burgers_utils


class Net(nn.Module):
    """Burgers-equation ansatz built from a learnable mix of heat-equation modes.

    Mode ``i`` (``i < num_components``) is

        w_i * sin(a_i * x + b_i) * exp(-a_i**2 * k * t + c_i) + s_i

    and each such mode satisfies ``phi_t = k * phi_xx``. The modes are averaged
    into one ``phi`` and mapped to ``u = -2 * k * phi_x / phi`` (the Cole-Hopf
    form). Note that nothing here keeps ``phi`` away from zero.
    """

    def __init__(self, num_components=100):
        super().__init__()
        shape = (1, num_components)
        # The creation order below is deliberate: under a fixed seed it
        # reproduces the same random initial values and the same parameter
        # registration order (state_dict / optimizer layout).
        self.x_shift = nn.Parameter(torch.rand(shape))  # b_i ~ U[0, 1)
        self.x_scale = nn.Parameter(10 * torch.rand(shape))  # a_i ~ U[0, 10)
        self.t_shift = nn.Parameter(torch.rand(shape))  # c_i ~ U[0, 1)
        self.out_shift = nn.Parameter(torch.rand(shape))  # s_i ~ U[0, 1)
        self.weights = nn.Parameter(torch.rand(shape))  # w_i ~ U[0, 1)
        # Coefficient used both in the modes' decay rate and in the Cole-Hopf
        # map; presumably the viscosity of the benchmark problem -- confirm.
        self.k = burgers_utils.benchmark.Benchmark().k

    def forward(self, x):
        """Evaluate the ansatz ``u`` at a batch of (x, t) points.

        Args:
            x: tensor whose column 0 is the spatial coordinate and column 1 is
                time; assumed shape (N, 2) -- TODO confirm. The spatial
                derivative goes through ``d``, so ``x`` presumably needs
                ``requires_grad=True``.

        Returns:
            ``u`` with its trailing singleton dimension removed; presumably
            shape (N,), assuming ``d`` returns a gradient shaped like its
            second argument.
        """
        space = x[:, 0:1]  # (N, 1), same view as x[:, 0].unsqueeze(1)
        t = x[:, 1:2]  # (N, 1)

        # exp(-a^2 k t) pairs with sin(a x + b) so each mode solves the heat equation.
        decay_rate = self.x_scale**2 * self.k
        spatial = torch.sin(self.x_scale * space + self.x_shift)
        temporal = torch.exp(-decay_rate * t + self.t_shift)

        modes = self.weights * spatial * temporal + self.out_shift  # (N, num_components)
        phi = modes.mean(dim=1, keepdim=True)  # (N, 1)

        u = -2 * self.k * d(phi, space) / phi
        return u.squeeze(dim=1)


def get_loss_function_and_network(num_components=100):
    """Create a Dirichlet-boundary MSE loss together with a fresh ``Net``.

    Args:
        num_components: number of modes in the created ``Net``. The default of
            100 matches ``Net``'s own default, so calling this with no
            arguments behaves exactly as before.

    Returns:
        ``(loss, net)``. ``loss(net)`` evaluates the given network on the
        ``dirichlet_boundary`` points of a ``burgers_utils.CollocationPoints``
        instance and returns the mean squared error against their targets.

    Raises:
        ValueError: from ``loss`` when the prediction and target shapes differ.
    """
    collocation_points = burgers_utils.CollocationPoints()
    net = Net(num_components=num_components)

    def loss(net):
        x, target = collocation_points.dirichlet_boundary
        # Net.forward differentiates with respect to its spatial input column,
        # so the inputs must be tracked by autograd.
        x.requires_grad = True
        pred = net(x)

        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`. A mismatch such as (N,) vs (N, 1) would otherwise
        # broadcast silently into an (N, N) difference and give a wrong loss.
        if target.shape != pred.shape:
            raise ValueError(
                f"prediction shape {tuple(pred.shape)} does not match "
                f"target shape {tuple(target.shape)}"
            )
        return (pred - target).pow(2).mean()

    return loss, net
