import torch
from torch import nn
from src import utils
from src.navierstokes import utils as ns_utils


class DivergenceFreeNetwork(nn.Module):
    """Velocity model built as a sum of transformed divergence-free 2D fields.

    Each base field ``funcN`` reads the columns ``x[:, 0]`` and ``x[:, 1]``
    of its input and returns ``torch.stack([...], dim=1)`` of two components
    whose divergence is identically zero (shown per function below).

    ``__init__`` wraps ``func5``..``func8`` in ``utils.TransformedFunc2D``
    modules (an orientation-preserving and an orientation-reversing copy of
    each), and ``forward`` returns the sum of all of them.
    NOTE(review): whether the transformed copies remain divergence-free is
    decided by ``utils.TransformedFunc2D`` — confirm there.
    ``func1``..``func4`` are defined but not used by ``__init__``.
    """

    def __init__(self, num_components_per_func=32):
        super().__init__()
        base_fields = [self.func5, self.func6, self.func7, self.func8]
        self.transformed_funcs = self.setup_transformed_funcs(
            base_fields, num_components_per_func
        )

    def forward(self, x):
        """Return the sum of every transformed field evaluated at ``x``.

        If there are no transformed fields the float ``0.0`` is returned.
        """
        # Accumulation starts from the scalar 0.0 and adds fields in module
        # order: 0.0 + f0(x) + f1(x) + ...
        return sum((field(x) for field in self.transformed_funcs), 0.0)

    @staticmethod
    def setup_transformed_funcs(functions, num_components):
        """Wrap each base field in two ``utils.TransformedFunc2D`` modules.

        For every entry of ``functions`` (in order) an orientation-preserving
        copy is built first, then an orientation-reversing one. All copies use
        ``num_components`` components, ``num_outputs=2`` and
        ``apply_rotation=True``.

        Returns:
            ``nn.ModuleList`` holding ``2 * len(functions)`` modules.
        """
        wrapped = [
            utils.TransformedFunc2D(
                num_components=num_components,
                orientation_preserving=keep_orientation,
                func=base,
                num_outputs=2,
                apply_rotation=True,
            )
            for base in functions
            for keep_orientation in (True, False)
        ]
        return nn.ModuleList(wrapped)

    @staticmethod
    def _split_xy(x):
        """Return the first and second columns of ``x`` as a pair."""
        return x[:, 0], x[:, 1]

    @staticmethod
    def func1(x):
        """Field ``(y, x)``; divergence ``0 + 0 = 0``."""
        a, b = DivergenceFreeNetwork._split_xy(x)
        return torch.stack([b, a], dim=1)

    @staticmethod
    def func2(x):
        """Field ``(x, -y)``; divergence ``1 - 1 = 0``."""
        a, b = DivergenceFreeNetwork._split_xy(x)
        return torch.stack([a, -b], dim=1)

    @staticmethod
    def func3(x):
        """Field ``(0, x)``; divergence ``0 + 0 = 0``."""
        a, _ = DivergenceFreeNetwork._split_xy(x)
        return torch.stack([torch.zeros_like(a), a], dim=1)

    @staticmethod
    def func4(x):
        """Field ``(y, 0)``; divergence ``0 + 0 = 0``."""
        a, b = DivergenceFreeNetwork._split_xy(x)
        return torch.stack([b, torch.zeros_like(a)], dim=1)

    @staticmethod
    def func5(x):
        """Field ``(cos(x+y), -cos(x+y))``; divergence ``-sin + sin = 0``."""
        a, b = DivergenceFreeNetwork._split_xy(x)
        wave = torch.cos(a + b)
        return torch.stack([wave, -wave], dim=1)

    @staticmethod
    def func6(x):
        """Field ``(exp(x+y), -exp(x+y))``; divergence ``e - e = 0``."""
        a, b = DivergenceFreeNetwork._split_xy(x)
        growth = torch.exp(a + b)
        return torch.stack([growth, -growth], dim=1)

    @staticmethod
    def func7(x):
        """Field ``(x cos(xy), -y cos(xy))``.

        Divergence: ``(cos(xy) - xy sin(xy)) + (-cos(xy) + xy sin(xy)) = 0``.
        """
        a, b = DivergenceFreeNetwork._split_xy(x)
        damp = torch.cos(a * b)
        return torch.stack([a * damp, -b * damp], dim=1)

    @staticmethod
    def func8(x):
        """Field ``(exp(x+y), -exp(x+y))``; divergence ``e - e = 0``.

        NOTE(review): this is identical to ``func6``, so the fourth family adds
        no new shape — likely a copy-paste. Confirm the intended field before
        changing it; a change alters the model and invalidates checkpoints.
        """
        a, b = DivergenceFreeNetwork._split_xy(x)
        growth = torch.exp(a + b)
        return torch.stack([growth, -growth], dim=1)


class Net(nn.Module):
    """Model pairing a velocity network with a pressure network.

    Velocity comes from a :class:`DivergenceFreeNetwork`; pressure comes from
    ``ns_utils.MLP`` configured with a single output.
    """

    def __init__(self):
        super().__init__()
        # Construction order is kept (velocity net first) so parameter
        # initialisation consumes the RNG in the same sequence.
        self.u_net = DivergenceFreeNetwork()
        self.p_net = ns_utils.MLP(num_outputs=1)

    def velocity(self, x, y):
        """Return the two velocity components at the points ``(x, y)``.

        ``x`` and ``y`` are stacked column-wise (``dim=1``) and fed to
        ``u_net``; the first and second output columns are returned as a
        tuple. Presumably ``x`` and ``y`` are 1-D tensors of equal length —
        confirm with callers.
        """
        points = torch.stack([x, y], dim=1)
        field = self.u_net(points)
        u_comp = field[:, 0]
        v_comp = field[:, 1]
        return u_comp, v_comp

    def pressure(self, x, y):
        """Return ``p_net`` evaluated with the coordinates passed separately."""
        return self.p_net(x, y)


def get_loss_function_and_network():
    """Build the training objective together with a fresh model.

    Returns:
        ``(loss, net)`` where ``net`` is a new :class:`Net` and ``loss`` is a
        callable taking a model and returning the sum of
        ``ns_utils.dirichlet_loss`` and ``ns_utils.ns_loss``. Both terms are
        evaluated on one shared ``ns_utils.CollocationPoints`` instance that is
        created here.
    """
    # Collocation points are created before the model, matching the original
    # order in case either step consumes random numbers.
    points = ns_utils.CollocationPoints()
    model = Net()

    def loss(net):
        # Boundary term first, then the interior term on the unpacked
        # interior collocation points.
        bc_term = ns_utils.dirichlet_loss(net, points)
        pde_term = ns_utils.ns_loss(net, *points.interior)
        return bc_term + pde_term

    return loss, model
